
afnio.utils.data.dataloader

afnio.utils.data.dataloader.DataLoader

Bases: Generic[T_co]

The data loader combines a dataset and a sampler, and provides an iterable over the given dataset.

The DataLoader supports both map-style and iterable-style datasets with single-process loading, customizable loading order, and optional automatic batching (collation) and memory pinning.

See afnio.utils.data documentation page for more details.
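The core batching rule can be sketched in plain Python. This is a simplified standalone mirror of the semantics, not the afnio implementation; `batches` is a hypothetical helper that groups sampler indices the way the loader does:

```python
def batches(indices, batch_size, drop_last=False):
    """Group sampler indices into fixed-size batches."""
    batch = []
    for i in indices:
        batch.append(i)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # A trailing, smaller batch is kept unless drop_last is set
    if batch and not drop_last:
        yield batch

print(list(batches(range(5), 2)))                  # [[0, 1], [2, 3], [4]]
print(list(batches(range(5), 2, drop_last=True)))  # [[0, 1], [2, 3]]
```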

Source code in afnio/utils/data/dataloader.py
class DataLoader(Generic[T_co]):
    """
    Data loader combines a dataset and a sampler, and provides an iterable over the
    given dataset.

    The [`DataLoader`][afnio.utils.data.DataLoader] supports both map-style and
    iterable-style datasets with single-process loading, customizable loading order,
    and optional automatic batching (collation) and memory pinning.

    See [`afnio.utils.data`][afnio.utils.data] documentation page for more details.
    """

    dataset: Dataset[T_co]
    batch_size: Optional[int]
    drop_last: bool
    sampler: Union[Sampler, Iterable]
    __initialized = False

    def __init__(
        self,
        dataset: Dataset[T_co],
        batch_size: Optional[int] = 1,
        shuffle: Optional[bool] = False,
        sampler: Union[Sampler, Iterable, None] = None,
        drop_last: bool = False,
        seed: Optional[int] = None,
    ):
        """Initializes the `DataLoader` with the given dataset and options.

        Args:
            dataset: Dataset from which to load the data.
            batch_size: How many samples per batch to load.
            shuffle: Set to `True` to have the data reshuffled at every epoch.
            sampler: Defines the strategy to draw samples from the dataset. Can be any
                `Iterable` with `__len__` implemented. If specified, `shuffle`
                must not be specified.
            drop_last: Set to `True` to drop the last incomplete batch, if the dataset
                size is not divisible by the batch size. If `False` and the size of
                dataset is not divisible by the batch size, then the last batch
                will be smaller.
            seed: If not `None`, this seed will be used by
                [`RandomSampler`][afnio.utils.data.RandomSampler]
                to generate random indexes.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        if shuffle not in {True, False}:
            raise ValueError(
                f"DataLoader: expected a boolean shuffle option, "
                f"but got shuffle={shuffle}"
            )

        if sampler is not None and shuffle:
            raise ValueError("sampler option is mutually exclusive with shuffle")

        if sampler is None:
            if shuffle:
                sampler = RandomSampler(dataset, seed=seed)
            else:
                sampler = SequentialSampler(dataset)

        self.index_sampler = sampler
        self._sampler_iter = iter(self.index_sampler)
        self.__initialized = True

    def __iter__(self) -> Iterable[Any]:
        self._sampler_iter = iter(self.index_sampler)  # Ensure new iterator every time
        return self

    def _next_index(self):
        return next(self._sampler_iter)

    def __next__(self) -> Any:
        """Returns the next batch from the dataset, collated according to the structure
        of the dataset's `__getitem__` output.

        **Batching logic:**

        - If the dataset returns a dictionary, this method aggregates each key across
          the batch into a list of values. For example, if each sample is
          `{'a': 'foo', 'b': 'bar'}`, the batch will be `{'a': [...], 'b': [...]}`.
        - If the dataset returns a tuple (e.g., `(X, y)`), this method recursively
          collates each position in the tuple using
          [`collate_tuple()`][afnio.utils.data.dataloader.collate_tuple], preserving
          nested tuple structure and batching [`Variables`][afnio.Variable]
          as described below.
        - If the dataset returns [`Variables`][afnio.Variable] directly, this method
          batches them into a single Variable whose [`data`][afnio.Variable.data] is a
          list of the original [`data`][afnio.Variable.data] fields, and whose
          [`role`][afnio.Variable.role] and
          [`requires_grad`][afnio.Variable.requires_grad] are taken
          from the first [`Variable`][afnio.Variable].
        - Otherwise, returns the batch as a `list`.
        """
        # Suppress notifications for individual Variables
        with suppress_variable_notifications():
            batch = []
            for _ in range(self.batch_size):
                try:
                    index = self._next_index()
                    batch.append(self.dataset[index])
                except StopIteration:
                    if not batch or self.drop_last:
                        raise
                    break

        # If dataset returns a dictionary, we aggregate each key across the batch
        if (
            batch
            and isinstance(batch[0], dict)  # noqa: W503
            and all(isinstance(item, dict) for item in batch)  # noqa: W503
        ):
            keys = batch[0].keys()
            collated = {}
            for key in keys:
                values = [item[key] for item in batch]
                collated[key] = values
            return collated
        # If dataset returns a tuple, we recursively collate each position in the tuple
        if (
            batch
            and isinstance(batch[0], tuple)  # noqa: W503
            and all(isinstance(item, tuple) for item in batch)  # noqa: W503
        ):
            return collate_tuple(batch)

        # If dataset returns Variables, we batch them into a single Variable
        if (
            batch
            and isinstance(batch[0], Variable)  # noqa: W503
            and all(isinstance(item, Variable) for item in batch)  # noqa: W503
        ):
            first = batch[0]
            return Variable(
                data=[item.data for item in batch],
                role=first.role,
                requires_grad=first.requires_grad,
            )

        return batch

    def __len__(self) -> int:
        length = len(self.dataset)
        if self.batch_size is not None:
            from math import ceil

            if self.drop_last:
                length = length // self.batch_size
            else:
                length = ceil(length / self.batch_size)
        return length

__init__(dataset, batch_size=1, shuffle=False, sampler=None, drop_last=False, seed=None)

Initializes the DataLoader with the given dataset and options.

Parameters:

  dataset (Dataset[T_co], required):
      Dataset from which to load the data.
  batch_size (int | None, default 1):
      How many samples per batch to load.
  shuffle (bool | None, default False):
      Set to True to have the data reshuffled at every epoch.
  sampler (Sampler | Iterable | None, default None):
      Defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.
  drop_last (bool, default False):
      Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of the dataset is not divisible by the batch size, then the last batch will be smaller.
  seed (int | None, default None):
      If not None, this seed will be used by RandomSampler to generate random indexes.
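The interaction of `batch_size` and `drop_last` determines `len(loader)`: floor division when incomplete batches are dropped, ceiling division otherwise. A quick check of that arithmetic:

```python
from math import ceil

n_samples, batch_size = 10, 3

# drop_last=True: the trailing partial batch of 1 sample is discarded
assert n_samples // batch_size == 3

# drop_last=False: the partial batch still counts as a batch
assert ceil(n_samples / batch_size) == 4
```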
Source code in afnio/utils/data/dataloader.py
def __init__(
    self,
    dataset: Dataset[T_co],
    batch_size: Optional[int] = 1,
    shuffle: Optional[bool] = False,
    sampler: Union[Sampler, Iterable, None] = None,
    drop_last: bool = False,
    seed: Optional[int] = None,
):
    """Initializes the `DataLoader` with the given dataset and options.

    Args:
        dataset: Dataset from which to load the data.
        batch_size: How many samples per batch to load.
        shuffle: Set to `True` to have the data reshuffled at every epoch.
        sampler: Defines the strategy to draw samples from the dataset. Can be any
            `Iterable` with `__len__` implemented. If specified, `shuffle`
            must not be specified.
        drop_last: Set to `True` to drop the last incomplete batch, if the dataset
            size is not divisible by the batch size. If `False` and the size of
            dataset is not divisible by the batch size, then the last batch
            will be smaller.
        seed: If not `None`, this seed will be used by
            [`RandomSampler`][afnio.utils.data.RandomSampler]
            to generate random indexes.
    """
    self.dataset = dataset
    self.batch_size = batch_size
    self.shuffle = shuffle
    self.drop_last = drop_last

    if shuffle not in {True, False}:
        raise ValueError(
            f"DataLoader: expected a boolean shuffle option, "
            f"but got shuffle={shuffle}"
        )

    if sampler is not None and shuffle:
        raise ValueError("sampler option is mutually exclusive with shuffle")

    if sampler is None:
        if shuffle:
            sampler = RandomSampler(dataset, seed=seed)
        else:
            sampler = SequentialSampler(dataset)

    self.index_sampler = sampler
    self._sampler_iter = iter(self.index_sampler)
    self.__initialized = True

__next__()

Returns the next batch from the dataset, collated according to the structure of the dataset's __getitem__ output.

Batching logic:

  • If the dataset returns a dictionary, this method aggregates each key across the batch into a list of values. For example, if each sample is {'a': 'foo', 'b': 'bar'}, the batch will be {'a': [...], 'b': [...]}.
  • If the dataset returns a tuple (e.g., (X, y)), this method recursively collates each position in the tuple using collate_tuple(), preserving nested tuple structure and batching Variables as described below.
  • If the dataset returns Variables directly, this method batches them into a single Variable whose data is a list of the original data fields, and whose role and requires_grad are taken from the first Variable.
  • Otherwise, returns the batch as a list.
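The dictionary rule above can be illustrated with plain Python. This is a standalone sketch of the collation behavior, not the library code; `collate_dicts` is a hypothetical helper:

```python
def collate_dicts(batch):
    # Aggregate each key across the batch into a list of values,
    # taking the key set from the first sample
    return {key: [item[key] for item in batch] for key in batch[0]}

samples = [{'a': 'foo', 'b': 'bar'}, {'a': 'baz', 'b': 'qux'}]
print(collate_dicts(samples))  # {'a': ['foo', 'baz'], 'b': ['bar', 'qux']}
```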
Source code in afnio/utils/data/dataloader.py
def __next__(self) -> Any:
    """Returns the next batch from the dataset, collated according to the structure
    of the dataset's `__getitem__` output.

    **Batching logic:**

    - If the dataset returns a dictionary, this method aggregates each key across
      the batch into a list of values. For example, if each sample is
      `{'a': 'foo', 'b': 'bar'}`, the batch will be `{'a': [...], 'b': [...]}`.
    - If the dataset returns a tuple (e.g., `(X, y)`), this method recursively
      collates each position in the tuple using
      [`collate_tuple()`][afnio.utils.data.dataloader.collate_tuple], preserving
      nested tuple structure and batching [`Variables`][afnio.Variable]
      as described below.
    - If the dataset returns [`Variables`][afnio.Variable] directly, this method
      batches them into a single Variable whose [`data`][afnio.Variable.data] is a
      list of the original [`data`][afnio.Variable.data] fields, and whose
      [`role`][afnio.Variable.role] and
      [`requires_grad`][afnio.Variable.requires_grad] are taken
      from the first [`Variable`][afnio.Variable].
    - Otherwise, returns the batch as a `list`.
    """
    # Suppress notifications for individual Variables
    with suppress_variable_notifications():
        batch = []
        for _ in range(self.batch_size):
            try:
                index = self._next_index()
                batch.append(self.dataset[index])
            except StopIteration:
                if not batch or self.drop_last:
                    raise
                break

    # If dataset returns a dictionary, we aggregate each key across the batch
    if (
        batch
        and isinstance(batch[0], dict)  # noqa: W503
        and all(isinstance(item, dict) for item in batch)  # noqa: W503
    ):
        keys = batch[0].keys()
        collated = {}
        for key in keys:
            values = [item[key] for item in batch]
            collated[key] = values
        return collated
    # If dataset returns a tuple, we recursively collate each position in the tuple
    if (
        batch
        and isinstance(batch[0], tuple)  # noqa: W503
        and all(isinstance(item, tuple) for item in batch)  # noqa: W503
    ):
        return collate_tuple(batch)

    # If dataset returns Variables, we batch them into a single Variable
    if (
        batch
        and isinstance(batch[0], Variable)  # noqa: W503
        and all(isinstance(item, Variable) for item in batch)  # noqa: W503
    ):
        first = batch[0]
        return Variable(
            data=[item.data for item in batch],
            role=first.role,
            requires_grad=first.requires_grad,
        )

    return batch

afnio.utils.data.dataloader.collate_tuple(items)

Recursively collates a batch of tuples, preserving nested structure.

This function should only be called when processing batches where each element is a tuple (i.e., when the dataset's __getitem__ returns tuples).

The function first transposes the batch, so that each position in the tuple is grouped together. For each group:

  • If all elements are Variables, returns a single Variable whose data is a list of the original data fields, and whose role and requires_grad are taken from the first Variable.
  • If all elements are tuples, recursively collates them to preserve nested structure.
  • If some elements are tuples and some are not, recursively collates the tuples and leaves other elements as is, preserving their position.
  • Otherwise, returns a list of the grouped items.

This enables flexible batching for datasets that return tuples of Variables, nested tuples, or mixed structures.

Parameters:

  items (Iterable[tuple], required):
      An iterable of tuples, where each tuple is a sample from the dataset.

Returns:

  tuple:
      A single tuple representing the collated batch, with structure determined by the rules above.

Source code in afnio/utils/data/dataloader.py
def collate_tuple(items: Iterable[tuple]) -> tuple:
    """Recursively collates a batch of tuples, preserving nested structure.

    This function should only be called when processing batches where each element
    is a tuple (i.e., when the dataset's `__getitem__` returns tuples).

    The function first transposes the batch, so that each position in the tuple is
    grouped together. For each group:

    - If all elements are [`Variable`][afnio.Variable]s, returns a single `Variable`
        whose [`data`][afnio.Variable.data] is a list of the original
        [`data`][afnio.Variable.data] fields, and whose [`role`][afnio.Variable.role]
        and [`requires_grad`][afnio.Variable.requires_grad] are taken
        from the first [`Variable`][afnio.Variable].
    - If all elements are tuples, recursively collates them to preserve nested
        structure.
    - If some elements are tuples and some are not, recursively collates the tuples and
        leaves other elements as is, preserving their position.
    - Otherwise, returns a `list` of the grouped items.

    This enables flexible batching for datasets that return tuples of
    [`Variable`][afnio.Variable]s, nested tuples, or mixed structures.

    Args:
        items: An iterable of tuples, where each tuple is a sample from the dataset.

    Returns:
        A single tuple representing the collated batch, with structure determined \
        by the rules above.
    """
    transposed = list(zip(*items))
    collated = []
    for group in transposed:
        # If all are Variables, batch as Variable
        if all(isinstance(x, Variable) for x in group):
            first = group[0]
            collated.append(
                Variable(
                    data=[x.data for x in group],
                    role=first.role,
                    requires_grad=first.requires_grad,
                )
            )
        # If all are tuples, recurse
        elif all(isinstance(x, tuple) for x in group):
            collated.append(collate_tuple(group))
        # If some are tuples and some are not, handle each element
        else:
            collated.append(
                [collate_tuple([x]) if isinstance(x, tuple) else x for x in group]
            )
    return tuple(collated)
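The transpose-and-recurse behavior can be sketched without Variables, using only the tuple and fallback rules. This is a simplified standalone mirror of the function above (the Variable and mixed-element branches are omitted), not the library code itself:

```python
def collate_tuple(items):
    """Transpose a batch of tuples, recursing into nested tuples."""
    collated = []
    for group in zip(*items):
        if all(isinstance(x, tuple) for x in group):
            # Nested tuples: recurse to preserve structure
            collated.append(collate_tuple(group))
        else:
            # Fallback: group the items at this position into a list
            collated.append(list(group))
    return tuple(collated)

batch = [((1, 2), 'a'), ((3, 4), 'b')]
print(collate_tuple(batch))  # (([1, 3], [2, 4]), ['a', 'b'])
```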