Trainer

Warning

Before running any code, ensure you are logged in to the Afnio backend (afnio login). See Logging in to Afnio Backend for details.

Tip

For full control over every training and evaluation step, use a manual optimization loop as shown in Optimization Loop. For most workflows, however, Trainer is the fastest and easiest way to train and evaluate agents in Afnio.

Afnio’s Trainer module provides a high-level interface for training, validating, and testing agents. It automates many aspects of the optimization loop, including experiment tracking, metric logging, checkpointing, and cost monitoring. If you want to get started quickly and focus on agent design rather than boilerplate training code, Trainer is the recommended approach.


Why Use Trainer?

  • Automatic Experiment Tracking: Metrics (loss, accuracy, etc.) are logged to Tellurio Studio, giving you interactive plots and dashboards for each run.

  • LM Cost Tracking: The Trainer automatically tracks and logs the cost of all language model (LM) calls, so you can monitor and optimize your budget.

  • Progress Bars and Summaries: Built-in progress bars and agent summaries make it easy to follow training progress and inspect agent state.

  • Checkpointing: Trainer automatically saves agent checkpoints during training, allowing you to resume experiments, analyze results, or deploy the best-performing agent to production.

  • Less Boilerplate: You don’t need to write custom loops for training, validation, or testing—just implement a few methods in your agent.


Preparing Your Agent and Data

Before using the Trainer module, you should define your agent, dataset, and data loaders. See Datasets and DataLoaders and Build the Agent or Workflow for details.

This example uses the same agent and dataset as in the Optimization Loop page, but demonstrates training with the Trainer module for a streamlined workflow.

import json
import os
import re

import afnio
import afnio.cognitive as cog
import afnio.cognitive.functional as F
import afnio.tellurio as te
from afnio.models.openai import AsyncOpenAI
from afnio.utils.data import DataLoader, WeightedRandomSampler
from afnio.utils.datasets import FacilitySupport

os.environ["OPENAI_API_KEY"] = "sk-..."  # Replace with your actual key


def compute_sample_weights(data):
    with te.suppress_variable_notifications():
        labels = [y.data for _, (_, y, _) in data]
        counts = {label: labels.count(label) for label in set(labels)}
        total = len(data)
    return [total / counts[label] for label in labels]


training_data = FacilitySupport(split="train", root="data")
validation_data = FacilitySupport(split="val", root="data")
test_data = FacilitySupport(split="test", root="data")

weights = compute_sample_weights(training_data)
sampler = WeightedRandomSampler(
    weights, num_samples=len(training_data), replacement=True
)

BATCH_SIZE = 33
train_dataloader = DataLoader(training_data, sampler=sampler, batch_size=BATCH_SIZE)
val_dataloader = DataLoader(validation_data, batch_size=BATCH_SIZE, seed=42)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, seed=42)

SENTIMENT_RESPONSE_FORMAT = {
    "type": "json_schema",
    "json_schema": {
        "strict": True,
        "name": "sentiment_response_schema",
        "schema": {
            "type": "object",
            "properties": {
                "sentiment": {
                    "type": "string",
                    "enum": ["positive", "neutral", "negative"],
                },
            },
            "additionalProperties": False,
            "required": ["sentiment"],
        },
    },
}

afnio.set_backward_model_client(
    "openai/gpt-5",
    completion_args={
        "temperature": 1.0,
        "max_completion_tokens": 32000,
        "reasoning_effort": "low",
    },
)
fw_model_client = AsyncOpenAI()
optim_model_client = AsyncOpenAI()


class FacilitySupportAnalyzer(cog.Module):

    def __init__(self):
        super().__init__()
        self.sentiment_task = cog.Parameter(
            data="Read the provided message and determine the sentiment.",
            role="system prompt for sentiment classification",
            requires_grad=True,
        )
        self.sentiment_user = afnio.Variable(
            data="**Message:**\n\n{message}\n\n",
            role="input template to sentiment classifier",
        )
        self.sentiment_classifier = cog.ChatCompletion()

    def forward(self, fwd_model, inputs, **completion_args):
        sentiment_messages = [
            {"role": "system", "content": [self.sentiment_task]},
            {"role": "user", "content": [self.sentiment_user]},
        ]
        return self.sentiment_classifier(
            fwd_model,
            sentiment_messages,
            inputs=inputs,
            response_format=SENTIMENT_RESPONSE_FORMAT,
            **completion_args,
        )

    def training_step(self, batch, batch_idx):
        X, y = batch
        _, gold_sentiment, _ = y
        pred_sentiment = self(
            fw_model_client,
            inputs={"message": X},
            model="gpt-4.1-nano",
            temperature=0.0,
        )
        pred_sentiment.data = [
            json.loads(re.sub(r"^```json\n|\n```$", "", item))["sentiment"].lower()
            for item in pred_sentiment.data
        ]
        loss = F.exact_match_evaluator(pred_sentiment, gold_sentiment)
        return {"loss": loss, "accuracy": loss[0].data / len(gold_sentiment.data)}

    def validation_step(self, batch, batch_idx):
        return self.training_step(batch, batch_idx)

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        constraints = [
            afnio.Variable(
                data="The improved variable must never include or reference the characters `{` or `}`. Do not output them, mention them, or describe them in any way.",
                role="optimizer constraint",
            )
        ]
        optimizer = afnio.optim.TGD(
            self.parameters(),
            model_client=optim_model_client,
            constraints=constraints,
            momentum=3,
            model="gpt-5",
            temperature=1.0,
            max_completion_tokens=32000,
            reasoning_effort="low",
        )
        return optimizer


agent = FacilitySupportAnalyzer()

Output:

INFO     : API key provided and stored securely in local keyring.
INFO     : Currently logged in as 'username' to 'http://localhost'. Use `afnio login --relogin` to force relogin.
INFO     : Project with slug 'my-project' already exists in namespace 'username'.
Downloading https://raw.githubusercontent.com/meta-llama/llama-prompt-ops/refs/heads/main/use-cases/facility-support-analyzer/dataset.json to data/FacilitySupport/raw/dataset.json
Downloading ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 383.7/383.7 kB 1.1 MB/s 0:00:00

Using downloaded and verified file: data/FacilitySupport/raw/dataset.json

Using downloaded and verified file: data/FacilitySupport/raw/dataset.json


========== [Consent Required] ==========
Your LM model API key will be sent to the Tellurio server for remote model execution and backpropagation.

Please review the following:
  • Tellurio will never use your key except to execute your requests.
  • The key is only used during your session.
  • The key is never stored and is removed from memory when your session ends.

Do you consent to share your API key with the server?
Type 'yes' to allow just this time, or 'always' to allow and remember your choice for all future sessions.
Consent [yes/always/no]: always
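The compute_sample_weights helper above gives each training sample the weight total / count(label), so WeightedRandomSampler draws under-represented sentiment classes more often. Here is a standalone sketch of the same arithmetic. It uses plain strings instead of Afnio Variables. The name inverse_frequency_weights and the toy labels are illustrative, and random.choices stands in for the sampler:

```python
import random
from collections import Counter


def inverse_frequency_weights(labels):
    """Return one weight per sample: total / (number of samples sharing its label)."""
    counts = Counter(labels)
    total = len(labels)
    return [total / counts[label] for label in labels]


labels = ["positive", "positive", "positive", "neutral", "negative", "negative"]
weights = inverse_frequency_weights(labels)

# Draw with replacement, analogous to WeightedRandomSampler(..., replacement=True).
rng = random.Random(0)
draws = rng.choices(range(len(labels)), weights=weights, k=len(labels))
print(weights)
print([labels[i] for i in draws])
```

Because each class's weights sum to the number of samples, every sentiment class has the same expected share of the draws.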

To work with Trainer, your agent (which should extend cog.Module) must implement the following methods:

  • training_step(batch, batch_idx): defines the logic for each training batch.
    • Input: a batch of training data and its index, batch_idx.
    • Returns: either a dict with keys such as "loss" and "accuracy", where "loss" is a tuple (score, explanation) of afnio.Variable objects and "accuracy" is a numeric value, or a bare tuple (score, explanation) of afnio.Variable objects representing the loss.

  • validation_step(batch, batch_idx): defines the logic for each validation batch.
    • Input: a batch of validation data and its index, batch_idx.
    • Returns: the same format as training_step.

  • test_step(batch, batch_idx): defines the logic for each test batch.
    • Input: a batch of test data and its index, batch_idx.
    • Returns: the same format as training_step.

  • configure_optimizers(): returns the optimizer(s) used for training.
    • Input: none.
    • Returns: optimizer instance(s), e.g., afnio.optim.TGD.

Tip

  • The batch input is typically a tuple (X, y) or a dictionary, depending on your DataLoader.

  • The "loss" output must be a tuple of two afnio.Variable objects: the numeric score and the explanation (used for gradients).
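The contract described in the list and tip above can be illustrated without Afnio. The sketch below uses plain Python stand-ins. It assumes, consistent with the accuracy formula in training_step, that the score counts exact matches in the batch. In real code, score and explanation are the afnio.Variable objects returned by F.exact_match_evaluator. The sketch also reuses the regular expression from training_step to unwrap replies wrapped in a ```json fence. toy_training_step and normalize_step_output are hypothetical helpers, not part of Afnio.

```python
import json
import re

# Same pattern as in training_step: drop a leading ```json line and a trailing ``` line.
FENCE_RE = re.compile(r"^```json\n|\n```$")


def parse_sentiment(reply):
    """Unwrap an optional ```json fence and return the lower-cased sentiment."""
    return json.loads(FENCE_RE.sub("", reply))["sentiment"].lower()


def toy_training_step(replies, gold):
    """Hypothetical step: score = exact matches, accuracy = score / batch size."""
    preds = [parse_sentiment(reply) for reply in replies]
    score = sum(pred == label for pred, label in zip(preds, gold))
    explanation = f"{score}/{len(gold)} predictions matched the gold labels"
    # Dict form: "loss" is a (score, explanation) pair and "accuracy" is a number.
    return {"loss": (score, explanation), "accuracy": score / len(gold)}


def normalize_step_output(output):
    """Accept either the dict form or the bare (score, explanation) tuple form."""
    if isinstance(output, dict):
        score, explanation = output["loss"]
        metrics = {key: value for key, value in output.items() if key != "loss"}
    else:
        score, explanation = output
        metrics = {}
    return score, explanation, metrics


replies = [
    '```json\n{"sentiment": "Positive"}\n```',
    '{"sentiment": "neutral"}',
    '{"sentiment": "negative"}',
]
out = toy_training_step(replies, ["positive", "negative", "negative"])
print(normalize_step_output(out))
```

Supporting both return shapes in one normalizer lets a loop treat extra metrics such as "accuracy" as optional.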


Train and Evaluate with Trainer

This section demonstrates a typical end-to-end workflow using Trainer.

Configure Trainer

Start by creating a remote Run with te.init(...). This registers a Run on Tellurio Studio, so metrics, LM usage, and cost are recorded in the cloud. Next, instantiate a Trainer to control the number of epochs, progress reporting, automatic checkpointing, and logging behavior; the Trainer writes checkpoint files to your local checkpoints/ folder. When calling training or evaluation methods, pass every LM client you use for forward, backward, and optimization calls in the llm_clients list, so the Trainer can attribute model usage and cost to each client accurately.

Before training, you can optionally call trainer.test(...) on the untrained agent to establish a baseline.

from afnio.trainer import Trainer

# Create Tellurio Run — replace "username" with your Tellurio Studio username slug
run = te.init("username", "my-project")

trainer = Trainer(max_epochs=5, enable_agent_summary=False)

# LM cost tracking requires passing a list of LM clients used during training
llm_clients = [
    fw_model_client,
    afnio.get_backward_model_client(),
    optim_model_client,
]

# Test baseline performance
trainer.test(
    agent=agent,
    test_dataloader=test_dataloader,
    llm_clients=llm_clients,
)

Output:

INFO     : Project with slug 'my-project' already exists in namespace 'dmpiergiacomo'.
INFO     : Run 'epic_kebab_599' created successfully at: https://platform.tellurio.ai/dmpiergiacomo/projects/my-project/runs/epic-kebab-599/
Testing
[Test] 68/68 ━━━━━━━━━━━━━━━━━━━━ 0:00:07 tot_cost: $0.0024  - test_loss: 17.3333 - test_accuracy: 0.6818
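The llm_clients list is what makes per-client cost attribution possible. This page does not describe how Afnio records usage internally, so the following is only a hypothetical plain-Python sketch of the idea. Each client accumulates the cost of its own calls, and the reported total is the sum over the clients you passed. CostTrackingClient, record, and cost_report are invented names, and the per-call prices are made up.

```python
from dataclasses import dataclass, field


@dataclass
class CostTrackingClient:
    """Hypothetical stand-in for an LM client that records the cost of each call."""

    name: str
    calls: list = field(default_factory=list)

    def record(self, cost_usd):
        self.calls.append(cost_usd)

    @property
    def cost(self):
        return sum(self.calls)


def cost_report(clients):
    """Return per-client cost and the overall total for the clients passed in."""
    per_client = {client.name: client.cost for client in clients}
    return per_client, sum(per_client.values())


forward = CostTrackingClient("forward")
backward = CostTrackingClient("backward")
optimizer = CostTrackingClient("optimizer")

for _ in range(3):
    forward.record(0.0004)  # made-up per-call prices
backward.record(0.0100)
optimizer.record(0.0050)

per_client, total = cost_report([forward, backward, optimizer])
print(per_client, f"tot_cost: ${total:.4f}")
```

In this sketch, a client omitted from the list would simply be missing from the total. That is the reason to pass the forward, backward, and optimizer clients together.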

Note

In Afnio, the backward (optimization) graph is constructed on the remote Afnio backend (hosted on Tellurio Studio) and is available only when an active Run exists. Therefore, perform backpropagation and gradient generation inside a Run created with te.init(...). You can create a Run as a context manager (with te.init("your-username", "project"):), or by calling run = te.init("your-username", "project") and later run.finish(). Only call .backward() after the Run is active. See Runs and Experiments for details.
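The two ways of creating a Run described in the note differ only in who calls finish(). The following plain-Python sketch uses a hypothetical ToyRun class (not the object returned by te.init) to show the equivalence. In this sketch, __exit__ calls finish() even when the body raises, and backward_allowed() stands in for the rule that backpropagation needs an active Run.

```python
class ToyRun:
    """Hypothetical stand-in for a Run: active between creation and finish()."""

    def __init__(self, namespace, project):
        self.namespace = namespace
        self.project = project
        self.status = "RUNNING"

    def finish(self):
        self.status = "COMPLETED"

    def backward_allowed(self):
        # Mirrors the note: backpropagation requires an active Run.
        return self.status == "RUNNING"

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.finish()  # called on normal exit and on exceptions alike
        return False  # do not swallow exceptions


# Style 1: context manager; finish() runs automatically when the block exits.
with ToyRun("your-username", "project") as run_a:
    assert run_a.backward_allowed()

# Style 2: explicit finish(), guarded by try/finally.
run_b = ToyRun("your-username", "project")
try:
    assert run_b.backward_allowed()  # training code would go here
finally:
    run_b.finish()

print(run_a.status, run_b.status)
```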

Train & Validate (trainer.fit)

Run trainer.fit(...) to execute the full training loop: it runs per-epoch training batches, performs validation after each epoch, logs metrics and LM costs, and saves checkpoints.

# Train and validate
trainer.fit(
    agent=agent,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    llm_clients=llm_clients,
)

Output:

Epoch 1/5
  [Training] 66/66 ━━━━━━━━━━━━━━━━━━━━ 0:01:41 1.2m/step tot_cost: $0.0104 train_loss: 24.5000 - train_accuracy: 0.7424 - val_loss: 22.0000 - val_accuracy: 0.6667
[Validation] 66/66 ━━━━━━━━━━━━━━━━━━━━ 0:01:45

Epoch 2/5
  [Training] 66/66 ━━━━━━━━━━━━━━━━━━━━ 0:01:18 0.7m/step tot_cost: $0.0223 train_loss: 31.5000 - train_accuracy: 0.9545 - val_loss: 25.5000 - val_accuracy: 0.7727
[Validation] 66/66 ━━━━━━━━━━━━━━━━━━━━ 0:01:25

Epoch 3/5
  [Training] 66/66 ━━━━━━━━━━━━━━━━━━━━ 0:03:01 2.3m/step tot_cost: $0.0353 train_loss: 27.0000 - train_accuracy: 0.8182 - val_loss: 21.0000 - val_accuracy: 0.6364
[Validation] 66/66 ━━━━━━━━━━━━━━━━━━━━ 0:03:05

Epoch 4/5
  [Training] 66/66 ━━━━━━━━━━━━━━━━━━━━ 0:03:12 2.4m/step tot_cost: $0.0479 train_loss: 23.5000 - train_accuracy: 0.7121 - val_loss: 20.0000 - val_accuracy: 0.6061
[Validation] 66/66 ━━━━━━━━━━━━━━━━━━━━ 0:03:15

Epoch 5/5
  [Training] 66/66 ━━━━━━━━━━━━━━━━━━━━ 0:03:00 2.2m/step tot_cost: $0.0628 train_loss: 23.0000 - train_accuracy: 0.6970 - val_loss: 22.5000 - val_accuracy: 0.6818
[Validation] 66/66 ━━━━━━━━━━━━━━━━━━━━ 0:03:04
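Here is a rough plain-Python sketch of an epoch loop that makes concrete what trainer.fit automates. It runs a training step for every batch and one validation pass at the end of each epoch, then records averaged per-epoch metrics. ToyAgent and toy_fit are hypothetical. Their steps return canned metrics instead of calling an LM. The sketch also leaves out the gradient, optimizer, logging, cost-tracking, and checkpointing work that Trainer handles.

```python
from statistics import mean


class ToyAgent:
    """Hypothetical agent: steps return canned metrics instead of calling an LM."""

    def __init__(self, val_accuracies):
        self.val_accuracies = val_accuracies  # one value per epoch
        self.epoch = 0

    def training_step(self, batch, batch_idx):
        return {"accuracy": 1.0}

    def validation_step(self, batch, batch_idx):
        return {"accuracy": self.val_accuracies[self.epoch]}


def toy_fit(agent, train_batches, val_batches, max_epochs):
    """Loop over epochs: train on every batch, then validate once per epoch."""
    history = []
    for epoch in range(max_epochs):
        agent.epoch = epoch
        train_acc = [
            agent.training_step(batch, i)["accuracy"]
            for i, batch in enumerate(train_batches)
        ]
        # A real training loop would also backpropagate and update parameters here.
        val_acc = [
            agent.validation_step(batch, i)["accuracy"]
            for i, batch in enumerate(val_batches)
        ]
        history.append(
            {
                "epoch": epoch + 1,
                "train_accuracy": mean(train_acc),
                "val_accuracy": mean(val_acc),
            }
        )
    return history


history = toy_fit(ToyAgent([0.5, 0.75, 0.6]), ["b0", "b1"], ["v0", "v1"], max_epochs=3)
for row in history:
    print(row)
```

Each entry in history plays the role of one epoch summary line in the output above.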

Select & Load Best Checkpoint

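After training, checkpoints are saved to checkpoints/. Choose the best file, typically the one from the epoch with the highest val_accuracy. With the exact-match evaluator used here, the loss score is the number of correct predictions in a batch (accuracy is computed as score / batch size), so a higher val_loss is also better. Load the file with afnio.load(...) and restore it into a fresh agent instance with load_state_dict(..., model_clients=...).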

import urllib.request

# Only run these lines if you want to download our reference checkpoint
checkpoint_path = "checkpoints/checkpoint_epoch2_20250912-190039.hf"
if not os.path.exists(checkpoint_path):
    os.makedirs("checkpoints", exist_ok=True)
    url = "https://github.com/Tellurio-AI/tutorials/raw/main/facility_support/checkpoints/checkpoint_epoch2_20250912-190039.hf"
    urllib.request.urlretrieve(url, checkpoint_path)

# Replace "checkpoint_path" with the path of your best local checkpoint,
# or keep the reference checkpoint downloaded above
checkpoint = afnio.load(checkpoint_path)
best_agent = FacilitySupportAnalyzer()
best_agent.load_state_dict(
    checkpoint["agent_state_dict"],
    model_clients={
        "sentiment_classifier.forward_model_client": fw_model_client,
    },
)
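When several checkpoints are on disk, choosing one can be scripted. This sketch assumes the file-name pattern checkpoint_epoch<N>_<YYYYMMDD-HHMMSS>.hf, inferred from the single reference checkpoint name on this page. It ranks epochs by the validation accuracy you read from your own trainer.fit output. best_checkpoint is a hypothetical helper, and the file names in the example are invented.

```python
import re

# Assumed file-name pattern, inferred from the reference checkpoint name on this page:
# checkpoint_epoch<N>_<YYYYMMDD-HHMMSS>.hf
CHECKPOINT_RE = re.compile(r"checkpoint_epoch(\d+)_(\d{8}-\d{6})\.hf$")


def best_checkpoint(paths, val_accuracy_by_epoch):
    """Return the checkpoint path whose epoch has the highest validation accuracy."""
    by_epoch = {}
    for path in paths:
        match = CHECKPOINT_RE.search(path)
        if match:
            by_epoch[int(match.group(1))] = path
    if not by_epoch:
        raise ValueError("no checkpoint files matched the expected pattern")
    best_epoch = max(
        by_epoch, key=lambda epoch: val_accuracy_by_epoch.get(epoch, float("-inf"))
    )
    return by_epoch[best_epoch]


# val_accuracy per epoch, copied from the trainer.fit output on this page.
val_accuracy = {1: 0.6667, 2: 0.7727, 3: 0.6364, 4: 0.6061, 5: 0.6818}

# Hypothetical file names; in practice, list them with os.listdir("checkpoints").
paths = [
    "checkpoints/checkpoint_epoch1_20250101-120000.hf",
    "checkpoints/checkpoint_epoch2_20250101-121000.hf",
    "checkpoints/checkpoint_epoch3_20250101-122000.hf",
    "notes.txt",  # ignored: does not match the pattern
]
print(best_checkpoint(paths, val_accuracy))
```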

Evaluate Restored Agent (trainer.test)

Run trainer.test(...) on the optimized agent (restored from the checkpoint) to report final test metrics, and call run.finish() to mark the Run as completed in Tellurio Studio.

# Test trained agent
trainer.test(agent=best_agent, test_dataloader=test_dataloader, llm_clients=llm_clients)

run.finish()

Output:

Testing
[Test] 68/68 ━━━━━━━━━━━━━━━━━━━━ 0:00:04 tot_cost: $0.0697  - test_loss: 19.3333 - test_accuracy: 0.8990
INFO     : Run 'epic_kebab_599' marked as COMPLETED.

Key Differences from Manual Optimization Loop

  • Less Code: You don’t need to write explicit loops for training, validation, or testing.

  • Automatic Logging: All metrics and LM costs are logged to Tellurio Studio, giving you interactive plots and experiment tracking out of the box.

  • Checkpointing: Trainer saves checkpoints automatically, so you can resume or analyze experiments later.

  • Progress Bar: Trainer provides a rich progress bar and agent summary for each run.

  • Cost Tracking: LM usage and cost are tracked and logged automatically.

If you want full control over every step, you can still use a manual optimization loop as shown in Optimization Loop. For most workflows, however, Trainer is the fastest and easiest way to train and evaluate agents in Afnio.


Further Reading