afnio.optim.tgd

afnio.optim.tgd.TGD

Bases: Optimizer

Textual Gradient Descent (TGD) optimizer.

TGD is an optimization algorithm for language-model–based systems where gradients are represented and propagated as natural language feedback rather than numerical tensors. Instead of computing numerical derivatives, TGD relies on a language model to generate textual critiques (gradients) that are used to iteratively refine prompt-based parameters.

This implementation follows the ideas introduced in the TextGrad paper, which proposes treating language-model feedback as a differentiable signal for optimizing textual variables and prompt programs.

TGD operates over Variable objects and consumes textual gradients produced by the automatic differentiation process. These gradients are used to update the optimized variables, with optional momentum applied to recent gradient history to stabilize and accelerate optimization.
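The momentum window described above can be pictured as a bounded buffer that retains only the most recent textual gradients. The following is a minimal, self-contained sketch of that idea in plain Python; `MomentumWindow` and its methods are hypothetical names for illustration, not the library's actual implementation:

```python
from collections import deque


class MomentumWindow:
    """Keeps only the last `size` textual gradients; oldest are evicted first."""

    def __init__(self, size: int):
        # size == 0 disables momentum: no gradient history is retained
        self.buffer = deque(maxlen=size) if size > 0 else None

    def push(self, feedback: str) -> None:
        # Appending to a full deque silently drops the oldest entry
        if self.buffer is not None:
            self.buffer.append(feedback)

    def history(self) -> list:
        """Textual gradients available to condition the next update."""
        return list(self.buffer) if self.buffer is not None else []


window = MomentumWindow(size=2)
for critique in ["too verbose", "missing constraints", "tone too formal"]:
    window.push(critique)

print(window.history())  # the two most recent critiques
```

With `size=2`, only the last two critiques survive, which is the sense in which a small window smooths the update signal without letting stale feedback dominate.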

Parameters are organized into parameter groups, similar to optimizers in PyTorch. This allows different optimization settings—such as optimization meta-prompts (messages), constraints, and momentum—to be applied consistently across groups.
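Conceptually, per-group settings override the optimizer-wide defaults, just as in PyTorch optimizers. A minimal sketch of that merging logic, assuming a hypothetical `merge_group_options` helper (not the library's code):

```python
def merge_group_options(defaults: dict, groups: list) -> list:
    """Give each parameter group a full option set: defaults plus overrides."""
    merged = []
    for group in groups:
        options = dict(defaults)   # start from optimizer-wide defaults
        options.update(group)      # per-group keys win
        merged.append(options)
    return merged


defaults = {"momentum": 0, "constraints": []}
groups = [
    {"params": ["system_prompt"], "momentum": 3},   # heavier smoothing here
    {"params": ["user_template"]},                  # inherits all defaults
]

for g in merge_group_options(defaults, groups):
    print(g["params"], g["momentum"])
```

This is why a single `TGD` instance can apply distinct meta-prompts, constraints, or momentum windows to different subsets of parameters while keeping one set of defaults.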

References:

- *TextGrad: Automatic Differentiation via Large Language Models*:
  https://arxiv.org/abs/2406.07496

Source code in afnio/optim/tgd.py
class TGD(Optimizer):
    """
    Textual Gradient Descent (TGD) optimizer.

    TGD is an optimization algorithm for language-model–based systems where
    gradients are represented and propagated as *natural language feedback*
    rather than numerical tensors. Instead of computing numerical derivatives,
    TGD relies on a language model to generate textual critiques (gradients)
    that are used to iteratively refine prompt-based parameters.

    This implementation follows the ideas introduced in the *TextGrad* paper,
    which proposes treating language-model feedback as a differentiable signal
    for optimizing textual variables and prompt programs.

    TGD operates over [`Variable`][afnio.Variable] objects and consumes textual
    gradients produced by the automatic differentiation process. These gradients
    are used to update the optimized variables, with optional momentum applied
    to recent gradient history to stabilize and accelerate optimization.

    Parameters are organized into parameter groups, similar to optimizers in
    PyTorch. This allows different optimization settings—such as optimization
    meta-prompts (`messages`), `constraints`, and `momentum`—to be applied
    consistently across groups.

    **References:**

    - *TextGrad: Automatic Differentiation via Large Language Models*
        [https://arxiv.org/abs/2406.07496](https://arxiv.org/abs/2406.07496)
    """

    def __init__(
        self,
        params: ParamsT,
        model_client: Optional[ChatCompletionModel],
        messages: MultiTurnMessages = TGD_MESSAGES,
        inputs: Optional[Dict[str, Union[str, Variable]]] = None,
        constraints: Optional[List[Union[str, Variable]]] = None,
        momentum: int = 0,
        **completion_args,
    ):
        """Initialize the Textual Gradient Descent (TGD) optimizer.

        Args:
            params (iterable): Iterable of parameters to optimize or dicts defining
                parameter groups.
            model_client: LM model client used for optimization.
            messages: Messages for multi-turn interactions. It typically defines
                the optimizer system prompt and user instruction. In-context
                examples (shots) can be added as well.
            inputs: Dynamic values to fill placeholders within message templates.
            constraints: A list of natural language constraints for optimization.
            momentum (int, optional): Momentum window size. Tracks the last `momentum`
                gradients, which helps accelerate updates in the right direction and
                dampen oscillations. Defaults to 0.
            completion_args (Dict[str, Any], optional): Additional arguments to pass to
                the model client when generating text completions. Defaults to an
                empty dictionary.
        """
        # Workaround to trigger TGD_MESSAGES registration with the server
        # and store related variable_ids on the client side
        if messages is TGD_MESSAGES:
            messages = [
                {
                    "role": "system",
                    "content": [
                        Variable(
                            data="Placeholder for Textual Gradient Descent optimizer system prompt",  # noqa: E501
                            role="Textual Gradient Descent optimizer system prompt",
                        )
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        Variable(
                            data="Placeholder for Textual Gradient Descent optimizer user prompt",  # noqa: E501
                            role="Textual Gradient Descent optimizer user prompt",
                        )
                    ],
                },
            ]

        defaults = dict(
            model_client=model_client,
            messages=messages,
            inputs=inputs or {},
            constraints=constraints or [],
            momentum=momentum,
            completion_args=completion_args,
        )
        super().__init__(params, defaults)

    def step(
        self, closure: Optional[Callable] = None
    ) -> Optional[Tuple[Variable, Variable]]:
        """Performs a single optimization step.

        Args:
            closure: A closure that reevaluates the model and returns the loss.

        Returns:
            The loss if `closure` is provided, otherwise None. The loss should \
            return a numerical or textual score and a textual explanation, both \
            wrapped as [`Variable`][afnio.Variable] objects.
        """
        loss = closure() if closure else (None, None)
        super().step()
        return loss

    def _extract_variable_ids_from_state(self, state):
        """
        Extract only the variable_ids of deepcopied parameters (i.e., those generated
        on the server) from the optimizer state.

        Args:
            state (list): The serialized optimizer state as returned by the server.

        Returns:
            Set[str]: Set of variable_ids for deepcopied parameters.
        """
        var_ids = set()
        for entry in state:
            momentum_buffer = entry.get("value", {}).get("momentum_buffer", [])
            for buf_entry in momentum_buffer:
                if (
                    isinstance(buf_entry, list)
                    and len(buf_entry) > 0
                    and isinstance(buf_entry[0], dict)
                    and "variable_id" in buf_entry[0]
                ):
                    var_ids.add(buf_entry[0]["variable_id"])
        return var_ids

__init__(params, model_client, messages=TGD_MESSAGES, inputs=None, constraints=None, momentum=0, **completion_args)

Initialize the Textual Gradient Descent (TGD) optimizer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `iterable` | Iterable of parameters to optimize or dicts defining parameter groups. | *required* |
| `model_client` | `ChatCompletionModel \| None` | LM model client used for optimization. | *required* |
| `messages` | `MultiTurnMessages` | Messages for multi-turn interactions. It typically defines the optimizer system prompt and user instruction. In-context examples (shots) can be added as well. | `TGD_MESSAGES` |
| `inputs` | `dict[str, str \| Variable] \| None` | Dynamic values to fill placeholders within message templates. | `None` |
| `constraints` | `list[str \| Variable] \| None` | A list of natural language constraints for optimization. | `None` |
| `momentum` | `int` | Momentum window size. Tracks the last `momentum` gradients, which helps accelerate updates in the right direction and dampen oscillations. | `0` |
| `completion_args` | `dict[str, Any]` | Additional arguments to pass to the model client when generating text completions. | `{}` |
Source code in afnio/optim/tgd.py
def __init__(
    self,
    params: ParamsT,
    model_client: Optional[ChatCompletionModel],
    messages: MultiTurnMessages = TGD_MESSAGES,
    inputs: Optional[Dict[str, Union[str, Variable]]] = None,
    constraints: Optional[List[Union[str, Variable]]] = None,
    momentum: int = 0,
    **completion_args,
):
    """Initialize the Textual Gradient Descent (TGD) optimizer.

    Args:
        params (iterable): Iterable of parameters to optimize or dicts defining
            parameter groups.
        model_client: LM model client used for optimization.
        messages: Messages for multi-turn interactions. It typically defines
            the optimizer system prompt and user instruction. In-context
            examples (shots) can be added as well.
        inputs: Dynamic values to fill placeholders within message templates.
        constraints: A list of natural language constraints for optimization.
        momentum (int, optional): Momentum window size. Tracks the last `momentum`
            gradients, which helps accelerate updates in the right direction and
            dampen oscillations. Defaults to 0.
        completion_args (Dict[str, Any], optional): Additional arguments to pass to
            the model client when generating text completions. Defaults to an
            empty dictionary.
    """
    # Workaround to trigger TGD_MESSAGES registration with the server
    # and store related variable_ids on the client side
    if messages is TGD_MESSAGES:
        messages = [
            {
                "role": "system",
                "content": [
                    Variable(
                        data="Placeholder for Textual Gradient Descent optimizer system prompt",  # noqa: E501
                        role="Textual Gradient Descent optimizer system prompt",
                    )
                ],
            },
            {
                "role": "user",
                "content": [
                    Variable(
                        data="Placeholder for Textual Gradient Descent optimizer user prompt",  # noqa: E501
                        role="Textual Gradient Descent optimizer user prompt",
                    )
                ],
            },
        ]

    defaults = dict(
        model_client=model_client,
        messages=messages,
        inputs=inputs or {},
        constraints=constraints or [],
        momentum=momentum,
        completion_args=completion_args,
    )
    super().__init__(params, defaults)

step(closure=None)

Performs a single optimization step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `closure` | `Callable \| None` | A closure that reevaluates the model and returns the loss. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `tuple[Variable, Variable] \| None` | The loss if `closure` is provided, otherwise `None`. The loss should return a numerical or textual score and a textual explanation, both wrapped as `Variable` objects. |
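The closure protocol can be illustrated with a stub: the step evaluates the closure (if given) and hands its `(score, explanation)` pair back to the caller. A self-contained sketch under those assumptions, with `run_step` and `evaluate_output` as hypothetical stand-ins for the real optimizer step and loss function:

```python
from typing import Callable, Optional, Tuple


def run_step(closure: Optional[Callable] = None) -> Optional[Tuple]:
    """Mimics the step contract: evaluate the closure, then apply the update."""
    loss = closure() if closure else None
    # ... the textual update of parameters would happen here ...
    return loss


def evaluate_output() -> Tuple[str, str]:
    # A real closure would re-run the LM program and grade its output;
    # here we return a fixed (score, explanation) pair.
    score = "7/10"
    explanation = "Answer is correct but overly verbose."
    return score, explanation


result = run_step(closure=evaluate_output)
print(result)  # the (score, explanation) pair produced by the closure
```

Returning the loss from the step lets a training loop log both the score and the textual explanation without evaluating the program twice.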

Source code in afnio/optim/tgd.py
def step(
    self, closure: Optional[Callable] = None
) -> Optional[Tuple[Variable, Variable]]:
    """Performs a single optimization step.

    Args:
        closure: A closure that reevaluates the model and returns the loss.

    Returns:
        The loss if `closure` is provided, otherwise None. The loss should \
        return a numerical or textual score and a textual explanation, both \
        wrapped as [`Variable`][afnio.Variable] objects.
    """
    loss = closure() if closure else (None, None)
    super().step()
    return loss

afnio.optim.tgd.tgd(params, grads, momentum_buffer_list, model_client, messages, inputs, constraints, momentum, **completion_args)

Functional API that performs TGD (Textual Gradient Descent) algorithm computation.

See TGD for details.
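This split mirrors PyTorch's convention of pairing a stateful `Optimizer` class with a stateless functional kernel (e.g. `torch.optim.SGD` and `torch.optim.sgd`). A toy illustration of the pattern, using hypothetical `polish`/`Polisher` names rather than the library's API:

```python
def polish(texts: list, critiques: list) -> list:
    """Stateless 'functional' update: pure inputs in, new values out."""
    return [f"{t} [revised per: {c}]" for t, c in zip(texts, critiques)]


class Polisher:
    """Thin stateful wrapper that gathers its state and delegates to polish()."""

    def __init__(self, texts):
        self.texts = list(texts)

    def step(self, critiques):
        # The class holds the state; the functional kernel does the work
        self.texts = polish(self.texts, critiques)


opt = Polisher(["Answer briefly"])
opt.step(["be more specific"])
print(opt.texts)  # ['Answer briefly [revised per: be more specific]']
```

Keeping the update logic functional makes it easy to test in isolation and to reuse across optimizer variants, while the class handles bookkeeping such as parameter groups and momentum buffers.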

Source code in afnio/optim/tgd.py
def tgd(
    params: List[Variable],
    grads: List[List[Variable]],
    momentum_buffer_list: List[Optional[List[Variable]]],
    model_client: Optional[ChatCompletionModel],
    messages: MultiTurnMessages,
    inputs: Optional[Dict[str, Union[str, Variable]]],
    constraints: Optional[List[Union[str, Variable]]],
    momentum: int,
    **completion_args,
):
    """Functional API that performs TGD (Textual Gradient Descent) algorithm
    computation.

    See [`TGD`][afnio.optim.tgd.TGD] for details.
    """
    # Set `_pending_data` for all parameters that will be optimized
    for p in params:
        p._pending_data = True
        logger.debug(f"Marked variable {p.variable_id!r} as pending for data update.")

    try:
        _, ws_client = get_default_clients()

        payload = {
            "params": _serialize_arg(params),
            "grads": _serialize_arg(grads),
            "momentum_buffer_list": _serialize_arg(momentum_buffer_list),
            "model_client": _serialize_arg(model_client),
            "messages": _serialize_arg(messages),
            "inputs": _serialize_arg(inputs),
            "constraints": _serialize_arg(constraints),
            "momentum": momentum,
            "completion_args": _serialize_arg(completion_args),
        }

        response = run_in_background_loop(ws_client.call("run_optimizer_tgd", payload))
        if "error" in response:
            raise RuntimeError(
                response["error"]["data"].get("exception", response["error"])
            )

        logger.debug(f"TGD optimization request sent: {payload!r}")

        result = response.get("result", {})
        result_message = result.get("message")
        result_momentum_buffer_list = result.get("momentum_buffer_list", [])

        # Extract all variable_ids from the result_momentum_buffer_list
        # and wait for them to be registered in VARIABLE_REGISTRY
        all_var_ids = _extract_variable_ids(result_momentum_buffer_list)
        for var_id in all_var_ids:
            _wait_for_variable(var_id)

        des_momentum_buffer_list = _deserialize_output(result_momentum_buffer_list)

        # Convert [param, grads] lists to (param, grads) tuples
        for i, buffer in enumerate(des_momentum_buffer_list):
            des_momentum_buffer_list[i] = [
                tuple(pair) if isinstance(pair, list) and len(pair) == 2 else pair
                for pair in buffer
            ]

        if result_message != "Functional TGD optimization step executed successfully.":
            raise RuntimeError(
                f"Server did not return any data for functional TGD optimization: "
                f"payload={payload!r}, response={response!r}"
            )

        # Update the momentum_buffer_list with the deserialized buffers
        momentum_buffer_list.clear()
        momentum_buffer_list.extend(des_momentum_buffer_list)

        logger.debug("Functional TGD optimization executed successfully")
    except Exception as e:
        logger.error(f"Failed to run functional TGD optimization on the server: {e!r}")

        # Clear all pending data flags to avoid deadlocks
        for p in params:
            p._pending_data = False
            logger.debug(
                f"Marked variable {p.variable_id!r} as not pending for data update "
                f"after error."
            )

        raise