Skip to content

afnio.tellurio.websocket_client

afnio.tellurio.websocket_client.TellurioWebSocketClient

A WebSocket client for interacting with the Tellurio backend.

This client establishes a WebSocket connection to the backend, sends requests, listens for responses, and handles reconnections. It supports JSON-RPC-style communication and is designed to work with asynchronous workflows.

Source code in afnio/tellurio/websocket_client.py
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
class TellurioWebSocketClient:
    """A WebSocket client for interacting with the Tellurio backend.

    This client establishes a WebSocket connection to the backend, sends requests,
    listens for responses, and handles reconnections. It supports JSON-RPC-style
    communication and is designed to work with asynchronous workflows.
    """

    def __init__(
        self,
        base_url: str = None,
        port: int = None,
        default_timeout: int = 30,
    ):
        """Initializes the `TellurioWebSocketClient`.

        Args:
            base_url: The base URL of the Tellurio backend
                (e.g., `"https://platform.tellurio.ai"`).
            port: The port number for the backend (default: 443).
            default_timeout: The default timeout (in seconds)
              for WebSocket requests.
        """
        self.base_url = base_url or os.getenv(
            "TELLURIO_BACKEND_WS_BASE_URL", "wss://platform.tellurio.ai"
        )
        self.port = port or os.getenv("TELLURIO_BACKEND_WS_PORT", 443)
        self.api_key = None
        self.default_timeout = default_timeout
        self.ws_url = self._build_ws_url(self.base_url, self.port)
        self.connection: websockets.ClientConnection = None
        self.listener_task = None
        self.pending = {}  # req_id → Future
        self._heartbeat_times = {}  # req_id -> last heartbeat time (monotonic)

    def _build_ws_url(self, base_url: str, port: int) -> str:
        """Constructs the WebSocket URL from the base URL and port.

        Args:
            base_url: The base URL of the Tellurio backend
              (e.g., `"wss://platform.tellurio.ai"`).
            port: The port number for the backend.

        Returns:
            The WebSocket URL (e.g., `"wss://platform.tellurio.ai/ws/v0/rpc/"`).
        """
        return f"{base_url}:{port}/ws/v0/rpc/"

    async def connect(
        self, api_key: str = None, retries: int = 3, delay: int = 5
    ) -> Dict[str, str]:
        """Connects to the WebSocket server with retry logic.

        Attempts to establish a WebSocket connection to the backend. If the connection
        fails, it retries up to the specified number of attempts with a delay between
        each attempt.

        Args:
            api_key: The API key for authenticating with the backend.
            retries: The number of reconnection attempts.
            delay: The delay (in seconds) between reconnection attempts.

        Returns:
            The session ID received from the server upon successful connection.

        Raises:
            RuntimeError: If the connection fails after all retry attempts.
        """
        self.api_key = api_key

        headers = {"Authorization": f"Api-Key {self.api_key}"}
        for attempt in range(retries):
            try:
                logger.debug(
                    f"Connecting to WebSocket at {self.ws_url} "
                    f"(attempt {attempt + 1}/{retries})"
                )
                self.connection = await websockets.connect(
                    self.ws_url, additional_headers=headers
                )

                # Start the listener task
                self.listener_task = asyncio.create_task(self._listener())
                logger.debug("WebSocket connection established.")

                # Example: Retrieve session ID from the server
                response = await self.connection.recv()
                response_data = json.loads(response)
                session_id = response_data.get("result", {}).get("session_id")
                return {"session_id": session_id}
            except Exception as e:
                logger.error(f"Failed to connect to WebSocket: {e}")
                if attempt < retries - 1:
                    await asyncio.sleep(delay)
                else:
                    raise RuntimeError(
                        "Failed to connect to WebSocket after multiple attempts."
                    )

    async def _listener(self):
        """Continuously listens for and processes incoming WebSocket messages.

        This method runs as a background task and handles every message received on
        `self.connection`, following the JSON-RPC 2.0 message shapes:

        - Messages with an `id` and no `method` are treated as responses to
          client-initiated requests: the matching future in `self.pending` is
          removed and resolved with the full response dictionary.
        - Messages with a `method` are dispatched to the `rpc_<method>` coroutine
          of this client. Requests (with an `id`) get a JSON-RPC response sent
          back; notifications (no `id`) are executed without a response.
        - Malformed or failing messages are logged and reported to the server
          through `self._send_error`.
        - When the connection is closed, a reconnect is attempted after a short
          delay, reusing the stored API key.

        Per-message failures are handled inside the loop, so one bad message does
        not stop the listener.

        Raises:
            RuntimeError: Propagated from `connect()` when the reconnect attempt
                fails.
        """
        try:
            async for message in self.connection:
                logger.debug(f"Received message: {message}")
                # Reset for every message: the error handlers below report `req_id`,
                # and it must neither be unbound (first message fails to parse) nor
                # carry over the ID of a previous message.
                req_id = None
                try:
                    data: Dict[str, Any] = json.loads(message)
                    req_id = data.get("id")

                    # Validate JSON-RPC version
                    jsonrpc_version = data.get("jsonrpc")
                    if jsonrpc_version != "2.0":
                        logger.warning(f"Invalid JSON-RPC version: {jsonrpc_version}")
                        await self._send_error(
                            req_id,
                            INVALID_REQUEST,
                            "Invalid JSON-RPC version. Expected '2.0'.",
                        )
                        continue

                    # Handle JSON-RPC responses to client-initiated requests
                    if req_id and "method" not in data:
                        future = self.pending.pop(req_id, None)
                        if future:
                            # Success and error responses are both handed to the
                            # caller as the full response dictionary.
                            if "error" in data or "result" in data:
                                future.set_result(data)
                            else:
                                logger.warning(f"Unexpected response format: {data}")
                                future.set_exception(
                                    ValueError(f"Unexpected response format: {data}")
                                )
                        else:
                            logger.warning(f"Unexpected data or missing ID: {data}")
                        continue

                    # Handle JSON-RPC requests and notifications (must have "method")
                    if "method" not in data:
                        logger.warning("Invalid request. Missing 'method' field.")
                        await self._send_error(
                            req_id,
                            INVALID_REQUEST,
                            "Missing required field: method",
                        )
                        continue

                    # Resolve the handler by naming convention: rpc_<method>
                    method = data.get("method")
                    handler = getattr(self, f"rpc_{method}", None)
                    if not handler:
                        logger.warning(f"RPC method not found: {method}")
                        await self._send_error(
                            req_id,
                            METHOD_NOT_FOUND,
                            f"Method '{method}' not found.",
                            {"method": method},
                        )
                        continue

                    # Handle notifications (no id): do not send a response
                    if req_id is None:
                        logger.debug(
                            f"Received notification for method '{method}' "
                            f"with params: {data.get('params', {})}"
                        )
                        await handler(data.get("params", {}))
                        continue

                    # Handle request (with id): execute RPC method and send response
                    params = data.get("params", {})
                    logger.debug(
                        f"RPC method call: method={method!r} "
                        f"params={params!r} id={req_id!r}"
                    )

                    result = await handler(params)

                    # Send the response
                    response = {
                        "jsonrpc": "2.0",
                        "id": req_id,
                        "result": result,
                    }
                    logger.debug(
                        f"RPC method executed successfully: method={method!r} "
                        f"result={result!r} id={req_id!r}"
                    )
                    await self.connection.send(json.dumps(response))

                except json.JSONDecodeError as e:
                    logger.error(f"Failed to decode JSON data: {e}")
                    await self._send_error(
                        req_id, PARSE_ERROR, "Parse error", {"error": str(e)}
                    )
                except KeyError as e:
                    logger.error(f"Missing key in request: {e}")
                    await self._send_error(
                        req_id,
                        INVALID_PARAMS,
                        f"Missing key: {e}",
                        {"missing_key": str(e)},
                    )
                except Exception as e:
                    logger.error(f"Unexpected error: {e}")
                    await self._send_error(
                        req_id, INTERNAL_ERROR, "Internal error", {"exception": str(e)}
                    )

        except ConnectionClosed as e:
            logger.warning(f"WebSocket connection closed: {e}")
            await asyncio.sleep(1)  # Add a delay before reconnecting
            # Reconnect with the key of the original connection; a bare connect()
            # would pass api_key=None and authenticate as "Api-Key None".
            await self.connect(api_key=self.api_key)
        except Exception as e:
            logger.error(f"Unexpected error in listener: {e}")

    async def rpc_heartbeat(self, params: dict):
        """Handle the `'heartbeat'` JSON-RPC notification from the server.

        This method is called when the server sends a heartbeat notification for a
        long-running operation. It updates the last heartbeat timestamp for the
        corresponding request ID, allowing the client to reset its timeout and avoid
        prematurely timing out while the server is still processing the request.

        Args:
            params: A dictionary with keys:

                - `id`: The request ID (str) associated with the long-running operation.
        """
        req_id = params.get("id")
        if req_id:
            self._heartbeat_times[req_id] = time.monotonic()
            logger.debug(f"Received heartbeat for request {req_id}: {params!r}")

    async def rpc_create_variable(self, params: dict) -> Dict[str, str]:
        """Handle the `'create_variable'` JSON-RPC method from the server.

        Reads the fields listed below from `params` and passes them, in this
        order, to `create_local_variable`.

        Args:
            params: A dictionary with keys:

                - `variable_id`: The unique identifier of the Variable.
                - `obj_type`: The type of the variable object
                    (e.g., "__variable__" or "__parameter__").
                - `data`: The initial data for the variable.
                - `role`: The role or description of the variable.
                - `requires_grad`: Whether the variable requires gradients.
                - `_retain_grad`: Whether to retain gradients for non-leaf variables.
                - `_grad`: The initial gradient(s) for the variable.
                - `_output_nr`: The output number for the variable in the computation
                    graph.
                - `_grad_fn`: The gradient function associated with the variable.
                - `is_leaf`: Whether the variable is a leaf node in the computation
                    graph.

        Returns:
            `{"message": "Ok"}` once the variable has been created.

        Raises:
            KeyError: If required keys are missing from params.
            RuntimeError: If the variable creation fails with a `RuntimeError`.
        """
        # Positional order expected by `create_local_variable`.
        field_order = (
            "variable_id",
            "obj_type",
            "data",
            "role",
            "requires_grad",
            "_retain_grad",
            "_grad",
            "_output_nr",
            "_grad_fn",
            "is_leaf",
        )
        try:
            arguments = [params[key] for key in field_order]
            created = create_local_variable(*arguments)
            logger.debug(f"Variable created: variable_id={created.variable_id!r}")
        except KeyError as missing:
            logger.error(f"Missing key in params: {missing}")
            raise KeyError(f"Missing key: {missing}")
        except RuntimeError as failure:
            logger.error(f"Failed to create variable: {failure}")
            raise RuntimeError(f"Failed to create variable: {failure}")
        return {"message": "Ok"}

    async def rpc_update_variable(self, params: dict) -> Dict[str, str]:
        """Handle the `'update_variable'` JSON-RPC method from the server.

        Applies a single field update to a locally registered Variable through
        `update_local_variable_field`, inside `suppress_variable_notifications()`.

        Args:
            params: A dictionary with keys:

                - `variable_id`: The unique identifier of the Variable.
                - `field`: The field name to update.
                - `value`: The new value to set for the field.

        Returns:
            `{"message": "Ok"}` once the update has been applied.

        Raises:
            KeyError: If required keys are missing from params.
            RuntimeError: If the update fails with a `RuntimeError`.
        """
        try:
            with suppress_variable_notifications():
                variable_id = params["variable_id"]
                field = params["field"]
                value = params["value"]
                update_local_variable_field(variable_id, field, value)
                logger.debug(
                    f"Variable updated: variable_id={variable_id!r} "
                    f"field={field!r} value={value!r}"
                )
                return {"message": "Ok"}
        except KeyError as missing:
            logger.error(f"Missing key in params: {missing}")
            raise KeyError(f"Missing key: {missing}")
        except RuntimeError as failure:
            # Same text for the log entry and the re-raised error.
            details = (
                f"Failed to update variable with ID "
                f"{params.get('variable_id')!r}: {failure}"
            )
            logger.error(details)
            raise RuntimeError(details)

    async def rpc_append_grad(self, params: dict) -> Dict[str, str]:
        """Handle the `'append_grad'` JSON-RPC method from the server.

        Forwards the gradient described by `params` to `append_grad_local`, inside
        `suppress_variable_notifications()`.

        Args:
            params: A dictionary containing:

                - `variable_id`: The unique identifier of the Variable to update.
                - `gradient_id`: The identifier passed along with the gradient.
                - `gradient`: The serialized gradient Variable to append.

        Returns:
            `{"message": "Ok"}` once the gradient has been appended.

        Raises:
            KeyError: If required keys are missing from params.
            RuntimeError: If appending the gradient fails with a `RuntimeError`.
        """
        try:
            with suppress_variable_notifications():
                variable_id = params["variable_id"]
                gradient_id = params["gradient_id"]
                gradient = params["gradient"]
                append_grad_local(variable_id, gradient_id, gradient)
                logger.debug(
                    f"Gradient appended: variable_id={variable_id!r} "
                    f"gradient={gradient!r}"
                )
                return {"message": "Ok"}
        except KeyError as missing:
            logger.error(f"Missing key in params: {missing}")
            raise KeyError(f"Missing key: {missing}")
        except RuntimeError as failure:
            # Same text for the log entry and the re-raised error.
            details = (
                f"Failed to append gradient for variable with ID "
                f"{params.get('variable_id')!r}: {failure}"
            )
            logger.error(details)
            raise RuntimeError(details)

    async def rpc_create_node(self, params: dict) -> Dict[str, str]:
        """Handle the `'create_node'` JSON-RPC method from the server.

        Passes `params` unchanged to `create_node` and logs the created node's
        ID and name.

        Args:
            params: A dictionary with keys:

                - `node_id`: The unique identifier of the Node.
                - `name`: The class name or type of the Node.

        Returns:
            `{"message": "Ok"}` once the node has been created.

        Raises:
            KeyError: If required keys are missing from params.
        """
        try:
            create_node(params)
            # Building the log line reads both keys, so a missing one surfaces here.
            summary = (
                f"Node created: node_id={params['node_id']!r} "
                f"name={params['name']!r}"
            )
            logger.debug(summary)
        except KeyError as missing:
            logger.error(f"Missing key in params: {missing}")
            raise KeyError(f"Missing key: {missing}")
        return {"message": "Ok"}

    async def rpc_create_edge(self, params: dict) -> Dict[str, str]:
        """Handle the `'create_edge'` JSON-RPC method from the server.

        Passes `params` unchanged to `create_and_append_edge`, which creates a
        [`GradientEdge`][afnio.GradientEdge] between two nodes in the local registry
        and appends it to the `from_node`'s
        [`next_functions`][afnio.Node.next_functions].

        Note:
            `from_node` and `to_node` are named from the point of view of the
            backward pass: the edge is stored on the `from_node` and points to the
            `to_node`, following the direction of gradient flow during
            backpropagation.

        Args:
            params: A dictionary with keys:

                - `from_node_id`: The unique identifier of the source node.
                - `to_node_id`: The unique identifier of the destination node.
                - `output_nr`: The output number associated with the edge.

        Returns:
            `{"message": "Ok"}` once the edge has been created.

        Raises:
            KeyError: If required keys are missing from params.
        """
        try:
            create_and_append_edge(params)
            # Building the log line reads all three keys, so a missing one
            # surfaces here.
            summary = (
                f"Edge created: from_node_id={params['from_node_id']!r} "
                f"to_node_id={params['to_node_id']!r} "
                f"output_nr={params['output_nr']!r}"
            )
            logger.debug(summary)
        except KeyError as missing:
            logger.error(f"Missing key in params: {missing}")
            raise KeyError(f"Missing key: {missing}")
        return {"message": "Ok"}

    async def rpc_update_model(self, params: dict) -> Dict[str, str]:
        """Handle the `'update_model'` JSON-RPC method from the server.

        Applies a single field update to a locally registered LM model through
        `update_local_model_field`.

        Args:
            params: A dictionary with keys:

                - `model_id`: The unique identifier of the LM model.
                - `field`: The field name to update.
                - `value`: The new value to set for the field.

        Returns:
            `{"message": "Ok"}` once the update has been applied.

        Raises:
            KeyError: If required keys are missing from params.
            RuntimeError: If the update fails with a `RuntimeError`.
        """
        try:
            model_id = params["model_id"]
            field = params["field"]
            value = params["value"]
            update_local_model_field(model_id, field, value)
            logger.debug(
                f"Model updated: model_id={model_id!r} "
                f"field={field!r} value={value!r}"
            )
        except KeyError as missing:
            logger.error(f"Missing key in params: {missing}")
            raise KeyError(f"Missing key: {missing}")
        except RuntimeError as failure:
            # Same text for the log entry and the re-raised error.
            details = (
                f"Failed to update model with ID {params.get('model_id')!r}: {failure}"
            )
            logger.error(details)
            raise RuntimeError(details)
        return {"message": "Ok"}

    async def rpc_run_callable(self, params: dict) -> Dict[str, Any]:
        """Handle the `'run_callable'` JSON-RPC method from the server.

        Executes the callable described by `params` through `run_callable` and
        returns its result, after verifying that the result can be serialized
        with `json.dumps`.

        Execution and serialization are checked in separate steps so that a
        non-serializable result is reported as `TypeError` rather than being
        re-wrapped by the handlers for execution failures.

        Args:
            params: A dictionary containing:

                - `callable_id`: A unique identifier for the callable.
                - `args`: Positional arguments (as a list or tuple) for the callable.
                - `kwargs`: Keyword arguments for the callable.

        Returns:
            A dictionary with the following structure:

                `{
                    "message": "Ok",
                    "data": <result of executing the callable>
                }`

        Raises:
            KeyError: If `run_callable` raises `KeyError` (e.g., missing keys).
            TypeError: If the result of the callable is not JSON-serializable.
            ValueError: If `run_callable` raises `ValueError`.
            RuntimeError: For any other exception encountered during callable
                execution.
        """
        # Step 1: execute the callable; translate failures per the contract above.
        try:
            result = run_callable(params)
        except KeyError as e:
            logger.error(f"Missing key in params: {e}")
            raise KeyError(f"Missing key: {e}") from e
        except ValueError as e:
            logger.error(
                f"Failed to run callable with ID {params.get('callable_id')!r}: {e}"
            )
            raise ValueError(
                f"Failed to run callable with ID {params.get('callable_id')!r}: {e}"
            ) from e
        except Exception as e:
            logger.error(
                f"Exception during execution of callable "
                f"with ID {params.get('callable_id')!r}: {e}"
            )
            raise RuntimeError(
                f"Exception during execution of callable "
                f"with ID {params.get('callable_id')!r}: {e}"
            ) from e

        # Step 2: the result is sent back as JSON, so reject anything json.dumps
        # cannot encode. This sits outside the try above so the TypeError reaches
        # the caller unchanged.
        try:
            json.dumps(result)
        except (TypeError, ValueError) as e:
            logger.error(
                f"Result of callable with ID {params.get('callable_id')!r} "
                f"is not JSON-serializable: {result!r} ({e})"
            )
            raise TypeError(
                f"Result of callable with ID {params.get('callable_id')!r} "
                f"is not JSON-serializable: {result!r} ({e})"
            ) from e

        logger.debug(
            f"Callable executed successfully: "
            f"callable_id={params.get('callable_id')!r}, "
            f"args={params.get('args', {})!r}, "
            f"kwargs={params.get('kwargs', {})!r}"
        )
        return {"message": "Ok", "data": result}

    async def rpc_clear_backward(self, params: dict) -> Dict[str, str]:
        """Handle the `'clear_backward'` JSON-RPC method from the server.

        This method clears the `_pending_grad` flag for the specified variables by
        passing their IDs to `clear_pending_grad`. It is called after the server
        finalizes the backward pass for the entire computation graph, indicating
        that the gradients for its variables have been computed and already shared
        with the client.

        Args:
            params: A dictionary containing:

                - `variable_ids`: A list of variable IDs for which to clear
                    the `_pending_grad` flag.

        Returns:
            A dictionary with a success message if the pending gradients are cleared.

        Raises:
            KeyError: If required keys are missing from params.
            RuntimeError: If clearing the pending grad fails for any reason.
        """
        try:
            variable_ids = params["variable_ids"]
            clear_pending_grad(variable_ids)

            logger.debug(f"Cleared pending gradients for variables: {variable_ids!r}")
            return {"message": "Ok"}
        except KeyError as e:
            logger.error(f"Missing key in params: {e}")
            raise KeyError(f"Missing key: {e}") from e
        except RuntimeError as e:
            # Report the key this method actually receives (`variable_ids`).
            logger.error(
                f"Failed to clear pending gradients for variables "
                f"{params.get('variable_ids')!r}: {e}"
            )
            raise RuntimeError(
                f"Failed to clear pending gradients for variables "
                f"{params.get('variable_ids')!r}: {e}"
            ) from e
        except Exception as e:
            logger.error(f"Exception during execution of backward clearing: {e}")
            raise RuntimeError(
                f"Exception during execution of backward clearing: {e}"
            ) from e

    async def rpc_clear_step(self, params: dict) -> Dict[str, str]:
        """Handle the `'clear_step'` JSON-RPC method from the server.

        This method clears the `_pending_data` flag for the specified variables by
        passing their IDs to `clear_pending_data`. It is called after the server
        completes an optimizer step and updates the data for the relevant
        variables.

        Args:
            params: A dictionary containing:

                - `variable_ids`: A list of variable IDs (str) for which to clear
                  the `_pending_data` flag.

        Returns:
            A dictionary with a success message if the pending data is cleared.

        Raises:
            KeyError: If required keys are missing from params.
            RuntimeError: If clearing the pending data fails for any reason.
        """
        try:
            variable_ids = params["variable_ids"]
            clear_pending_data(variable_ids)

            logger.debug(f"Cleared pending data for variables: {variable_ids!r}")
            return {"message": "Ok"}
        except KeyError as e:
            logger.error(f"Missing key in params: {e}")
            raise KeyError(f"Missing key: {e}") from e
        except RuntimeError as e:
            # Report the key this method actually receives (`variable_ids`).
            logger.error(
                f"Failed to clear pending data for variables "
                f"{params.get('variable_ids')!r}: {e}"
            )
            raise RuntimeError(
                f"Failed to clear pending data for variables "
                f"{params.get('variable_ids')!r}: {e}"
            ) from e
        except Exception as e:
            # This is the optimizer-step handler, so the message names step
            # clearing rather than backward clearing.
            logger.error(f"Exception during execution of step clearing: {e}")
            raise RuntimeError(
                f"Exception during execution of step clearing: {e}"
            ) from e

    async def call(self, method: str, params: dict, timeout=None) -> Any:
        """Sends a request over the WebSocket connection and waits for a response.

        Constructs a JSON-RPC request, sends it to the WebSocket server, and waits
        for the corresponding response. If no response is received within the timeout
        period, a `TimeoutError` is raised.

        The response future is registered in `self.pending` *before* the request
        is sent, so a reply that arrives while `send()` is still suspended can
        already be matched to its request ID.

        For methods listed in `LONG_RUNNING_METHODS` the wait is heartbeat-aware:
        a timeout only becomes fatal when no heartbeat has been recorded for the
        request within the last `timeout` seconds.

        Note:
            `params` is modified in place: the active run's UUID is stored under
            the `"run_uuid"` key before the request is built.

        Args:
            method: The name of the method to call on the backend.
            params: The parameters to pass to the method.
            timeout: The timeout (in seconds) for the response.
                If not provided, the default timeout is used. If the effective
                timeout is falsy, the request is sent as a notification (no `id`)
                and `None` is returned immediately.

        Returns:
            The result of the method call, or `None` for notifications.

        Raises:
            RuntimeError: If the WebSocket connection is not established.
            asyncio.TimeoutError: If the response is not received within
              the timeout period.
        """
        timeout = timeout or self.default_timeout  # Use default timeout if not provided

        if not self.connection:
            raise RuntimeError("WebSocket is not connected")

        active_run = get_active_run()
        params["run_uuid"] = active_run.uuid

        req_id = str(uuid.uuid4()) if timeout else None
        request = {
            "jsonrpc": "2.0",
            "method": method,
            "params": params,
        }

        # If it's a notification (no `id`), send it and return immediately
        if not req_id:
            await self.connection.send(json.dumps(request))
            logger.debug(f"Sent RPC request: {request}")
            return None

        request["id"] = req_id

        # Register the future first: `send()` is a suspension point, and the
        # reply must find its future in `self.pending` whenever it arrives.
        future = asyncio.get_running_loop().create_future()
        self.pending[req_id] = future

        try:
            await self.connection.send(json.dumps(request))
            logger.debug(f"Sent RPC request: {request}")

            if method not in LONG_RUNNING_METHODS:
                # Standard wait
                try:
                    return await asyncio.wait_for(future, timeout=timeout)
                except asyncio.TimeoutError:
                    logger.error(f"Request timed out: {request}")
                    raise

            # Heartbeat-aware wait loop
            self._heartbeat_times[req_id] = time.monotonic()
            last_heartbeat = time.monotonic()
            while True:
                try:
                    # Using `shield` to prevent cancellation of the future to allow
                    # heartbeat updates to keep it alive
                    return await asyncio.wait_for(
                        asyncio.shield(future), timeout=timeout
                    )
                except asyncio.TimeoutError:
                    now = time.monotonic()
                    last_heartbeat = self._heartbeat_times.get(
                        req_id, last_heartbeat
                    )
                    if now - last_heartbeat > timeout:
                        logger.error(f"Request timed out (no heartbeat): {request}")
                        raise
        finally:
            # Runs on success, timeout, cancellation and send failure alike, so
            # no bookkeeping for this request ID is left behind.
            self.pending.pop(req_id, None)
            self._heartbeat_times.pop(req_id, None)

    async def close(self):
        """Closes the WebSocket connection and cleans up resources.

        After a one-second grace period, cancels the listener task, closes the
        WebSocket connection, and cancels all pending requests.

        Every cleanup step runs even if an earlier one raises (for example when
        awaiting the listener task surfaces an error other than cancellation, or
        when closing the socket fails), so the client is never left half torn
        down. The error still propagates to the caller after cleanup.
        """
        # Add a delay to allow receiving and replying to remaining server requests
        await asyncio.sleep(1)

        try:
            if self.listener_task:
                logger.debug("Canceling listener task...")
                self.listener_task.cancel()
                try:
                    await self.listener_task  # Wait for the listener task to finish
                except asyncio.CancelledError:
                    # Expected outcome of the cancel() above
                    logger.debug("Listener task canceled.")
        finally:
            self.listener_task = None  # Clean up the listener task
            try:
                if self.connection:
                    logger.debug("Closing WebSocket connection...")
                    try:
                        await self.connection.close()
                    finally:
                        self.connection = None
            finally:
                logger.debug("Clearing pending requests...")
                self._cancel_pending_requests()  # Clear pending requests

        logger.debug("WebSocket connection closed.")

    async def _send_error(
        self,
        req_id: Optional[str],
        code: int,
        message: str,
        data: Optional[Dict[str, Any]] = None,
    ):
        """Report a failure to the server as a JSON-RPC 2.0 error response.

        The response is logged at warning level and then sent over the current
        connection as a JSON string.

        Args:
            req_id: The ID of the request that caused the error.
            code: The JSON-RPC error code.
            message: A description of the error.
            data: Additional data about the error. It is attached to the error
                object only when truthy.
        """
        logger.warning(
            f"Sending error response. ID: {req_id}, Code: {code}, "
            f"Message: {message}, Data: {data}"
        )

        # Build the inner error object first, then wrap it in the envelope.
        error_body = {"code": code, "message": message}
        if data:
            error_body["data"] = data

        envelope = {"jsonrpc": "2.0", "id": req_id, "error": error_body}
        serialized = json.dumps(envelope)
        await self.connection.send(serialized)

    async def __aenter__(self) -> "TellurioWebSocketClient":
        """Asynchronous context manager entry.

        Establishes the WebSocket connection when entering the context.
        If the connection is already established, it ensures the connection is active.

        Returns:
            The WebSocket client instance.
        """
        if not self.connection or self.connection.closed:
            await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Asynchronous context manager exit.

        Closes the WebSocket connection and cleans up resources
        when exiting the context.
        """
        await self.close()

    def _cancel_pending_requests(self):
        """Cancel every unfinished pending future, then empty `self.pending`."""
        for waiter in self.pending.values():
            if waiter.done():
                # Already resolved or cancelled; nothing to do.
                continue
            waiter.cancel()
        self.pending.clear()
        logger.debug("All pending requests have been canceled.")

__init__(base_url=None, port=None, default_timeout=30)

Initializes the TellurioWebSocketClient.

Parameters:

Name Type Description Default
base_url str

The base URL of the Tellurio backend (e.g., "https://platform.tellurio.ai").

None
port int

The port number for the backend (default: 443).

None
default_timeout int

The default timeout (in seconds) for WebSocket requests.

30
Source code in afnio/tellurio/websocket_client.py
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def __init__(
    self,
    base_url: str = None,
    port: int = None,
    default_timeout: int = 30,
):
    """Set up a `TellurioWebSocketClient` without opening a connection.

    Args:
        base_url: The base URL of the Tellurio backend
            (e.g., `"https://platform.tellurio.ai"`). When falsy, the
            `TELLURIO_BACKEND_WS_BASE_URL` environment variable is used,
            falling back to `"wss://platform.tellurio.ai"`.
        port: The port number for the backend. When falsy, the
            `TELLURIO_BACKEND_WS_PORT` environment variable is used,
            falling back to 443. A value taken from the environment
            is a string, not an int.
        default_timeout: The default timeout (in seconds)
          for WebSocket requests.
    """
    # Resolve the endpoint: explicit argument first, then environment, then default.
    if not base_url:
        base_url = os.getenv(
            "TELLURIO_BACKEND_WS_BASE_URL", "wss://platform.tellurio.ai"
        )
    if not port:
        port = os.getenv("TELLURIO_BACKEND_WS_PORT", 443)

    self.base_url = base_url
    self.port = port
    self.api_key = None
    self.default_timeout = default_timeout
    self.ws_url = self._build_ws_url(self.base_url, self.port)

    # Runtime state, populated once a connection is made.
    self.connection: websockets.ClientConnection = None
    self.listener_task = None
    self.pending = {}  # req_id -> Future
    self._heartbeat_times = {}  # req_id -> last heartbeat time (monotonic)

connect(api_key=None, retries=3, delay=5) async

Connects to the WebSocket server with retry logic.

Attempts to establish a WebSocket connection to the backend. If the connection fails, it retries up to the specified number of attempts with a delay between each attempt.

Parameters:

Name Type Description Default
api_key str

The API key for authenticating with the backend.

None
retries int

The number of reconnection attempts.

3
delay int

The delay (in seconds) between reconnection attempts.

5

Returns:

Type Description
dict[str, str]

The session ID received from the server upon successful connection.

Raises:

Type Description
RuntimeError

If the connection fails after all retry attempts.

Source code in afnio/tellurio/websocket_client.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
async def connect(
    self, api_key: str = None, retries: int = 3, delay: int = 5
) -> Dict[str, str]:
    """Connects to the WebSocket server with retry logic.

    Attempts to establish a WebSocket connection to the backend. If the connection
    fails, it retries up to the specified number of attempts with a delay between
    each attempt. On success the listener task is started and the first message
    received on the connection is parsed for the session ID.

    Args:
        api_key: The API key for authenticating with the backend. It is stored
            on the instance and sent as an `Authorization: Api-Key ...` header.
            NOTE(review): when left as `None`, the header value is the literal
            text `Api-Key None` - confirm callers always pass a key.
        retries: The number of connection attempts.
        delay: The delay (in seconds) between connection attempts.

    Returns:
        A dictionary of the form `{"session_id": <value>}`, where the value is
        read from `result.session_id` of the first server message (`None` if
        that field is absent).

    Raises:
        RuntimeError: If the connection fails after all retry attempts. The
            last underlying error is attached as the exception's cause.
    """
    self.api_key = api_key

    headers = {"Authorization": f"Api-Key {self.api_key}"}
    for attempt in range(retries):
        try:
            logger.debug(
                f"Connecting to WebSocket at {self.ws_url} "
                f"(attempt {attempt + 1}/{retries})"
            )
            self.connection = await websockets.connect(
                self.ws_url, additional_headers=headers
            )

            # Start the listener task
            self.listener_task = asyncio.create_task(self._listener())
            logger.debug("WebSocket connection established.")

            # Example: Retrieve session ID from the server
            response = await self.connection.recv()
            response_data = json.loads(response)
            session_id = response_data.get("result", {}).get("session_id")
            return {"session_id": session_id}
        except Exception as e:
            logger.error(f"Failed to connect to WebSocket: {e}")
            if attempt < retries - 1:
                await asyncio.sleep(delay)
            else:
                # Chain the last failure so the root cause stays visible
                # in the traceback.
                raise RuntimeError(
                    "Failed to connect to WebSocket after multiple attempts."
                ) from e

rpc_heartbeat(params) async

Handle the 'heartbeat' JSON-RPC notification from the server.

This method is called when the server sends a heartbeat notification for a long-running operation. It updates the last heartbeat timestamp for the corresponding request ID, allowing the client to reset its timeout and avoid prematurely timing out while the server is still processing the request.

Parameters:

Name Type Description Default
params dict

A dictionary with keys:

  • id: The request ID (str) associated with the long-running operation.
required
Source code in afnio/tellurio/websocket_client.py
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
async def rpc_heartbeat(self, params: dict):
    """Handle the `'heartbeat'` JSON-RPC notification from the server.

    Records the current monotonic time as the latest heartbeat for the request
    named in `params`, which lets a heartbeat-aware wait on that request keep
    going instead of timing out. Notifications without a request ID are ignored.

    Args:
        params: A dictionary with keys:

            - `id`: The request ID (str) associated with the long-running operation.
    """
    request_id = params.get("id")
    if not request_id:
        # Nothing to attribute the heartbeat to.
        return

    self._heartbeat_times[request_id] = time.monotonic()
    logger.debug(f"Received heartbeat for request {request_id}: {params!r}")

rpc_create_variable(params) async

Handle the 'create_variable' JSON-RPC method from the server.

This method creates and registers a new Variable instance in the local registry using the provided parameters. It is typically called when the server creates a deepcopy of a Variable or Parameter and needs to notify the client.

Parameters:

Name Type Description Default
params dict

A dictionary with keys:

  • variable_id: The unique identifier of the Variable.
  • obj_type: The type of the variable object (e.g., "__variable__" or "__parameter__").
  • data: The initial data for the variable.
  • role: The role or description of the variable.
  • requires_grad: Whether the variable requires gradients.
  • _retain_grad: Whether to retain gradients for non-leaf variables.
  • _grad: The initial gradient(s) for the variable.
  • _output_nr: The output number for the variable in the computation graph.
  • _grad_fn: The gradient function associated with the variable.
  • is_leaf: Whether the variable is a leaf node in the computation graph.
required

Returns:

Type Description
dict[str, str]

A dictionary with a success message if the variable is created.

Raises:

Type Description
KeyError

If required keys are missing from params.

RuntimeError

If the variable creation fails for any reason.

Source code in afnio/tellurio/websocket_client.py
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
async def rpc_create_variable(self, params: dict) -> Dict[str, str]:
    """Handle the `'create_variable'` JSON-RPC method from the server.

    Reads the required fields from `params` and forwards them, in a fixed
    positional order, to `create_local_variable`. It is typically called when
    the server creates a deepcopy of a Variable or Parameter and needs to
    notify the client.

    Args:
        params: A dictionary with keys:

            - `variable_id`: The unique identifier of the Variable.
            - `obj_type`: The type of the variable object
                (e.g., "__variable__" or "__parameter__").
            - `data`: The initial data for the variable.
            - `role`: The role or description of the variable.
            - `requires_grad`: Whether the variable requires gradients.
            - `_retain_grad`: Whether to retain gradients for non-leaf variables.
            - `_grad`: The initial gradient(s) for the variable.
            - `_output_nr`: The output number for the variable in the computation
                graph.
            - `_grad_fn`: The gradient function associated with the variable.
            - `is_leaf`: Whether the variable is a leaf node in the computation
                graph.

    Returns:
        A dictionary with a success message if the variable is created.

    Raises:
        KeyError: If required keys are missing from params.
        RuntimeError: If the variable creation fails for any reason.
    """
    # Order matters: this is the positional order the helper is called with.
    required_fields = (
        "variable_id",
        "obj_type",
        "data",
        "role",
        "requires_grad",
        "_retain_grad",
        "_grad",
        "_output_nr",
        "_grad_fn",
        "is_leaf",
    )
    try:
        positional = [params[field_name] for field_name in required_fields]
        created = create_local_variable(*positional)
        logger.debug(f"Variable created: variable_id={created.variable_id!r}")
        return {"message": "Ok"}
    except KeyError as missing:
        logger.error(f"Missing key in params: {missing}")
        raise KeyError(f"Missing key: {missing}")
    except RuntimeError as failure:
        detail = f"Failed to create variable: {failure}"
        logger.error(detail)
        raise RuntimeError(detail)

rpc_update_variable(params) async

Handle the 'update_variable' JSON-RPC method from the server.

This method updates a specific field of a registered Variable instance in response to a server notification. It uses the provided parameters to identify the variable and the field to update.

Parameters:

Name Type Description Default
params dict

A dictionary with keys:

  • variable_id: The unique identifier of the Variable.
  • field: The field name to update.
  • value: The new value to set for the field.
required

Returns:

Type Description
dict[str, str]

A dictionary with a success message if the update is successful.

Raises:

Type Description
KeyError

If required keys are missing from params.

RuntimeError

If the update fails for any reason.

Source code in afnio/tellurio/websocket_client.py
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
async def rpc_update_variable(self, params: dict) -> Dict[str, str]:
    """Handle the `'update_variable'` JSON-RPC method from the server.

    Applies a single field update to a locally registered Variable via
    `update_local_variable_field`. The update runs inside
    `suppress_variable_notifications()`.

    Args:
        params: A dictionary with keys:

            - `variable_id`: The unique identifier of the Variable.
            - `field`: The field name to update.
            - `value`: The new value to set for the field.

    Returns:
        A dictionary with a success message if the update is successful.

    Raises:
        KeyError: If required keys are missing from params.
        RuntimeError: If the update fails for any reason.
    """
    try:
        with suppress_variable_notifications():
            variable_id = params["variable_id"]
            field = params["field"]
            value = params["value"]
            update_local_variable_field(variable_id, field, value)
            logger.debug(
                f"Variable updated: variable_id={variable_id!r} "
                f"field={field!r} value={value!r}"
            )
            return {"message": "Ok"}
    except KeyError as missing:
        logger.error(f"Missing key in params: {missing}")
        raise KeyError(f"Missing key: {missing}")
    except RuntimeError as failure:
        detail = (
            f"Failed to update variable with ID {params.get('variable_id')!r}: "
            f"{failure}"
        )
        logger.error(detail)
        raise RuntimeError(detail)

rpc_append_grad(params) async

Handle the 'append_grad' JSON-RPC method from the server.

This method appends a new gradient variable to the local grad list of the specified Variable instance. It is typically called when the server notifies the client that a new gradient has been added to a variable during the backward pass.

Parameters:

Name Type Description Default
params dict

A dictionary containing:

  • variable_id: The unique identifier of the Variable to update.
  • gradient_id: The identifier of the gradient, passed along with the serialized gradient.
  • gradient: The serialized gradient Variable to append.
required

Returns:

Type Description
dict[str, str]

A dictionary with a success message if the gradient is appended.

Raises:

Type Description
KeyError

If required keys are missing from params.

RuntimeError

If appending the gradient fails for any reason.

Source code in afnio/tellurio/websocket_client.py
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
async def rpc_append_grad(self, params: dict) -> Dict[str, str]:
    """Handle the `'append_grad'` JSON-RPC method from the server.

    Forwards the gradient described in `params` to `append_grad_local` so it
    is added to the specified Variable. It is typically called when the server
    notifies the client that a new gradient has been added to a variable
    during the backward pass. The append runs inside
    `suppress_variable_notifications()`.

    Args:
        params: A dictionary containing:

            - `variable_id`: The unique identifier of the Variable to update.
            - `gradient_id`: The identifier passed to the local append helper
                together with the gradient.
            - `gradient`: The serialized gradient Variable to append.

    Returns:
        A dictionary with a success message if the gradient is appended.

    Raises:
        KeyError: If required keys are missing from params.
        RuntimeError: If appending the gradient fails for any reason.
    """
    try:
        with suppress_variable_notifications():
            variable_id = params["variable_id"]
            gradient_id = params["gradient_id"]
            gradient = params["gradient"]
            append_grad_local(variable_id, gradient_id, gradient)
            logger.debug(
                f"Gradient appended: variable_id={variable_id!r} "
                f"gradient={gradient!r}"
            )
            return {"message": "Ok"}
    except KeyError as missing:
        logger.error(f"Missing key in params: {missing}")
        raise KeyError(f"Missing key: {missing}")
    except RuntimeError as failure:
        detail = (
            f"Failed to append gradient for variable with ID "
            f"{params.get('variable_id')!r}: {failure}"
        )
        logger.error(detail)
        raise RuntimeError(detail)

rpc_create_node(params) async

Handle the 'create_node' JSON-RPC method from the server.

This method creates and registers a new Node instance in the local registry using the provided parameters. It is typically called when the server notifies the client that a new node has been created in the computation graph.

Parameters:

Name Type Description Default
params dict

A dictionary with keys:

  • node_id: The unique identifier of the Node.
  • name: The class name or type of the Node.
required

Returns:

Type Description
dict[str, str]

A dictionary with a success message if the node is created.

Raises:

Type Description
KeyError

If required keys are missing from params.

Source code in afnio/tellurio/websocket_client.py
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
async def rpc_create_node(self, params: dict) -> Dict[str, str]:
    """Handle the `'create_node'` JSON-RPC method from the server.

    Passes `params` unchanged to `create_node` and then logs the node's ID
    and name. It is typically called when the server notifies the client that
    a new node has been created in the computation graph.

    Args:
        params: A dictionary with keys:

            - `node_id`: The unique identifier of the Node.
            - `name`: The class name or type of the Node.

    Returns:
        A dictionary with a success message if the node is created.

    Raises:
        KeyError: If required keys are missing from params.
    """
    try:
        create_node(params)
        node_id = params["node_id"]
        node_name = params["name"]
        logger.debug(f"Node created: node_id={node_id!r} name={node_name!r}")
        return {"message": "Ok"}
    except KeyError as missing:
        logger.error(f"Missing key in params: {missing}")
        raise KeyError(f"Missing key: {missing}")

rpc_create_edge(params) async

Handle the 'create_edge' JSON-RPC method from the server.

This method creates a GradientEdge between two nodes in the local registry, appending the edge to the from_node's next_functions. It is typically called when the server notifies the client that a new edge has been created in the computation graph.

Note

The terms from_node and to_node should be interpreted in the context of the backward pass (backpropagation): the edge is added to the from_node's next_functions and points to the to_node, following the direction of gradient flow during backpropagation.

Parameters:

Name Type Description Default
params dict

A dictionary with keys:

  • from_node_id: The unique identifier of the source node.
  • to_node_id: The unique identifier of the destination node.
  • output_nr: The output number associated with the edge.
required

Returns:

Type Description
dict[str, str]

A dictionary with a success message if the edge is created.

Raises:

Type Description
KeyError

If required keys are missing from params.

Source code in afnio/tellurio/websocket_client.py
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
async def rpc_create_edge(self, params: dict) -> Dict[str, str]:
    """Handle the `'create_edge'` JSON-RPC method from the server.

    Passes `params` unchanged to `create_and_append_edge`, which creates a
    [`GradientEdge`][afnio.GradientEdge] between two nodes in the local registry
    and appends it to the `from_node`'s
    [`next_functions`][afnio.Node.next_functions].
    It is typically called when the server notifies the client that a new edge
    has been created in the computation graph.

    Note:
        Read `from_node` and `to_node` in the direction of the backward pass
        (backpropagation): the edge lives in the `from_node`'s
        [`next_functions`][afnio.Node.next_functions] and points at the
        `to_node`, i.e. it follows the gradient flow.

    Args:
        params: A dictionary with keys:

            - `from_node_id`: The unique identifier of the source node.
            - `to_node_id`: The unique identifier of the destination node.
            - `output_nr`: The output number associated with the edge.

    Returns:
        A dictionary with a success message if the edge is created.

    Raises:
        KeyError: If required keys are missing from params.
    """
    try:
        create_and_append_edge(params)
        source_id = params["from_node_id"]
        target_id = params["to_node_id"]
        output_nr = params["output_nr"]
        logger.debug(
            f"Edge created: from_node_id={source_id!r} "
            f"to_node_id={target_id!r} output_nr={output_nr!r}"
        )
        return {"message": "Ok"}
    except KeyError as missing:
        logger.error(f"Missing key in params: {missing}")
        raise KeyError(f"Missing key: {missing}")

rpc_update_model(params) async

Handle the 'update_model' JSON-RPC method from the server.

This method updates a specific field of a registered LM model instance in response to a server notification. It uses the provided parameters to identify the LM model and the field to update.

Parameters:

Name Type Description Default
params dict

A dictionary with keys:

  • model_id: The unique identifier of the LM model.
  • field: The field name to update.
  • value: The new value to set for the field.
required

Returns:

Type Description
dict[str, str]

A dictionary with a success message if the update is successful.

Raises:

Type Description
KeyError

If required keys are missing from params.

RuntimeError

If the update fails for any reason.

Source code in afnio/tellurio/websocket_client.py
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
async def rpc_update_model(self, params: dict) -> Dict[str, str]:
    """Handle the `'update_model'` JSON-RPC method from the server.

    Applies a single field update to a locally registered LM model via
    `update_local_model_field`, using `params` to identify the model and
    the field to change.

    Args:
        params: A dictionary with keys:

            - `model_id`: The unique identifier of the LM model.
            - `field`: The field name to update.
            - `value`: The new value to set for the field.

    Returns:
        A dictionary with a success message if the update is successful.

    Raises:
        KeyError: If required keys are missing from params.
        RuntimeError: If the update fails for any reason.
    """
    try:
        model_id = params["model_id"]
        field = params["field"]
        value = params["value"]
        update_local_model_field(model_id, field, value)
        logger.debug(
            f"Model updated: model_id={model_id!r} "
            f"field={field!r} value={value!r}"
        )
        return {"message": "Ok"}
    except KeyError as missing:
        logger.error(f"Missing key in params: {missing}")
        raise KeyError(f"Missing key: {missing}")
    except RuntimeError as failure:
        detail = (
            f"Failed to update model with ID {params.get('model_id')!r}: {failure}"
        )
        logger.error(detail)
        raise RuntimeError(detail)

rpc_run_callable(params) async

Handle the 'run_callable' JSON-RPC method from the server.

This method is invoked when the server sends a JSON-RPC request with the method "run_callable". It extracts callable details from the provided parameters, executes the callable from the registry, and returns a response containing the result. The response is expected to be JSON-serializable.

Parameters:

Name Type Description Default
params dict

A dictionary containing:

  • callable_id: A unique identifier for the callable.
  • args: Positional arguments (as a list or tuple) for the callable.
  • kwargs: Keyword arguments for the callable.
required

Returns:

Type Description
dict[str, Any]

A dictionary with the following structure:

{ "message": "Ok", "data": <result of executing the callable> }

Raises:

Type Description
KeyError

If required keys are missing in params.

TypeError

If the result of the callable is not JSON-serializable.

ValueError

If the callable execution fails due to invalid parameters.

RuntimeError

For any other exception encountered during callable execution.

Source code in afnio/tellurio/websocket_client.py
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
async def rpc_run_callable(self, params: dict) -> Dict[str, Any]:
    """Handle the `'run_callable'` JSON-RPC method from the server.

    This method is invoked when the server sends a JSON-RPC request with the
    method `"run_callable"`. It passes the parameters to `run_callable`, verifies
    that the result can be encoded with `json.dumps`, and returns a response
    containing the result.

    Args:
        params: A dictionary containing:

            - `callable_id`: A unique identifier for the callable.
            - `args`: Positional arguments (as a list or tuple) for the callable.
            - `kwargs`: Keyword arguments for the callable.

    Returns:
        A dictionary with the following structure:

            `{
                "message": "Ok",
                "data": <result of executing the callable>
            }`

    Raises:
        KeyError: If required keys are missing in params.
        TypeError: If the result of the callable is not JSON-serializable.
        ValueError: If the callable execution fails due to invalid parameters.
        RuntimeError: For any other exception encountered during callable execution.
    """
    callable_id = params.get("callable_id")

    try:
        result = run_callable(params)
    except KeyError as e:
        logger.error(f"Missing key in params: {e}")
        raise KeyError(f"Missing key: {e}") from e
    except ValueError as e:
        message = f"Failed to run callable with ID {callable_id!r}: {e}"
        logger.error(message)
        raise ValueError(message) from e
    except Exception as e:
        message = (
            f"Exception during execution of callable with ID {callable_id!r}: {e}"
        )
        logger.error(message)
        raise RuntimeError(message) from e

    # The serializability check is deliberately kept outside the `try` above:
    # when it lived inside, the TypeError raised here was swallowed by the
    # catch-all handler and surfaced to callers as a RuntimeError instead.
    try:
        json.dumps(result)
    except (TypeError, ValueError) as e:
        message = (
            f"Result of callable with ID {callable_id!r} "
            f"is not JSON-serializable: {result!r} ({e})"
        )
        logger.error(message)
        raise TypeError(message) from e

    logger.debug(
        "Callable executed successfully: "
        f"callable_id={callable_id!r}, "
        f"args={params.get('args', [])!r}, "
        f"kwargs={params.get('kwargs', {})!r}"
    )
    return {"message": "Ok", "data": result}

rpc_clear_backward(params) async

Handle the 'clear_backward' JSON-RPC method from the server.

This method clears the _pending_grad flag for the specified variables. It is called after the server finalizes the backward pass for the entire computation graph, indicating that the gradients for its variables have been computed and already shared with the client. Once it receives 'clear_backward', the client can safely access the values of these gradients without worrying about them being modified.

Parameters:

Name Type Description Default
params dict

A dictionary containing:

  • variable_ids: A list of variable IDs for which to clear the _pending_grad flag.
required

Returns:

Type Description
dict[str, str]

A dictionary with a success message if the pending gradients are cleared.

Raises:

Type Description
KeyError

If required keys are missing from params.

RuntimeError

If clearing the pending grad fails for any variable.

Source code in afnio/tellurio/websocket_client.py
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
async def rpc_clear_backward(self, params: dict) -> Dict[str, str]:
    """Handle the `'clear_backward'` JSON-RPC method from the server.

    This method clears the `_pending_grad` flag for the specified variables.
    It is called after the server finalizes the backward pass for the entire
    computation graph, indicating that the gradients for its variables have been
    computed and already shared with the client. Once it receives
    `'clear_backward'`, the client can safely access the values of these gradients
    without worrying about them being modified.

    Args:
        params: A dictionary containing:

            - `variable_ids`: A list of variable IDs for which to clear
                the `_pending_grad` flag.

    Returns:
        A dictionary with a success message if the pending gradients are cleared.

    Raises:
        KeyError: If required keys are missing from params.
        RuntimeError: If clearing the pending grad fails for any variable.
    """
    try:
        variable_ids = params["variable_ids"]
        clear_pending_grad(variable_ids)

        logger.debug(f"Cleared pending gradients for variables: {variable_ids!r}")
        return {"message": "Ok"}
    except KeyError as e:
        logger.error(f"Missing key in params: {e}")
        raise KeyError(f"Missing key: {e}") from e
    except RuntimeError as e:
        # Report the IDs under the key this method actually receives
        # (`variable_ids`); the singular `variable_id` is never part of params.
        message = (
            "Failed to clear pending gradients for variables with IDs "
            f"{params.get('variable_ids')!r}: {e}"
        )
        logger.error(message)
        raise RuntimeError(message) from e
    except Exception as e:
        message = f"Exception during execution of backward clearing: {e}"
        logger.error(message)
        raise RuntimeError(message) from e

rpc_clear_step(params) async

Handle the 'clear_step' JSON-RPC method from the server.

This method clears the _pending_data flag for the specified variables. It is called after the server completes an optimizer step and updates the data for the relevant variables. Once 'clear_step' is received, the client can safely access the updated values of these variables, knowing that the data is no longer pending or being modified.

Parameters:

Name Type Description Default
params dict

A dictionary containing:

  • variable_ids: A list of variable IDs (str) for which to clear the _pending_data flag.
required

Returns:

Type Description
dict[str, str]

A dictionary with a success message if the pending data is cleared.

Raises:

Type Description
KeyError

If required keys are missing from params.

RuntimeError

If clearing the pending data fails for any variable.

Source code in afnio/tellurio/websocket_client.py
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
async def rpc_clear_step(self, params: dict) -> Dict[str, str]:
    """Handle the `'clear_step'` JSON-RPC method from the server.

    This method clears the `_pending_data` flag for the specified variables.
    It is called after the server completes an optimizer step and updates
    the data for the relevant variables. Once `'clear_step'` is received,
    the client can safely access the updated values of these variables,
    knowing that the data is no longer pending or being modified.

    Args:
        params: A dictionary containing:

            - `variable_ids`: A list of variable IDs (str) for which to clear
              the `_pending_data` flag.

    Returns:
        A dictionary with a success message if the pending data is cleared.

    Raises:
        KeyError: If required keys are missing from params.
        RuntimeError: If clearing the pending data fails for any variable.
    """
    try:
        variable_ids = params["variable_ids"]
        clear_pending_data(variable_ids)

        logger.debug(f"Cleared pending data for variables: {variable_ids!r}")
        return {"message": "Ok"}
    except KeyError as e:
        logger.error(f"Missing key in params: {e}")
        raise KeyError(f"Missing key: {e}") from e
    except RuntimeError as e:
        # Report the IDs under the key this method actually receives
        # (`variable_ids`); the singular `variable_id` is never part of params.
        message = (
            "Failed to clear pending data for variables with IDs "
            f"{params.get('variable_ids')!r}: {e}"
        )
        logger.error(message)
        raise RuntimeError(message) from e
    except Exception as e:
        # This handler belongs to the optimizer-step path, so the message names
        # step clearing rather than backward clearing.
        message = f"Exception during execution of step clearing: {e}"
        logger.error(message)
        raise RuntimeError(message) from e

call(method, params, timeout=None) async

Sends a request over the WebSocket connection and waits for a response.

Constructs a JSON-RPC request, sends it to the WebSocket server, and waits for the corresponding response. If no response is received within the timeout period, a TimeoutError is raised.

Parameters:

Name Type Description Default
method str

The name of the method to call on the backend.

required
params dict

The parameters to pass to the method.

required
timeout

The timeout (in seconds) for the response. If not provided, the default timeout is used.

None

Returns:

Type Description
Any

The result of the method call.

Raises:

Type Description
RuntimeError

If the WebSocket connection is not established.

TimeoutError

If the response is not received within the timeout period.

Source code in afnio/tellurio/websocket_client.py
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
async def call(self, method: str, params: dict, timeout=None) -> Any:
    """Sends a request over the WebSocket connection and waits for a response.

    Constructs a JSON-RPC request, sends it to the WebSocket server, and waits
    for the corresponding response. If no response is received within the timeout
    period, a `TimeoutError` is raised.

    The UUID of the active run is added to the request parameters as `run_uuid`.
    The caller's `params` dictionary is copied first and is not modified.

    For methods listed in `LONG_RUNNING_METHODS` the wait is heartbeat-aware:
    each time the timeout elapses, the request is only abandoned if the
    timestamp stored in `self._heartbeat_times` for this request is older than
    the timeout. Otherwise the wait continues.

    Args:
        method: The name of the method to call on the backend.
        params: The parameters to pass to the method.
        timeout: The timeout (in seconds) for the response.
            If not provided, the default timeout is used. If the resulting
            timeout is falsy, the request is sent as a notification (no `id`)
            and `None` is returned without waiting.

    Returns:
        The result of the method call, or `None` for notifications.

    Raises:
        RuntimeError: If the WebSocket connection is not established.
        asyncio.TimeoutError: If the response is not received within
          the timeout period.
    """
    timeout = timeout or self.default_timeout  # Use default timeout if not provided

    if not self.connection:
        raise RuntimeError("WebSocket is not connected")

    # Build a new dict instead of writing `run_uuid` into the caller's object
    active_run = get_active_run()
    params = {**params, "run_uuid": active_run.uuid}

    req_id = str(uuid.uuid4()) if timeout else None
    request = {
        "jsonrpc": "2.0",
        "method": method,
        "params": params,
    }

    # If it's a notification (no `id`), send it and return immediately
    if not req_id:
        await self.connection.send(json.dumps(request))
        logger.debug(f"Sent RPC request: {request}")
        return None

    request["id"] = req_id
    long_running = method in LONG_RUNNING_METHODS

    # Register the future BEFORE sending. `send` can yield to the event loop,
    # and a response handled during that window would find no pending entry
    # for this id and be lost, turning a successful call into a timeout.
    future = asyncio.get_running_loop().create_future()
    self.pending[req_id] = future
    if long_running:
        self._heartbeat_times[req_id] = time.monotonic()

    try:
        await self.connection.send(json.dumps(request))
        logger.debug(f"Sent RPC request: {request}")

        if not long_running:
            # Standard wait
            try:
                return await asyncio.wait_for(future, timeout=timeout)
            except asyncio.TimeoutError:
                logger.error(f"Request timed out: {request}")
                raise

        # Heartbeat-aware wait loop
        last_heartbeat = time.monotonic()
        while True:
            try:
                # Using `shield` to prevent cancellation of the future to allow
                # heartbeat updates to keep it alive
                return await asyncio.wait_for(
                    asyncio.shield(future), timeout=timeout
                )
            except asyncio.TimeoutError:
                last_heartbeat = self._heartbeat_times.get(req_id, last_heartbeat)
                if time.monotonic() - last_heartbeat > timeout:
                    logger.error(f"Request timed out (no heartbeat): {request}")
                    raise
    finally:
        # Always drop the bookkeeping entries, including when `send` itself fails
        self.pending.pop(req_id, None)
        self._heartbeat_times.pop(req_id, None)

close() async

Closes the WebSocket connection and cleans up resources.

Cancels the listener task, clears pending requests, and closes the WebSocket connection.

Source code in afnio/tellurio/websocket_client.py
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
async def close(self):
    """Closes the WebSocket connection and cleans up resources.

    Waits one second, then cancels the listener task, closes the WebSocket
    connection, and cancels pending requests.

    Each cleanup step runs even if an earlier one raises (for example when the
    listener task had already finished with an error), so a failure never leaves
    the connection open or pending requests uncancelled. The first such error
    still propagates to the caller after cleanup completes.
    """
    # Add a delay to allow receiving and replying to remaining server requests
    await asyncio.sleep(1)

    try:
        if self.listener_task:
            logger.debug("Canceling listener task...")
            self.listener_task.cancel()
            try:
                await self.listener_task  # Wait for the listener task to finish
            except asyncio.CancelledError:
                logger.debug("Listener task canceled.")
    finally:
        self.listener_task = None  # Clean up the listener task

        try:
            if self.connection:
                logger.debug("Closing WebSocket connection...")
                try:
                    await self.connection.close()
                finally:
                    self.connection = None
        finally:
            logger.debug("Clearing pending requests...")
            self._cancel_pending_requests()  # Clear pending requests

    logger.debug("WebSocket connection closed.")

__aenter__() async

Asynchronous context manager entry.

Establishes the WebSocket connection when entering the context. If a connection already exists and is not closed, it is reused; otherwise a new connection is established.

Returns:

Type Description
TellurioWebSocketClient

The WebSocket client instance.

Source code in afnio/tellurio/websocket_client.py
851
852
853
854
855
856
857
858
859
860
861
862
async def __aenter__(self) -> "TellurioWebSocketClient":
    """Enter the asynchronous context.

    Connects via `connect()` when there is no connection object or the existing
    one reports itself as closed. An open connection is left untouched.

    Returns:
        This client instance, so it can be bound with `async with ... as client`.
    """
    connection = self.connection
    already_open = bool(connection) and not connection.closed
    if not already_open:
        await self.connect()
    return self

__aexit__(exc_type, exc_val, exc_tb) async

Asynchronous context manager exit.

Closes the WebSocket connection and cleans up resources when exiting the context.

Source code in afnio/tellurio/websocket_client.py
864
865
866
867
868
869
870
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
    """Asynchronous context manager exit.

    Closes the WebSocket connection and cleans up resources
    when exiting the context by delegating to `close()`.

    The exception arguments are not inspected. Because nothing is returned,
    an exception raised inside the `async with` body is not suppressed.

    Args:
        exc_type: Type of the exception raised in the context body, if any.
        exc_val: The exception instance raised in the context body, if any.
        exc_tb: Traceback of that exception, if any.
    """
    await self.close()