
afnio.cognitive.modules.module

afnio.cognitive.modules.module.STEP_OUTPUT = Optional[Union[Tuple[Variable, Variable], Mapping[str, Any]]]  (module attribute)

The expected return type of the methods used by the Trainer for training and evaluation.

The training_step(), validation_step(), and test_step() methods of a Module are expected to return one of the following:

  • A tuple of two Variables: the evaluation score (a Variable containing the loss value) and the explanation (a Variable containing a string explanation of the evaluation result).
  • A dictionary that may contain any keys but must include the key 'loss', whose value is a tuple of two Variables as described above. Additional keys can be used for logging or other purposes; the 'loss' key is required for the Trainer to perform optimization steps correctly.
  • None, which indicates that the current batch should be skipped and no optimization should be performed for it.
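The three accepted shapes can be made concrete with a small checker. This is a sketch under stated assumptions, not afnio's implementation: `Variable` below is a hypothetical dataclass standing in for `afnio.Variable`, and `extract_loss` is a hypothetical helper (not part of afnio) that only encodes the rules listed above.

```python
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union


@dataclass
class Variable:
    """Hypothetical stand-in for afnio.Variable: a payload plus a role label."""

    data: Any
    role: str = ""


# Mirrors the documented STEP_OUTPUT alias, using the stand-in Variable.
StepOutput = Optional[Union[Tuple[Variable, Variable], Mapping[str, Any]]]


def extract_loss(output: StepOutput) -> Optional[Tuple[Variable, Variable]]:
    """Return the (score, explanation) pair, or None when the batch is skipped.

    Hypothetical helper that only encodes the three documented rules.
    """
    if output is None:
        # None means: skip this batch, no optimization step.
        return None
    if isinstance(output, Mapping):
        # A mapping must carry the pair under 'loss'; other keys are extras.
        if "loss" not in output:
            raise KeyError("A mapping step output must contain the key 'loss'.")
        output = output["loss"]
    # Plain tuple (or the 'loss' value of a mapping): exactly two Variables.
    if (
        isinstance(output, tuple)
        and len(output) == 2
        and all(isinstance(v, Variable) for v in output)
    ):
        return output
    raise TypeError("Expected a (score, explanation) tuple of two Variables.")


score = Variable(data=0.8, role="evaluation score")
explanation = Variable(data="The answer cites the wrong dosage.", role="explanation")

print(extract_loss((score, explanation))[0].data)  # 0.8
print(extract_loss({"loss": (score, explanation), "acc": 1.0})[1].role)  # explanation
print(extract_loss(None))  # None
```

Extra dictionary keys such as `acc` are simply ignored by the checker, matching the description of additional keys being available for logging.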

afnio.cognitive.modules.module.Module

Base class for all LM pipeline modules.

Your pipeline should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes.

Examples:

>>> import afnio
>>> import afnio.cognitive as cog
>>> import afnio.cognitive.functional as F
>>> from afnio.models.openai import OpenAI
>>> from afnio import set_backward_model_client
>>>
>>> fwd_model_client = OpenAI()
>>> fwd_model_args = {"model": "gpt-4o", "temperature": 0.7}
>>> set_backward_model_client("openai/gpt-4o")
>>>
>>> class MedQA(cog.Module):
...     def __init__(self):
...         super().__init__()
...         self.system_prompt = cog.Parameter(
...             data="You are a doctor. Only answer medical questions on these areas:",
...             role="system prompt",
...             requires_grad=True,
...         )
...         self.topics = cog.Parameter(
...             data="Dermatology and Cardiology",
...             role="medical topics",
...             requires_grad=False,
...         )
...         self.epilogue = afnio.Variable(
...             data="\nThank you for your query.",
...             role="response epilogue",
...         )
...         self.chat = cog.ChatCompletion()
...
...     def forward(self, fwd_model, user_query, inputs, **completion_args):
...         messages = [
...             {"role": "system", "content": [self.system_prompt, self.topics]},
...             {"role": "user", "content": [user_query]},
...         ]
...         response = self.chat(fwd_model, messages, inputs, **completion_args)
...         return F.Add(response, self.epilogue)

Submodules assigned in this way will be registered with the Module. For example, in MedQA above, self.chat (the ChatCompletion() instance) is registered as a submodule of MedQA.
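A rough model of how attribute assignment can register a submodule is sketched below. Module's real `__setattr__` is not shown on this page, so `ToyModule`, `ToyChat`, and `ToyMedQA` are hypothetical classes that only illustrate the idea: values that are themselves modules are routed into an ordered `_modules` registry, while everything else becomes a plain attribute.

```python
from collections import OrderedDict
from typing import Any


class ToyModule:
    """Hypothetical, stripped-down model of attribute-based registration."""

    def __init__(self) -> None:
        # Bypass our own __setattr__ so the registry exists before any assignment.
        super().__setattr__("_modules", OrderedDict())

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(value, ToyModule):
            # Submodules go into the registry instead of the instance __dict__.
            self._modules[name] = value
        else:
            super().__setattr__(name, value)

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails: fall back to the registry.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class ToyChat(ToyModule):
    pass


class ToyMedQA(ToyModule):
    def __init__(self) -> None:
        super().__init__()
        self.chat = ToyChat()  # registered as a submodule
        self.label = "medqa"  # stored as an ordinary attribute


qa = ToyMedQA()
print(list(qa._modules))  # ['chat']
print(type(qa.chat).__name__)  # ToyChat
```

Keeping submodules in a dedicated ordered registry is what lets a container walk its children (for example in `__repr__`, which iterates over `self._modules`).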

Note

As in the example above, the parent class's __init__() must be called before assigning attributes on the child.
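The ordering matters because the registries (such as `_modules`) are created inside `Module.__init__()`. The sketch below uses a hypothetical `ToyModule` whose guard is similar in spirit to the "Cannot assign ... before Module.__init__() call." errors raised by the `register_*` methods in the source below; it is an illustration, not afnio's code.

```python
from collections import OrderedDict
from typing import Any


class ToyModule:
    """Hypothetical model of why super().__init__() must run first."""

    def __init__(self) -> None:
        super().__setattr__("_modules", OrderedDict())

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(value, ToyModule):
            if "_modules" not in self.__dict__:
                # Registry missing: the parent __init__() has not run yet.
                raise AttributeError(
                    "Cannot assign module before Module.__init__() call."
                )
            self._modules[name] = value
        else:
            super().__setattr__(name, value)


class Good(ToyModule):
    def __init__(self) -> None:
        super().__init__()  # creates the registry first
        self.child = ToyModule()


class Bad(ToyModule):
    def __init__(self) -> None:
        self.child = ToyModule()  # registry does not exist yet
        super().__init__()


print(list(Good()._modules))  # ['child']
try:
    Bad()
except AttributeError as err:
    print(err)  # Cannot assign module before Module.__init__() call.
```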

Attributes:

  • training (bool): Whether this module is in training or evaluation mode. Defaults to True.
  • automatic_optimization (bool): Whether optimization steps are handled automatically by Trainer.fit() or manually by the user. Defaults to True.
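The difference between the two optimization modes can be sketched as follows. Everything here is a hypothetical stand-in (`ToyVariable`, `ToyOptimizer`, `run_batch`, `AutoModule`, `ManualModule`); only the call sequence `optimizer.clear_grad()`, `explanation.backward()`, `optimizer.step()` is taken from this module's `automatic_optimization` docstring. The real Trainer.fit() loop does more than this.

```python
from typing import Any, List, Optional, Tuple


class ToyVariable:
    """Hypothetical stand-in for afnio.Variable that logs backward() calls."""

    def __init__(self, data: Any, log: List[str]) -> None:
        self.data = data
        self.log = log

    def backward(self) -> None:
        self.log.append("backward")


class ToyOptimizer:
    """Hypothetical optimizer exposing the clear_grad()/step() methods."""

    def __init__(self, log: List[str]) -> None:
        self.log = log

    def clear_grad(self) -> None:
        self.log.append("clear_grad")

    def step(self) -> None:
        self.log.append("step")


def run_batch(module: Any, optimizer: ToyOptimizer, batch: Any) -> None:
    """One training iteration: a sketch, not the real Trainer.fit() loop."""
    out: Optional[Tuple[ToyVariable, ToyVariable]] = module.training_step(batch)
    if out is None or not module.automatic_optimization:
        # Batch skipped, or the module already ran its own optimization steps.
        return
    _score, explanation = out
    optimizer.clear_grad()
    explanation.backward()
    optimizer.step()


class AutoModule:
    automatic_optimization = True

    def __init__(self, log: List[str]) -> None:
        self.log = log

    def training_step(self, batch: Any) -> Tuple[ToyVariable, ToyVariable]:
        return ToyVariable(1.0, self.log), ToyVariable("correct", self.log)


class ManualModule:
    automatic_optimization = False

    def __init__(self, log: List[str], optimizer: ToyOptimizer) -> None:
        self.log = log
        self.optimizer = optimizer

    def training_step(self, batch: Any) -> Tuple[ToyVariable, ToyVariable]:
        score = ToyVariable(0.0, self.log)
        explanation = ToyVariable("wrong dosage", self.log)
        # With automatic_optimization = False, these calls are the user's job.
        self.optimizer.clear_grad()
        explanation.backward()
        self.optimizer.step()
        return score, explanation


auto_log: List[str] = []
run_batch(AutoModule(auto_log), ToyOptimizer(auto_log), batch=None)
print(auto_log)  # ['clear_grad', 'backward', 'step']

manual_log: List[str] = []
manual_opt = ToyOptimizer(manual_log)
run_batch(ManualModule(manual_log, manual_opt), manual_opt, batch=None)
print(manual_log)  # ['clear_grad', 'backward', 'step'], issued by training_step
```

In the manual case each step still runs exactly once: the loop returns early instead of repeating the steps the module already performed.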

Source code in afnio/cognitive/modules/module.py
class Module:
    r"""
    Base class for all LM pipeline modules.

    Your pipeline should also subclass this class.

    Modules can also contain other Modules, allowing them to be nested in
    a tree structure. You can assign the submodules as regular attributes.

    Examples:
        >>> import afnio
        >>> import afnio.cognitive as cog
        >>> import afnio.cognitive.functional as F
        >>> from afnio.models.openai import OpenAI
        >>> from afnio import set_backward_model_client
        >>>
        >>> fwd_model_client = OpenAI()
        >>> fwd_model_args = {"model": "gpt-4o", "temperature": 0.7}
        >>> set_backward_model_client("openai/gpt-4o")
        >>>
        >>> class MedQA(cog.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.system_prompt = cog.Parameter(
        ...             data="You are a doctor. Only answer medical questions on these areas:",
        ...             role="system prompt",
        ...             requires_grad=True,
        ...         )
        ...         self.topics = cog.Parameter(
        ...             data="Dermatology and Cardiology",
        ...             role="medical topics",
        ...             requires_grad=False,
        ...         )
        ...         self.epilogue = afnio.Variable(
        ...             data="\nThank you for your query.",
        ...             role="response epilogue",
        ...         )
        ...         self.chat = cog.ChatCompletion()
        ...
        ...     def forward(self, fwd_model, user_query, inputs, **completion_args):
        ...         messages = [
        ...             {"role": "system", "content": [self.system_prompt, self.topics]},
        ...             {"role": "user", "content": [user_query]},
        ...         ]
        ...         response = self.chat(fwd_model, messages, inputs, **completion_args)
        ...         return F.Add(response, self.epilogue)

    Submodules assigned in this way will be registered with the [`Module`][.].
    For example, in `MedQA` above, `self.chat` (the `ChatCompletion()` instance)
    is registered as a submodule of `MedQA`.

    Note:
        As in the example above, the parent class's `__init__()` must be called
        before assigning attributes on the child.

    Attributes:
        training: Whether this module is in training or evaluation mode.
            Defaults to `True`.
        automatic_optimization: Whether optimization steps are handled
            automatically by the
            [`Trainer.fit()`][afnio.trainer.trainer.Trainer.fit]
            or manually by the user. Defaults to `True`.
    """  # noqa: E501

    _version: int = 1
    """Version number that enables backward compatibility for
    [`load_state_dict`][.load_state_dict]. In [`state_dict`][.state_dict], the
    version number is saved in the `_metadata` attribute of the returned state dict,
    and is thus pickled. `_metadata` is a dictionary with keys that follow the naming
    convention of state dict. See [`_load_from_state_dict`][._load_from_state_dict]
    for how to use this information when loading.

    If parameters or buffers are added to or removed from a module, this number
    should be bumped, and the module's [`_load_from_state_dict`][._load_from_state_dict]
    method can compare the version number and make the appropriate changes if the
    state dict predates the change.
    """

    training: bool
    """Whether this module is in training or evaluation mode. Defaults to `True`.
    """
    automatic_optimization: bool
    """Whether optimization steps are handled automatically by the
    [`Trainer.fit()`][afnio.trainer.trainer.Trainer.fit] or manually by the user.
    Defaults to `True`.

    If `True`, the [`Trainer.fit()`][afnio.trainer.trainer.Trainer.fit] method will
    automatically call `optimizer.clear_grad()`, `explanation.backward()`, and
    `optimizer.step()`. If `False`, the user must perform backpropagation and optimizer
    steps manually in the [`training_step()`][.training_step] method.
    """
    _optimizers: Optional[Union[Optimizer, List[Optimizer]]]
    _parameters: Dict[str, Optional[Parameter]]
    _buffers: Dict[str, Optional[Variable]]
    _non_persistent_buffers_set: Set[str]
    _chats: Dict[str, Optional[MultiTurnMessages]]
    _modules: Dict[str, Optional["Module"]]
    _models: Dict[str, Optional[BaseModel]]
    _completion_configs: Dict[str, Optional[Dict[str, Any]]]
    _functions: Dict[str, Optional[Callable]]

    def __init__(self, *args, **kwargs) -> None:
        """Initialize internal Module state.

        Calls `super().__setattr__('a', a)` instead of the typical `self.a = a`
        to avoid `Module.__setattr__` overhead. Module's `__setattr__` has special
        handling for parameters, submodules, buffers, multi-turn chats, language model
        clients and completion configurations but simply calls into
        `super().__setattr__` for all other attributes.
        """
        super().__setattr__("training", True)
        super().__setattr__("automatic_optimization", True)
        super().__setattr__("_optimizers", None)
        super().__setattr__("_parameters", OrderedDict())
        super().__setattr__("_buffers", OrderedDict())
        super().__setattr__("_non_persistent_buffers_set", set())
        super().__setattr__("_chats", OrderedDict())
        super().__setattr__("_modules", OrderedDict())
        super().__setattr__("_models", OrderedDict())
        super().__setattr__("_completion_configs", OrderedDict())
        super().__setattr__("_functions", OrderedDict())

    def forward(self, *args, **kwargs) -> Any:
        """Define the computation performed at every call.

        Should be overridden by all subclasses.

        Note:
            Invoke the [`Module`][..] instance (the `Module.__call__` method)
            instead of calling `Module.forward()` directly, so that any behavior
            added to `__call__` (such as hooks) is applied.
        """
        raise NotImplementedError(
            f"Module [{type(self).__name__}] is missing the required "
            '"forward" function.'
        )

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def _get_name(self):
        return self.__class__.__name__

    @abstractmethod
    def extra_repr(self) -> str:
        """Set the extra representation of the module.

        To print customized extra information, you should re-implement this method in
        your own modules. Both single-line and multi-line strings are acceptable.

        Returns:
            A string containing the extra representation of the module, which will be \
            included in the module's `__repr__` output.
        """
        return ""

    def __repr__(self):
        def _addindent(s_, numSpaces):
            s = s_.split("\n")
            # don't do anything for single-line stuff
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(numSpaces * " ") + line for line in s]
            s = "\n".join(s)
            s = first + "\n" + s
            return s

        # We treat the extra repr like the sub-module, one item per line
        extra_lines = []
        extra_repr = self.extra_repr()
        # empty string will be split into list ['']
        if extra_repr:
            extra_lines = extra_repr.split("\n")
        child_lines = []
        for key, module in self._modules.items():
            mod_str = repr(module)
            mod_str = _addindent(mod_str, 2)
            child_lines.append("(" + key + "): " + mod_str)
        lines = extra_lines + child_lines

        main_str = self._get_name() + "("
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += "\n  " + "\n  ".join(lines) + "\n"

        main_str += ")"
        return main_str

    def __dir__(self):
        module_attrs = dir(self.__class__)
        attrs = list(self.__dict__.keys())
        parameters = list(self._parameters.keys())
        modules = list(self._modules.keys())
        buffers = list(self._buffers.keys())
        chats = list(self._chats.keys())
        models = list(self._models.keys())
        completion_config = list(self._completion_configs.keys())
        functions = list(self._functions.keys())
        keys = (
            module_attrs
            + attrs
            + parameters
            + modules
            + buffers
            + chats
            + models
            + completion_config
            + functions
        )

        # Eliminate attrs that are not legal Python variable names
        keys = [key for key in keys if not key[0].isdigit()]

        return sorted(keys)

    def register_buffer(
        self, name: str, variable: Optional[Variable], persistent: bool = True
    ) -> None:
        """Add a buffer to the module.

        This is typically used to register a buffer that should not be
        considered an agent [`Parameter`][afnio.cognitive.parameter.Parameter].
        For example, [`LMJudgeEvaluator`][afnio.cognitive.modules.lm_judge_evaluator.LMJudgeEvaluator]'s
        `reduction_fn_purpose` is not a parameter, but is part of the module's state.
        Buffers, by default, are persistent and will be saved alongside parameters.
        This behavior can be changed by setting `persistent` to `False`. The only
        difference between a persistent buffer and a non-persistent buffer is that the
        latter will not be a part of this module's [`state_dict`][..state_dict].

        Buffers can be accessed as attributes using given names.

        Args:
            name: Name of the buffer. The buffer can be accessed from this module
                using the given name.
            variable: Buffer to be registered. If `None`, operations that run on
                buffers ignore it, and the buffer is **not** included in the
                module's [`state_dict`][..state_dict].
            persistent: Whether the buffer is part of this module's
                [`state_dict`][..state_dict].

        Example:
            >>> self.register_buffer('reduction_fn_purpose', afnio.Variable(data="summation", role="reduction function purpose"))
        """  # noqa: E501
        if "_buffers" not in self.__dict__:
            raise AttributeError("Cannot assign buffer before Module.__init__() call.")
        elif not isinstance(name, str):
            raise TypeError(
                f"Buffer name should be a string. Got {type(name).__name__}."
            )
        elif "." in name:
            raise KeyError('Buffer name cannot contain ".".')
        elif name == "":
            raise KeyError('Buffer name cannot be empty string "".')
        elif hasattr(self, name) and name not in self._buffers:
            raise KeyError(f"Attribute '{name}' already exists.")
        elif variable is not None and not isinstance(variable, Variable):
            raise TypeError(
                f"Cannot assign '{type(variable).__name__}' object to buffer '{name}' "
                "(afnio.Variable or None required)."
            )
        else:
            self._buffers[name] = variable
            if persistent:
                self._non_persistent_buffers_set.discard(name)
            else:
                self._non_persistent_buffers_set.add(name)

    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
        """Add a parameter to the module.

        The parameter can be accessed as an attribute using given name.

        Args:
            name: Name of the parameter. The parameter can be accessed from this module
                using the given name.
            param: Parameter to be added to the module. If `None`, operations that
                run on parameters ignore it, and the parameter is **not** included
                in the module's [`state_dict`][..state_dict].
        """
        if "_parameters" not in self.__dict__:
            raise AttributeError(
                "Cannot assign parameter before Module.__init__() call."
            )

        elif not isinstance(name, str):
            raise TypeError(
                f"Parameter name should be a string. Got {type(name).__name__}."
            )
        elif "." in name:
            raise KeyError('Parameter name cannot contain ".".')
        elif name == "":
            raise KeyError('Parameter name cannot be empty string "".')
        elif hasattr(self, name) and name not in self._parameters:
            raise KeyError(f"Attribute '{name}' already exists.")

        if param is None:
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError(
                f"Cannot assign '{type(param).__name__}' object to parameter '{name}' "
                "(afnio.cognitive.Parameter or None required)."
            )
        elif param.grad_fn:
            raise ValueError(
                f"Cannot assign non-leaf Variable to parameter '{name}'. Model "
                f"parameters must be created explicitly. To express '{name}' "
                "as a function of another Variable, compute the value in "
                "the forward() method."
            )
        else:
            self._parameters[name] = param

    def register_chat(self, name: str, messages: Optional[MultiTurnMessages]) -> None:
        """Add multi-turn chat messages to the module.

        The chat can be accessed as an attribute using given name.

        Args:
            name: Name of the chat. The chat can be accessed from this module
                using the given name.
            messages: Chat to be added to the module. If `None`, operations that
                run on chats ignore it, and the chat is **not** included in the
                module's [`state_dict`][..state_dict].
        """
        if "_chats" not in self.__dict__:
            raise AttributeError("Cannot assign chat before Module.__init__() call.")

        elif not isinstance(name, str):
            raise TypeError(
                f"Chat name should be a string. " f"Got {type(name).__name__}."
            )
        elif "." in name:
            raise KeyError('Chat name cannot contain ".".')
        elif name == "":
            raise KeyError('Chat name cannot be empty string "".')
        elif hasattr(self, name) and name not in self._chats:
            raise KeyError(f"Attribute '{name}' already exists.")

        if messages is None:
            self._chats[name] = None
        elif not is_multi_turn_messages(messages):
            raise TypeError(
                f"Cannot assign '{type(messages).__name__}' object to chat '{name}' "
                "(afnio.MultiTurnMessages or None required)."
            )
        else:
            self._chats[name] = messages

    def register_model(self, name: str, model: Optional[BaseModel]) -> None:
        """Add a language model to the module.

        The language model can be accessed as an attribute using given name.

        Args:
            name: Name of the model. The model can be accessed from this module
                using the given name.
            model: Model to be added to the module. If `None`, operations that run
                on models ignore it, and the model is **not** included in the
                module's [`state_dict`][..state_dict].
        """
        if "_models" not in self.__dict__:
            raise AttributeError("Cannot assign model before Module.__init__() call.")

        elif not isinstance(name, str):
            raise TypeError(
                f"Model name should be a string. " f"Got {type(name).__name__}."
            )
        elif "." in name:
            raise KeyError('Model name cannot contain ".".')
        elif name == "":
            raise KeyError('Model name cannot be empty string "".')
        elif hasattr(self, name) and name not in self._models:
            raise KeyError(f"Attribute '{name}' already exists.")

        if model is None:
            self._models[name] = None
        elif not isinstance(model, BaseModel):
            raise TypeError(
                f"Cannot assign '{type(model).__name__}' object to model '{name}' "
                "(afnio.models.BaseModel or None required)."
            )
        else:
            self._models[name] = model

    def register_completion_config(
        self, name: str, args: Optional[Dict[str, Any]]
    ) -> None:
        """Register completion-specific arguments for text generation.

        This method allows dynamic storage of completion-related parameters
        such as `temperature`, `max_tokens`, `top_p`, etc.

        Args:
            name: Name of the completion argument set.
            args: Dictionary of completion arguments. If `None`, the argument is **not**
                included in the module's [`state_dict`][..state_dict].
        """
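        if "_completion_configs" not in self.__dict__:
            raise AttributeError(
                "Cannot assign completion config before Module.__init__() call."
            )
        elif not isinstance(name, str):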
            raise TypeError(
                f"Completion config name should be a string. Got {type(name).__name__}."
            )
        elif "." in name:
            raise KeyError('Completion config name cannot contain ".".')
        elif name == "":
            raise KeyError('Completion config name cannot be an empty string "".')
        elif hasattr(self, name) and name not in self._completion_configs:
            raise KeyError(f"Attribute '{name}' already exists.")

        if args is None:
            self._completion_configs[name] = None
        elif not isinstance(args, dict):
            raise TypeError(
                f"Cannot assign '{type(args).__name__}' object to "
                f"completion config '{name}' (dict or None required)."
            )
        else:
            self._completion_configs[name] = args

    def register_function(self, name: str, func: Optional[FunctionType]) -> None:
        """Add a function to the module.

        The function can be accessed as an attribute using the given name.

        Args:
            name: Name of the function. The function can be accessed from this module
                using the given name.
            func: A standard Python function (i.e., a `def`-defined function, not a
                lambda or callable object) that can be pickled and registered for
                later execution. If `None`, the function is unregistered and is
                **not** included in the module's [`state_dict`][..state_dict].
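
        The requirements on `func` can be approximated as shown below. This is a
        hypothetical stand-in for the private `_validate_function` helper, whose
        exact checks are not shown here. It accepts only module-level `def`
        functions, because `pickle` serializes functions by qualified name:

        ```python
        import types
        from typing import Any


        def looks_registrable(func: Any) -> bool:
            # Hypothetical approximation; the real `_validate_function` may differ
            if not isinstance(func, types.FunctionType):
                return False  # rejects builtins, bound methods and callable objects
            if func.__name__ == "<lambda>":
                return False  # lambdas cannot be pickled by reference
            # Nested functions have '<locals>' in their qualified name
            return "<locals>" not in func.__qualname__


        def normalize(text: str) -> str:
            return text.strip().lower()


        class Scorer:
            def __call__(self, text: str) -> int:
                return len(text)
        ```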
        """
        if "_functions" not in self.__dict__:
            raise AttributeError(
                "Cannot assign function before Module.__init__() call."
            )
        elif not isinstance(name, str):
            raise TypeError(
                f"Function name should be a string. Got {type(name).__name__}."
            )
        elif "." in name:
            raise KeyError('Function name cannot contain ".".')
        elif name == "":
            raise KeyError('Function name cannot be empty string "".')
        elif hasattr(self, name) and name not in self._functions:
            raise KeyError(f"Attribute '{name}' already exists.")

        if func is None:
            self._functions[name] = None
        else:
            _validate_function(func)  # Validate the function before registering
            self._functions[name] = func

    def register_module(self, name: str, module: Optional["Module"]) -> None:
        """Add a child module to the current module.

        This method explicitly adds a child module to the current module's hierarchy.
        The child module can then be accessed as an attribute using the given name
        and will be registered in the `_modules` dictionary.

        **When to use**:
            - Use `register_module()` when dynamically adding submodules at runtime,
                especially when the submodule name is determined programmatically.
            - This can be useful for creating flexible and modular architectures.

        **When it's unnecessary**:
            - Directly assigning the module to an attribute (e.g.,
                `self.module_name = SubModule()`) automatically registers it, so using
                `register_module()` is unnecessary in such cases.

        Args:
            name: Name of the child module. The child module can be accessed from
                this module using the given name.
            module: Child module to be added to the module.

        Raises:
            TypeError: If `module` is neither a `Module` instance nor `None`, or
                if `name` is not a string.
            KeyError: If `name` is already an attribute of the module but not
                in `_modules`, or if `name` contains `'.'` or is empty.

        Examples:
            >>> import afnio.cognitive as cog
            >>> class DynamicPipeline(cog.Module):
            ...     def __init__(self):
            ...         super().__init__()
            ...         # Dynamically add submodules
            ...         for i in range(3):
            ...             self.register_module(f"layer_{i}", cog.Module())
            ...
            >>> pipeline = DynamicPipeline()
            >>> print(pipeline._modules.keys())
            odict_keys(['layer_0', 'layer_1', 'layer_2'])

        Note:
            If assigning submodules using standard attribute assignment
            (e.g., `self.submodule = SubModule()`), calling `register_module()`
            explicitly is not required. Direct assignment automatically registers
            the module.
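
        The automatic registration mentioned in the note happens in `__setattr__`,
        which routes values into per-kind dictionaries. `__getattr__` then looks
        them up again. Below is a heavily simplified standalone sketch of that
        pattern that covers only child modules. `TinyModule` and `Pipeline` are
        hypothetical, not afnio classes:

        ```python
        from typing import Any


        class TinyModule:
            # Hypothetical miniature of the registration pattern (modules only)
            def __init__(self) -> None:
                # Bypass our own __setattr__ while the registry does not exist yet
                object.__setattr__(self, "_modules", {})

            def register_module(self, name: str, module: "TinyModule") -> None:
                self._modules[name] = module

            def __setattr__(self, name: str, value: Any) -> None:
                modules = self.__dict__.get("_modules")
                if isinstance(value, TinyModule):
                    if modules is None:
                        raise AttributeError("Cannot assign module before __init__().")
                    self.__dict__.pop(name, None)  # drop a plain attribute it shadows
                    modules[name] = value
                else:
                    object.__setattr__(self, name, value)

            def __getattr__(self, name: str) -> Any:
                # Only called when normal attribute lookup fails
                modules = self.__dict__.get("_modules", {})
                if name in modules:
                    return modules[name]
                raise AttributeError(
                    f"'{type(self).__name__}' object has no attribute '{name}'"
                )


        class Pipeline(TinyModule):
            def __init__(self) -> None:
                super().__init__()
                self.encoder = TinyModule()  # registered by __setattr__
                for i in range(2):
                    self.register_module(f"layer_{i}", TinyModule())  # explicit
                self.label = "demo"  # ordinary attribute, not registered


        pipeline = Pipeline()
        ```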
        """
        if not isinstance(module, Module) and module is not None:
            raise TypeError(
                f"'{type(module).__name__}' is not a valid Module subclass."
            )
        elif not isinstance(name, str):
            raise TypeError(
                f"Module name must be a string, but got '{type(name).__name__}'."
            )
        elif hasattr(self, name) and name not in self._modules:
            raise KeyError(
                f"Attribute '{name}' already exists and "
                f"cannot be used as a module name."
            )
        elif "." in name:
            raise KeyError(f"Module name cannot contain '.', but got: '{name}'.")
        elif name == "":
            raise KeyError('Module name cannot be an empty string "".')
        self._modules[name] = module

    def __getattr__(self, name: str) -> Any:
        if "_parameters" in self.__dict__:
            _parameters = self.__dict__["_parameters"]
            if name in _parameters:
                return _parameters[name]
        if "_buffers" in self.__dict__:
            _buffers = self.__dict__["_buffers"]
            if name in _buffers:
                return _buffers[name]
        if "_chats" in self.__dict__:
            _chats = self.__dict__["_chats"]
            if name in _chats:
                return _chats[name]
        if "_modules" in self.__dict__:
            modules = self.__dict__["_modules"]
            if name in modules:
                return modules[name]
        if "_models" in self.__dict__:
            _models = self.__dict__["_models"]
            if name in _models:
                return _models[name]
        if "_completion_configs" in self.__dict__:
            _completion_configs = self.__dict__["_completion_configs"]
            if name in _completion_configs:
                return _completion_configs[name]
        if "_functions" in self.__dict__:
            _functions = self.__dict__["_functions"]
            if name in _functions:
                return _functions[name]
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )

    def __setattr__(self, name: str, value: Any) -> None:
        def remove_from(*dicts_or_sets):
            for d in dicts_or_sets:
                if name in d:
                    if isinstance(d, dict):
                        del d[name]
                    else:
                        d.discard(name)

        params = self.__dict__.get("_parameters")
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "Cannot assign parameters before Module.__init__() call."
                )
            remove_from(
                self.__dict__,
                self._buffers,
                self._modules,
                self._non_persistent_buffers_set,
                self._chats,
                self._models,
                self._completion_configs,
                self._functions,
            )
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError(
                    f"Cannot assign '{type(value).__name__}' as parameter '{name}' "
                    "(afnio.cognitive.Parameter or None expected)."
                )
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get("_modules")
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "Cannot assign module before Module.__init__() call."
                    )
                remove_from(
                    self.__dict__,
                    self._parameters,
                    self._buffers,
                    self._non_persistent_buffers_set,
                    self._chats,
                    self._models,
                    self._completion_configs,
                    self._functions,
                )
                modules[name] = value  # TODO: use `register_*` method?
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError(
                        f"Cannot assign '{type(value).__name__}' as child module "
                        f"'{name}' (afnio.cognitive.Module or None expected)."
                    )
                modules[name] = value  # TODO: use `register_*` method?
            else:
                chats = self.__dict__.get("_chats")
                if is_multi_turn_messages(value):
                    if chats is None:
                        raise AttributeError(
                            "Cannot assign chat before Module.__init__() call."
                        )
                    remove_from(
                        self.__dict__,
                        self._parameters,
                        self._modules,
                        self._buffers,
                        self._non_persistent_buffers_set,
                        self._models,
                        self._completion_configs,
                        self._functions,
                    )
                    chats[name] = value  # TODO: use `register_*` method?
                elif chats is not None and name in chats:
                    if value is not None:
                        raise TypeError(
                            f"Cannot assign '{type(value).__name__}' as chat '{name}' "
                            "(afnio.MultiTurnMessages or None expected)."
                        )
                    chats[name] = value  # TODO: use `register_*` method?
                else:
                    models = self.__dict__.get("_models")
                    if isinstance(value, BaseModel):
                        if models is None:
                            raise AttributeError(
                                "Cannot assign model before Module.__init__() call."
                            )
                        remove_from(
                            self.__dict__,
                            self._parameters,
                            self._modules,
                            self._buffers,
                            self._non_persistent_buffers_set,
                            self._chats,
                            self._completion_configs,
                            self._functions,
                        )
                        models[name] = value  # TODO: use `register_*` method?
                    elif models is not None and name in models:
                        if value is not None:
                            raise TypeError(
                                f"Cannot assign '{type(value).__name__}' "
                                f"as model '{name}' "
                                "(afnio.models.BaseModel or None expected)."
                            )
                        models[name] = value  # TODO: use `register_*` method?
                    else:
                        completion_configs = self.__dict__.get("_completion_configs")
                        if isinstance(value, dict):
                            if completion_configs is None:
                                raise AttributeError(
                                    "Cannot assign completion config "
                                    "before Module.__init__() call."
                                )
                            remove_from(
                                self.__dict__,
                                self._parameters,
                                self._modules,
                                self._buffers,
                                self._non_persistent_buffers_set,
                                self._chats,
                                self._models,
                                self._functions,
                            )
                            completion_configs[name] = (
                                value  # TODO: use `register_*` method?
                            )
                        elif (
                            completion_configs is not None
                            and name in completion_configs
                        ):
                            if value is not None:
                                raise TypeError(
                                    f"Cannot assign '{type(value).__name__}' "
                                    f"as completion config '{name}' "
                                    "(dict or None expected)."
                                )
                            completion_configs[name] = (
                                value  # TODO: use `register_*` method?
                            )
                        else:
                            functions = self.__dict__.get("_functions")
                            if _is_valid_function(value):
                                if functions is None:
                                    raise AttributeError(
                                        "Cannot assign function "
                                        "before Module.__init__() call."
                                    )
                                remove_from(
                                    self.__dict__,
                                    self._parameters,
                                    self._modules,
                                    self._buffers,
                                    self._non_persistent_buffers_set,
                                    self._chats,
                                    self._models,
                                    self._completion_configs,
                                )
                                functions[name] = (
                                    value  # TODO: use `register_*` method?
                                )
                            elif functions is not None and name in functions:
                                if value is not None:
                                    raise TypeError(
                                        f"Cannot assign '{type(value).__name__}' "
                                        f"as function '{name}' "
                                        "(standalone function or None expected)."
                                    )
                                functions[name] = (
                                    value  # TODO: use `register_*` method?
                                )
                            else:
                                buffers = self.__dict__.get("_buffers")
                                if buffers is not None and name in buffers:
                                    if value is not None and not isinstance(
                                        value, Variable
                                    ):
                                        raise TypeError(
                                            f"Cannot assign '{type(value).__name__}' "
                                            f"as buffer '{name}' "
                                            f"(afnio.Variable or None expected)."
                                        )
                                    buffers[name] = (
                                        value  # TODO: use `register_*` method?
                                    )
                                else:
                                    super().__setattr__(name, value)

    def __delattr__(self, name):
        if name in self._parameters:
            del self._parameters[name]
        elif name in self._buffers:
            del self._buffers[name]
            self._non_persistent_buffers_set.discard(name)
        elif name in self._modules:
            del self._modules[name]
        elif name in self._chats:
            del self._chats[name]
        elif name in self._models:
            del self._models[name]
        elif name in self._completion_configs:
            del self._completion_configs[name]
        elif name in self._functions:
            del self._functions[name]
        else:
            super().__delattr__(name)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Save module state to the `destination` dictionary.

        The `destination` dictionary will contain the state
        of the module, but not its descendants. This is called on every
        submodule in [`state_dict`][..state_dict].

        In rare cases, subclasses can achieve class-specific behavior by
        overriding this method with custom logic.

        Args:
            destination: A dict where state will be stored.
            prefix: The prefix for parameters, buffers, chats, models,
                completion configs and functions used in this module.
            keep_vars: Whether to keep Variables in the state dict as is or detach them.
                Detaching is performed by default to avoid unexpected side effects
                in the user code.
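
        Chats are stored as lists of `{"role": ..., "content": [...]}` messages.
        Detaching one therefore means rebuilding that nesting, not calling
        `detach()` once. Below is a minimal sketch. It uses a hypothetical `Var`
        class in place of `afnio.Variable`, and `detach_chat` is also hypothetical:

        ```python
        from dataclasses import dataclass
        from typing import Any, Dict, List


        @dataclass
        class Var:
            # Hypothetical stand-in for afnio.Variable
            data: Any
            requires_grad: bool = True

            def detach(self) -> "Var":
                return Var(self.data, requires_grad=False)


        def detach_chat(
            chat: List[Dict[str, Any]], keep_vars: bool
        ) -> List[Dict[str, Any]]:
            # New list and message dicts; Variables are detached unless keep_vars
            return [
                {
                    "role": message["role"],
                    "content": [
                        v if keep_vars else v.detach() for v in message["content"]
                    ],
                }
                for message in chat
            ]


        chat = [
            {"role": "system", "content": [Var("You are a doctor.")]},
            {"role": "user", "content": [Var("Question:"), Var("What is eczema?")]},
        ]
        saved = detach_chat(chat, keep_vars=False)
        ```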
        """
        for name, param in self._parameters.items():
            if param is not None:
                destination[prefix + name] = param if keep_vars else param.detach()
        for name, buf in self._buffers.items():
            if buf is not None and name not in self._non_persistent_buffers_set:
                destination[prefix + name] = buf if keep_vars else buf.detach()
        for name, chat in self._chats.items():
            if chat is not None:
                detached_chat = [
                    {
                        "role": message["role"],
                        "content": [
                            var if keep_vars else var.detach()
                            for var in message["content"]
                        ],
                    }
                    for message in chat
                ]
                destination[prefix + name] = detached_chat
        for name, model in self._models.items():
            if model is not None:
                # Always trigger custom __deepcopy__
                destination[prefix + name] = deepcopy(model)
        for name, config in self._completion_configs.items():
            if config is not None:
                destination[prefix + name] = config
        for name, func in self._functions.items():
            if func is not None:
                destination[prefix + name] = func

        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if (
            getattr(self.__class__, "get_extra_state", Module.get_extra_state)
            is not Module.get_extra_state
        ):
            destination[extra_state_key] = self.get_extra_state()

    # The user can pass an optional arbitrary mappable object to `state_dict`, in which
    # case `state_dict` returns back that same object. But if they pass nothing, an
    # `OrderedDict` is created and returned.
    T_destination = TypeVar("T_destination", bound=Dict[str, Any])

    def state_dict(
        self,
        *,
        destination: T_destination = None,
        prefix: str = "",
        keep_vars: bool = False,
    ) -> T_destination:
        """Return a dictionary containing references to the whole state of the module.

        Parameters, persistent buffers, multi-turn chats,
        models, completion configs and functions are included. Keys are corresponding
        parameter, buffer, chat, model, completion config and function names.
        Parameters, buffers, chats, models, completion configs and functions
        set to `None` are not included.

        Note:
            The returned object is a shallow copy. It contains references
            to the module's parameters, buffers, chats, models, completion configs
            and functions.

        Warning:
            Please avoid the use of argument `destination` as it is not
            designed for end-users.

        Args:
            destination (dict, optional): If provided, the state of the module will
                be updated into the dict and the same object is returned.
                Otherwise, an `OrderedDict` will be created and returned.
                Default: `None`.
            prefix (str, optional): A prefix added to parameter, buffer, chat, model,
                completion config and function names to compose the keys in state_dict.
                Default: `''`.
            keep_vars (bool, optional): By default the [`Variable`][afnio.Variable]s
                returned in the state dict are detached from autodiff. If it's
                set to `True`, detaching will not be performed.
                Default: `False`.

        Returns:
            dict (dict): A dictionary containing the whole state of the module.

        Examples:
            >>> module.state_dict().keys()
            odict_keys(['system_prompt', 'classification_labels', 'format_type', 'user_prompt'])
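
        Keys of nested modules are built by recursively extending `prefix` with
        the child's name and a dot. Entries set to `None` are skipped. Below is a
        stripped-down sketch of that recursion over plain dictionaries. `Node` is
        hypothetical and holds a single kind of state:

        ```python
        from collections import OrderedDict
        from typing import Any, Dict, Optional


        class Node:
            def __init__(
                self, state: Dict[str, Any], children: Optional[Dict[str, "Node"]] = None
            ) -> None:
                self.state = state
                self.children = children or {}

            def state_dict(self, destination=None, prefix: str = "") -> "OrderedDict":
                if destination is None:
                    destination = OrderedDict()
                # Local entries first; None values are left out
                for name, value in self.state.items():
                    if value is not None:
                        destination[prefix + name] = value
                # Then recurse into children with an extended prefix
                for name, child in self.children.items():
                    if child is not None:
                        child.state_dict(destination, prefix + name + ".")
                return destination


        root = Node(
            {"system_prompt": "You are a doctor.", "unused": None},
            {
                "classifier": Node(
                    {"labels": ["yes", "no"]},
                    {"fmt": Node({"format_type": "json"})},
                )
            },
        )
        ```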
        """
        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()

        local_metadata = dict(version=self._version)
        if hasattr(destination, "_metadata"):
            destination._metadata[prefix[:-1]] = local_metadata

        self._save_to_state_dict(destination, prefix, keep_vars)
        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(
                    destination=destination,
                    prefix=prefix + name + ".",
                    keep_vars=keep_vars,
                )
        return destination

    def _load_chat_from_state_dict(
        self, name, key, param, input_param, error_msgs, assign
    ):
        """Load chat from `state_dict` while handling structure mismatches.

        This function ensures chat messages are correctly loaded into the module,
        validating structure, roles, and content sizes when `param` is already
        initialized. If `param` is `None`, the chat is registered dynamically, allowing
        `self.register_chat("messages", None)` to be used in `__init__` without
        predefining the number of messages and Variables.

        Args:
            name (str): Attribute name of the chat in the module.
            key (str): Attribute name of the chat in the state dictionary.
            param (Optional[MultiTurnMessages]): Existing chat structure.
            input_param (MultiTurnMessages): Chat data from `state_dict`.
            error_msgs (List[str]): Accumulator for mismatch error messages.
            assign (bool): Whether to directly assign `input_param` to the module.
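
        The structural checks run in a fixed order: message count, then role,
        then content size per message, then per-variable scalar/list shape. They
        stop at the first mismatch. Below is a minimal sketch of that order over
        plain lists. It makes the hypothetical simplification that each variable
        is represented only by its `.data`, and `chat_mismatch` is also
        hypothetical:

        ```python
        from typing import Any, Dict, List, Optional

        Chat = List[Dict[str, Any]]


        def chat_mismatch(current: Chat, checkpoint: Chat) -> Optional[str]:
            # Returns the first mismatch found, or None if compatible
            if len(checkpoint) != len(current):
                return f"message count {len(checkpoint)} != {len(current)}"
            for i, (ckpt, cur) in enumerate(zip(checkpoint, current)):
                if ckpt["role"] != cur["role"]:
                    return f"role mismatch at message {i}"
                if len(ckpt["content"]) != len(cur["content"]):
                    return f"content size mismatch at message {i}"
                for j, (a, b) in enumerate(zip(ckpt["content"], cur["content"])):
                    a_scalar, b_scalar = not isinstance(a, list), not isinstance(b, list)
                    if a_scalar != b_scalar:
                        return f"type mismatch at message {i}, variable {j}"
                    if not a_scalar and len(a) != len(b):
                        return f"list length mismatch at message {i}, variable {j}"
            return None


        current = [{"role": "system", "content": ["Be brief."]}]
        ```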
        """
        if param is not None:
            if not is_multi_turn_messages(input_param):
                error_msgs.append(
                    f"While copying the chat '{key}', expected an "
                    f"afnio.MultiTurnMessages from checkpoint, "
                    f"but received {type(input_param).__name__}."
                )
                return

            if len(input_param) != len(param):
                error_msgs.append(
                    f"Size mismatch for chat '{key}': copying a chat with "
                    f"{len(input_param)} messages from checkpoint, but the "
                    f"chat in the current model has {len(param)} messages."
                )
                return

            for i, (msg_input, msg_param) in enumerate(zip(input_param, param)):
                if msg_input["role"] != msg_param["role"]:
                    error_msgs.append(
                        f"Role mismatch for chat '{key}' at message {i}: "
                        f"copying a role '{msg_input['role']}' from checkpoint, "
                        f"but the role in the current model is '{msg_param['role']}'."
                    )
                    return

                if len(msg_input["content"]) != len(msg_param["content"]):
                    error_msgs.append(
                        f"Content size mismatch for chat '{key}' at message {i}: "
                        f"copying {len(msg_input['content'])} variables from "
                        f"checkpoint, but the message in the current model has "
                        f"{len(msg_param['content'])} variables."
                    )
                    return

                for j, (var_input, var_param) in enumerate(
                    zip(msg_input["content"], msg_param["content"])
                ):
                    is_input_scalar = not isinstance(var_input.data, list)
                    is_param_scalar = not isinstance(var_param.data, list)

                    if is_input_scalar != is_param_scalar:
                        error_msgs.append(
                            f"Type mismatch for chat '{key}' at message {i}, "
                            f"variable {j}: copying a "
                            f"{'scalar' if is_input_scalar else 'non-scalar'} "
                            f"param from checkpoint, but the param in the "
                            f"current model is "
                            f"{'scalar' if is_param_scalar else 'non-scalar'}."
                        )
                        return

                    if not is_input_scalar and len(var_input.data) != len(
                        var_param.data
                    ):
                        error_msgs.append(
                            f"Size mismatch for chat '{key}' at message {i}, "
                            f"variable {j}: copying a param with `.data` list "
                            f"of length {len(var_input.data)} from "
                            f"checkpoint, but the param in the current model "
                            f"has length {len(var_param.data)}."
                        )
                        return
        try:
            with hf.no_grad():
                if assign or param is None:
                    setattr(self, name, input_param)
                else:
                    # Shape checks are already done above
                    for i, (msg_input, msg_param) in enumerate(zip(input_param, param)):
                        for j, (var_input, var_param) in enumerate(
                            zip(msg_input["content"], msg_param["content"])
                        ):
                            var_param.copy_(var_input)
        except Exception as ex:
            error_msgs.append(
                f"While copying the chat named '{key}', "
                f"an exception occurred: {ex.args}."
            )

    def _load_param_buf_from_state_dict(
        self, name, key, param, input_param, error_msgs, assign
    ):
        """Load parameters and buffers from `state_dict`, ensuring consistency with
        the model.

        This function validates and assigns parameters or buffers from `state_dict`,
        ensuring they match the expected type, shape, and scalar properties. If `param`
        is `None`, the parameter is registered dynamically, allowing
        `self.register_parameter(name, None)` or `self.register_buffer(name, None)`
        in `__init__` without requiring a predefined structure.

        Args:
            name (str): Attribute name of the parameter or buffer in the module.
            key (str): Attribute name of the parameter or buffer in
                the state dictionary.
            param (Optional[Variable]): Existing parameter or buffer.
            input_param (Variable): Parameter or buffer data from `state_dict`.
            error_msgs (List[str]): Accumulator for mismatch error messages.
            assign (bool): Whether to directly assign `input_param` to the module.
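
        Copying in place and assigning behave differently when other objects, such
        as an optimizer, hold references to the existing parameter. `copy_` keeps
        the object identity, while assignment swaps the object. Below is a minimal
        sketch. It uses a hypothetical `Var` class whose `copy_` copies only
        `.data`, and `load_entry` is also hypothetical:

        ```python
        from typing import Any, Dict


        class Var:
            # Hypothetical stand-in for afnio.Variable / cog.Parameter
            def __init__(self, data: Any, requires_grad: bool = False) -> None:
                self.data = data
                self.requires_grad = requires_grad

            def copy_(self, other: "Var") -> "Var":
                self.data = other.data  # in place: identity is preserved
                return self


        def load_entry(
            params: Dict[str, Var], name: str, incoming: Var, assign: bool
        ) -> None:
            current = params.get(name)
            if assign or current is None:
                if current is not None:
                    # Keep the module's own requires_grad setting
                    incoming.requires_grad = current.requires_grad
                params[name] = incoming
            else:
                current.copy_(incoming)


        params = {"system_prompt": Var("old prompt", requires_grad=True)}
        held_by_optimizer = params["system_prompt"]
        ```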
        """
        if param is not None:
            if not isinstance(input_param, Variable):
                error_msgs.append(
                    f"While copying the parameter named '{key}', "
                    f"expected afnio.Variable from checkpoint, "
                    f"but received {type(input_param).__name__}."
                )
                return

            is_scalar_input_param = is_scalar_variable(input_param)
            is_scalar_param = is_scalar_variable(param)

            if (
                not is_scalar_input_param
                and not is_scalar_param
                and len(input_param.data) != len(param.data)
            ):
                # local shape should match the one in checkpoint
                error_msgs.append(
                    f"Size mismatch for '{key}': copying a param with `.data` list "
                    f"of length {len(input_param.data)} from checkpoint, "
                    f"but the param in the current model has length {len(param.data)}."
                )
                return

            if is_scalar_input_param != is_scalar_param:
                # local and checkpoint params should be both either scalar or not
                error_msgs.append(
                    f"Type mismatch for '{key}': copying a "
                    f"{'scalar' if is_scalar_input_param else 'non-scalar'} "
                    f"param from checkpoint, but the param in the current model is "
                    f"{'scalar' if is_scalar_param else 'non-scalar'}."
                )
                return
        try:
            with hf.no_grad():
                if assign or param is None:
                    # Shape checks are already done above
                    if isinstance(param, Parameter):
                        if not isinstance(input_param, Parameter):
                            input_param = Parameter(
                                input_param.data,
                                input_param.role,
                                requires_grad=param.requires_grad,
                            )
                        else:
                            input_param.requires_grad_(param.requires_grad)
                    setattr(self, name, input_param)
                else:
                    param.copy_(input_param)
        except Exception as ex:
            if param is None:
                # Size/type details are only computed when a local param exists
                error_msgs.append(
                    f"While copying the parameter named '{key}', "
                    f"an exception occurred: {ex.args}."
                )
                return
            model_data_info = (
                "scalar value"
                if is_scalar_param
                else f"list of length {len(param.data)}"
            )
            checkpoint_data_info = (
                "scalar value"
                if is_scalar_input_param
                else f"list of length {len(input_param.data)}"
            )
            error_msgs.append(
                f"While copying the parameter named '{key}', "
                f"which is a {model_data_info} in the current model and "
                f"which is a {checkpoint_data_info} in the checkpoint, "
                f"an exception occurred: {ex.args}."
            )

    def _load_model_from_state_dict(
        self, name, key, param, input_param, error_msgs, model_clients
    ):
        """Load model clients from `state_dict`, ensuring they are provided
        via `model_clients`.

        This function enforces that model clients must be pre-initialized and passed
        through `model_clients`. It validates the class type and ensures consistency
        between the expected and provided model client types.

        Args:
            name (str): Attribute name of the model client in the module.
            key (str): Attribute name of the model client in the state dictionary.
            param (Optional[BaseModel]): Existing model instance.
            input_param (dict): Serialized model client data from `state_dict`.
            error_msgs (List[str]): Accumulator for mismatch error messages.
            model_clients (Dict[str, BaseModel]): Pre-initialized model clients.

        Note:
            Validation failures are not raised here. They are appended to
            `error_msgs` and reported together by `load_state_dict`, which
            raises a `RuntimeError`.
        """
        if not isinstance(input_param, dict) or "class_type" not in input_param:
            error_msgs.append(
                f"While copying the model client '{key}', expected a serialized "
                f"dictionary with a 'class_type' entry from checkpoint, "
                f"but received {type(input_param).__name__}."
            )
            return

        model_cls_name = input_param["class_type"]
        model_cls = MODEL_REGISTRY.get(model_cls_name)

        if model_cls is None or not issubclass(model_cls, BaseModel):
            error_msgs.append(
                f"Model client '{key}' referenced an unknown or invalid class "
                f"type '{model_cls_name}' from checkpoint. "
                f"Ensure that '{model_cls_name}' is registered in MODEL_REGISTRY "
                f"and inherits from BaseModel."
            )
            return

        if key not in model_clients or not isinstance(model_clients[key], model_cls):
            error_msgs.append(
                f"Missing model client for '{key}' of expected "
                f"type '{model_cls_name}'. Please provide an instance "
                f"of '{model_cls_name}' using the `model_clients` "
                f"dictionary when calling `load_state_dict()`."
            )
            return

        if param is not None and param.get("class_type") != model_cls_name:
            error_msgs.append(
                f"Type mismatch for model client '{key}': the module expects "
                f"'{param.get('class_type', 'Unknown')}', "
                f"but the checkpoint contains '{model_cls_name}'."
            )
            return

        try:
            # Assign the pre-initialized client supplied via `model_clients`
            setattr(self, name, model_clients[key])

            # Restore the checkpoint's usage metadata on the assigned client
            usage = input_param.get("usage", {})
            new_model = getattr(self, name)
            new_model.update_usage(usage)
        except Exception as ex:
            error_msgs.append(
                f"Failed to initialize model client '{key}' of type '{model_cls_name}' "
                f"from state_dict: {ex.args}."
            )

    def _load_completion_config_from_state_dict(
        self, name, key, param, input_param, error_msgs
    ):
        """Load completion configuration from `state_dict`, ensuring consistency
        with the model.

        This function assigns completion configurations from `state_dict`. If `param` is
        `None`, the completion config is registered dynamically, allowing
        `self.register_completion_config("completion_config", None)`
        to be used in `__init__` without requiring predefined values.

        Args:
            name (str): Attribute name of the completion config in the module.
            key (str): Attribute name of the completion config in the state dictionary.
            param (Optional[Dict[str, Any]]): Existing completion config dictionary.
            input_param (Dict[str, Any]): Completion config data from `state_dict`.
            error_msgs (List[str]): Accumulator for mismatch error messages.
        """
        if param is not None:
            if not isinstance(input_param, dict):
                error_msgs.append(
                    f"While copying the completion config '{key}', "
                    f"expected a dictionary from checkpoint, "
                    f"but received {type(input_param).__name__}."
                )
                return

        try:
            setattr(self, name, input_param)
        except Exception as ex:
            error_msgs.append(
                f"While copying the completion config named '{key}', "
                f"an exception occurred: {ex.args}."
            )

    def _load_function_from_state_dict(self, name, key, param, input_param, error_msgs):
        """Load a function from `state_dict`, ensuring consistency with the model.

        This function assigns functions from `state_dict`. If `param` is
        `None`, the function is registered dynamically, allowing
        `self.register_function("function", None)` to be used in `__init__`
        without requiring predefined values.

        Args:
            name (str): Attribute name of the function in the module.
            key (str): Attribute name of the function in the state dictionary.
            param (Optional[Callable[..., Any]]): Existing function reference.
            input_param (Callable[..., Any]): Function data from `state_dict`.
            error_msgs (List[str]): Accumulator for mismatch error messages.
        """
        if param is not None:
            if not _is_valid_function(input_param):
                error_msgs.append(
                    f"While copying the function '{key}', "
                    f"expected a standalone function from checkpoint, "
                    f"but received {type(input_param).__name__}."
                )
                return

        try:
            setattr(self, name, input_param)
        except Exception as ex:
            error_msgs.append(
                f"While copying the function named '{key}', "
                f"an exception occurred: {ex.args}."
            )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
        model_clients: Dict[str, BaseModel] = None,
    ):
        """Copy parameters, buffers, chats, models, completion configs and functions
        from `state_dict` into only this module, but not its descendants.

        This is called on every submodule in [`load_state_dict`][..load_state_dict].
        Metadata saved for this module in input `state_dict` is provided as
        `local_metadata`. For state dicts without metadata, `local_metadata` is empty.
        Subclasses can achieve class-specific backward compatible loading using
        the version number at `local_metadata.get("version", None)`.
        Additionally, `local_metadata` can also contain the key
        `assign_to_params_buffers_chats` that indicates whether keys should be
        assigned their corresponding `Variable` or `MultiTurnMessages`
        in the `state_dict`.
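
        When `strict` is `True`, a key under this module's `prefix` counts as
        unexpected unless its remainder either names local state or starts with
        the name of a registered child module (the child then judges the rest).
        The following is a minimal sketch of that rule in plain Python. It uses
        a hypothetical `classify_unexpected` helper and leaves out the
        `_extra_state` key:

        ```python
        from typing import Iterable, List, Set


        def classify_unexpected(
            state_keys: Iterable[str],
            prefix: str,
            local_names: Set[str],
            child_names: Set[str],
        ) -> List[str]:
            unexpected = []
            for key in state_keys:
                # Keys outside this module's prefix belong to other modules.
                if not key.startswith(prefix):
                    continue
                # Split off the first name component after the prefix.
                parts = key[len(prefix):].split(".", 1)
                if len(parts) > 1:
                    # A dotted remainder must start with a registered child name.
                    if parts[0] not in child_names:
                        unexpected.append(key)
                elif parts[0] not in local_names:
                    # A plain remainder must match this module's own state.
                    unexpected.append(key)
            return unexpected
        ```

        For example, take `prefix=""`, `local_names={"topics"}` and
        `child_names={"chat"}`. The key `"chat.system_prompt"` is accepted at
        this level, while `"stray"` and `"other.chat"` are reported.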

        Note:
            `state_dict` is not the same object as the input `state_dict` to
            [`load_state_dict`][..load_state_dict], so it can be modified.

        Args:
            state_dict (dict): A dict containing parameters, persistent buffers,
                chats, models, completion configs and functions.
            prefix (str): The prefix for parameters, buffers, chats, models,
                completion configs and functions used in this module.
            local_metadata (dict): A dict containing the metadata for this module.
            strict (bool): Whether to strictly enforce that the keys in
                `state_dict` with `prefix` match the names of parameters, buffers,
                chats, models, completion configs and functions in this module.
            missing_keys (list of str): If `strict=True`, add missing keys to
                this list.
            unexpected_keys (list of str): If `strict=True`, add unexpected
                keys to this list.
            error_msgs (list of str): Error messages should be added to this
                list, and will be reported together in
                [`load_state_dict`][..load_state_dict].
            model_clients (dict, optional): A dictionary mapping model client keys
                (e.g., 'fw_model_client') to their respective instances of
                [`BaseModel`][afnio.models.model.BaseModel]. These instances are
                used to reconstruct any model clients referenced in `state_dict`.
                If a required model client is missing, an error message is added
                to `error_msgs` explaining how to provide it.
        """
        persistent_buffers = {
            k: v
            for k, v in self._buffers.items()
            if k not in self._non_persistent_buffers_set
        }
        local_name_params = itertools.chain(
            self._parameters.items(),
            persistent_buffers.items(),
            self._chats.items(),
            self._models.items(),
            self._completion_configs.items(),
            self._functions.items(),
        )
        local_state = {k: v for k, v in local_name_params}
        assign_to_params_buffers_chats = local_metadata.get(
            "assign_to_params_buffers_chats", False
        )
        model_clients = model_clients or {}

        for name, param in local_state.items():
            key = prefix + name
            if key in state_dict:
                input_param = state_dict[key]
                # Handle chats
                if name in self._chats:
                    self._load_chat_from_state_dict(
                        name,
                        key,
                        param,
                        input_param,
                        error_msgs,
                        assign_to_params_buffers_chats,
                    )
                # Handle models
                elif name in self._models:
                    self._load_model_from_state_dict(
                        name,
                        key,
                        param,
                        input_param,
                        error_msgs,
                        model_clients,
                    )
                # Handle completion configs
                elif name in self._completion_configs:
                    self._load_completion_config_from_state_dict(
                        name,
                        key,
                        param,
                        input_param,
                        error_msgs,
                    )
                # Handle functions
                elif name in self._functions:
                    self._load_function_from_state_dict(
                        name,
                        key,
                        param,
                        input_param,
                        error_msgs,
                    )
                else:
                    # Handle parameters and buffers
                    self._load_param_buf_from_state_dict(
                        name,
                        key,
                        param,
                        input_param,
                        error_msgs,
                        assign_to_params_buffers_chats,
                    )
            elif strict:
                missing_keys.append(key)

        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if (
            getattr(self.__class__, "set_extra_state", Module.set_extra_state)
            is not Module.set_extra_state
        ):
            if extra_state_key in state_dict:
                self.set_extra_state(state_dict[extra_state_key])
            elif strict:
                missing_keys.append(extra_state_key)
        elif strict and (extra_state_key in state_dict):
            unexpected_keys.append(extra_state_key)

        if strict:
            for key in state_dict.keys():
                if key.startswith(prefix) and key != extra_state_key:
                    input_name = key[len(prefix) :].split(".", 1)  # noqa: E203
                    # A dotted remainder must belong to a registered submodule
                    if len(input_name) > 1:
                        if input_name[0] not in self._modules:
                            unexpected_keys.append(key)
                    elif input_name[0] not in local_state:
                        unexpected_keys.append(key)

    def load_state_dict(
        self,
        state_dict: Mapping[str, Any],
        strict: bool = True,
        assign: bool = False,
        model_clients: Dict[str, BaseModel] = None,
    ):
        """Copy parameters, buffers, chats, models, completion configs and functions
        from `state_dict` into this module and its descendants.

        If `strict` is `True`, then the keys of `state_dict` must exactly match the keys
        returned by this module's [`state_dict`][..state_dict] function.

        Warning:
            If `assign` is `True`, the optimizer must be created after
            the call to [`load_state_dict`][.].

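        Note:
            If a parameter, buffer, or chat is registered as `None` and its
            corresponding key exists in `state_dict`, [`load_state_dict`][.] will
            raise a `RuntimeError`. Model clients, completion configs, and
            functions registered as `None` are instead populated from
            `state_dict`; model clients also need a matching entry in
            `model_clients`.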

        Args:
            state_dict (dict): A dict containing parameters, persistent buffers,
                chats, models, completion configs and functions.
            strict (bool, optional): Whether to strictly enforce that the keys
                in `state_dict` match the keys returned by this module's
                [`state_dict`][..state_dict] function. Default: `True`
            assign (bool, optional): When `False`, the properties of the Variables
                in the current module are preserved; when `True`, the properties
                of the Variables in the state dict are preserved instead. The only
                exception is the `requires_grad` field of
                [`Parameter`][afnio.cognitive.parameter.Parameter], for which the
                value from the module is preserved. Default: `False`
            model_clients (dict, optional): A dictionary mapping model client keys
                (e.g., 'fw_model_client') to their respective instances of
                [`BaseModel`][afnio.models.model.BaseModel]. These instances will be
                used to reconstruct any model clients referenced in `state_dict`.
                If a required model client is missing, a `RuntimeError` is raised
                with instructions on how to provide it.

        Returns:
            incompatible_keys (NamedTuple): A `NamedTuple` with `missing_keys` and
                `unexpected_keys` fields (see the note below).

        Note:
            The return value reports key mismatches encountered during loading:

            - `missing_keys` is a list of str containing any keys that are expected by
                this module but missing from the provided `state_dict`.
            - `unexpected_keys` is a list of str containing the keys that are not
                expected by this module but present in the provided `state_dict`.

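        Loading walks the module tree. Each module consumes the keys formed by
        its own prefix, and each child receives only the entries that start with
        `<prefix><child_name>.`. The following is a minimal sketch of that
        recursion. It assumes a hypothetical `Node` class in place of `Module`
        and ignores buffers, chats, extra state, assignment semantics and
        unexpected keys:

        ```python
        from typing import Any, Dict, List


        class Node:
            # Hypothetical stand-in for a Module: local state plus named children.
            def __init__(self, state: Dict[str, Any], children: Dict[str, "Node"]):
                self.state = state
                self.children = children


        def load(node: Node, sd: Dict[str, Any], prefix: str, missing: List[str]):
            # Copy the entries this node owns directly, recording absent keys.
            for name in node.state:
                key = prefix + name
                if key in sd:
                    node.state[name] = sd[key]
                else:
                    missing.append(key)
            # Recurse with only the entries under each child's dotted prefix.
            for child_name, child in node.children.items():
                child_prefix = prefix + child_name + "."
                child_sd = {k: v for k, v in sd.items() if k.startswith(child_prefix)}
                load(child, child_sd, child_prefix, missing)
        ```

        Consider a root holding `topics` with a child `chat` that holds
        `system_prompt`. The flat keys `"topics"` and `"chat.system_prompt"`
        land on the root and the child, respectively.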
        """
        if not isinstance(state_dict, Mapping):
            raise TypeError(
                f"Expected state_dict to be dict-like, got {type(state_dict)}."
            )

        missing_keys: List[str] = []
        unexpected_keys: List[str] = []
        error_msgs: List[str] = []

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = OrderedDict(state_dict)
        if metadata is not None:
            # mypy isn't aware that "_metadata" exists in state_dict
            state_dict._metadata = metadata  # type: ignore[attr-defined]

        def load(module, local_state_dict, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            if assign:
                local_metadata["assign_to_params_buffers_chats"] = assign
            module._load_from_state_dict(
                local_state_dict,
                prefix,
                local_metadata,
                True,
                missing_keys,
                unexpected_keys,
                error_msgs,
                model_clients,
            )
            for name, child in module._modules.items():
                if child is not None:
                    child_prefix = prefix + name + "."
                    child_state_dict = {
                        k: v
                        for k, v in local_state_dict.items()
                        if k.startswith(child_prefix)
                    }
                    load(child, child_state_dict, child_prefix)  # noqa: F821

        load(self, state_dict)
        del load

        if strict:
            if len(unexpected_keys) > 0:
                error_msgs.insert(
                    0,
                    "Unexpected key(s) in state_dict: {}. ".format(
                        ", ".join(f'"{k}"' for k in unexpected_keys)
                    ),
                )
            if len(missing_keys) > 0:
                error_msgs.insert(
                    0,
                    "Missing key(s) in state_dict: {}. ".format(
                        ", ".join(f'"{k}"' for k in missing_keys)
                    ),
                )

        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in loading state_dict for {}:\n\t{}".format(
                    self.__class__.__name__, "\n\t".join(error_msgs)
                )
            )
        return _IncompatibleKeys(missing_keys, unexpected_keys)

    def get_extra_state(self) -> Any:
        """Return any extra state to include in the module's state_dict.

        Implement this and a corresponding [`set_extra_state`][..set_extra_state] for
        your module if you need to store extra state. This function is called when
        building the module's `state_dict()`.

        Extra state should be picklable so that the `state_dict` can be
        serialized.

        Returns:
            object: Any extra state to store in the module's state_dict.
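
        A module opts in by overriding both methods. The following is a minimal
        sketch of the contract. It uses a hypothetical `CallCounter` class that
        does not subclass `cog.Module`, so it runs without afnio. The pickle
        round trip stands in for saving and reloading the state dict.

        ```python
        import pickle
        from typing import Any, Dict


        class CallCounter:
            # Hypothetical stand-in for a cog.Module subclass with extra state.
            def __init__(self) -> None:
                self.calls = 0

            def get_extra_state(self) -> Dict[str, Any]:
                # Return plain, picklable data rather than live objects.
                return {"calls": self.calls}

            def set_extra_state(self, state: Dict[str, Any]) -> None:
                # Restore exactly what get_extra_state produced.
                self.calls = state["calls"]


        source = CallCounter()
        source.calls = 3
        blob = pickle.dumps(source.get_extra_state())

        target = CallCounter()
        target.set_extra_state(pickle.loads(blob))
        ```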
        """
        raise RuntimeError(
            "Reached a code path in Module.get_extra_state() that "
            "should never be called."
        )

    def set_extra_state(self, state: Any) -> None:
        """Set extra state contained in the loaded `state_dict`.

        This function is called from [`load_state_dict`][..load_state_dict] to handle
        any extra state found within the `state_dict`. Implement this function and a
        corresponding [`get_extra_state`][..get_extra_state] for your module if you need
        to store extra state within its `state_dict`.

        Args:
            state (Any): Extra state from the `state_dict`.
        """
        raise RuntimeError(
            "Reached a code path in Module.set_extra_state() that "
            "should never be called."
        )

    def _named_members(
        self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True
    ):
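        r"""Yield `(name, member)` pairs produced by `get_members_fn` for this
        module and, if `recurse` is `True`, all of its submodules, skipping
        `None` members and, if `remove_duplicate` is `True`, repeated ones."""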
        memo = set()
        modules = (
            self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate)
            if recurse
            else [(prefix, self)]
        )
        for module_prefix, module in modules:
            members = get_members_fn(module)
            for k, v in members:
                value = v

                # Convert chat messages into a hashable structure
                if is_multi_turn_messages(v):
                    v = tuple((entry["role"], tuple(entry["content"])) for entry in v)

                # Convert dictionaries (e.g., completion_args) into hashable tuples
                elif isinstance(v, dict):
                    v = tuple(sorted(v.items()))

                if v is None or v in memo:
                    continue

                if remove_duplicate:
                    memo.add(v)
                name = module_prefix + ("." if module_prefix else "") + k
                yield name, value

    def buffers(self, recurse: bool = True) -> Iterator[Variable]:
        r"""Return an iterator over module buffers.

        Args:
            recurse (bool): if `True`, then yields buffers of this module
                and all submodules. Otherwise, yields only buffers that
                are direct members of this module.

        Yields:
            Module buffer

        Examples:
            >>> for buf in model.buffers():
            ...     print(type(buf), buf.data)
            <class 'afnio.Variable'> ("Structure your answer as JSON.")
            <class 'afnio.Variable'> ("Use the format\n\n{\n  \"response\": \"Your concise answer here.\"\n}")
        """  # noqa: E501
        for _, buf in self.named_buffers(recurse=recurse):
            yield buf

    def named_buffers(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Variable]]:
        r"""Return an iterator over module buffers, yielding both the name of
        the buffer as well as the buffer itself.

        Args:
            prefix (str): prefix to prepend to all buffer names.
            recurse (bool, optional): if `True`, then yields buffers of this module
                and all submodules. Otherwise, yields only buffers that
                are direct members of this module. Defaults to `True`.
            remove_duplicate (bool, optional): whether to remove the duplicated buffers
                in the result. Defaults to `True`.

        Yields:
            Tuple containing the name and buffer

        Examples:
            >>> for name, buf in self.named_buffers():
            ...     if "format_type" in name:
            ...         print(buf.data)
        """
        gen = self._named_members(
            lambda module: module._buffers.items(),
            prefix=prefix,
            recurse=recurse,
            remove_duplicate=remove_duplicate,
        )
        yield from gen

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """Return an iterator over module parameters.

        This is typically passed to an optimizer.

        Args:
            recurse (bool): if `True`, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.

        Yields:
            Module parameter

        Examples:
            >>> for param in pipeline.parameters():
            ...     print(type(param), param.data)
            <class 'afnio.cognitive.parameter.Parameter'> ("You are a doctor.")
            <class 'afnio.cognitive.parameter.Parameter'> ("Only answer with YES or NO.")
        """
        for _, param in self.named_parameters(recurse=recurse):
            yield param

    def named_parameters(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Parameter]]:
        """Return an iterator over module parameters, yielding both the name of the
        parameter as well as the parameter itself.

        Args:
            prefix (str): prefix to prepend to all parameter names.
            recurse (bool): if `True`, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.
            remove_duplicate (bool, optional): whether to remove the duplicated
                parameters in the result. Defaults to `True`.

        Yields:
            Tuple containing the name and parameter

        Examples:
            >>> for name, param in self.named_parameters():
            ...     if "prompt" in name:
            ...         print(param.data)
        """
        gen = self._named_members(
            lambda module: module._parameters.items(),
            prefix=prefix,
            recurse=recurse,
            remove_duplicate=remove_duplicate,
        )
        yield from gen

    def chats(self, recurse: bool = True) -> Iterator[MultiTurnMessages]:
        """Return an iterator over module multi-turn chats.

        This is typically passed to an optimizer.

        Args:
            recurse (bool): if `True`, then yields chats of this module
                and all submodules. Otherwise, yields only chats that
                are direct members of this module.

        Yields:
            Module chats

        Examples:
            >>> for chat in pipeline.chats():
            ...     print(type(chat), chat)
            <class 'cog.MultiTurnMessages'> [{'role': 'system', 'content': [Variable(data=You are a doctor., role=system instruction, requires_grad=False)]}, {'role': 'user', 'content': [Variable(data=Is {item} a disease?, role=user query, requires_grad=False)]}]
            <class 'cog.MultiTurnMessages'> [{'role': 'system', 'content': [Variable(data=You are a helpful assistant., role=system instruction, requires_grad=False), Variable(data=Only answer with YES or NO., role=user query, requires_grad=False)]}]
        """  # noqa: E501
        for _, chat in self.named_chats(recurse=recurse):
            yield chat

    def named_chats(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, MultiTurnMessages]]:
        """Return an iterator over module multi-turn chats, yielding both
        the name of the chat as well as the chat itself.

        Args:
            prefix (str): prefix to prepend to all chat names.
            recurse (bool): if `True`, then yields chats of this module
                and all submodules. Otherwise, yields only chats that
                are direct members of this module.
            remove_duplicate (bool, optional): whether to remove the duplicated
                chats in the result. Defaults to `True`.

        Yields:
            Tuple containing the name and chat

        Examples:
            >>> for name, chat in self.named_chats():
            ...     if "messages" in name:
            ...         print(chat[0]["role"])
        """
        gen = self._named_members(
            lambda module: module._chats.items(),
            prefix=prefix,
            recurse=recurse,
            remove_duplicate=remove_duplicate,
        )
        yield from gen

    def models(self, recurse: bool = True) -> Iterator[BaseModel]:
        """Return an iterator over module language model clients.

        Args:
            recurse (bool): if `True`, then yields models of this module
                and all submodules. Otherwise, yields only models that
                are direct members of this module.

        Yields:
            Module model

        Examples:
            >>> for model in pipeline.models():
            ...     print(type(model))
            <class 'afnio.models.openai.AsyncOpenAI'>
        """
        for _, model in self.named_models(recurse=recurse):
            yield model

    def named_models(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, BaseModel]]:
        """Return an iterator over module model clients, yielding both the name of the
        model as well as the model itself.

        Args:
            prefix (str): prefix to prepend to all model names.
            recurse (bool): if `True`, then yields models of this module
                and all submodules. Otherwise, yields only models that
                are direct members of this module.
            remove_duplicate (bool, optional): whether to remove the duplicated
                models in the result. Defaults to `True`.

        Yields:
            Tuple containing the name and model

        Examples:
            >>> for name, model in self.named_models():
            >>>     print(name, type(model))
            model_client <class 'afnio.models.openai.AsyncOpenAI'>
        """
        gen = self._named_members(
            lambda module: module._models.items(),
            prefix=prefix,
            recurse=recurse,
            remove_duplicate=remove_duplicate,
        )
        yield from gen

    def completion_configs(self, recurse: bool = True) -> Iterator[Dict[str, Any]]:
        """Return an iterator over registered completion configs.

        Args:
            recurse (bool): if `True`, then yields completion configs of this module
                and all submodules. Otherwise, yields only completion configs that
                are direct members of this module.

        Yields:
            Module completion config

        Examples:
            >>> for config in model.completion_configs():
            ...     print(config)
            {'model': 'gpt-4o', 'seed': 42, 'temperature': 0}
        """
        for _, config in self.named_completion_configs(recurse=recurse):
            yield config

    def named_completion_configs(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Dict[str, Any]]]:
        """Return an iterator over module completion configs, yielding both the name of
        the completion config as well as the completion config itself.
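
        With `remove_duplicate=True`, duplicates are detected by value. Each
        config dict is converted to a tuple of its sorted items before the
        membership check, so two distinct dicts with equal contents are yielded
        once, under the first name encountered. The following is a minimal
        sketch of that rule with a hypothetical `dedup_configs` helper. Like the
        conversion in `_named_members`, it assumes every config value is
        hashable.

        ```python
        from typing import Any, Dict, Iterable, Iterator, Tuple


        def dedup_configs(
            named: Iterable[Tuple[str, Dict[str, Any]]],
        ) -> Iterator[Tuple[str, Dict[str, Any]]]:
            memo = set()
            for name, config in named:
                # Dicts are unhashable, so compare them as sorted item tuples.
                key = tuple(sorted(config.items()))
                if key in memo:
                    continue
                memo.add(key)
                # Yield the original dict, not its hashable key.
                yield name, config
        ```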

        Args:
            prefix (str): prefix to prepend to all completion config names.
            recurse (bool): if `True`, then yields completion configs of this module
                and all submodules. Otherwise, yields only completion configs that
                are direct members of this module.
            remove_duplicate (bool, optional): whether to remove the duplicated
                completion configs in the result. Defaults to `True`.

        Yields:
            Tuple containing the name and completion config

        Examples:
            >>> for name, config in self.named_completion_configs():
            ...     print(name, config)
            chat.completion_args {'model': 'gpt-4o', 'seed': 42, 'temperature': 0}
        """
        gen = self._named_members(
            lambda module: module._completion_configs.items(),
            prefix=prefix,
            recurse=recurse,
            remove_duplicate=remove_duplicate,
        )
        yield from gen

    def functions(self, recurse: bool = True) -> Iterator[Dict[str, Any]]:
        """Return an iterator over registered functions.

        Args:
            recurse (bool): if `True`, then yields functions of this module
                and all submodules. Otherwise, yields only functions that
                are direct members of this module.

        Yields:
            Module function

        Examples:
            >>> for func in model.functions():
            ...     print(func)
            <built-in function sum>
            <function my_func at 0x7e7a0665b9c0>
        """
        for _, func in self.named_functions(recurse=recurse):
            yield func

    def named_functions(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Dict[str, Any]]]:
        """Return an iterator over module functions, yielding both the name of
        the function as well as the function itself.

        Args:
            prefix (str): prefix to prepend to all function names.
            recurse (bool): if `True`, then yields functions of this module
                and all submodules. Otherwise, yields only functions that
                are direct members of this module.
            remove_duplicate (bool, optional): whether to remove the duplicated
                functions in the result. Defaults to `True`.

        Yields:
            Tuple containing the name and function

        Examples:
            >>> for name, func in self.named_functions():
            ...     print(name, func)
            reduction_fn <built-in function sum>
            eval_fn <function my_func at 0x7e7a0665b9c0>
        """
        gen = self._named_members(
            lambda module: module._functions.items(),
            prefix=prefix,
            recurse=recurse,
            remove_duplicate=remove_duplicate,
        )
        yield from gen

    def children(self) -> Iterator["Module"]:
        """Return an iterator over immediate children modules.

        Yields:
            A child module
        """
        for _, module in self.named_children():
            yield module

    def named_children(self) -> Iterator[Tuple[str, "Module"]]:
        """Return an iterator over immediate children modules, yielding both the name
        of the module as well as the module itself.

        Yields:
            Tuple containing a name and child module
        """
        memo = set()
        for name, module in self._modules.items():
            if module is not None and module not in memo:
                memo.add(module)
                yield name, module

    def modules(self) -> Iterator["Module"]:
        """Return an iterator over all modules in the network.

        Yields:
            A module in the network

        Note:
            Duplicate modules are returned only once. In the following
            example, `add` will be returned only once.

        Examples:
            >>> class MyPipeline(cog.Module):
            ...     def __init__(self):
            ...         super().__init__()
            ...         add = cog.Add()
            ...         self.module1 = add
            ...         self.module2 = add
            ...
            ...     def forward(self, x, y):
            ...         out1 = self.module1(x, x)
            ...         out2 = self.module2(x, y)
            ...         return out1 + out2
            >>> pipeline = MyPipeline()
            >>> for idx, m in enumerate(pipeline.modules()):
            ...     print(idx, '->', m)
            0 -> MyPipeline(
              (module1): Add()
              (module2): Add()
            )
            1 -> Add()
        """
        for _, module in self.named_modules():
            yield module

    def named_modules(
        self,
        memo: Optional[Set["Module"]] = None,
        prefix: str = "",
        remove_duplicate: bool = True,
    ) -> Iterator[Tuple[str, "Module"]]:
        """Return an iterator over all modules in the network, yielding both
        the name of the module as well as the module itself.

        Args:
            memo: a memo to store the set of modules already added to the result
            prefix: a prefix that will be added to the name of the module
            remove_duplicate: whether to remove the duplicated module instances
                in the result or not

        Yields:
            Tuple of name and module

        Note:
            By default, duplicate modules are returned only once. In the following
            example, `add` is returned only once unless `remove_duplicate=False`
            is passed.

        Examples:
            >>> class MyPipeline(cog.Module):
            ...     def __init__(self):
            ...         super().__init__()
            ...         add = cog.Add()
            ...         self.module1 = add
            ...         self.module2 = add
            ...
            ...     def forward(self, x, y):
            ...         out1 = self.module1(x, x)
            ...         out2 = self.module2(x, y)
            ...         return out1 + out2
            >>> pipeline = MyPipeline()
            >>> for idx, m in enumerate(pipeline.named_modules()):
            ...     print(idx, '->', m)
            0 -> ('', MyPipeline(
              (module1): Add()
              (module2): Add()
            ))
            1 -> ('module1', Add())

            >>> for idx, m in enumerate(pipeline.named_modules(remove_duplicate=False)):
            ...     print(idx, '->', m)
            0 -> ('', MyPipeline(
              (module1): Add()
              (module2): Add()
            ))
            1 -> ('module1', Add())
            2 -> ('module2', Add())
        """
        if memo is None:
            memo = set()
        if self not in memo:
            if remove_duplicate:
                memo.add(self)
            yield prefix, self
            for name, module in self._modules.items():
                if module is None:
                    continue
                submodule_prefix = prefix + ("." if prefix else "") + name
                yield from module.named_modules(
                    memo, submodule_prefix, remove_duplicate
                )

    def train(self: T, mode: bool = True) -> T:
        """Set the module in training mode.

        This only has an effect on certain modules. See the documentation of
        particular modules for details of their behavior in training/evaluation
        mode, if they are affected.

        Args:
            mode: whether to set training mode (`True`) or evaluation mode (`False`).

        Returns:
            self (Module): The module itself.
        """
        if not isinstance(mode, bool):
            raise ValueError("Training mode is expected to be boolean.")
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def eval(self: T) -> T:
        """Set the module in evaluation mode.

        This only has an effect on certain modules. See the documentation of
        particular modules for details of their behavior in training/evaluation
        mode, if they are affected.

        This is equivalent to calling `self.train(False)`.
        See [`train`][..train] for more details.

        Returns:
            self (Module): The module itself.
        """
        return self.train(False)

    def requires_grad_(self: T, requires_grad: bool = True) -> T:
        """Change whether autodiff should record operations on parameters and chats
        in this module.

        This method sets the [`requires_grad`][afnio.Variable.requires_grad] attributes
        of all module parameters in-place. It also sets the
        [`requires_grad`][afnio.Variable.requires_grad] attributes of all the
        `Variables` within the content of multi-turn chats.

        **Effect on Parameters:**

        - Sets [`requires_grad`][afnio.Variable.requires_grad] for each registered
            parameter in the module.

        **Effect on Chats:**

        - Iterates through all multi-turn chats and sets
            [`requires_grad`][afnio.Variable.requires_grad] for each `Variable` in
            the `"content"` key of the chat's message.

        This method is helpful for freezing part of the module for finetuning
        or training parts of a model individually.

        Args:
            requires_grad: Whether autodiff should record operations on parameters and
                chats in this module.

        Returns:
            self (Module): The module itself.
        """
        # Set requires_grad on all parameters
        for p in self.parameters():
            p.requires_grad_(requires_grad)

        # Set requires_grad on all variables in message content
        for chat in self.chats():
            for message in chat:
                for variable in message["content"]:
                    variable.requires_grad_(requires_grad)

        return self

    def empty_grad(self) -> None:
        """Reset gradients of all model parameters and content variables
        in chats' messages.

        This method is useful for clearing out gradients before starting a new
        optimization step. It ensures that both module parameters and Variables within
        multi-turn chats' message contents have their gradients reset, avoiding
        unintended gradient accumulation.
        """
        # Reset gradients of all parameters
        for p in self.parameters():
            if p.grad:
                p.grad = []

        # Reset gradients of all variables in message content
        for chat in self.chats():
            for message in chat:
                for variable in message["content"]:
                    if variable.grad:
                        variable.grad = []

    def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
        """Perform a single training step.

        This method should be implemented in subclasses to define the training logic.
        It is called by the [`Trainer`][afnio.trainer.trainer.Trainer]
        during the training loop.

        Args:
            batch: The output of your data iterable, normally
                a [`DataLoader`][afnio.util.data.DataLoader].
            batch_idx: The index of this batch.

        Returns:
            The result of the training step (see the note below for details).

        Notes:
            The return value can be one of the following:

            - `Tuple[Variable, Variable]`: The loss as a tuple of two `Variable`s:
                - The evaluation `score` (a `Variable` containing the loss value).
                - The `explanation` (a `Variable` containing a string explanation
                    of the evaluation result).
            - `dict`: A dictionary. Can include any keys, but must include
                the key `'loss'` containing a tuple of two `Variable`s
                (`score` and `explanation`).
            - `None`: Skip to the next batch.

        Raises:
            NotImplementedError: If not implemented in a subclass.
        """
        raise NotImplementedError(
            "You must implement training_step in your Module subclass."
        )

    def validation_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
        """Perform a single validation step.

        This method should be implemented in subclasses to define the validation logic.
        It is called by the [`Trainer`][afnio.trainer.trainer.Trainer]
        during the validation loop.

        Args:
            batch: The output of your data iterable,
                normally a [`DataLoader`][afnio.util.data.DataLoader].
            batch_idx: The index of this batch.

        Returns:
            The result of the validation step (see the note below for details).

        Notes:
            The return value can be one of the following:

            - `Tuple[Variable, Variable]`: The loss as a tuple of two `Variable`s:
                - The evaluation `score` (a `Variable` containing the loss value).
                - The `explanation` (a `Variable` containing a string explanation
                    of the evaluation result).
            - `dict`: A dictionary. Can include any keys, but must include
                the key `'loss'` containing a tuple of two `Variable`s
                (`score` and `explanation`).
            - `None`: Skip to the next batch.

        Raises:
            NotImplementedError: If not implemented in a subclass.
        """
        raise NotImplementedError(
            "You must implement validation_step in your Module subclass."
        )

    def test_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
        """Perform a single test step.

        This method should be implemented in subclasses to define the test logic.
        It is called by the [`Trainer`][afnio.trainer.trainer.Trainer]
        during the testing loop.

        Args:
            batch: The output of your data iterable,
                normally a [`DataLoader`][afnio.util.data.DataLoader].
            batch_idx: The index of this batch.

        Returns:
            The result of the test step (see the note below for details).

        Notes:
            The return value can be one of the following:

            - `Tuple[Variable, Variable]`: The loss as a tuple of two `Variable`s:
                - The evaluation `score` (a `Variable` containing the loss value).
                - The `explanation` (a `Variable` containing a string explanation
                    of the evaluation result).
            - `dict`: A dictionary. Can include any keys, but must include
                the key `'loss'` containing a tuple of two `Variable`s
                (`score` and `explanation`).
            - `None`: Skip to the next batch.

        Raises:
            NotImplementedError: If not implemented in a subclass.
        """
        raise NotImplementedError(
            "You must implement test_step in your Module subclass."
        )

    def configure_optimizers(self) -> Optimizer:
        """Configure and return the optimizer for this module.

        This method should be implemented in subclasses to define the optimizer
        configuration. It is called by the [`Trainer`][afnio.trainer.trainer.Trainer]
        to set up the optimization routine.

        Returns:
            An instance of an optimizer configured for this module.

        Raises:
            NotImplementedError: If not implemented in a subclass.
        """
        raise NotImplementedError(
            "You must implement configure_optimizers in your Module subclass."
        )

    def optimizers(self) -> Union[Optimizer, List[Optimizer]]:
        """Return the optimizer(s) used during training. Useful for manual
        optimization.

        This method gives access to the optimizer(s) returned by
        [`configure_optimizers`][..configure_optimizers] and set up by the
        [`Trainer.fit()`][afnio.trainer.trainer.Trainer.fit] method.

        Returns:
            The optimizer(s) used by this module.

        Examples:
            >>> optimizers = model.optimizers()
            >>> for optimizer in optimizers:
            ...     print(optimizer)
            TGD (
            Parameter Group 0
                completion_args: {'model': 'gpt-4.1'}
                constraints: []
                inputs: {}
                messages: [
                {'role': 'system', 'content': [Variable(data="Placeholder Textual Gradient Descent optimizer system prompt", role=Textual Gradient Descent optimizer system prompt, requires_grad=False)]},
                {'role': 'user', 'content': [Variable(data="Placeholder for Textual Gradient Descent optimizer user prompt", role=Textual Gradient Descent optimizer user prompt, requires_grad=False)]}
                ]
                model_client: <afnio.models.openai.AsyncOpenAI object at 0x710df9c149a0>
                momentum: 3
            )
        """  # noqa: E501
        if self._optimizers is not None:
            return self._optimizers
        raise AttributeError(
            "No optimizer found. Did you call `configure_optimizers()` "
            "and did the `Trainer` set `_optimizers`?"
        )
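The deduplication logic in `named_modules()` is easiest to see on a toy tree. The sketch below copies the method's control flow onto a hypothetical `Node` class (not part of afnio) so that it runs without the library. One instance is registered under two names, mirroring the `add` example above, and it has a child so that the dotted prefixes are visible.

```python
from typing import Dict, Iterator, Optional, Set, Tuple


class Node:
    """Hypothetical stand-in for a module that only tracks named children."""

    def __init__(self) -> None:
        self._modules: Dict[str, Optional["Node"]] = {}

    def named_modules(
        self,
        memo: Optional[Set["Node"]] = None,
        prefix: str = "",
        remove_duplicate: bool = True,
    ) -> Iterator[Tuple[str, "Node"]]:
        # Same control flow as Module.named_modules() above.
        if memo is None:
            memo = set()
        if self not in memo:
            if remove_duplicate:
                memo.add(self)  # only remembered when deduplicating
            yield prefix, self
            for name, module in self._modules.items():
                if module is None:
                    continue  # unset slots are skipped
                submodule_prefix = prefix + ("." if prefix else "") + name
                yield from module.named_modules(memo, submodule_prefix, remove_duplicate)


root = Node()
shared = Node()
shared._modules["inner"] = Node()
root._modules["module1"] = shared
root._modules["module2"] = shared  # the same instance under a second name
root._modules["unused"] = None

print([name for name, _ in root.named_modules()])
# ['', 'module1', 'module1.inner']
print([name for name, _ in root.named_modules(remove_duplicate=False)])
# ['', 'module1', 'module1.inner', 'module2', 'module2.inner']
```

`memo` is only populated when `remove_duplicate=True`. Passing `False` therefore yields every path to a shared instance, which is why `module2` and its subtree appear only in the second list.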

training instance-attribute

Boolean representing whether this module is in training or evaluation mode. Defaults to True.
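The hypothetical `ToyModule` below is a rough illustration of how this flag is kept consistent. It reproduces the recursion in `Module.train()` and `Module.eval()`: setting the mode on a parent overwrites `training` on every descendant. It stores children in a plain dict instead of using afnio's registration machinery.

```python
from typing import Dict


class ToyModule:
    """Hypothetical stand-in showing how train()/eval() reach every descendant."""

    def __init__(self) -> None:
        self.training = True  # same default as Module.__init__
        self.children_: Dict[str, "ToyModule"] = {}

    def train(self, mode: bool = True) -> "ToyModule":
        if not isinstance(mode, bool):
            raise ValueError("Training mode is expected to be boolean.")
        self.training = mode
        for child in self.children_.values():
            child.train(mode)  # recurse into each child
        return self  # returning self allows chaining

    def eval(self) -> "ToyModule":
        return self.train(False)


parent, child, grandchild = ToyModule(), ToyModule(), ToyModule()
child.children_["inner"] = grandchild
parent.children_["outer"] = child
print(parent.eval().training, grandchild.training)  # False False
```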

automatic_optimization instance-attribute

Boolean that determines whether optimization steps are handled automatically by Trainer.fit() or manually by the user. Defaults to True.

If True, the Trainer.fit() method will automatically call optimizer.clear_grad(), explanation.backward(), and optimizer.step(). If False, the user must perform backpropagation and optimizer steps manually in the training_step() method.
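The return contract of `training_step()` (`STEP_OUTPUT`) and the automatic sequence described above can be sketched as a single training iteration. Everything here is a hypothetical stand-in: `FakeVariable`, `FakeOptimizer`, `extract_loss`, and `run_training_step` are not afnio APIs. The real `Trainer.fit()` may validate outputs and order its work differently.

```python
from collections.abc import Mapping
from typing import Any, List, Optional, Tuple


class FakeVariable:
    """Hypothetical stand-in for afnio.Variable that logs backward() calls."""

    def __init__(self, data: Any, log: List[str]) -> None:
        self.data = data
        self._log = log

    def backward(self) -> None:
        self._log.append("backward")


class FakeOptimizer:
    """Hypothetical stand-in for an optimizer exposing clear_grad() and step()."""

    def __init__(self, log: List[str]) -> None:
        self._log = log

    def clear_grad(self) -> None:
        self._log.append("clear_grad")

    def step(self) -> None:
        self._log.append("step")


def extract_loss(output: Any) -> Optional[Tuple[Any, Any]]:
    """Normalize a STEP_OUTPUT value to (score, explanation), or None to skip."""
    if output is None:
        return None  # the batch is skipped
    if isinstance(output, Mapping):
        if "loss" not in output:
            raise KeyError("A dict step output must contain the key 'loss'.")
        output = output["loss"]  # other keys are free-form (e.g. for logging)
    if not (isinstance(output, tuple) and len(output) == 2):
        raise TypeError("Expected a (score, explanation) tuple of Variables.")
    return output


def run_training_step(output: Any, optimizer: FakeOptimizer, automatic: bool) -> bool:
    """Return True if an automatic optimization step was performed."""
    loss = extract_loss(output)
    if loss is None or not automatic:
        return False  # skipped batch, or manual optimization in training_step()
    _score, explanation = loss
    optimizer.clear_grad()
    explanation.backward()  # backpropagate from the explanation Variable
    optimizer.step()
    return True


log: List[str] = []
optimizer = FakeOptimizer(log)
loss = (FakeVariable(0.8, log), FakeVariable("Answer is correct.", log))
run_training_step({"loss": loss, "accuracy": 1.0}, optimizer, automatic=True)
print(log)  # ['clear_grad', 'backward', 'step']
```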

__init__(*args, **kwargs)

Initialize internal Module state.

Calls super().__setattr__('a', a) instead of the typical self.a = a to avoid Module.__setattr__ overhead. Module's __setattr__ has special handling for parameters, submodules, buffers, multi-turn chats, language model clients and completion configurations but simply calls into super().__setattr__ for all other attributes.

Source code in afnio/cognitive/modules/module.py
def __init__(self, *args, **kwargs) -> None:
    """Initialize internal Module state.

    Calls `super().__setattr__('a', a)` instead of the typical `self.a = a`
    to avoid `Module.__setattr__` overhead. Module's `__setattr__` has special
    handling for parameters, submodules, buffers, multi-turn chats, language model
    clients and completion configurations but simply calls into
    `super().__setattr__` for all other attributes.
    """
    super().__setattr__("training", True)
    super().__setattr__("automatic_optimization", True)
    super().__setattr__("_optimizers", None)
    super().__setattr__("_parameters", OrderedDict())
    super().__setattr__("_buffers", OrderedDict())
    super().__setattr__("_non_persistent_buffers_set", set())
    super().__setattr__("_chats", OrderedDict())
    super().__setattr__("_modules", OrderedDict())
    super().__setattr__("_models", OrderedDict())
    super().__setattr__("_completion_configs", OrderedDict())
    super().__setattr__("_functions", OrderedDict())
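The `super().__setattr__` pattern can be reproduced in plain Python. In the hypothetical `Recorder` class below, the custom `__setattr__` depends on bookkeeping state (`intercepted`), so `__init__` creates that state without going through the override. `BrokenRecorder` shows what happens with ordinary assignment. This illustrates the general pattern only; it is not afnio's actual `__setattr__`, whose stated motivation above is avoiding overhead.

```python
from typing import Any


class Recorder:
    """Hypothetical class whose __setattr__ logs every assignment it intercepts."""

    def __init__(self) -> None:
        # Write the initial state via object.__setattr__, bypassing the override.
        super().__setattr__("intercepted", [])
        super().__setattr__("training", True)

    def __setattr__(self, name: str, value: Any) -> None:
        # Stand-in for special handling; it relies on `intercepted` existing.
        self.intercepted.append(name)
        super().__setattr__(name, value)


class BrokenRecorder(Recorder):
    """Uses plain assignment in __init__, so the override runs too early."""

    def __init__(self) -> None:
        self.intercepted = []  # Recorder.__setattr__ reads self.intercepted first


recorder = Recorder()
recorder.topics = "Dermatology and Cardiology"
print(recorder.intercepted)  # ['topics']

try:
    BrokenRecorder()
except AttributeError as exc:
    print("AttributeError:", exc)
```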

forward(*args, **kwargs)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Call the Module instance (the Module.__call__ method) instead of calling Module.forward() directly, so that hooks are run.

Source code in afnio/cognitive/modules/module.py
def forward(self, *args, **kwargs) -> Any:
    """Define the computation performed at every call.

    Should be overridden by all subclasses.

    Note:
        Call the [`Module`][..] instance (the `Module.__call__` method)
        instead of calling `Module.forward()` directly, so that hooks are
        run.
    """
    raise NotImplementedError(
        f"Module [{type(self).__name__}] is missing the required "
        '"forward" function.'
    )
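The reason to call the instance rather than `forward()` can be sketched with a hypothetical `HookedModule` whose `__call__` wraps `forward()` and then runs a list of forward hooks. The `forward_hooks` attribute and the hook signature are assumptions made for illustration; afnio's own hook API is not shown on this page.

```python
from typing import Any, Callable, List, Tuple

Hook = Callable[[Any, Tuple[Any, ...], Any], None]


class HookedModule:
    """Hypothetical stand-in: __call__ wraps forward() and then runs hooks."""

    def __init__(self) -> None:
        self.forward_hooks: List[Hook] = []  # illustrative name, not afnio API

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError(
            f"Module [{type(self).__name__}] is missing the required "
            '"forward" function.'
        )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        result = self.forward(*args, **kwargs)
        for hook in self.forward_hooks:
            hook(self, args, result)  # hooks fire only on this path
        return result


class Upper(HookedModule):
    def forward(self, text: str) -> str:
        return text.upper()


seen: List[Any] = []
module = Upper()
module.forward_hooks.append(lambda mod, args, out: seen.append(out))
module("hi")          # goes through __call__, so the hook runs
module.forward("yo")  # bypasses __call__, so the hook is skipped
print(seen)  # ['HI']
```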

extra_repr() abstractmethod

Set the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

Returns:

  • str: A string containing the extra representation of the module, which will be included in the module's __repr__ output.

Source code in afnio/cognitive/modules/module.py
@abstractmethod
def extra_repr(self) -> str:
    """Set the extra representation of the module.

    To print customized extra information, you should re-implement this method in
    your own modules. Both single-line and multi-line strings are acceptable.

    Returns:
        A string containing the extra representation of the module, which will be \
        included in the module's `__repr__` output.
    """
    return ""
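To show where `extra_repr()` ends up, here is a minimal sketch of a `__repr__` that nests child reprs and inlines a single-line extra representation. It follows the PyTorch convention and only assumes that afnio's `__repr__` behaves similarly. `ReprModule`, `Judge`, and `Pipeline` are hypothetical.

```python
from typing import Dict


class ReprModule:
    """Hypothetical stand-in: __repr__ combines extra_repr() with child reprs."""

    def __init__(self) -> None:
        self._modules: Dict[str, "ReprModule"] = {}

    def extra_repr(self) -> str:
        return ""  # default: nothing extra

    def __repr__(self) -> str:
        extra = self.extra_repr()
        extra_lines = extra.split("\n") if extra else []
        child_lines = []
        for name, child in self._modules.items():
            # Indent every line of the child's repr by two spaces.
            child_repr = "\n  ".join(repr(child).split("\n"))
            child_lines.append(f"({name}): {child_repr}")
        main = type(self).__name__ + "("
        if len(extra_lines) == 1 and not child_lines:
            main += extra_lines[0]  # short single-line form
        elif extra_lines or child_lines:
            main += "\n  " + "\n  ".join(extra_lines + child_lines) + "\n"
        return main + ")"


class Judge(ReprModule):
    def __init__(self, purpose: str) -> None:
        super().__init__()
        self.purpose = purpose

    def extra_repr(self) -> str:
        return f"purpose={self.purpose!r}"


class Pipeline(ReprModule):
    def __init__(self) -> None:
        super().__init__()
        self._modules["judge"] = Judge("summation")


print(Pipeline())
# Pipeline(
#   (judge): Judge(purpose='summation')
# )
```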

register_buffer(name, variable, persistent=True)

Add a buffer to the module.

This is typically used to register a buffer that should not be considered an agent Parameter. For example, LMJudgeEvaluator's reduction_fn_purpose is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting persistent to False. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's state_dict.

Buffers can be accessed as attributes using given names.

Parameters:

  • name (str, required): Name of the buffer. The buffer can be accessed from this module using the given name.
  • variable (Variable | None, required): Buffer to be registered. If None, operations that run on buffers are ignored and the buffer is not included in the module's state_dict.
  • persistent (bool, default True): Whether the buffer is part of this module's state_dict.

Examples:

>>> self.register_buffer('reduction_fn_purpose', afnio.Variable(data="summation", role="reduction function purpose"))

Source code in afnio/cognitive/modules/module.py
def register_buffer(
    self, name: str, variable: Optional[Variable], persistent: bool = True
) -> None:
    """Add a buffer to the module.

    This is typically used to register a buffer that should not be
    considered an agent [`Parameter`][afnio.cognitive.parameter.Parameter].
    For example, [`LMJudgeEvaluator`][afnio.cognitive.modules.lm_judge_evaluator.LMJudgeEvaluator]'s
    `reduction_fn_purpose` is not a parameter, but is part of the module's state.
    Buffers, by default, are persistent and will be saved alongside parameters.
    This behavior can be changed by setting `persistent` to `False`. The only
    difference between a persistent buffer and a non-persistent buffer is that the
    latter will not be a part of this module's [`state_dict`][..state_dict].

    Buffers can be accessed as attributes using given names.

    Args:
        name: Name of the buffer. The buffer can be accessed from this module
            using the given name.
        variable: Buffer to be registered. If `None`, operations that run on
            buffers are ignored and the buffer is **not** included in the
            module's [`state_dict`][..state_dict].
        persistent: Whether the buffer is part of this module's
            [`state_dict`][..state_dict].

    Examples:
        >>> self.register_buffer('reduction_fn_purpose', afnio.Variable(data="summation", role="reduction function purpose"))
    """  # noqa: E501
    if "_buffers" not in self.__dict__:
        raise AttributeError("Cannot assign buffer before Module.__init__() call.")
    elif not isinstance(name, str):
        raise TypeError(
            f"Buffer name should be a string. Got {type(name).__name__}."
        )
    elif "." in name:
        raise KeyError('Buffer name cannot contain ".".')
    elif name == "":
        raise KeyError('Buffer name cannot be empty string "".')
    elif hasattr(self, name) and name not in self._buffers:
        raise KeyError(f"Attribute '{name}' already exists.")
    elif variable is not None and not isinstance(variable, Variable):
        raise TypeError(
            f"Cannot assign '{type(variable).__name__}' object to buffer '{name}' "
            "(afnio.Variable or None required)."
        )
    else:
        self._buffers[name] = variable
        if persistent:
            self._non_persistent_buffers_set.discard(name)
        else:
            self._non_persistent_buffers_set.add(name)
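The persistent/non-persistent bookkeeping above can be mirrored in a small stand-alone class. `BufferHolder` and its `state_dict()` are hypothetical. The filtering rule (drop `None` buffers and names in `_non_persistent_buffers_set`) is inferred from the docstring rather than taken from afnio's `state_dict` implementation.

```python
from typing import Any, Dict, Optional, Set


class BufferHolder:
    """Hypothetical stand-in mirroring register_buffer's bookkeeping."""

    def __init__(self) -> None:
        self._buffers: Dict[str, Optional[Any]] = {}
        self._non_persistent_buffers_set: Set[str] = set()

    def register_buffer(
        self, name: str, value: Optional[Any], persistent: bool = True
    ) -> None:
        self._buffers[name] = value
        if persistent:
            self._non_persistent_buffers_set.discard(name)
        else:
            self._non_persistent_buffers_set.add(name)

    def state_dict(self) -> Dict[str, Any]:
        # Assumed rule: skip None buffers and non-persistent buffers.
        return {
            name: value
            for name, value in self._buffers.items()
            if value is not None and name not in self._non_persistent_buffers_set
        }


holder = BufferHolder()
holder.register_buffer("reduction_fn_purpose", "summation")
holder.register_buffer("scratch", "temporary notes", persistent=False)
holder.register_buffer("unused", None)
print(holder.state_dict())  # {'reduction_fn_purpose': 'summation'}
```

Because the persistent branch uses `discard` rather than `remove`, re-registering a name with `persistent=True` moves a buffer back into the state dict without raising.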

register_parameter(name, param)

Add a parameter to the module.

The parameter can be accessed as an attribute using given name.

Parameters:

  • name (str, required): Name of the parameter. The parameter can be accessed from this module using the given name.
  • param (Parameter | None, required): Parameter to be added to the module. If None, operations that run on parameters are ignored and the parameter is not included in the module's state_dict.

Source code in afnio/cognitive/modules/module.py
def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
    """Add a parameter to the module.

    The parameter can be accessed as an attribute using given name.

    Args:
        name: Name of the parameter. The parameter can be accessed from this module
            using the given name.
        param: Parameter to be added to the module. If `None`, operations that
            run on parameters are ignored and the parameter is **not** included
            in the module's [`state_dict`][..state_dict].
    """
    if "_parameters" not in self.__dict__:
        raise AttributeError(
            "Cannot assign parameter before Module.__init__() call."
        )

    elif not isinstance(name, str):
        raise TypeError(
            f"Parameter name should be a string. Got {type(name).__name__}."
        )
    elif "." in name:
        raise KeyError('Parameter name cannot contain ".".')
    elif name == "":
        raise KeyError('Parameter name cannot be empty string "".')
    elif hasattr(self, name) and name not in self._parameters:
        raise KeyError(f"Attribute '{name}' already exists.")

    if param is None:
        self._parameters[name] = None
    elif not isinstance(param, Parameter):
        raise TypeError(
            f"Cannot assign '{type(param).__name__}' object to parameter '{name}' "
            "(afnio.cognitive.Parameter or None required)."
        )
    elif param.grad_fn:
        raise ValueError(
            f"Cannot assign non-leaf Variable to parameter '{name}'. Model "
            f"parameters must be created explicitly. To express '{name}' "
            "as a function of another Variable, compute the value in "
            "the forward() method."
        )
    else:
        self._parameters[name] = param

register_chat(name, messages)

Add multi-turn chat messages to the module.

The chat can be accessed as an attribute using given name.

Parameters:

  • name (str, required): Name of the chat. The chat can be accessed from this module using the given name.
  • messages (MultiTurnMessages | None, required): Chat to be added to the module. If None, operations that run on chats are ignored and the chat is not included in the module's state_dict.

Source code in afnio/cognitive/modules/module.py
def register_chat(self, name: str, messages: Optional[MultiTurnMessages]) -> None:
    """Add multi-turn chat messages to the module.

    The chat can be accessed as an attribute using given name.

    Args:
        name: Name of the chat. The chat can be accessed from this module
            using the given name.
        messages: Chat to be added to the module. If `None`, operations that run
            on chats are ignored and the chat is **not** included in the
            module's [`state_dict`][..state_dict].
    """
    if "_chats" not in self.__dict__:
        raise AttributeError("Cannot assign chat before Module.__init__() call.")

    elif not isinstance(name, str):
        raise TypeError(
            f"Chat name should be a string. Got {type(name).__name__}."
        )
    elif "." in name:
        raise KeyError('Chat name cannot contain ".".')
    elif name == "":
        raise KeyError('Chat name cannot be empty string "".')
    elif hasattr(self, name) and name not in self._chats:
        raise KeyError(f"Attribute '{name}' already exists.")

    if messages is None:
        self._chats[name] = None
    elif not is_multi_turn_messages(messages):
        raise TypeError(
            f"Cannot assign '{type(messages).__name__}' object to chat '{name}' "
            "(afnio.MultiTurnMessages or None required)."
        )
    else:
        self._chats[name] = messages
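The body of `is_multi_turn_messages` is not shown here. The sketch below is a hypothetical checker for the message shape visible elsewhere on this page: a list of dicts with a string `role` and a `content` list of `Variable`s. This is the shape printed in the `optimizers()` example and iterated by `requires_grad_()`. `StubVariable` stands in for `afnio.Variable`, and afnio's real check may be stricter or looser.

```python
from typing import Any


class StubVariable:
    """Hypothetical minimal stand-in for afnio.Variable."""

    def __init__(self, data: str, role: str = "") -> None:
        self.data = data
        self.role = role


def looks_like_multi_turn_messages(obj: Any) -> bool:
    """Hypothetical check: a list of {'role': str, 'content': [Variable, ...]}."""
    if not isinstance(obj, list):
        return False
    for message in obj:
        if not isinstance(message, dict) or not isinstance(message.get("role"), str):
            return False
        content = message.get("content")
        if not isinstance(content, list):
            return False
        if not all(isinstance(item, StubVariable) for item in content):
            return False
    return True


chat = [
    {"role": "system", "content": [StubVariable("You are a doctor.", "system prompt")]},
    {"role": "user", "content": [StubVariable("{question}", "user query")]},
]
print(looks_like_multi_turn_messages(chat))  # True
print(looks_like_multi_turn_messages([{"role": "user", "content": "plain text"}]))  # False
```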

register_model(name, model)

Add a language model to the module.

The language model can be accessed as an attribute using given name.

Parameters:

  • name (str, required): Name of the model. The model can be accessed from this module using the given name.
  • model (BaseModel | None, required): Model to be added to the module. If None, operations that run on models are ignored and the model is not included in the module's state_dict.

Source code in afnio/cognitive/modules/module.py
def register_model(self, name: str, model: Optional[BaseModel]) -> None:
    """Add a language model to the module.

    The language model can be accessed as an attribute using given name.

    Args:
        name: Name of the model. The model can be accessed from this module
            using the given name.
        model: Model to be added to the module. If `None`, operations that run on
            models are ignored and the model is **not** included in the
            module's [`state_dict`][..state_dict].
    """
    if "_models" not in self.__dict__:
        raise AttributeError("Cannot assign model before Module.__init__() call.")

    elif not isinstance(name, str):
        raise TypeError(
            f"Model name should be a string. Got {type(name).__name__}."
        )
    elif "." in name:
        raise KeyError('Model name cannot contain ".".')
    elif name == "":
        raise KeyError('Model name cannot be empty string "".')
    elif hasattr(self, name) and name not in self._models:
        raise KeyError(f"Attribute '{name}' already exists.")

    if model is None:
        self._models[name] = None
    elif not isinstance(model, BaseModel):
        raise TypeError(
            f"Cannot assign '{type(model).__name__}' object to model '{name}' "
            "(afnio.models.BaseModel or None required)."
        )
    else:
        self._models[name] = model

register_completion_config(name, args)

Register completion-specific arguments for text generation.

This method allows dynamic storage of completion-related parameters such as temperature, max_tokens, top_p, etc.

Parameters:

  • name (str, required): Name of the completion argument set.
  • args (dict[str, Any] | None, required): Dictionary of completion arguments. If None, the argument set is not included in the module's state_dict.

Source code in afnio/cognitive/modules/module.py
def register_completion_config(
    self, name: str, args: Optional[Dict[str, Any]]
) -> None:
    """Register completion-specific arguments for text generation.

    This method allows dynamic storage of completion-related parameters
    such as `temperature`, `max_tokens`, `top_p`, etc.

    Args:
        name: Name of the completion argument set.
        args: Dictionary of completion arguments. If `None`, the argument is **not**
            included in the module's [`state_dict`][..state_dict].
    """
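    if "_completion_configs" not in self.__dict__:
        raise AttributeError(
            "Cannot assign completion config before Module.__init__() call."
        )
    elif not isinstance(name, str):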
        raise TypeError(
            f"Completion config name should be a string. Got {type(name).__name__}."
        )
    elif "." in name:
        raise KeyError('Completion config name cannot contain ".".')
    elif name == "":
        raise KeyError('Completion config name cannot be an empty string "".')
    elif hasattr(self, name) and name not in self._completion_configs:
        raise KeyError(f"Attribute '{name}' already exists.")

    if args is None:
        self._completion_configs[name] = None
    elif not isinstance(args, dict):
        raise TypeError(
            f"Cannot assign '{type(args).__name__}' object to "
            f"completion config '{name}' (dict or None required)."
        )
    else:
        self._completion_configs[name] = args

register_function(name, func)

Add a function to the module.

The function can be accessed as an attribute using given name.

Parameters:

  • name (str, required): Name of the function. The function can be accessed from this module using the given name.
  • func (FunctionType | None, required): A standard Python function (i.e., a def-defined function, not a lambda or callable object) that can be pickled and registered for later execution. If None, the function is unregistered and is not included in the module's state_dict.

Source code in afnio/cognitive/modules/module.py
def register_function(self, name: str, func: Optional[FunctionType]) -> None:
    """Add a function to the module.

    The function can be accessed as an attribute using given name.

    Args:
        name: Name of the function. The function can be accessed from this module
            using the given name.
        func: A standard Python function (i.e., a def-defined function, not a lambda
            or callable object) that can be pickled and registered for later
            execution. If `None`, the function is unregistered and is **not**
            included in the module's [`state_dict`][..state_dict].
    """
    if "_functions" not in self.__dict__:
        raise AttributeError(
            "Cannot assign function before Module.__init__() call."
        )
    elif not isinstance(name, str):
        raise TypeError(
            f"Function name should be a string. Got {type(name).__name__}."
        )
    elif "." in name:
        raise KeyError('Function name cannot contain ".".')
    elif name == "":
        raise KeyError('Function name cannot be empty string "".')
    elif hasattr(self, name) and name not in self._functions:
        raise KeyError(f"Attribute '{name}' already exists.")

    if func is None:
        self._functions[name] = None
    else:
        _validate_function(func)  # Validate the function before registering
        self._functions[name] = func
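You can reproduce the checks above outside afnio to see how they interact. The following is a minimal sketch built on a hypothetical `ToyModule` class, not afnio's `Module`. It copies the name checks from `register_function`. It also approximates the private `_validate_function` helper, whose body is not shown here, by accepting only `def`-defined functions, as the docstring describes.

```python
import types


class ToyModule:
    """Hypothetical stand-in for cog.Module: only the function registry is modeled."""

    def __init__(self):
        self._functions = {}

    def __getattr__(self, name):
        # Python calls this only when normal lookup fails, so registered
        # functions become reachable as attributes.
        functions = self.__dict__.get("_functions", {})
        if name in functions:
            return functions[name]
        raise AttributeError(name)

    def register_function(self, name, func):
        # Name checks, in the same order as the source above.
        if not isinstance(name, str):
            raise TypeError(f"Function name should be a string. Got {type(name).__name__}.")
        if "." in name:
            raise KeyError('Function name cannot contain ".".')
        if name == "":
            raise KeyError('Function name cannot be empty string "".')
        if hasattr(self, name) and name not in self._functions:
            raise KeyError(f"Attribute '{name}' already exists.")
        if func is None:
            # None unregisters the function; the name stays in the registry as None.
            self._functions[name] = None
            return
        # Assumed approximation of the private _validate_function helper:
        # accept only def-defined functions, reject lambdas and other callables.
        if not isinstance(func, types.FunctionType) or func.__name__ == "<lambda>":
            raise TypeError(f"Cannot register {func!r}: a def-defined function is required.")
        self._functions[name] = func


def normalize(text: str) -> str:
    return text.strip().lower()


m = ToyModule()
m.register_function("normalize", normalize)
print(m.normalize("  YES "))  # yes
```

Names that collide with existing attributes (for example `register_function` itself) are rejected. Re-registering an already registered name, including with `None`, is allowed.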

register_module(name, module)

Add a child module to the current module.

This method explicitly adds a child module to the current module's hierarchy. The child module can then be accessed as an attribute using the given name and will be registered in the _modules dictionary.

When to use:

  • Use register_module() when dynamically adding submodules at runtime, especially when the submodule name is determined programmatically.
  • This can be useful for creating flexible and modular architectures.

When it's unnecessary:

  • Directly assigning the module to an attribute (e.g., self.module_name = SubModule()) automatically registers it, so using register_module() is unnecessary in such cases.

Parameters:

Name Type Description Default
name str

Name of the child module. The child module can be accessed from this module using the given name.

required
module Module | None

Child module to be added to the module.

required

Raises:

Type Description
TypeError

If module is neither None nor an instance of a Module subclass, or if name is not a string.

KeyError

If name is already an attribute of the module but not in _modules, or if name contains invalid characters such as '.' or is empty.

Examples:

>>> class DynamicPipeline(cog.Module):
...     def __init__(self):
...         super().__init__()
...         # Dynamically add submodules
...         for i in range(3):
...             self.register_module(f"layer_{i}", cog.Module())
...
>>> pipeline = DynamicPipeline()
>>> print(pipeline._modules.keys())
odict_keys(['layer_0', 'layer_1', 'layer_2'])
Note

If assigning submodules using standard attribute assignment (e.g., self.submodule = SubModule()), calling register_module() explicitly is not required. Direct assignment automatically registers the module.

Source code in afnio/cognitive/modules/module.py
def register_module(self, name: str, module: Optional["Module"]) -> None:
    """Add a child module to the current module.

    This method explicitly adds a child module to the current module's hierarchy.
    The child module can then be accessed as an attribute using the given name
    and will be registered in the `_modules` dictionary.

    **When to use**:
        - Use `register_module()` when dynamically adding submodules at runtime,
            especially when the submodule name is determined programmatically.
        - This can be useful for creating flexible and modular architectures.

    **When it's unnecessary**:
        - Directly assigning the module to an attribute (e.g.,
            `self.module_name = SubModule()`) automatically registers it, so using
            `register_module()` is unnecessary in such cases.

    Args:
        name: Name of the child module. The child module can be accessed from
            this module using the given name.
        module: Child module to be added to the module.

    Raises:
        TypeError: If `module` is neither `None` nor an instance of a `Module`
            subclass, or if `name` is not a string.
        KeyError: If `name` is already an attribute of the module but not
            in `_modules`, or if `name` contains invalid characters
            such as `'.'` or is empty.

    Examples:
        >>> class DynamicPipeline(cog.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         # Dynamically add submodules
        ...         for i in range(3):
        ...             self.register_module(f"layer_{i}", cog.Module())
        ...
        >>> pipeline = DynamicPipeline()
        >>> print(pipeline._modules.keys())
        odict_keys(['layer_0', 'layer_1', 'layer_2'])

    Note:
        If assigning submodules using standard attribute assignment
        (e.g., `self.submodule = SubModule()`), calling `register_module()`
        explicitly is not required. Direct assignment automatically registers
        the module.
    """
    if not isinstance(module, Module) and module is not None:
        raise TypeError(
            f"'{type(module).__name__}' is not a valid Module subclass."
        )
    elif not isinstance(name, str):
        raise TypeError(
            f"Module name must be a string, but got '{type(name).__name__}'."
        )
    elif hasattr(self, name) and name not in self._modules:
        raise KeyError(
            f"Attribute '{name}' already exists and "
            f"cannot be used as a module name."
        )
    elif "." in name:
        raise KeyError(f"Module name cannot contain '.', but got: '{name}'.")
    elif name == "":
        raise KeyError('Module name cannot be an empty string ""')
    self._modules[name] = module
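The difference between `register_module()` and plain attribute assignment can be sketched with a hypothetical `ToyModule` whose `__setattr__` routes module values into `_modules`. This is an assumption about how automatic registration works, since it mirrors the PyTorch pattern. Only the type and name-format checks from the source above are reproduced.

```python
from collections import OrderedDict


class ToyModule:
    """Hypothetical stand-in for cog.Module (not afnio's implementation)."""

    def __init__(self):
        # Bypass our own __setattr__ while the registry does not exist yet.
        object.__setattr__(self, "_modules", OrderedDict())

    def __setattr__(self, name, value):
        # Assigning a module to an attribute registers it automatically.
        if isinstance(value, ToyModule):
            self.register_module(name, value)
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Reached only when normal lookup fails: fall back to child modules.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(name)

    def register_module(self, name, module):
        # Only the type and name-format checks from the source are reproduced.
        if module is not None and not isinstance(module, ToyModule):
            raise TypeError(f"'{type(module).__name__}' is not a valid Module subclass.")
        if not isinstance(name, str):
            raise TypeError(f"Module name must be a string, but got '{type(name).__name__}'.")
        if "." in name or name == "":
            raise KeyError(f"Invalid module name: {name!r}")
        self._modules[name] = module


class Pipeline(ToyModule):
    def __init__(self, n_layers):
        super().__init__()
        self.encoder = ToyModule()  # fixed name: plain assignment is enough
        for i in range(n_layers):
            # Name computed at runtime: register_module is the natural fit.
            self.register_module(f"layer_{i}", ToyModule())


p = Pipeline(2)
print(list(p._modules))  # ['encoder', 'layer_0', 'layer_1']
```

A fixed child name reads best as an attribute assignment. Names built in a loop are where `register_module()` is actually needed.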

state_dict(*, destination=None, prefix='', keep_vars=False)

Return a dictionary containing references to the whole state of the module.

Parameters, persistent buffers (e.g. running averages), multi-turn chats, models, completion configs and functions are included. Keys are corresponding parameter, buffer, chat, model, completion config and function names. Parameters, buffers, chats, models, completion configs and functions set to None are not included.

Note

The returned object is a shallow copy. It contains references to the module's parameters, buffers, chats, models, completion configs and functions.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Parameters:

Name Type Description Default
destination dict

If provided, the module's state is written into this dict and the same object is returned. Otherwise, an OrderedDict is created and returned. Default: None.

None
prefix str

A prefix added to parameter, buffer, chat, model, completion config and function names to compose the keys in state_dict. Default: ''.

''
keep_vars bool

By default the Variables returned in the state dict are detached from autodiff. If it's set to True, detaching will not be performed. Default: False.

False

Returns:

Name Type Description
dict dict

A dictionary containing the whole state of the module.

Examples:

>>> list(module.state_dict().keys())
['system_prompt', 'classification_labels', 'format_type', 'user_prompt']
Source code in afnio/cognitive/modules/module.py
def state_dict(
    self,
    *,
    destination: T_destination = None,
    prefix: str = "",
    keep_vars: bool = False,
) -> T_destination:
    """Return a dictionary containing references to the whole state of the module.

    Parameters, persistent buffers (e.g. running averages), multi-turn chats,
    models, completion configs and functions are included. Keys are corresponding
    parameter, buffer, chat, model, completion config and function names.
    Parameters, buffers, chats, models, completion configs and functions
    set to `None` are not included.

    Note:
        The returned object is a shallow copy. It contains references
        to the module's parameters, buffers, chats, models, completion configs
        and functions.

    Warning:
        Please avoid the use of argument `destination` as it is not
        designed for end-users.

    Args:
        destination (dict, optional): If provided, the state of module will
            be updated into the dict and the same object is returned.
            Otherwise, an `OrderedDict` will be created and returned.
            Default: `None`.
        prefix (str, optional): A prefix added to parameter, buffer, chat, model,
            completion config and function names to compose the keys in state_dict.
            Default: `''`.
        keep_vars (bool, optional): By default the [`Variable`][afnio.Variable]s
            returned in the state dict are detached from autodiff. If it's
            set to `True`, detaching will not be performed.
            Default: `False`.

    Returns:
        dict (dict): A dictionary containing the whole state of the module.

    Examples:
        >>> list(module.state_dict().keys())
        ['system_prompt', 'classification_labels', 'format_type', 'user_prompt']
    """
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()

    local_metadata = dict(version=self._version)
    if hasattr(destination, "_metadata"):
        destination._metadata[prefix[:-1]] = local_metadata

    self._save_to_state_dict(destination, prefix, keep_vars)
    for name, module in self._modules.items():
        if module is not None:
            module.state_dict(
                destination=destination,
                prefix=prefix + name + ".",
                keep_vars=keep_vars,
            )
    return destination
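A small hypothetical `Node` class can illustrate the key layout that the recursion above produces. It mirrors the prefix handling shown in the source: children extend the prefix with `name + "."`, and metadata is keyed by `prefix[:-1]`. `_save_to_state_dict` is not shown here, so the sketch approximates it by copying every non-`None` entry.

```python
from collections import OrderedDict


class Node:
    """Hypothetical toy module holding string 'parameters' and child nodes."""

    _version = 1

    def __init__(self, params=None, children=None):
        self._parameters = OrderedDict(params or {})
        self._modules = OrderedDict(children or {})

    def _save_to_state_dict(self, destination, prefix):
        # Entries registered as None are skipped, as documented above.
        for name, value in self._parameters.items():
            if value is not None:
                destination[prefix + name] = value

    def state_dict(self, destination=None, prefix=""):
        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()
        if hasattr(destination, "_metadata"):
            # prefix[:-1] drops the trailing dot, so the root is stored under "".
            destination._metadata[prefix[:-1]] = {"version": self._version}
        self._save_to_state_dict(destination, prefix)
        for name, child in self._modules.items():
            if child is not None:
                # Children write into the same destination under an extended prefix.
                child.state_dict(destination, prefix + name + ".")
        return destination


leaf = Node({"system_prompt": "You are a doctor.", "unused": None})
root = Node({"user_prompt": "Answer briefly."}, {"qa": leaf})
sd = root.state_dict()
print(list(sd))            # ['user_prompt', 'qa.system_prompt']
print(list(sd._metadata))  # ['', 'qa']
```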

load_state_dict(state_dict, strict=True, assign=False, model_clients=None)

Copy parameters, buffers, chats, models, completion configs and functions from state_dict into this module and its descendants.

If strict is True, then the keys of state_dict must exactly match the keys returned by this module's state_dict function.

Warning

If assign is True, the optimizer must be created after the call to load_state_dict.

Note

If a parameter, buffer, chat, model, completion config, or function is registered as None and its corresponding key exists in state_dict, load_state_dict will raise a RuntimeError.

Parameters:

Name Type Description Default
state_dict dict

A dict containing parameters, persistent buffers, chats, models, completion configs and functions.

required
strict bool

Whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict function. Default: True

True
assign bool

When False, the properties of the Variables in the current module are preserved; when True, the properties of the Variables in the state dict are preserved. The only exception is the requires_grad field of Parameters, for which the value from the module is always preserved. Default: False

False
model_clients dict

A dictionary mapping model client keys (e.g., 'fw_model_client') to their respective instances of BaseModel. These instances are used to reconstruct any model clients referenced within state_dict. If a required model client is missing, an error is raised with instructions on how to provide the missing client.

None

Returns:

Name Type Description
incompatible_keys NamedTuple

A NamedTuple with missing_keys and unexpected_keys fields (see the note below for details).

Note

The return value reports key mismatches encountered during loading:

  • missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
  • unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.
Source code in afnio/cognitive/modules/module.py
def load_state_dict(
    self,
    state_dict: Mapping[str, Any],
    strict: bool = True,
    assign: bool = False,
    model_clients: Dict[str, BaseModel] = None,
):
    """Copy parameters, buffers, chats, models, completion configs and functions
    from `state_dict` into this module and its descendants.

    If `strict` is `True`, then the keys of `state_dict` must exactly match the keys
    returned by this module's [`state_dict`][..state_dict] function.

    Warning:
        If `assign` is `True`, the optimizer must be created after
        the call to [`load_state_dict`][.].

    Note:
        If a parameter, buffer, chat, model, completion config, or function is
        registered as `None` and its corresponding key exists in `state_dict`,
        [`load_state_dict`][.] will raise a `RuntimeError`.

    Args:
        state_dict (dict): A dict containing parameters, persistent buffers,
            chats, models, completion configs and functions.
        strict (bool, optional): Whether to strictly enforce that the keys
            in `state_dict` match the keys returned by this module's
            [`state_dict`][..state_dict] function. Default: `True`
        assign (bool, optional): When `False`, the properties of the Variables
            in the current module are preserved; when `True`, the properties
            of the Variables in the state dict are preserved. The only exception
            is the `requires_grad` field of
            [`Parameter`][afnio.cognitive.parameter.Parameter]s, for which the
            value from the module is always preserved. Default: `False`
        model_clients (dict, optional): A dictionary mapping model client keys
            (e.g., 'fw_model_client') to their respective instances of
            [`BaseModel`][afnio.models.model.BaseModel]. These instances are
            used to reconstruct any model clients referenced within
            `state_dict`. If a required model client is missing, an error is
            raised with instructions on how to provide the missing client.

    Returns:
        incompatible_keys (NamedTuple): A `NamedTuple` with `missing_keys` and
            `unexpected_keys` fields (see the note below for details).

    Note:
        The return value reports key mismatches encountered during loading:

        - `missing_keys` is a list of str containing any keys that are expected by
            this module but missing from the provided `state_dict`.
        - `unexpected_keys` is a list of str containing the keys that are not
            expected by this module but present in the provided `state_dict`.

    """
    if not isinstance(state_dict, Mapping):
        raise TypeError(
            f"Expected state_dict to be dict-like, got {type(state_dict)}."
        )

    missing_keys: List[str] = []
    unexpected_keys: List[str] = []
    error_msgs: List[str] = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        # mypy isn't aware that "_metadata" exists in state_dict
        state_dict._metadata = metadata  # type: ignore[attr-defined]

    def load(module, local_state_dict, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        if assign:
            local_metadata["assign_to_params_buffers_chats"] = assign
        module._load_from_state_dict(
            local_state_dict,
            prefix,
            local_metadata,
            True,
            missing_keys,
            unexpected_keys,
            error_msgs,
            model_clients,
        )
        for name, child in module._modules.items():
            if child is not None:
                child_prefix = prefix + name + "."
                child_state_dict = {
                    k: v
                    for k, v in local_state_dict.items()
                    if k.startswith(child_prefix)
                }
                load(child, child_state_dict, child_prefix)  # noqa: F821

    load(self, state_dict)
    del load

    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0,
                "Unexpected key(s) in state_dict: {}. ".format(
                    ", ".join(f'"{k}"' for k in unexpected_keys)
                ),
            )
        if len(missing_keys) > 0:
            error_msgs.insert(
                0,
                "Missing key(s) in state_dict: {}. ".format(
                    ", ".join(f'"{k}"' for k in missing_keys)
                ),
            )

    if len(error_msgs) > 0:
        raise RuntimeError(
            "Error(s) in loading state_dict for {}:\n\t{}".format(
                self.__class__.__name__, "\n\t".join(error_msgs)
            )
        )
    return _IncompatibleKeys(missing_keys, unexpected_keys)
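In the method above, `_load_from_state_dict` (not shown) fills `missing_keys` and `unexpected_keys` module by module, and each child sees only the keys that start with its own prefix. The hypothetical `check_keys` function below flattens that bookkeeping into a single comparison, so the effect of `strict` and the order of the error messages are easy to see. It is a sketch, not afnio's code.

```python
from collections import namedtuple
from typing import Iterable, List, Mapping

# Hypothetical mirror of the _IncompatibleKeys result returned above.
IncompatibleKeys = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])


def check_keys(expected: Iterable[str], provided: Mapping[str, object], strict: bool = True):
    """Flat approximation of the key bookkeeping in load_state_dict."""
    expected = list(expected)
    missing: List[str] = [k for k in expected if k not in provided]
    unexpected: List[str] = [k for k in provided if k not in expected]
    error_msgs: List[str] = []
    if strict:
        # Both messages are inserted at index 0, so "Missing" ends up first.
        if unexpected:
            error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(
                ", ".join(f'"{k}"' for k in unexpected)))
        if missing:
            error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(
                ", ".join(f'"{k}"' for k in missing)))
    if error_msgs:
        raise RuntimeError("Error(s) in loading state_dict:\n\t" + "\n\t".join(error_msgs))
    return IncompatibleKeys(missing, unexpected)


expected = ["system_prompt", "qa.user_prompt"]
provided = {
    "system_prompt": "You are a doctor.",
    "qa.user_promt": "Is {item} a disease?",  # deliberately misspelled key
}
print(check_keys(expected, provided, strict=False))
# IncompatibleKeys(missing_keys=['qa.user_prompt'], unexpected_keys=['qa.user_promt'])
```

With `strict=True` the same call raises a `RuntimeError` that lists the missing key before the unexpected one.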

get_extra_state()

Return any extra state to include in the module's state_dict.

Implement this and a corresponding set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict().

Extra state should be picklable so that the state_dict can be serialized.

Returns:

Name Type Description
object Any

Any extra state to store in the module's state_dict.

Source code in afnio/cognitive/modules/module.py
def get_extra_state(self) -> Any:
    """Return any extra state to include in the module's state_dict.

    Implement this and a corresponding [`set_extra_state`][..set_extra_state] for
    your module if you need to store extra state. This function is called when
    building the module's `state_dict()`.

    Extra state should be picklable so that the state_dict can be
    serialized.

    Returns:
        object: Any extra state to store in the module's state_dict.
    """
    raise RuntimeError(
        "Reached a code path in Module.get_extra_state() that "
        "should never be called."
    )

set_extra_state(state)

Set extra state contained in the loaded state_dict.

This function is called from load_state_dict to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state for your module if you need to store extra state within its state_dict.

Parameters:

Name Type Description Default
state dict

Extra state from the state_dict.

required
Source code in afnio/cognitive/modules/module.py
def set_extra_state(self, state: Any) -> None:
    """Set extra state contained in the loaded `state_dict`.

    This function is called from [`load_state_dict`][..load_state_dict] to handle
    any extra state found within the `state_dict`. Implement this function and a
    corresponding [`get_extra_state`][..get_extra_state] for your module if you need
    to store extra state within its `state_dict`.

    Args:
        state (dict): Extra state from the `state_dict`.
    """
    raise RuntimeError(
        "Reached a code path in Module.set_extra_state() that "
        "should never be called. "
    )
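The base implementations above only raise, which suggests the framework calls these hooks only when a subclass overrides them. The sketch below assumes the PyTorch-style convention: an override check plus a reserved `_extra_state` key. Neither detail appears in the source shown here, so treat both as assumptions. `Base` and `FewShotCache` are hypothetical classes.

```python
import pickle
from collections import OrderedDict
from typing import Any

# Assumed key name, borrowed from PyTorch's convention; not confirmed for afnio.
EXTRA_STATE_KEY = "_extra_state"


class Base:
    def get_extra_state(self) -> Any:
        raise RuntimeError("Base.get_extra_state() should never be called.")

    def set_extra_state(self, state: Any) -> None:
        raise RuntimeError("Base.set_extra_state() should never be called.")

    def _has_extra_state(self) -> bool:
        # Call the hooks only when a subclass overrides get_extra_state.
        return type(self).get_extra_state is not Base.get_extra_state

    def state_dict(self) -> OrderedDict:
        sd = OrderedDict()
        if self._has_extra_state():
            sd[EXTRA_STATE_KEY] = self.get_extra_state()
        return sd

    def load_state_dict(self, sd) -> None:
        if self._has_extra_state() and EXTRA_STATE_KEY in sd:
            self.set_extra_state(sd[EXTRA_STATE_KEY])


class FewShotCache(Base):
    """Hypothetical module that remembers which examples it has already used."""

    def __init__(self):
        self.seen_ids = set()

    def get_extra_state(self):
        # Return a plain, picklable value rather than the live set.
        return {"seen_ids": sorted(self.seen_ids)}

    def set_extra_state(self, state):
        self.seen_ids = set(state["seen_ids"])


src = FewShotCache()
src.seen_ids.update({3, 1})
blob = pickle.dumps(src.state_dict())  # the extra state must survive pickling

dst = FewShotCache()
dst.load_state_dict(pickle.loads(blob))
print(sorted(dst.seen_ids))      # [1, 3]
print(len(Base().state_dict()))  # 0: no override, so the hooks are never called
```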

buffers(recurse=True)

Return an iterator over module buffers.

Parameters:

Name Type Description Default
recurse bool

if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.

True

Yields:

Type Description
Variable

Module buffer

Examples:

>>> for buf in model.buffers():
...     print(type(buf), buf.data)
<class 'afnio.Variable'> ("Structure your answer as JSON.")
<class 'afnio.Variable'> ("Use the format\n\n{\n  \"response\": \"Your concise answer here.\"\n}")
Source code in afnio/cognitive/modules/module.py
def buffers(self, recurse: bool = True) -> Iterator[Variable]:
    r"""Return an iterator over module buffers.

    Args:
        recurse: if `True`, then yields buffers of this module
            and all submodules. Otherwise, yields only buffers that
            are direct members of this module.

    Yields:
        Module buffer

    Examples:
        >>> for buf in model.buffers():
        ...     print(type(buf), buf.data)
        <class 'afnio.Variable'> ("Structure your answer as JSON.")
        <class 'afnio.Variable'> ("Use the format\n\n{\n  \"response\": \"Your concise answer here.\"\n}")
    """  # noqa: E501
    for _, buf in self.named_buffers(recurse=recurse):
        yield buf

named_buffers(prefix='', recurse=True, remove_duplicate=True)

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters:

Name Type Description Default
prefix str

prefix to prepend to all buffer names.

''
recurse bool

if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.

True
remove_duplicate bool

whether to remove the duplicated buffers in the result. Defaults to True.

True

Yields:

Type Description
tuple[str, Variable]

Tuple containing the name and buffer

Examples:

>>> for name, buf in self.named_buffers():
...     if "format_type" in name:
...         print(buf.data)
Source code in afnio/cognitive/modules/module.py
def named_buffers(
    self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, Variable]]:
    r"""Return an iterator over module buffers, yielding both the name of
    the buffer as well as the buffer itself.

    Args:
        prefix (str): prefix to prepend to all buffer names.
        recurse (bool, optional): if `True`, then yields buffers of this module
            and all submodules. Otherwise, yields only buffers that
            are direct members of this module. Defaults to `True`.
        remove_duplicate (bool, optional): whether to remove the duplicated buffers
            in the result. Defaults to `True`.

    Yields:
        Tuple containing the name and buffer

    Examples:
        >>> for name, buf in self.named_buffers():
        ...     if "format_type" in name:
        ...         print(buf.data)
    """
    gen = self._named_members(
        lambda module: module._buffers.items(),
        prefix=prefix,
        recurse=recurse,
        remove_duplicate=remove_duplicate,
    )
    yield from gen
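`_named_members` itself is not shown in this section. Assume it behaves like its PyTorch namesake: it walks the module tree depth-first, joins names with dots, and, when `remove_duplicate=True`, yields each object only once by identity. Under that assumption, the traversal can be sketched as follows. `Var` and `Toy` are hypothetical. The same walk would serve `named_parameters`, `named_chats`, and `named_models`, which differ only in the dictionary they read.

```python
from collections import OrderedDict


class Var:
    """Hypothetical stand-in for afnio.Variable; identity is what matters here."""

    def __init__(self, data):
        self.data = data


class Toy:
    """Hypothetical container with buffers and child modules."""

    def __init__(self, buffers=None, children=None):
        self._buffers = OrderedDict(buffers or {})
        self._modules = OrderedDict(children or {})

    def named_modules(self, prefix=""):
        # Depth-first: this module first, then each child under a dotted prefix.
        yield prefix, self
        for name, child in self._modules.items():
            if child is not None:
                yield from child.named_modules(prefix + ("." if prefix else "") + name)

    def named_buffers(self, prefix="", recurse=True, remove_duplicate=True):
        modules = self.named_modules(prefix) if recurse else [(prefix, self)]
        seen = set()  # ids of buffers already yielded
        for module_prefix, module in modules:
            for name, buf in module._buffers.items():
                if buf is None:
                    continue
                if remove_duplicate:
                    if id(buf) in seen:
                        continue
                    seen.add(id(buf))
                yield module_prefix + ("." if module_prefix else "") + name, buf


shared = Var("Structure your answer as JSON.")
child = Toy({"format_type": shared})
root = Toy({"format_type": shared, "style": Var("Be concise.")}, {"qa": child})

print([n for n, _ in root.named_buffers()])
# ['format_type', 'style']  ('qa.format_type' is the same object, so it is skipped)
print([n for n, _ in root.named_buffers(remove_duplicate=False)])
# ['format_type', 'style', 'qa.format_type']
print([n for n, _ in root.named_buffers(prefix="root", recurse=False)])
# ['root.format_type', 'root.style']
```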

parameters(recurse=True)

Return an iterator over module parameters.

This is typically passed to an optimizer.

Parameters:

Name Type Description Default
recurse bool

if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

True

Yields:

Type Description
Parameter

Module parameter

Examples:

>>> for param in pipeline.parameters():
...     print(type(param), param.data)
<class 'cog.Parameter'> ("You are a doctor.")
<class 'cog.Parameter'> ("Only answer with YES or NO.")
Source code in afnio/cognitive/modules/module.py
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
    """Return an iterator over module parameters.

    This is typically passed to an optimizer.

    Args:
        recurse (bool): if `True`, then yields parameters of this module
            and all submodules. Otherwise, yields only parameters that
            are direct members of this module.

    Yields:
        Module parameter

    Examples:
        >>> for param in pipeline.parameters():
        ...     print(type(param), param.data)
        <class 'cog.Parameter'> ("You are a doctor.")
        <class 'cog.Parameter'> ("Only answer with YES or NO.")
    """
    for _, param in self.named_parameters(recurse=recurse):
        yield param
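Because `parameters()` is a generator (the source above uses `yield`), it can be iterated only once. The sketch below uses a hypothetical `ToyOptimizer` that copies the iterable into a list, as PyTorch optimizers do. Whether afnio's optimizers behave the same way is an assumption. If you need the parameters again after handing them over, call `parameters()` a second time or materialize them with `list()` first.

```python
from typing import Iterable, Iterator, List


class Param:
    """Hypothetical stand-in for cog.Parameter."""

    def __init__(self, data: str):
        self.data = data


class ToyOptimizer:
    """Hypothetical optimizer that copies its parameters into a list up front."""

    def __init__(self, params: Iterable[Param]):
        self.params: List[Param] = list(params)


def fake_parameters() -> Iterator[Param]:
    # A generator, like Module.parameters(): it can be consumed only once.
    yield Param("You are a doctor.")
    yield Param("Only answer with YES or NO.")


gen = fake_parameters()
opt = ToyOptimizer(gen)
print(len(opt.params))  # 2
print(len(list(gen)))   # 0, because the optimizer already exhausted the generator

# Calling the function again returns a fresh generator.
print([p.data for p in fake_parameters()])  # ['You are a doctor.', 'Only answer with YES or NO.']
```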

named_parameters(prefix='', recurse=True, remove_duplicate=True)

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters:

Name Type Description Default
prefix str

prefix to prepend to all parameter names.

''
recurse bool

if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

True
remove_duplicate bool

whether to remove the duplicated parameters in the result. Defaults to True.

True

Yields:

Type Description
tuple[str, Parameter]

Tuple containing the name and parameter

Examples:

>>> for name, param in self.named_parameters():
...     if "prompt" in name:
...         print(param.data)
Source code in afnio/cognitive/modules/module.py
def named_parameters(
    self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, Parameter]]:
    """Return an iterator over module parameters, yielding both the name of the
    parameter as well as the parameter itself.

    Args:
        prefix (str): prefix to prepend to all parameter names.
        recurse (bool): if `True`, then yields parameters of this module
            and all submodules. Otherwise, yields only parameters that
            are direct members of this module.
        remove_duplicate (bool, optional): whether to remove the duplicated
            parameters in the result. Defaults to `True`.

    Yields:
        Tuple containing the name and parameter

    Examples:
        >>> for name, param in self.named_parameters():
        ...     if "prompt" in name:
        ...         print(param.data)
    """
    gen = self._named_members(
        lambda module: module._parameters.items(),
        prefix=prefix,
        recurse=recurse,
        remove_duplicate=remove_duplicate,
    )
    yield from gen
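Names follow the dotted paths of the module tree, so `named_parameters()` is a convenient way to select a subset of parameters. The sketch below freezes every parameter whose name does not contain a keyword. `Param` is a hypothetical stand-in for `cog.Parameter`. The sketch assumes `requires_grad` is a writable attribute, as the constructor argument in the class example suggests. The names are illustrative, not taken from a real pipeline.

```python
from dataclasses import dataclass
from typing import Iterable, List, Tuple


@dataclass
class Param:
    """Hypothetical stand-in for cog.Parameter."""
    data: str
    requires_grad: bool = True


def freeze_except(named_params: Iterable[Tuple[str, Param]], keyword: str) -> List[str]:
    """Keep requires_grad only for parameters whose name contains `keyword`.

    Returns the names that remain trainable.
    """
    trainable = []
    for name, param in named_params:
        param.requires_grad = keyword in name
        if param.requires_grad:
            trainable.append(name)
    return trainable


# Illustrative (name, parameter) pairs shaped like named_parameters() output.
named = [
    ("system_prompt", Param("You are a doctor.")),
    ("classifier.user_prompt", Param("Is {item} a disease?")),
    ("classifier.topics", Param("Dermatology and Cardiology")),
]
print(freeze_except(named, "prompt"))  # ['system_prompt', 'classifier.user_prompt']
print(named[2][1].requires_grad)       # False
```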

chats(recurse=True)

Return an iterator over module multi-turn chats.

This is typically passed to an optimizer.

Parameters:

Name Type Description Default
recurse bool

if True, then yields chats of this module and all submodules. Otherwise, yields only chats that are direct members of this module.

True

Yields:

Type Description
MultiTurnMessages

Module chats

Examples:

>>> for chat in pipeline.chats():
...     print(type(chat), chat)
<class 'cog.MultiTurnMessages'> [{'role': 'system', 'content': [Variable(data=You are a doctor., role=system instruction, requires_grad=False)]}, {'role': 'user', 'content': [Variable(data=Is {item} a disease?, role=user query, requires_grad=False)]}]
<class 'cog.MultiTurnMessages'> [{'role': 'system', 'content': [Variable(data=You are a helpful assistant., role=system instruction, requires_grad=False), Variable(data=Only answer with YES or NO., role=user query, requires_grad=False)]}]
Source code in afnio/cognitive/modules/module.py
def chats(self, recurse: bool = True) -> Iterator[MultiTurnMessages]:
    """Return an iterator over module multi-turn chats.

    This is typically passed to an optimizer.

    Args:
        recurse (bool): if `True`, then yields chats of this module
            and all submodules. Otherwise, yields only chats that
            are direct members of this module.

    Yields:
        Module chats

    Examples:
        >>> for chat in pipeline.chats():
        ...     print(type(chat), chat)
        <class 'cog.MultiTurnMessages'> [{'role': 'system', 'content': [Variable(data=You are a doctor., role=system instruction, requires_grad=False)]}, {'role': 'user', 'content': [Variable(data=Is {item} a disease?, role=user query, requires_grad=False)]}]
        <class 'cog.MultiTurnMessages'> [{'role': 'system', 'content': [Variable(data=You are a helpful assistant., role=system instruction, requires_grad=False), Variable(data=Only answer with YES or NO., role=user query, requires_grad=False)]}]
    """  # noqa: E501
    for _, chat in self.named_chats(recurse=recurse):
        yield chat

named_chats(prefix='', recurse=True, remove_duplicate=True)

Return an iterator over module multi-turn chats, yielding both the name of the chat as well as the chat itself.

Parameters:

Name Type Description Default
prefix str

prefix to prepend to all chat names.

''
recurse bool

if True, then yields chats of this module and all submodules. Otherwise, yields only chats that are direct members of this module.

True
remove_duplicate bool

whether to remove the duplicated chats in the result. Defaults to True.

True

Yields:

Type Description
tuple[str, MultiTurnMessages]

Tuple containing the name and chat

Examples:

>>> for name, chat in self.named_chats():
...     if "messages" in name:
...         print(chat[0]["role"])
Source code in afnio/cognitive/modules/module.py
def named_chats(
    self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, MultiTurnMessages]]:
    """Return an iterator over module multi-turn chats, yielding both
    the name of the chat as well as the chat itself.

    Args:
        prefix (str): prefix to prepend to all chat names.
        recurse (bool): if `True`, then yields chats of this module
            and all submodules. Otherwise, yields only chats that
            are direct members of this module.
        remove_duplicate (bool, optional): whether to remove the duplicated
            chats in the result. Defaults to `True`.

    Yields:
        Tuple containing the name and chat

    Examples:
        >>> for name, chat in self.named_chats():
        ...     if "messages" in name:
        ...         print(chat[0]["role"])
    """
    gen = self._named_members(
        lambda module: module._chats.items(),
        prefix=prefix,
        recurse=recurse,
        remove_duplicate=remove_duplicate,
    )
    yield from gen
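The printed output of `chats()` suggests that a `MultiTurnMessages` behaves like a list of dictionaries, each with a `"role"` key and a `"content"` list. Under that assumption, you can summarize the pairs yielded by `named_chats()` as below. Plain strings stand in for `Variable`s, and the chat names are hypothetical.

```python
from typing import Dict, Iterable, List, Tuple

# A chat modeled as a list of {"role": str, "content": [str, ...]} dicts,
# matching the printed MultiTurnMessages above (an assumption).
Chat = List[Dict[str, object]]


def role_summary(named_chats: Iterable[Tuple[str, Chat]]) -> Dict[str, List[str]]:
    """Map each chat name to the sequence of roles in that chat."""
    return {name: [message["role"] for message in chat] for name, chat in named_chats}


named = [
    ("messages", [
        {"role": "system", "content": ["You are a doctor."]},
        {"role": "user", "content": ["Is {item} a disease?"]},
    ]),
    ("judge.messages", [
        {"role": "system", "content": ["You are a helpful assistant.", "Only answer with YES or NO."]},
    ]),
]
print(role_summary(named))
# {'messages': ['system', 'user'], 'judge.messages': ['system']}
```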

models(recurse=True)

Return an iterator over module language model clients.

Parameters:

Name Type Description Default
recurse bool

if True, then yields models of this module and all submodules. Otherwise, yields only models that are direct members of this module.

True

Yields:

Type Description
BaseModel

Module model

Examples:

>>> for model in pipeline.models():
...     print(type(model))
<class 'afnio.models.openai.AsyncOpenAI'>
Source code in afnio/cognitive/modules/module.py
def models(self, recurse: bool = True) -> Iterator[BaseModel]:
    """Return an iterator over module language model clients.

    Args:
        recurse (bool): if `True`, then yields models of this module
            and all submodules. Otherwise, yields only models that
            are direct members of this module.

    Yields:
        Module model

    Examples:
        >>> for model in pipeline.models():
        ...     print(type(model))
        <class 'afnio.models.openai.AsyncOpenAI'>
    """
    for _, model in self.named_models(recurse=recurse):
        yield model

named_models(prefix='', recurse=True, remove_duplicate=True)

Return an iterator over module model clients, yielding both the name of the model as well as the model itself.

Parameters:

  • prefix (str, default ''): prefix to prepend to all model names.
  • recurse (bool, default True): if True, then yields models of this module and all submodules. Otherwise, yields only models that are direct members of this module.
  • remove_duplicate (bool, default True): whether to remove the duplicated models in the result.

Yields:

  • tuple[str, BaseModel]: Tuple containing the name and model.

Examples:

>>> for name, model in self.named_models():
...     print(name, type(model))
model_client <class 'afnio.models.openai.AsyncOpenAI'>
Source code in afnio/cognitive/modules/module.py
def named_models(
    self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, BaseModel]]:
    """Return an iterator over module model clients, yielding both the name of the
    model as well as the model itself.

    Args:
        prefix (str): prefix to prepend to all model names.
        recurse (bool): if `True`, then yields models of this module
            and all submodules. Otherwise, yields only models that
            are direct members of this module.
        remove_duplicate (bool, optional): whether to remove the duplicated
            models in the result. Defaults to `True`.

    Yields:
        Tuple containing the name and model

    Examples:
        >>> for name, model in self.named_models():
        ...     print(name, type(model))
        model_client <class 'afnio.models.openai.AsyncOpenAI'>
    """
    gen = self._named_members(
        lambda module: module._models.items(),
        prefix=prefix,
        recurse=recurse,
        remove_duplicate=remove_duplicate,
    )
    yield from gen
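Each `named_*` accessor above is a thin wrapper that passes a member getter (for example `lambda module: module._models.items()`) to `_named_members`, whose body is not part of this section. Assuming it follows the same pattern as PyTorch's `nn.Module`, which this API closely mirrors, the simplified sketch below shows how `prefix` and `recurse` shape the yielded names. `ToyModule`, `named_members`, and `get_models` are hypothetical names, and duplicate removal is deliberately left out.

```python
from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Tuple


class ToyModule:
    """Hypothetical stand-in for a Module: child modules plus model clients."""

    def __init__(self, models: Optional[Dict[str, Any]] = None) -> None:
        self._modules: Dict[str, "ToyModule"] = {}
        self._models: Dict[str, Any] = dict(models or {})

    def named_modules(self, prefix: str = "") -> Iterator[Tuple[str, "ToyModule"]]:
        # Depth-first walk: yield self, then every submodule under a dotted name.
        yield prefix, self
        for name, child in self._modules.items():
            yield from child.named_modules(prefix + ("." if prefix else "") + name)


def named_members(
    module: ToyModule,
    get_members: Callable[[ToyModule], Iterable[Tuple[str, Any]]],
    prefix: str = "",
    recurse: bool = True,
) -> Iterator[Tuple[str, Any]]:
    # With recurse=False only the module itself is inspected.
    modules = module.named_modules(prefix) if recurse else [(prefix, module)]
    for module_prefix, mod in modules:
        for key, value in get_members(mod):
            # Qualify each member name with its owner's dotted path.
            yield module_prefix + ("." if module_prefix else "") + key, value


def get_models(m: ToyModule) -> Iterable[Tuple[str, Any]]:
    return m._models.items()


root = ToyModule(models={"model_client": "fwd-client"})
root._modules["judge"] = ToyModule(models={"model_client": "judge-client"})

print(list(named_members(root, get_models)))
# [('model_client', 'fwd-client'), ('judge.model_client', 'judge-client')]
print(list(named_members(root, get_models, recurse=False)))
# [('model_client', 'fwd-client')]
print([name for name, _ in named_members(root, get_models, prefix="pipeline")])
# ['pipeline.model_client', 'pipeline.judge.model_client']
```

The same walk serves chats, completion configs, and functions; only the getter changes.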

completion_configs(recurse=True)

Return an iterator over registered completion configs.

Parameters:

  • recurse (bool, default True): if True, then yields completion configs of this module and all submodules. Otherwise, yields only completion configs that are direct members of this module.

Yields:

  • dict[str, Any]: Completion arguments.

Examples:

>>> for config in model.completion_configs():
...     print(config)
{'model': 'gpt-4o', 'seed': 42, 'temperature': 0}
Source code in afnio/cognitive/modules/module.py
def completion_configs(self, recurse: bool = True) -> Iterator[Dict[str, Any]]:
    """Return an iterator over registered completion configs.

    Args:
        recurse (bool): if `True`, then yields completion configs of this module
            and all submodules. Otherwise, yields only completion configs that
            are direct members of this module.

    Yields:
        Completion arguments

    Examples:
        >>> for config in model.completion_configs():
        ...     print(config)
        {'model': 'gpt-4o', 'seed': 42, 'temperature': 0}
    """
    for _, config in self.named_completion_configs(recurse=recurse):
        yield config

named_completion_configs(prefix='', recurse=True, remove_duplicate=True)

Return an iterator over module completion configs, yielding both the name of the completion config as well as the completion config itself.

Parameters:

  • prefix (str, default ''): prefix to prepend to all completion config names.
  • recurse (bool, default True): if True, then yields completion configs of this module and all submodules. Otherwise, yields only completion configs that are direct members of this module.
  • remove_duplicate (bool, default True): whether to remove the duplicated completion configs in the result.

Yields:

  • tuple[str, dict[str, Any]]: Tuple containing the name and completion config.

Examples:

>>> for name, config in self.named_completion_configs():
...     print(name, config)
chat.completion_args {'model': 'gpt-4o', 'seed': 42, 'temperature': 0}
Source code in afnio/cognitive/modules/module.py
def named_completion_configs(
    self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, Dict[str, Any]]]:
    """Return an iterator over module completion configs, yielding both the name of
    the completion config as well as the completion config itself.

    Args:
        prefix (str): prefix to prepend to all completion config names.
        recurse (bool): if `True`, then yields completion configs of this module
            and all submodules. Otherwise, yields only completion configs that
            are direct members of this module.
        remove_duplicate (bool, optional): whether to remove the duplicated
            completion configs in the result. Defaults to `True`.

    Yields:
        Tuple containing the name and completion configs

    Examples:
        >>> for name, config in self.named_completion_configs():
        ...     print(name, config)
        chat.completion_args {'model': 'gpt-4o', 'seed': 42, 'temperature': 0}
    """
    gen = self._named_members(
        lambda module: module._completion_configs.items(),
        prefix=prefix,
        recurse=recurse,
        remove_duplicate=remove_duplicate,
    )
    yield from gen
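Completion configs are plain `dict`s, which cannot be stored in a `set`, so `remove_duplicate` has to track something other than the dict itself. The body of `_named_members` is not shown in this section; the sketch below assumes deduplication by object identity (`id()`), under which a config dict shared by two submodules is reported once, while two distinct dicts with equal contents are both reported. `dedupe_by_identity` is a hypothetical helper and the config names are illustrative.

```python
from typing import Any, Dict, Iterable, Iterator, Tuple


def dedupe_by_identity(
    pairs: Iterable[Tuple[str, Dict[str, Any]]], remove_duplicate: bool = True
) -> Iterator[Tuple[str, Dict[str, Any]]]:
    """Yield (name, config) pairs, skipping repeats of the *same* dict object."""
    seen = set()  # ids of dicts already yielded; the dicts themselves are unhashable
    for name, config in pairs:
        if remove_duplicate:
            if id(config) in seen:
                continue
            seen.add(id(config))
        yield name, config


shared = {"model": "gpt-4o", "seed": 42, "temperature": 0}
twin = dict(shared)  # equal contents, but a different object

pairs = [
    ("chat.completion_args", shared),
    ("summary.completion_args", shared),  # same object registered twice
    ("judge.completion_args", twin),
]

print([name for name, _ in dedupe_by_identity(pairs)])
# ['chat.completion_args', 'judge.completion_args']
print(len(list(dedupe_by_identity(pairs, remove_duplicate=False))))
# 3
```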

functions(recurse=True)

Return an iterator over registered functions.

Parameters:

  • recurse (bool, default True): if True, then yields functions of this module and all submodules. Otherwise, yields only functions that are direct members of this module.

Yields:

  • dict[str, Any]: Functions.

Examples:

>>> for func in model.functions():
...     print(func)
<built-in function sum>
<function my_func at 0x7e7a0665b9c0>
Source code in afnio/cognitive/modules/module.py
def functions(self, recurse: bool = True) -> Iterator[Dict[str, Any]]:
    """Return an iterator over registered functions.

    Args:
        recurse (bool): if `True`, then yields functions of this module
            and all submodules. Otherwise, yields only functions that
            are direct members of this module.

    Yields:
        Functions

    Examples:
        >>> for func in model.functions():
        ...     print(func)
        <built-in function sum>
        <function my_func at 0x7e7a0665b9c0>
    """
    for _, func in self.named_functions(recurse=recurse):
        yield func

named_functions(prefix='', recurse=True, remove_duplicate=True)

Return an iterator over module functions, yielding both the name of the function as well as the function itself.

Parameters:

  • prefix (str, default ''): prefix to prepend to all function names.
  • recurse (bool, default True): if True, then yields functions of this module and all submodules. Otherwise, yields only functions that are direct members of this module.
  • remove_duplicate (bool, default True): whether to remove the duplicated functions in the result.

Yields:

  • tuple[str, dict[str, Any]]: Tuple containing the name and function.

Examples:

>>> for name, func in self.named_functions():
...     print(name, func)
reduction_fn <built-in function sum>
eval_fn <function my_func at 0x7e7a0665b9c0>
Source code in afnio/cognitive/modules/module.py
def named_functions(
    self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, Dict[str, Any]]]:
    """Return an iterator over module functions, yielding both the name of
    the function as well as the function itself.

    Args:
        prefix (str): prefix to prepend to all function names.
        recurse (bool): if `True`, then yields functions of this module
            and all submodules. Otherwise, yields only functions that
            are direct members of this module.
        remove_duplicate (bool, optional): whether to remove the duplicated
            functions in the result. Defaults to `True`.

    Yields:
        Tuple containing the name and functions

    Examples:
        >>> for name, func in self.named_functions():
        ...     print(name, func)
        reduction_fn <built-in function sum>
        eval_fn <function my_func at 0x7e7a0665b9c0>
    """
    gen = self._named_members(
        lambda module: module._functions.items(),
        prefix=prefix,
        recurse=recurse,
        remove_duplicate=remove_duplicate,
    )
    yield from gen

children()

Return an iterator over immediate children modules.

Yields:

  • Module: A child module.

Source code in afnio/cognitive/modules/module.py
def children(self) -> Iterator["Module"]:
    """Return an iterator over immediate children modules.

    Yields:
        A child module
    """
    for _, module in self.named_children():
        yield module

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

Yields:

  • tuple[str, Module]: Tuple containing a name and child module.

Source code in afnio/cognitive/modules/module.py
def named_children(self) -> Iterator[Tuple[str, "Module"]]:
    """Return an iterator over immediate children modules, yielding both the name
    of the module as well as the module itself.

    Yields:
        Tuple containing a name and child module
    """
    memo = set()
    for name, module in self._modules.items():
        if module is not None and module not in memo:
            memo.add(module)
            yield name, module
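`named_children` only looks one level down and, through its local `memo`, reports a child that is assigned to two attributes only under the first name; `None` entries are skipped. The sketch below copies that loop onto a hypothetical `Node` class (not an afnio type) so the behavior can be checked in isolation.

```python
from typing import Dict, Iterator, Optional, Tuple


class Node:
    """Hypothetical stand-in for a Module with a `_modules` registry."""

    def __init__(self) -> None:
        self._modules: Dict[str, Optional["Node"]] = {}

    def named_children(self) -> Iterator[Tuple[str, "Node"]]:
        # Same loop as Module.named_children: skip None and already-seen children.
        memo = set()
        for name, module in self._modules.items():
            if module is not None and module not in memo:
                memo.add(module)
                yield name, module


root, shared, grandchild = Node(), Node(), Node()
shared._modules["inner"] = grandchild  # not an immediate child of root
root._modules["module1"] = shared
root._modules["module2"] = shared      # same object under a second name
root._modules["disabled"] = None       # placeholder slots are skipped

print([name for name, _ in root.named_children()])
# ['module1']
```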

modules()

Return an iterator over all modules in the network.

Yields:

  • Module: A module in the network.

Note

Duplicate modules are returned only once. In the following example, add will be returned only once.

Examples:

>>> class MyPipeline(cog.Module):
...     def __init__(self):
...         super().__init__()
...         add = cog.Add()
...         self.module1 = add
...         self.module2 = add
...
...     def forward(self, x, y):
...         out1 = self.module1(x, x)
...         out2 = self.module2(x, y)
...         return out1 + out2
>>> pipeline = MyPipeline()
>>> for idx, m in enumerate(pipeline.modules()):
...     print(idx, '->', m)
0 -> MyPipeline(
  (module1): Module()
  (module2): Module()
)
1 -> Module()
Source code in afnio/cognitive/modules/module.py
def modules(self) -> Iterator["Module"]:
    """Return an iterator over all modules in the network.

    Yields:
        A module in the network

    Note:
        Duplicate modules are returned only once. In the following
        example, `add` will be returned only once.

    Examples:
        >>> class MyPipeline(cog.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         add = cog.Add()
        ...         self.module1 = add
        ...         self.module2 = add
        ...
        ...     def forward(self, x, y):
        ...         out1 = self.module1(x, x)
        ...         out2 = self.module2(x, y)
        ...         return out1 + out2
        >>> pipeline = MyPipeline()
        >>> for idx, m in enumerate(pipeline.modules()):
        ...     print(idx, '->', m)
        0 -> MyPipeline(
          (module1): Module()
          (module2): Module()
        )
        1 -> Module()
    """
    for _, module in self.named_modules():
        yield module

named_modules(memo=None, prefix='', remove_duplicate=True)

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

Parameters:

  • memo (set[Module] | None, default None): a memo to store the set of modules already added to the result.
  • prefix (str, default ''): a prefix that will be added to the name of the module.
  • remove_duplicate (bool, default True): whether to remove the duplicated module instances in the result or not.

Yields:

  • tuple[str, Module]: Tuple of name and module.

Note

With remove_duplicate=True (the default), duplicate modules are returned only once. In the following example, add is returned only once unless remove_duplicate=False is passed.

Examples:

>>> class MyPipeline(cog.Module):
...     def __init__(self):
...         super().__init__()
...         add = cog.Add()
...         self.module1 = add
...         self.module2 = add
...
...     def forward(self, x, y):
...         out1 = self.module1(x, x)
...         out2 = self.module2(x, y)
...         return out1 + out2
>>> pipeline = MyPipeline()
>>> for idx, m in enumerate(pipeline.named_modules()):
...     print(idx, '->', m)
0 -> ('', MyPipeline(
  (module1): Module()
  (module2): Module()
))
1 -> ('module1', Module())
>>> for idx, m in enumerate(pipeline.named_modules(remove_duplicate=False)):
...     print(idx, '->', m)
0 -> ('', MyPipeline(
  (module1): Module()
  (module2): Module()
))
1 -> ('module1', Module())
2 -> ('module2', Module())
Source code in afnio/cognitive/modules/module.py
def named_modules(
    self,
    memo: Optional[Set["Module"]] = None,
    prefix: str = "",
    remove_duplicate: bool = True,
) -> Iterator[Tuple[str, "Module"]]:
    """Return an iterator over all modules in the network, yielding both
    the name of the module as well as the module itself.

    Args:
        memo: a memo to store the set of modules already added to the result
        prefix: a prefix that will be added to the name of the module
        remove_duplicate: whether to remove the duplicated module instances
            in the result or not

    Yields:
        Tuple of name and module

    Note:
        With `remove_duplicate=True` (the default), duplicate modules are
        returned only once. In the following example, `add` is returned
        only once unless `remove_duplicate=False` is passed.

    Examples:
        >>> class MyPipeline(cog.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         add = cog.Add()
        ...         self.module1 = add
        ...         self.module2 = add
        ...
        ...     def forward(self, x, y):
        ...         out1 = self.module1(x, x)
        ...         out2 = self.module2(x, y)
        ...         return out1 + out2
        >>> pipeline = MyPipeline()
        >>> for idx, m in enumerate(pipeline.named_modules()):
        ...     print(idx, '->', m)
        0 -> ('', MyPipeline(
          (module1): Module()
          (module2): Module()
        ))
        1 -> ('module1', Module())

        >>> for idx, m in enumerate(pipeline.named_modules(remove_duplicate=False)):
        ...     print(idx, '->', m)
        0 -> ('', MyPipeline(
          (module1): Module()
          (module2): Module()
        ))
        1 -> ('module1', Module())
        2 -> ('module2', Module())
    """
    if memo is None:
        memo = set()
    if self not in memo:
        if remove_duplicate:
            memo.add(self)
        yield prefix, self
        for name, module in self._modules.items():
            if module is None:
                continue
            submodule_prefix = prefix + ("." if prefix else "") + name
            yield from module.named_modules(
                memo, submodule_prefix, remove_duplicate
            )
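To see the effect of `memo` and `remove_duplicate` without any LM client, the sketch below copies the traversal from `named_modules` onto a hypothetical `Node` class and rebuilds the layout of the `MyPipeline` example, where one child is shared by `module1` and `module2`. Only the registry is modelled; `Node` is not an afnio type.

```python
from typing import Dict, Iterator, Optional, Set, Tuple


class Node:
    """Hypothetical stand-in for a Module; only the child registry matters here."""

    def __init__(self) -> None:
        self._modules: Dict[str, Optional["Node"]] = {}

    def named_modules(
        self,
        memo: Optional[Set["Node"]] = None,
        prefix: str = "",
        remove_duplicate: bool = True,
    ) -> Iterator[Tuple[str, "Node"]]:
        # Same algorithm as Module.named_modules above.
        if memo is None:
            memo = set()
        if self not in memo:
            if remove_duplicate:
                memo.add(self)  # only recorded when deduplicating
            yield prefix, self
            for name, module in self._modules.items():
                if module is None:
                    continue
                submodule_prefix = prefix + ("." if prefix else "") + name
                yield from module.named_modules(memo, submodule_prefix, remove_duplicate)


pipeline, add = Node(), Node()
pipeline._modules["module1"] = add
pipeline._modules["module2"] = add

print([name for name, _ in pipeline.named_modules()])
# ['', 'module1']
print([name for name, _ in pipeline.named_modules(remove_duplicate=False)])
# ['', 'module1', 'module2']
print([name for name, _ in pipeline.named_modules(prefix="root")])
# ['root', 'root.module1']
```

Because `memo` is only filled when `remove_duplicate` is true, passing `False` lets the shared child be visited once per attribute that references it.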

train(mode=True)

Set the module in training mode.

This only has an effect on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected.

Parameters:

  • mode (bool, default True): whether to set training mode (True) or evaluation mode (False).

Returns:

  • self (Module): The module itself.

Source code in afnio/cognitive/modules/module.py
def train(self: T, mode: bool = True) -> T:
    """Set the module in training mode.

    This only has an effect on certain modules. See the documentation of
    particular modules for details of their behavior in training/evaluation
    mode, if they are affected.

    Args:
        mode: whether to set training mode (`True`) or evaluation mode (`False`).

    Returns:
        self (Module): The module itself.
    """
    if not isinstance(mode, bool):
        raise ValueError("Training mode is expected to be boolean.")
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self

eval()

Set the module in evaluation mode.

This only has an effect on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected.

This is equivalent to calling self.train(False). See train for more details.

Returns:

  • self (Module): The module itself.

Source code in afnio/cognitive/modules/module.py
def eval(self: T) -> T:
    """Set the module in evaluation mode.

    This only has an effect on certain modules. See the documentation of
    particular modules for details of their behavior in training/evaluation
    mode, if they are affected.

    This is equivalent to calling `self.train(False)`.
    See [`train`][..train] for more details.

    Returns:
        self (Module): The module itself.
    """
    return self.train(False)
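The recursion in `train`, and the fact that `eval()` is simply `train(False)`, can be checked with a hypothetical stand-in that keeps only the `training` flag and the child registry. The method bodies copy the logic shown above; `Node` itself is illustrative, not an afnio class.

```python
from typing import Dict, Iterator, TypeVar

T = TypeVar("T", bound="Node")


class Node:
    """Hypothetical stand-in for a Module with a training flag."""

    def __init__(self) -> None:
        self.training = True
        self._modules: Dict[str, "Node"] = {}

    def children(self) -> Iterator["Node"]:
        yield from self._modules.values()

    def train(self: T, mode: bool = True) -> T:
        # Mirrors Module.train: validate, set the flag, recurse into children.
        if not isinstance(mode, bool):
            raise ValueError("Training mode is expected to be boolean.")
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def eval(self: T) -> T:
        return self.train(False)


root, child, grandchild = Node(), Node(), Node()
child._modules["leaf"] = grandchild
root._modules["child"] = child

print(root.eval() is root)  # both methods return self, so calls can be chained
# True
print(root.training, child.training, grandchild.training)
# False False False
root.train()
print(grandchild.training)
# True
```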

requires_grad_(requires_grad=True)

Change whether autodiff should record operations on parameters and chats in this module.

This method sets the requires_grad attribute of all module parameters in-place. It also sets the requires_grad attribute of every Variable within the content of multi-turn chats.

Effect on Parameters:

  • Sets requires_grad for each registered parameter in the module.

Effect on Chats:

  • Iterates through all multi-turn chats and sets requires_grad for each Variable in the "content" key of every chat message.

This method is helpful for freezing part of the module for finetuning, or for training parts of a model individually.

Parameters:

  • requires_grad (bool, default True): Whether autodiff should record operations on parameters and chats in this module.

Returns:

  • self (Module): The module itself.

Source code in afnio/cognitive/modules/module.py
def requires_grad_(self: T, requires_grad: bool = True) -> T:
    """Change whether autodiff should record operations on parameters and chats
    in this module.

    This method sets the [`requires_grad`][afnio.Variable.requires_grad] attributes
    of all module parameters in-place. It also sets the
    [`requires_grad`][afnio.Variable.requires_grad] attributes of all the
    `Variables` within the content of multi-turn chats.

    **Effect on Parameters:**

    - Sets [`requires_grad`][afnio.Variable.requires_grad] for each registered
        parameter in the module.

    **Effect on Chats:**

    - Iterates through all multi-turn chats and sets
        [`requires_grad`][afnio.Variable.requires_grad] for each `Variable` in
        the `"content"` key of every chat message.

    This method is helpful for freezing part of the module for finetuning
    or training parts of a model individually.

    Args:
        requires_grad: Whether autodiff should record operations on parameters and
            chats in this module.

    Returns:
        self (Module): The module itself.
    """
    # Set requires_grad on all parameters
    for p in self.parameters():
        p.requires_grad_(requires_grad)

    # Set requires_grad on all variables in message content
    for chat in self.chats():
        for message in chat:
            for variable in message["content"]:
                variable.requires_grad_(requires_grad)

    return self
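Freezing reaches both the parameters and every `Variable` listed under a message's `"content"` key, so each chat is expected to be a list of message dicts whose `"content"` is itself a list. The sketch below replays the two loops from `requires_grad_` on a hypothetical `ToyVariable` class (not `afnio.Variable`) and a hand-built chat; `set_requires_grad` is likewise an illustrative helper.

```python
from dataclasses import dataclass
from typing import Any, Dict, List

Message = Dict[str, Any]  # {"role": str, "content": List[ToyVariable]}


@dataclass
class ToyVariable:
    """Hypothetical stand-in for afnio.Variable (only requires_grad matters here)."""

    data: str
    requires_grad: bool = False

    def requires_grad_(self, requires_grad: bool = True) -> "ToyVariable":
        self.requires_grad = requires_grad
        return self


def set_requires_grad(
    params: List[ToyVariable], chats: List[List[Message]], requires_grad: bool = True
) -> None:
    # Same two passes as Module.requires_grad_: parameters, then chat contents.
    for p in params:
        p.requires_grad_(requires_grad)
    for chat in chats:
        for message in chat:
            for variable in message["content"]:
                variable.requires_grad_(requires_grad)


system_prompt = ToyVariable("You are a doctor.", requires_grad=True)
chat: List[Message] = [
    {"role": "system", "content": [ToyVariable("Answer briefly.", requires_grad=True)]},
    {"role": "user", "content": [ToyVariable("Question: {question}")]},
]

set_requires_grad([system_prompt], [chat], requires_grad=False)  # freeze everything
print(system_prompt.requires_grad, [m["content"][0].requires_grad for m in chat])
# False [False, False]
```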

empty_grad()

Reset the gradients of all module parameters and of the Variables in chat message contents.

This method is useful for clearing out gradients before starting a new optimization step. It ensures that both module parameters and the Variables within multi-turn chat message contents have their gradients reset, avoiding unintended gradient accumulation.

Source code in afnio/cognitive/modules/module.py
def empty_grad(self) -> None:
    """Reset the gradients of all module parameters and of the Variables in
    chat message contents.

    This method is useful for clearing out gradients before starting a new
    optimization step. It ensures that both module parameters and the Variables
    within multi-turn chat message contents have their gradients reset, avoiding
    unintended gradient accumulation.
    """
    # Reset gradients of all parameters
    for p in self.parameters():
        if p.grad:
            p.grad = []

    # Reset gradients of all variables in message content
    for chat in self.chats():
        for message in chat:
            for variable in message["content"]:
                if variable.grad:
                    variable.grad = []
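The code above stores gradients as lists and resets them to `[]`. The hypothetical sketch below shows why that reset matters: without it, feedback from successive backward passes piles up on the same variable. `ToyVariable` and the feedback strings are placeholders, not real afnio objects or output.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ToyVariable:
    """Hypothetical stand-in for afnio.Variable; grad holds textual feedback."""

    data: str
    grad: List[str] = field(default_factory=list)


def empty_grad(params: List[ToyVariable], chats: List[List[Dict[str, Any]]]) -> None:
    # Same logic as Module.empty_grad: clear non-empty grads on params and chats.
    for p in params:
        if p.grad:
            p.grad = []
    for chat in chats:
        for message in chat:
            for variable in message["content"]:
                if variable.grad:
                    variable.grad = []


prompt = ToyVariable("You are a doctor.")
chat = [{"role": "system", "content": [ToyVariable("Answer briefly.")]}]

# Two "backward passes" without clearing: feedback accumulates.
for step in range(2):
    prompt.grad.append(f"feedback from step {step}")
    chat[0]["content"][0].grad.append(f"chat feedback {step}")
print(len(prompt.grad))
# 2

empty_grad([prompt], [chat])
print(prompt.grad, chat[0]["content"][0].grad)
# [] []
```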

training_step(batch, batch_idx)

Perform a single training step.

This method should be implemented in subclasses to define the training logic. It is called by the Trainer during the training loop.

Parameters:

  • batch (Any, required): The output of your data iterable, normally a DataLoader.
  • batch_idx (int, required): The index of this batch.

Returns:

  • STEP_OUTPUT: The result of the training step (see the note below for details).

Notes

The return value can be one of the following:

  • Tuple[Variable, Variable]: The loss as a tuple of two Variables:
    • The evaluation score (a Variable containing the loss value).
    • The explanation (a Variable containing a string explanation of the evaluation result).
  • dict: A dictionary. Can include any keys, but must include the key 'loss' containing a tuple of two Variables (score and explanation).
  • None: Skip to the next batch.

Raises:

  • NotImplementedError: If not implemented in a subclass.

Source code in afnio/cognitive/modules/module.py
def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
    """Perform a single training step.

    This method should be implemented in subclasses to define the training logic.
    It is called by the [`Trainer`][afnio.trainer.trainer.Trainer]
    during the training loop.

    Args:
        batch: The output of your data iterable, normally
            a [`DataLoader`][afnio.util.data.DataLoader].
        batch_idx: The index of this batch.

    Returns:
        The result of the training step (see the note below for details).

    Notes:
        The return value can be one of the following:

        - `Tuple[Variable, Variable]`: The loss as a tuple of two `Variable`s:
            - The evaluation `score` (a `Variable` containing the loss value).
            - The `explanation` (a `Variable` containing a string explanation
                of the evaluation result).
        - `dict`: A dictionary. Can include any keys, but must include
            the key `'loss'` containing a tuple of two `Variable`s
            (`score` and `explanation`).
        - `None`: Skip to the next batch.

    Raises:
        NotImplementedError: If not implemented in a subclass.
    """
    raise NotImplementedError(
        "You must implement training_step in your Module subclass."
    )
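On the consuming side, the three shapes listed in the Notes can be told apart with a short normalization step. The helper below is a hypothetical sketch of how a trainer could interpret a `STEP_OUTPUT`; it is not the actual `Trainer` implementation, and plain strings stand in for the score and explanation `Variable`s.

```python
import collections.abc
from typing import Any, Mapping, Optional, Tuple, Union

Loss = Tuple[Any, Any]  # (score, explanation): Variables in afnio, strings here
StepOutput = Optional[Union[Loss, Mapping[str, Any]]]


def _check_pair(value: Any) -> Loss:
    if not (isinstance(value, tuple) and len(value) == 2):
        raise ValueError("Expected a (score, explanation) tuple.")
    return value


def extract_loss(output: StepOutput) -> Optional[Loss]:
    """Return the (score, explanation) pair, or None if the batch is skipped."""
    if output is None:
        return None  # skip this batch: no optimization step
    if isinstance(output, tuple):
        return _check_pair(output)
    if isinstance(output, collections.abc.Mapping):
        if "loss" not in output:
            raise KeyError("Step output dict must contain a 'loss' key.")
        return _check_pair(output["loss"])  # other keys are free-form extras
    raise TypeError(f"Unsupported step output type: {type(output).__name__}")


print(extract_loss(("0.8", "Mostly correct.")))
# ('0.8', 'Mostly correct.')
print(extract_loss({"loss": ("0.8", "Mostly correct."), "accuracy": 0.75}))
# ('0.8', 'Mostly correct.')
print(extract_loss(None))
# None
```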

validation_step(batch, batch_idx)

Perform a single validation step.

This method should be implemented in subclasses to define the validation logic. It is called by the Trainer during the validation loop.

Parameters:

  • batch (Any, required): The output of your data iterable, normally a DataLoader.
  • batch_idx (int, required): The index of this batch.

Returns:

  • STEP_OUTPUT: The result of the validation step (see the note below for details).

Notes

The return value can be one of the following:

  • Tuple[Variable, Variable]: The loss as a tuple of two Variables:
    • The evaluation score (a Variable containing the loss value).
    • The explanation (a Variable containing a string explanation of the evaluation result).
  • dict: A dictionary. Can include any keys, but must include the key 'loss' containing a tuple of two Variables (score and explanation).
  • None: Skip to the next batch.

Raises:

  • NotImplementedError: If not implemented in a subclass.

Source code in afnio/cognitive/modules/module.py
def validation_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
    """Perform a single validation step.

    This method should be implemented in subclasses to define the validation logic.
    It is called by the [`Trainer`][afnio.trainer.trainer.Trainer]
    during the validation loop.

    Args:
        batch: The output of your data iterable,
            normally a [`DataLoader`][afnio.util.data.DataLoader].
        batch_idx: The index of this batch.

    Returns:
        The result of the validation step (see the note below for details).

    Notes:
        The return value can be one of the following:

        - `Tuple[Variable, Variable]`: The loss as a tuple of two `Variable`s:
            - The evaluation `score` (a `Variable` containing the loss value).
            - The `explanation` (a `Variable` containing a string explanation
                of the evaluation result).
        - `dict`: A dictionary. Can include any keys, but must include
            the key `'loss'` containing a tuple of two `Variable`s
            (`score` and `explanation`).
        - `None`: Skip to the next batch.

    Raises:
        NotImplementedError: If not implemented in a subclass.
    """
    raise NotImplementedError(
        "You must implement validation_step in your Module subclass."
    )
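On the producing side, a subclass might return the dict form from `training_step` to log extra values, and `None` from `validation_step` for batches it cannot score. The sketch below shows only that return contract: `BaseStepModule` is a hypothetical stand-in for `Module`, and the keyword-count "evaluator" replaces the LM-based loss a real pipeline would use.

```python
from typing import Any, Dict, List, Tuple


class BaseStepModule:
    """Hypothetical stand-in for afnio's Module: unimplemented steps raise."""

    def training_step(self, batch: Any, batch_idx: int):
        raise NotImplementedError("You must implement training_step in your Module subclass.")

    def validation_step(self, batch: Any, batch_idx: int):
        raise NotImplementedError("You must implement validation_step in your Module subclass.")


class KeywordQA(BaseStepModule):
    """Toy module whose 'loss' counts questions containing their keyword."""

    def _score(self, batch: List[Dict[str, str]]) -> Tuple[str, str]:
        # Hypothetical evaluator standing in for an LM-based loss.
        hits = sum(ex["keyword"] in ex["question"] for ex in batch)
        return str(hits / len(batch)), f"{hits} of {len(batch)} questions contain the keyword."

    def training_step(self, batch, batch_idx):
        # Dict form: 'loss' is required, other keys are free-form (e.g., for logging).
        return {"loss": self._score(batch), "batch_size": len(batch)}

    def validation_step(self, batch, batch_idx):
        if not batch:
            return None  # nothing to score: skip this batch
        return self._score(batch)  # tuple form


batch = [
    {"question": "Is aspirin an NSAID?", "keyword": "aspirin"},
    {"question": "What causes eczema?", "keyword": "psoriasis"},
]
module = KeywordQA()
print(module.training_step(batch, 0))
# {'loss': ('0.5', '1 of 2 questions contain the keyword.'), 'batch_size': 2}
print(module.validation_step([], 1))
# None
```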

test_step(batch, batch_idx)

Perform a single test step.

This method should be implemented in subclasses to define the test logic. It is called by the Trainer during the testing loop.

Parameters:

  • batch (Any, required): The output of your data iterable, normally a DataLoader.
  • batch_idx (int, required): The index of this batch.

Returns:

  • STEP_OUTPUT: The result of the test step (see the note below for details).

Notes

The return value can be one of the following:

  • Tuple[Variable, Variable]: The loss as a tuple of two Variables:
    • The evaluation score (a Variable containing the loss value).
    • The explanation (a Variable containing a string explanation of the evaluation result).
  • dict: A dictionary. Can include any keys, but must include the key 'loss' containing a tuple of two Variables (score and explanation).
  • None: Skip to the next batch.

Raises:

  • NotImplementedError: If not implemented in a subclass.

Source code in afnio/cognitive/modules/module.py
def test_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
    """Perform a single test step.

    This method should be implemented in subclasses to define the test logic.
    It is called by the [`Trainer`][afnio.trainer.trainer.Trainer]
    during the testing loop.

    Args:
        batch: The output of your data iterable,
            normally a [`DataLoader`][afnio.util.data.DataLoader].
        batch_idx: The index of this batch.

    Returns:
        The result of the test step (see the note below for details).

    Notes:
        The return value can be one of the following:

        - `Tuple[Variable, Variable]`: The loss as a tuple of two `Variable`s:
            - The evaluation `score` (a `Variable` containing the loss value).
            - The `explanation` (a `Variable` containing a string explanation
                of the evaluation result).
        - `dict`: A dictionary. Can include any keys, but must include
            the key `'loss'` containing a tuple of two `Variable`s
            (`score` and `explanation`).
        - `None`: Skip to the next batch.

    Raises:
        NotImplementedError: If not implemented in a subclass.
    """
    raise NotImplementedError(
        "You must implement test_step in your Module subclass."
    )

configure_optimizers()

Configure and return the optimizer for this module.

This method should be implemented in subclasses to define the optimizer configuration. It is called by the Trainer to set up the optimization routine.

Returns:

  • Optimizer: An instance of an optimizer configured for this module.

Raises:

  • NotImplementedError: If not implemented in a subclass.

Source code in afnio/cognitive/modules/module.py
def configure_optimizers(self) -> Optimizer:
    """Configure and return the optimizer for this module.

    This method should be implemented in subclasses to define the optimizer
    configuration. It is called by the [`Trainer`][afnio.trainer.trainer.Trainer]
    to set up the optimization routine.

    Returns:
        An instance of an optimizer configured for this module.

    Raises:
        NotImplementedError: If not implemented in a subclass.
    """
    raise NotImplementedError(
        "You must implement configure_optimizers in your Module subclass."
    )
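`optimizers()` is typed as returning either a single `Optimizer` or a list, so code that iterates over the result, as the `optimizers()` example does, may want to normalize it first. The sketch below is a hypothetical helper for that; `ToyOptimizer` and `as_optimizer_list` are illustrative names, not afnio APIs.

```python
from typing import List, Sequence, Union


class ToyOptimizer:
    """Hypothetical stand-in for an afnio optimizer such as TGD."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return f"ToyOptimizer({self.name!r})"


def as_optimizer_list(
    opts: Union[ToyOptimizer, Sequence[ToyOptimizer]],
) -> List[ToyOptimizer]:
    # Wrap a single optimizer so callers can always iterate.
    if isinstance(opts, ToyOptimizer):
        return [opts]
    return list(opts)


for opt in as_optimizer_list(ToyOptimizer("tgd")):
    print(opt)
# ToyOptimizer('tgd')
print(as_optimizer_list([ToyOptimizer("a"), ToyOptimizer("b")]))
# [ToyOptimizer('a'), ToyOptimizer('b')]
```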

optimizers()

Returns the optimizer(s) that are being used during training. Useful for manual optimization.

This method is useful for accessing the optimizer(s) that the Trainer.fit() method set up through configure_optimizers.

Returns:

  • Optimizer | list[Optimizer]: The optimizer(s) used by this module.

Examples:

>>> optimizers = model.optimizers()
>>> for optimizer in optimizers:
...     print(optimizer)
TGD (
Parameter Group 0
    completion_args: {'model': 'gpt-4.1'}
    constraints: []
    inputs: {}
    messages: [
    {'role': 'system', 'content': [Variable(data="Placeholder Textual Gradient Descent optimizer system prompt", role=Textual Gradient Descent optimizer system prompt, requires_grad=False)]},
    {'role': 'user', 'content': [Variable(data="Placeholder for Textual Gradient Descent optimizer user prompt", role=Textual Gradient Descent optimizer user prompt, requires_grad=False)]}
    ]
    model_client: <afnio.models.openai.AsyncOpenAI object at 0x710df9c149a0>
    momentum: 3
)
Source code in afnio/cognitive/modules/module.py
def optimizers(self) -> Union[Optimizer, List[Optimizer]]:
    """Returns the optimizer(s) that are being used during training. Useful for
    manual optimization.

    This method is useful for accessing the optimizer(s) configured in the
    [`configure_optimizers`][..configure_optimizers] method by the
    [`Trainer.fit()`][afnio.trainer.trainer.Trainer.fit] method.

    Returns:
        The optimizer(s) used by this module.

    Examples:
        >>> optimizers = model.optimizers()
        >>> for optimizer in optimizers:
        ...     print(optimizer)
        TGD (
        Parameter Group 0
            completion_args: {'model': 'gpt-4.1'}
            constraints: []
            inputs: {}
            messages: [
            {'role': 'system', 'content': [Variable(data="Placeholder Textual Gradient Descent optimizer system prompt", role=Textual Gradient Descent optimizer system prompt, requires_grad=False)]},
            {'role': 'user', 'content': [Variable(data="Placeholder for Textual Gradient Descent optimizer user prompt", role=Textual Gradient Descent optimizer user prompt, requires_grad=False)]}
            ]
            model_client: <afnio.models.openai.AsyncOpenAI object at 0x710df9c149a0>
            momentum: 3
        )
    """  # noqa: E501
    if self._optimizers is not None:
        return self._optimizers
    raise AttributeError(
        "No optimizer found. Did you call `configure_optimizers()` "
        "and did the `Trainer` set `_optimizers`?"
    )
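`optimizers()` is annotated as returning either a single `Optimizer` or a `list[Optimizer]`. A single optimizer object may not be iterable, so the loop in the example above could fail in that case. The sketch below copies the method's logic onto a hypothetical stand-in `Module` and adds a hypothetical `as_optimizer_list` helper so that manual-optimization code can always iterate. The stand-in `Optimizer` and its `step()` method are assumptions, not afnio's API.

```python
from typing import List, Optional, Union


class Optimizer:
    """Hypothetical stand-in for an afnio optimizer."""

    def __init__(self, name: str) -> None:
        self.name = name

    def step(self) -> None:
        print(f"{self.name}: step")


class Module:
    """Stand-in mirroring the optimizers() logic in the source above."""

    def __init__(self) -> None:
        self._optimizers: Optional[Union[Optimizer, List[Optimizer]]] = None

    def optimizers(self) -> Union[Optimizer, List[Optimizer]]:
        if self._optimizers is not None:
            return self._optimizers
        raise AttributeError(
            "No optimizer found. Did you call `configure_optimizers()` "
            "and did the `Trainer` set `_optimizers`?"
        )


def as_optimizer_list(opts: Union[Optimizer, List[Optimizer]]) -> List[Optimizer]:
    """Hypothetical helper: always return a list so callers can iterate."""
    return list(opts) if isinstance(opts, list) else [opts]


model = Module()
model._optimizers = Optimizer("TGD")  # normally set by the Trainer
for opt in as_optimizer_list(model.optimizers()):
    opt.step()
```

Calling `optimizers()` before a trainer has stored anything in `_optimizers` raises `AttributeError`, matching the source above.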