afnio.autodiff.lm_ops
afnio.autodiff.lm_ops.ChatCompletion
Bases: Function
Implements a chat completion operation using the specified language model within
the afnio framework, supporting automatic differentiation.
This class inherits from Function and
requires both the forward and
backward methods to be defined.
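The Function contract mirrors PyTorch-style autodiff functions: subclasses implement static forward and backward methods and are invoked through apply. The following is a minimal, self-contained sketch of that pattern only; it is not afnio's actual implementation, and the Context class and Echo op are hypothetical stand-ins:

```python
# Illustrative sketch of the Function pattern; NOT afnio's implementation.
# Context, apply(), and the toy Echo op are hypothetical stand-ins.

class Context:
    """Carries values saved in forward() for use in backward()."""
    def __init__(self):
        self.saved = ()

    def save_for_backward(self, *values):
        self.saved = values


class Function:
    @staticmethod
    def forward(ctx, *args, **kwargs):
        raise NotImplementedError

    @staticmethod
    def backward(ctx, grad_output):
        raise NotImplementedError

    @classmethod
    def apply(cls, *args, **kwargs):
        # A real autodiff engine would attach ctx to the computation graph
        # so that backward() can be invoked automatically later.
        ctx = Context()
        out = cls.forward(ctx, *args, **kwargs)
        return ctx, out


class Echo(Function):
    """Toy op: forward uppercases text; backward echoes the feedback."""
    @staticmethod
    def forward(ctx, text):
        ctx.save_for_backward(text)
        return text.upper()

    @staticmethod
    def backward(ctx, grad_output):
        (text,) = ctx.saved
        return f"feedback on {text!r}: {grad_output}"


ctx, out = Echo.apply("hello")
print(out)  # HELLO
```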
Features
- Mini-Batching: Processes multiple input dictionaries simultaneously to improve throughput.
- Asynchronous Execution: Both the forward and backward passes are optimized to run asynchronous calls for each mini-batch, reducing latency.
- Gradient Computation: Supports automatic differentiation for all
Variables in the messages and inputs arguments, maintaining the order of gradients.
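The asynchronous fan-out described above can be sketched with asyncio: one coroutine per mini-batch item is awaited concurrently, so total latency approaches that of the slowest call rather than their sum. The call_model coroutine below is a hypothetical stand-in for a real model client:

```python
import asyncio

async def call_model(prompt):
    """Hypothetical stand-in for one LM call; a real client would await
    an HTTP request instead of sleeping."""
    await asyncio.sleep(0.01)  # simulated network latency
    return f"completion for: {prompt}"

async def run_batch(prompts):
    # Fan out one coroutine per mini-batch item and await them together;
    # asyncio.gather preserves input order in the results.
    return await asyncio.gather(*(call_model(p) for p in prompts))

results = asyncio.run(run_batch(["Italian", "Spanish"]))
```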
The ChatCompletion function generates a Variable response by
passing a composite prompt, built from a list of messages and optional inputs,
to the forward_model_client. Each message is a dictionary with a 'role' (e.g.,
'system', 'user') and a list of Variable objects as
'content'. inputs is a dictionary containing strings, lists of strings, or
Variables providing dynamic values to fill placeholders within
message templates. If inputs contains lists of strings or
Variables whose data field is a list,
the response's data field will be a list, corresponding to
the batched results. Otherwise, the data field will be a
scalar string. Additional behavior, such as temperature or token limits, can be
customized through completion_args.
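The scalar-versus-batched behavior described above can be sketched with plain str.format: a scalar input yields a single rendered prompt, while a list-valued input yields one prompt per element. The render_prompts helper is illustrative only, not afnio's internal logic:

```python
def render_prompts(template, inputs):
    """Fill {placeholders}; return a list if any input value is a list.
    Hypothetical helper mirroring the scalar/batched behavior described
    in the docs, not afnio's internal implementation."""
    list_keys = [k for k, v in inputs.items() if isinstance(v, list)]
    if not list_keys:
        # Scalar inputs: a single rendered prompt.
        return template.format(**inputs)
    # Batched inputs: render one prompt per element of the lists,
    # broadcasting any scalar values across the batch.
    n = len(inputs[list_keys[0]])
    batch = []
    for i in range(n):
        row = {k: (v[i] if isinstance(v, list) else v)
               for k, v in inputs.items()}
        batch.append(template.format(**row))
    return batch

template = "Translate 'Hello' to {language}."
scalar = render_prompts(template, {"language": "Italian"})
batched = render_prompts(template, {"language": ["Italian", "Spanish"]})
```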
Examples:
Example with scalar inputs:
>>> system = Variable(
... "You are a helpful assistant.",
... role="system instruction",
... requires_grad=True
... )
>>> user = Variable("Translate 'Hello' to {language}.", role="user query")
>>> messages = [
... {"role": "system", "content": [system]},
... {"role": "user", "content": [user]},
... ]
>>> inputs = {"language": Variable("Italian", role="language")}
>>> response = ChatCompletion.apply(
... model_client,
... messages,
... inputs=inputs,
... temperature=0.7
... )
>>> print(response.data)
'Ciao'
>>> feedback = Variable("Use only capital letters.", role="feedback")
>>> response.backward(feedback)
>>> system.grad[0].data
'The system instruction should enforce the use of capital letters only.'
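As the example shows, textual gradients accumulate on a Variable's grad as an ordered list (hence system.grad[0].data). A toy sketch of that accumulation, assuming a minimal Variable with a grad list; this is not afnio's real Variable class:

```python
class Variable:
    """Toy stand-in for a Variable: holds text data and accumulates
    textual gradients in a list. NOT afnio's actual class."""
    def __init__(self, data, role="", requires_grad=False):
        self.data = data
        self.role = role
        self.requires_grad = requires_grad
        self.grad = []  # gradients accumulate in order, as a list

def accumulate_grad(var, feedback_text):
    # Hypothetical helper: append one textual gradient to the variable.
    if var.requires_grad:
        var.grad.append(Variable(feedback_text, role="gradient"))

system = Variable("You are a helpful assistant.",
                  role="system instruction", requires_grad=True)
accumulate_grad(system, "Enforce the use of capital letters only.")
```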
Example with batched inputs:
>>> system = Variable(
... "You are a helpful assistant.",
... role="system instruction",
... requires_grad=True
... )
>>> user = Variable("Translate 'Hello' to {language}.", role="user query")
>>> messages = [
... {"role": "system", "content": [system]},
... {"role": "user", "content": [user]},
... ]
>>> inputs = {
... "language": [
... Variable("Italian", role="language"),
... Variable("Spanish", role="language")
... ]
... }
>>> response = ChatCompletion.apply(
... model_client,
... messages,
... inputs=inputs,
... temperature=0.7
... )
>>> print(response.data)
['Ciao', 'Hola']
Source code in afnio/autodiff/lm_ops.py
forward(ctx, forward_model_client, messages, inputs=None, **completion_args)
staticmethod
Forward pass for the chat completion function.
Warning
This method is invoked by apply() and should not be called directly.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ctx | | Context object used to save information for the backward pass. | required |
| forward_model_client | ChatCompletionModel \| None | The LM model client used for generating chat completions. | required |
| messages | MultiTurnMessages | A list of messages that compose the prompt/context for the LM. Each message is a dictionary with a 'role' and a 'content' list of Variable objects. | required |
| inputs | dict[str, str \| list[str] \| Variable] \| None | A dictionary mapping placeholder names to their corresponding values, which can be strings, lists of strings, or Variables. | None |
| **completion_args | | Additional keyword arguments to pass to the LM model client's completion call. | {} |
Returns:
| Name | Type | Description |
|---|---|---|
| response | Variable | A Variable containing the generated chat completion response(s). |
Raises:
| Type | Description |
|---|---|
| TypeError | If the types of messages or inputs are invalid. |
Source code in afnio/autodiff/lm_ops.py
backward(ctx, grad_output)
staticmethod
Backward pass for the chat completion function.
Warning
This method is invoked by the autodiff engine and should not be called directly.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ctx | | Context object containing saved information from the forward pass. | required |
| grad_output | Variable | The gradient of the output Variable, propagated from downstream operations. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | Placeholder for the forward_model_client gradient, which is not computed. |
| grad_messages | tuple[Variable \| None, ...] | A tuple of gradients for the Variables in the messages argument, in order. |
| grad_inputs | tuple[Variable \| None, ...] | A tuple of gradients for the Variables in the inputs argument, in order. |
| None | tuple[None, ...] | Placeholder for any additional completion arguments passed to forward(). |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If the LM response used to generate the gradients cannot be parsed as valid JSON after the maximum number of retries. |
| ValueError | If the number of gradients returned by the LM does not match the expected number. |
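The RuntimeError above corresponds to a parse-with-retries loop over the LM's JSON output. A self-contained sketch of that pattern, with a stub in place of the real LM call; the parse_gradients helper and its signature are hypothetical:

```python
import json

def parse_gradients(call_lm, expected_n, max_retries=3):
    """Hypothetical sketch: ask the LM for gradients, retry on invalid
    JSON, and validate the gradient count. Not afnio's actual code."""
    for attempt in range(max_retries):
        raw = call_lm()
        try:
            grads = json.loads(raw)
        except json.JSONDecodeError:
            continue  # malformed JSON: retry the LM call
        if len(grads) != expected_n:
            # Mirrors the documented ValueError on a count mismatch.
            raise ValueError(f"expected {expected_n} gradients, got {len(grads)}")
        return grads
    # Mirrors the documented RuntimeError after exhausting retries.
    raise RuntimeError(f"invalid JSON after {max_retries} retries")

# Stub LM that fails once with malformed output, then returns valid JSON.
replies = iter(["not json", '["g1", "g2"]'])
grads = parse_gradients(lambda: next(replies), expected_n=2)
```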
Source code in afnio/autodiff/lm_ops.py