
afnio.autodiff.lm_ops

afnio.autodiff.lm_ops.ChatCompletion

Bases: Function

Implements a chat completion operation using the specified language model within the afnio framework, supporting automatic differentiation.

This class inherits from Function and requires both the forward and backward methods to be defined.

Features
  • Mini-Batching: Processes multiple input dictionaries simultaneously to improve throughput.
  • Asynchronous Execution: Both the forward and backward passes are optimized to run asynchronous calls for each mini-batch, reducing latency.
  • Gradient Computation: Supports automatic differentiation for all Variables in messages and inputs arguments, maintaining the order of gradients.

The ChatCompletion function generates a Variable response by passing a composite prompt, built from a list of messages and optional inputs, to the forward_model_client. Each message is a dictionary with a 'role' (e.g., 'system', 'user') and a list of Variable objects as 'content'. inputs is a dictionary mapping placeholder names to strings, lists of strings, or Variables that provide dynamic values for the message templates. If inputs contains lists of strings, or Variables whose data field is a list, the response's data field will be a list corresponding to the batched results; otherwise, it will be a scalar string. Additional behavior, such as temperature or token limits, can be customized through completion_args.

Examples:

Example with scalar inputs:

>>> system = Variable(
...     "You are a helpful assistant.",
...     role="system instruction",
...     requires_grad=True
... )
>>> user = Variable("Translate 'Hello' to {language}.", role="user query")
>>> messages = [
...     {"role": "system", "content": [system]},
...     {"role": "user", "content": [user]},
... ]
>>> inputs = {"language": Variable("Italian", role="language")}
>>> response = ChatCompletion.apply(
...     model_client,
...     messages,
...     inputs=inputs,
...     temperature=0.7
... )
>>> print(response.data)
'Ciao'
>>> feedback = Variable("Use only capital letters.", role="feedback")
>>> response.backward(feedback)
>>> system.grad[0].data
'The system instruction should enforce the use of capital letters only.'

Example with batched inputs:

>>> system = Variable(
...     "You are a helpful assistant.",
...     role="system instruction",
...     requires_grad=True
... )
>>> user = Variable("Translate 'Hello' to {language}.", role="user query")
>>> messages = [
...     {"role": "system", "content": [system]},
...     {"role": "user", "content": [user]},
... ]
>>> inputs = {
...     "language": [
...         Variable("Italian", role="language"),
...         Variable("Spanish", role="language")
...     ]
... }
>>> response = ChatCompletion.apply(
...     model_client,
...     messages,
...     inputs=inputs,
...     temperature=0.7
... )
>>> print(response.data)
['Ciao', 'Hola']
Source code in afnio/autodiff/lm_ops.py
class ChatCompletion(Function):
    """
    Implements a chat completion operation using the specified language model within
    the `afnio` framework, supporting automatic differentiation.

    This class inherits from [`Function`][afnio.autodiff.function.Function] and
    requires both the [`forward`][afnio.autodiff.function.Function.forward] and
    [`backward`][afnio.autodiff.function.Function.backward] methods to be defined.

    Features:
        - **Mini-Batching**: Processes multiple input dictionaries simultaneously
            to improve throughput.
        - **Asynchronous Execution**: Both the forward and backward passes are optimized
            to run asynchronous calls for each mini-batch, reducing latency.
        - **Gradient Computation**: Supports automatic differentiation for all
            [`Variable`][afnio.Variable]s in `messages` and `inputs` arguments,
            maintaining the order of gradients.

    The `ChatCompletion` function generates a [`Variable`][afnio.Variable] response
    by passing a composite prompt, built from a list of `messages` and optional
    `inputs`, to the `forward_model_client`. Each message is a dictionary with a
    `'role'` (e.g., `'system'`, `'user'`) and a list of
    [`Variable`][afnio.Variable] objects as `'content'`. `inputs` is a dictionary
    mapping placeholder names to strings, lists of strings, or
    [`Variable`][afnio.Variable]s that provide dynamic values for the message
    templates. If `inputs` contains lists of strings, or
    [`Variable`][afnio.Variable]s whose [`data`][afnio.Variable.data] field is a
    list, the response's [`data`][afnio.Variable.data] field will be a list
    corresponding to the batched results; otherwise, it will be a scalar string.
    Additional behavior, such as temperature or token limits, can be customized
    through `completion_args`.

    Examples:
        Example with scalar inputs:
        >>> system = Variable(
        ...     "You are a helpful assistant.",
        ...     role="system instruction",
        ...     requires_grad=True
        ... )
        >>> user = Variable("Translate 'Hello' to {language}.", role="user query")
        >>> messages = [
        ...     {"role": "system", "content": [system]},
        ...     {"role": "user", "content": [user]},
        ... ]
        >>> inputs = {"language": Variable("Italian", role="language")}
        >>> response = ChatCompletion.apply(
        ...     model_client,
        ...     messages,
        ...     inputs=inputs,
        ...     temperature=0.7
        ... )
        >>> print(response.data)
        'Ciao'
        >>> feedback = Variable("Use only capital letters.", role="feedback")
        >>> response.backward(feedback)
        >>> system.grad[0].data
        'The system instruction should enforce the use of capital letters only.'

        Example with batched inputs:
        >>> system = Variable(
        ...     "You are a helpful assistant.",
        ...     role="system instruction",
        ...     requires_grad=True
        ... )
        >>> user = Variable("Translate 'Hello' to {language}.", role="user query")
        >>> messages = [
        ...     {"role": "system", "content": [system]},
        ...     {"role": "user", "content": [user]},
        ... ]
        >>> inputs = {
        ...     "language": [
        ...         Variable("Italian", role="language"),
        ...         Variable("Spanish", role="language")
        ...     ]
        ... }
        >>> response = ChatCompletion.apply(
        ...     model_client,
        ...     messages,
        ...     inputs=inputs,
        ...     temperature=0.7
        ... )
        >>> print(response.data)
        ['Ciao', 'Hola']
    """

    @staticmethod
    def forward(
        ctx,
        forward_model_client: Optional[ChatCompletionModel],
        messages: MultiTurnMessages,
        inputs: Optional[Dict[str, Union[str, List[str], Variable]]] = None,
        **completion_args,
    ) -> Variable:
        """
        Forward pass for the chat completion function.

        Warning:
            This method is invoked by
            [`apply()`][afnio.autodiff.function.Function.apply]
            and should not be called directly.

        Args:
            ctx: Context object used to save information for [`backward`][..backward]
                computation.
            forward_model_client: The LM model client used for generating
                chat completions.
            messages: A list of messages that compose the prompt/context for the LM.
                Each message is a dictionary with a `"role"` (e.g., `"system"`,
                `"user"`, `"assistant"`) and a `"content"` field, which is a list of
                `Variable` objects. The `Variable` objects in the `"content"` can
                contain placeholders (e.g., `{prediction}`, `{target}`) that will be
                populated with the corresponding values from the `inputs` dictionary.
            inputs: A dictionary mapping placeholder names to their corresponding
                values, which can be strings or `Variable` instances. These values
                will be used to populate the placeholders in the `messages` content
                before sending the prompt to the LM. For example, if a message
                `"content"` field contains the placeholder `{color}`, the `inputs`
                dictionary should have a key `"color"` with the value to substitute
                in the prompt. Optional if there are no placeholders in the messages or
                if all placeholders are directly related to `prediction` and `target`.
            **completion_args: Additional keyword arguments to pass to the LM model
                client's `chat` method, such as temperature, max tokens, or seed values,
                to customize the LLM's behavior during the evaluation.

        Returns:
            response: A `Variable` containing the LM's response. \
                The [`data`][afnio.Variable.data] field of the returned `Variable` \
                will be a string if all inputs are scalar, or a list of strings if \
                any input is a list. The `role` field will indicate that this is a \
                response to the input messages, and the `requires_grad` field will \
                be set to `True` if any of the input `Variable` objects in `messages` \
                require gradients, otherwise `False`.

        Raises:
            TypeError: If the types of `forward_model_client`, `messages`,
                or `inputs` are not as expected.
        """
        raise NotImplementedError(
            "ChatCompletion.forward is implemented on the server. "
            "Client-side execution is not supported."
        )

    @staticmethod
    def backward(ctx, grad_output: Variable) -> Tuple[Optional[Variable], ...]:
        """
        Backward pass for the chat completion function.

        Warning:
            This method is invoked by the autodiff engine
            and should not be called directly.

        Args:
            ctx: Context object containing saved information from the
                [`forward`][..forward] pass.
            grad_output: The gradient of the `response` `Variable`
                w.r.t. the output of the `forward()` method.

        Returns:
            None (None): Placeholder for the `forward_model_client` argument of
                `forward()`, which does not require a gradient.
            grad_messages (Tuple[Optional[Variable], ...]): A tuple of gradients for the
                `messages` argument of `forward()`, where each gradient corresponds to
                the respective message variable.
            grad_inputs (Tuple[Optional[Variable], ...]): A tuple of gradients for the
                `inputs` argument of `forward()`, where each gradient corresponds to the
                respective input variable.
            None (Tuple[None, ...]): Placeholder for any additional completion arguments
                passed to `forward()`, which do not require gradients.

        Raises:
            RuntimeError: If the LM response to generate the gradients cannot be parsed
                as valid JSON after the maximum number of retries.
            ValueError: If the number of gradients returned by the LM does not match the
                expected number.
        """
        raise NotImplementedError(
            "ChatCompletion.backward is implemented on the server. "
            "Client-side execution is not supported."
        )

forward(ctx, forward_model_client, messages, inputs=None, **completion_args) staticmethod

Forward pass for the chat completion function.

Warning

This method is invoked by apply() and should not be called directly.

Parameters:

  • ctx (required): Context object used to save information for backward computation.
  • forward_model_client (ChatCompletionModel | None, required): The LM model client used for generating chat completions.
  • messages (MultiTurnMessages, required): A list of messages that compose the prompt/context for the LM. Each message is a dictionary with a "role" (e.g., "system", "user", "assistant") and a "content" field, which is a list of Variable objects. The Variable objects in the "content" can contain placeholders (e.g., {prediction}, {target}) that will be populated with the corresponding values from the inputs dictionary.
  • inputs (dict[str, str | list[str] | Variable] | None, default None): A dictionary mapping placeholder names to their corresponding values, which can be strings or Variable instances. These values will be used to populate the placeholders in the messages content before sending the prompt to the LM. For example, if a message "content" field contains the placeholder {color}, the inputs dictionary should have a key "color" with the value to substitute in the prompt. Optional if there are no placeholders in the messages or if all placeholders are directly related to prediction and target.
  • **completion_args (default {}): Additional keyword arguments to pass to the LM model client's chat method, such as temperature, max tokens, or seed values, to customize the LLM's behavior during the evaluation.

Returns:

  • response (Variable): A Variable containing the LM's response. The data field of the returned Variable will be a string if all inputs are scalar, or a list of strings if any input is a list. The role field will indicate that this is a response to the input messages, and the requires_grad field will be set to True if any of the input Variable objects in messages require gradients, otherwise False.

Raises:

  • TypeError: If the types of forward_model_client, messages, or inputs are not as expected.
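The placeholder-filling and batching semantics described above can be sketched in plain Python. This is an illustrative approximation only, not the afnio implementation; the helper name render_prompts is hypothetical:

```python
from typing import Dict, List, Union


def render_prompts(
    template: str,
    inputs: Dict[str, Union[str, List[str]]],
) -> List[str]:
    """Fill {placeholder}s in a message template, broadcasting scalar
    values against any list-valued inputs (mini-batching)."""
    # The batch size comes from the first list-valued input, if any.
    batch_size = max(
        (len(v) for v in inputs.values() if isinstance(v, list)),
        default=1,
    )
    prompts = []
    for i in range(batch_size):
        # Scalars are reused for every batch element; lists are indexed.
        filled = {
            k: (v[i] if isinstance(v, list) else v) for k, v in inputs.items()
        }
        prompts.append(template.format(**filled))
    return prompts


# Scalar input -> one prompt; list input -> one prompt per element.
print(render_prompts("Translate 'Hello' to {language}.", {"language": "Italian"}))
print(render_prompts(
    "Translate 'Hello' to {language}.",
    {"language": ["Italian", "Spanish"]},
))
```

This mirrors why the response's data field is a scalar string for scalar inputs but a list for batched inputs: each rendered prompt yields one completion.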

Source code in afnio/autodiff/lm_ops.py
@staticmethod
def forward(
    ctx,
    forward_model_client: Optional[ChatCompletionModel],
    messages: MultiTurnMessages,
    inputs: Optional[Dict[str, Union[str, List[str], Variable]]] = None,
    **completion_args,
) -> Variable:
    """
    Forward pass for the chat completion function.

    Warning:
        This method is invoked by
        [`apply()`][afnio.autodiff.function.Function.apply]
        and should not be called directly.

    Args:
        ctx: Context object used to save information for [`backward`][..backward]
            computation.
        forward_model_client: The LM model client used for generating
            chat completions.
        messages: A list of messages that compose the prompt/context for the LM.
            Each message is a dictionary with a `"role"` (e.g., `"system"`,
            `"user"`, `"assistant"`) and a `"content"` field, which is a list of
            `Variable` objects. The `Variable` objects in the `"content"` can
            contain placeholders (e.g., `{prediction}`, `{target}`) that will be
            populated with the corresponding values from the `inputs` dictionary.
        inputs: A dictionary mapping placeholder names to their corresponding
            values, which can be strings or `Variable` instances. These values
            will be used to populate the placeholders in the `messages` content
            before sending the prompt to the LM. For example, if a message
            `"content"` field contains the placeholder `{color}`, the `inputs`
            dictionary should have a key `"color"` with the value to substitute
            in the prompt. Optional if there are no placeholders in the messages or
            if all placeholders are directly related to `prediction` and `target`.
        **completion_args: Additional keyword arguments to pass to the LM model
            client's `chat` method, such as temperature, max tokens, or seed values,
            to customize the LLM's behavior during the evaluation.

    Returns:
        response: A `Variable` containing the LM's response. \
            The [`data`][afnio.Variable.data] field of the returned `Variable` \
            will be a string if all inputs are scalar, or a list of strings if \
            any input is a list. The `role` field will indicate that this is a \
            response to the input messages, and the `requires_grad` field will \
            be set to `True` if any of the input `Variable` objects in `messages` \
            require gradients, otherwise `False`.

    Raises:
        TypeError: If the types of `forward_model_client`, `messages`,
            or `inputs` are not as expected.
    """
    raise NotImplementedError(
        "ChatCompletion.forward is implemented on the server. "
        "Client-side execution is not supported."
    )

backward(ctx, grad_output) staticmethod

Backward pass for the chat completion function.

Warning

This method is invoked by the autodiff engine and should not be called directly.

Parameters:

  • ctx (required): Context object containing saved information from the forward pass.
  • grad_output (Variable, required): The gradient of the response Variable w.r.t. the output of the forward() method.

Returns:

  • None (None): Placeholder for the forward_model_client argument of forward(), which does not require a gradient.
  • grad_messages (tuple[Variable | None, ...]): A tuple of gradients for the messages argument of forward(), where each gradient corresponds to the respective message variable.
  • grad_inputs (tuple[Variable | None, ...]): A tuple of gradients for the inputs argument of forward(), where each gradient corresponds to the respective input variable.
  • None (tuple[None, ...]): Placeholder for any additional completion arguments passed to forward(), which do not require gradients.

Raises:

  • RuntimeError: If the LM response to generate the gradients cannot be parsed as valid JSON after the maximum number of retries.
  • ValueError: If the number of gradients returned by the LM does not match the expected number.
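A Function's backward must return one gradient slot per forward argument, in order. The tuple layout described above can be sketched in plain Python, using placeholder strings in place of Variable gradients purely for illustration (fake_backward is a hypothetical name, not part of afnio):

```python
def fake_backward(num_message_vars: int, num_input_vars: int, num_completion_args: int):
    """Build the gradient tuple shape described above:
    (None for forward_model_client, one slot per message Variable,
     one slot per input Variable, one None per completion arg)."""
    grad_messages = tuple(f"grad_msg_{i}" for i in range(num_message_vars))
    grad_inputs = tuple(f"grad_in_{i}" for i in range(num_input_vars))
    return (None,) + grad_messages + grad_inputs + (None,) * num_completion_args


# Two message Variables (system, user), one input Variable ("language"),
# and one completion arg (temperature) -> five slots in total.
print(fake_backward(2, 1, 1))
# (None, 'grad_msg_0', 'grad_msg_1', 'grad_in_0', None)
```

This one-to-one alignment is why the docs stress that gradient order is maintained: the autodiff engine matches each slot back to the corresponding forward argument by position.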

Source code in afnio/autodiff/lm_ops.py
@staticmethod
def backward(ctx, grad_output: Variable) -> Tuple[Optional[Variable], ...]:
    """
    Backward pass for the chat completion function.

    Warning:
        This method is invoked by the autodiff engine
        and should not be called directly.

    Args:
        ctx: Context object containing saved information from the
            [`forward`][..forward] pass.
        grad_output: The gradient of the `response` `Variable`
            w.r.t. the output of the `forward()` method.

    Returns:
        None (None): Placeholder for the `forward_model_client` argument of
            `forward()`, which does not require a gradient.
        grad_messages (Tuple[Optional[Variable], ...]): A tuple of gradients for the
            `messages` argument of `forward()`, where each gradient corresponds to
            the respective message variable.
        grad_inputs (Tuple[Optional[Variable], ...]): A tuple of gradients for the
            `inputs` argument of `forward()`, where each gradient corresponds to the
            respective input variable.
        None (Tuple[None, ...]): Placeholder for any additional completion arguments
            passed to `forward()`, which do not require gradients.

    Raises:
        RuntimeError: If the LM response to generate the gradients cannot be parsed
            as valid JSON after the maximum number of retries.
        ValueError: If the number of gradients returned by the LM does not match the
            expected number.
    """
    raise NotImplementedError(
        "ChatCompletion.backward is implemented on the server. "
        "Client-side execution is not supported."
    )