Skip to content

afnio.utils.data

afnio.utils.data.DataLoader

Bases: Generic[T_co]

Data loader combines a dataset and a sampler, and provides an iterable over the given dataset.

The DataLoader supports both map-style and iterable-style datasets with single-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.

See afnio.utils.data documentation page for more details.

Source code in afnio/utils/data/dataloader.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
class DataLoader(Generic[T_co]):
    """
    Data loader combines a dataset and a sampler, and provides an iterable over the
    given dataset.

    The [`DataLoader`][afnio.utils.data.DataLoader] supports both map-style and
    iterable-style datasets with single-process loading, customizing loading order
    and optional automatic batching (collation) and memory pinning.

    See [`afnio.utils.data`][afnio.utils.data] documentation page for more details.
    """

    dataset: Dataset[T_co]
    batch_size: Optional[int]
    drop_last: bool
    sampler: Union[Sampler, Iterable]
    __initialized = False

    def __init__(
        self,
        dataset: Dataset[T_co],
        batch_size: Optional[int] = 1,
        shuffle: Optional[bool] = False,
        sampler: Union[Sampler, Iterable, None] = None,
        drop_last: bool = False,
        seed: Optional[int] = None,
    ):
        """Initializes the `DataLoader` with the given dataset and options.

        Args:
            dataset: Dataset from which to load the data.
            batch_size: How many samples per batch to load. If `None`, automatic
                batching is disabled and each iteration yields a single raw sample.
            shuffle: Set to `True` to have the data reshuffled at every epoch.
            sampler: Defines the strategy to draw samples from the dataset. Can be any
                `Iterable` with `__len__` implemented. If specified, `shuffle`
                must not be specified.
            drop_last: Set to `True` to drop the last incomplete batch, if the dataset
                size is not divisible by the batch size. If `False` and the size of
                dataset is not divisible by the batch size, then the last batch
                will be smaller.
            seed: If not `None`, this seed will be used by
                [`RandomSampler`][afnio.utils.data.RandomSampler]
                to generate random indexes.

        Raises:
            ValueError: If `shuffle` is not a boolean, if `batch_size` is not a
                positive integer (or `None`), or if both `sampler` and `shuffle`
                are specified.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        # `shuffle` must be a plain boolean flag (0/1 also pass, since they
        # compare equal to False/True).
        if shuffle not in {True, False}:
            raise ValueError(
                f"shuffle should be a boolean value, but got shuffle={shuffle}"
            )

        # Reject zero/negative batch sizes early: `__next__` would otherwise
        # loop forever yielding empty batches. `None` disables batching.
        if batch_size is not None and (
            not isinstance(batch_size, int)
            or isinstance(batch_size, bool)  # noqa: W503
            or batch_size <= 0  # noqa: W503
        ):
            raise ValueError(
                f"batch_size should be a positive integer value or None, "
                f"but got batch_size={batch_size}"
            )

        if sampler is not None and shuffle:
            raise ValueError("sampler option is mutually exclusive with shuffle")

        # Default samplers: random order when shuffling, sequential otherwise.
        if sampler is None:
            if shuffle:
                sampler = RandomSampler(dataset, seed=seed)
            else:
                sampler = SequentialSampler(dataset)

        self.index_sampler = sampler
        self._sampler_iter = iter(self.index_sampler)
        self.__initialized = True

    def __iter__(self) -> Iterable[Any]:
        # Restart the sampler so each epoch begins from a fresh iterator.
        self._sampler_iter = iter(self.index_sampler)  # Ensure new iterator every time
        return self

    def _next_index(self):
        # Advance the sampler; propagates StopIteration at epoch end.
        return next(self._sampler_iter)

    def __next__(self) -> Any:
        """Returns the next batch from the dataset, collated according to the structure
        of the dataset's `__getitem__` output.

        **Batching logic:**

        - If `batch_size` is `None`, automatic batching is disabled and a single
          raw, uncollated sample is returned per iteration.
        - If the dataset returns a dictionary, this method aggregates each key across
          the batch into a list of values. For example, if each sample is
          `{'a': 'foo', 'b': 'bar'}`, the batch will be `{'a': [...], 'b': [...]}`.
        - If the dataset returns a tuple (e.g., `(X, y)`), this method recursively
          collates each position in the tuple using
          [`collate_tuple()`][afnio.utils.data.dataloader.collate_tuple], preserving
          nested tuple structure and batching [`Variables`][afnio.Variable]
          as described below.
        - If the dataset returns [`Variables`][afnio.Variable] directly, this method
          batches them into a single Variable whose [`data`][afnio.Variable.data] is a
          list of the original [`data`][afnio.Variable.data] fields, and whose
          [`role`][afnio.Variable.role] and
          [`requires_grad`][afnio.Variable.requires_grad] are taken
          from the first [`Variables`][afnio.Variable].
        - Otherwise, returns the batch as a `list`.
        """
        # Suppress notifications for individual Variables
        with suppress_variable_notifications():
            if self.batch_size is None:
                # Automatic batching disabled: yield one raw sample per step.
                # StopIteration from the sampler ends the epoch.
                return self.dataset[self._next_index()]
            batch = []
            for _ in range(self.batch_size):
                try:
                    index = self._next_index()
                    batch.append(self.dataset[index])
                except StopIteration:
                    # Empty batch, or a partial batch with drop_last: end epoch.
                    if not batch or self.drop_last:
                        raise
                    break

        # If dataset returns a dictionary, we aggregate each key across the batch
        if (
            batch
            and isinstance(batch[0], dict)  # noqa: W503
            and all(isinstance(item, dict) for item in batch)  # noqa: W503
        ):
            keys = batch[0].keys()
            collated = {}
            for key in keys:
                values = [item[key] for item in batch]
                collated[key] = values
            return collated
        # If dataset returns a tuple, we recursively collate each position in the tuple
        if (
            batch
            and isinstance(batch[0], tuple)  # noqa: W503
            and all(isinstance(item, tuple) for item in batch)  # noqa: W503
        ):
            return collate_tuple(batch)

        # If dataset returns Variables, we batch them into a single Variable
        if (
            batch
            and isinstance(batch[0], Variable)  # noqa: W503
            and all(isinstance(item, Variable) for item in batch)  # noqa: W503
        ):
            first = batch[0]
            return Variable(
                data=[item.data for item in batch],
                role=first.role,
                requires_grad=first.requires_grad,
            )

        return batch

    def __len__(self) -> int:
        """Number of batches per epoch (number of samples when `batch_size` is
        `None`)."""
        length = len(self.dataset)
        if self.batch_size is not None:
            if self.drop_last:
                length = length // self.batch_size
            else:
                # Ceiling division without importing math on every call.
                length = -(-length // self.batch_size)
        return length

__init__(dataset, batch_size=1, shuffle=False, sampler=None, drop_last=False, seed=None)

Initializes the DataLoader with the given dataset and options.

Parameters:

Name Type Description Default
dataset Dataset[T_co]

Dataset from which to load the data.

required
batch_size int | None

How many samples per batch to load.

1
shuffle bool | None

Set to True to have the data reshuffled at every epoch.

False
sampler Sampler | Iterable | None

Defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.

None
drop_last bool

Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller.

False
seed int | None

If not None, this seed will be used by RandomSampler to generate random indexes.

None
Source code in afnio/utils/data/dataloader.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def __init__(
    self,
    dataset: Dataset[T_co],
    batch_size: Optional[int] = 1,
    shuffle: Optional[bool] = False,
    sampler: Union[Sampler, Iterable, None] = None,
    drop_last: bool = False,
    seed: Optional[int] = None,
):
    """Initializes the `DataLoader` with the given dataset and options.

    Args:
        dataset: Dataset from which to load the data.
        batch_size: How many samples per batch to load. If `None`, automatic
            batching is disabled and each iteration yields a single raw sample.
        shuffle: Set to `True` to have the data reshuffled at every epoch.
        sampler: Defines the strategy to draw samples from the dataset. Can be any
            `Iterable` with `__len__` implemented. If specified, `shuffle`
            must not be specified.
        drop_last: Set to `True` to drop the last incomplete batch, if the dataset
            size is not divisible by the batch size. If `False` and the size of
            dataset is not divisible by the batch size, then the last batch
            will be smaller.
        seed: If not `None`, this seed will be used by
            [`RandomSampler`][afnio.utils.data.RandomSampler]
            to generate random indexes.

    Raises:
        ValueError: If `shuffle` is not a boolean, if `batch_size` is not a
            positive integer (or `None`), or if both `sampler` and `shuffle`
            are specified.
    """
    self.dataset = dataset
    self.batch_size = batch_size
    self.shuffle = shuffle
    self.drop_last = drop_last

    # `shuffle` must be a plain boolean flag (0/1 also pass, since they
    # compare equal to False/True).
    if shuffle not in {True, False}:
        raise ValueError(
            f"shuffle should be a boolean value, but got shuffle={shuffle}"
        )

    # Reject zero/negative batch sizes early: `__next__` would otherwise
    # loop forever yielding empty batches. `None` disables batching.
    if batch_size is not None and (
        not isinstance(batch_size, int)
        or isinstance(batch_size, bool)  # noqa: W503
        or batch_size <= 0  # noqa: W503
    ):
        raise ValueError(
            f"batch_size should be a positive integer value or None, "
            f"but got batch_size={batch_size}"
        )

    if sampler is not None and shuffle:
        raise ValueError("sampler option is mutually exclusive with shuffle")

    # Default samplers: random order when shuffling, sequential otherwise.
    if sampler is None:
        if shuffle:
            sampler = RandomSampler(dataset, seed=seed)
        else:
            sampler = SequentialSampler(dataset)

    self.index_sampler = sampler
    self._sampler_iter = iter(self.index_sampler)
    self.__initialized = True

__next__()

Returns the next batch from the dataset, collated according to the structure of the dataset's __getitem__ output.

Batching logic:

  • If the dataset returns a dictionary, this method aggregates each key across the batch into a list of values. For example, if each sample is {'a': 'foo', 'b': 'bar'}, the batch will be {'a': [...], 'b': [...]}.
  • If the dataset returns a tuple (e.g., (X, y)), this method recursively collates each position in the tuple using collate_tuple(), preserving nested tuple structure and batching Variables as described below.
  • If the dataset returns Variables directly, this method batches them into a single Variable whose data is a list of the original data fields, and whose role and requires_grad are taken from the first Variables.
  • Otherwise, returns the batch as a list.
Source code in afnio/utils/data/dataloader.py
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def __next__(self) -> Any:
    """Returns the next batch from the dataset, collated according to the structure
    of the dataset's `__getitem__` output.

    **Batching logic:**

    - If `batch_size` is `None`, automatic batching is disabled and a single
      raw, uncollated sample is returned per iteration.
    - If the dataset returns a dictionary, this method aggregates each key across
      the batch into a list of values. For example, if each sample is
      `{'a': 'foo', 'b': 'bar'}`, the batch will be `{'a': [...], 'b': [...]}`.
    - If the dataset returns a tuple (e.g., `(X, y)`), this method recursively
      collates each position in the tuple using
      [`collate_tuple()`][afnio.utils.data.dataloader.collate_tuple], preserving
      nested tuple structure and batching [`Variables`][afnio.Variable]
      as described below.
    - If the dataset returns [`Variables`][afnio.Variable] directly, this method
      batches them into a single Variable whose [`data`][afnio.Variable.data] is a
      list of the original [`data`][afnio.Variable.data] fields, and whose
      [`role`][afnio.Variable.role] and
      [`requires_grad`][afnio.Variable.requires_grad] are taken
      from the first [`Variables`][afnio.Variable].
    - Otherwise, returns the batch as a `list`.
    """
    # Suppress notifications for individual Variables
    with suppress_variable_notifications():
        if self.batch_size is None:
            # Automatic batching disabled: yield one raw sample per step.
            # StopIteration from the sampler ends the epoch.
            return self.dataset[self._next_index()]
        batch = []
        for _ in range(self.batch_size):
            try:
                index = self._next_index()
                batch.append(self.dataset[index])
            except StopIteration:
                # Empty batch, or a partial batch with drop_last: end epoch.
                if not batch or self.drop_last:
                    raise
                break

    # If dataset returns a dictionary, we aggregate each key across the batch
    if (
        batch
        and isinstance(batch[0], dict)  # noqa: W503
        and all(isinstance(item, dict) for item in batch)  # noqa: W503
    ):
        keys = batch[0].keys()
        collated = {}
        for key in keys:
            values = [item[key] for item in batch]
            collated[key] = values
        return collated
    # If dataset returns a tuple, we recursively collate each position in the tuple
    if (
        batch
        and isinstance(batch[0], tuple)  # noqa: W503
        and all(isinstance(item, tuple) for item in batch)  # noqa: W503
    ):
        return collate_tuple(batch)

    # If dataset returns Variables, we batch them into a single Variable
    if (
        batch
        and isinstance(batch[0], Variable)  # noqa: W503
        and all(isinstance(item, Variable) for item in batch)  # noqa: W503
    ):
        first = batch[0]
        return Variable(
            data=[item.data for item in batch],
            role=first.role,
            requires_grad=first.requires_grad,
        )

    return batch

afnio.utils.data.Dataset

Bases: Generic[T_co]

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__(), supporting fetching a data sample for a given key, and __len__(), which is expected to return the size of the dataset by the default options of DataLoader. Subclasses could also optionally implement __getitems__() to speed up batched sample loading. This method accepts a list of sample indices for a batch and returns the corresponding list of samples.

Source code in afnio/utils/data/dataset.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Dataset(Generic[T_co]):
    """Abstract base class for map-style datasets.

    A map-style dataset maps keys (indices) to data samples. Concrete
    subclasses must override `__getitem__()`, which fetches the sample for a
    given key, and `__len__()`, which reports the dataset size used by the
    default options of
    [`DataLoader`][afnio.utils.data.dataloader.DataLoader]. A subclass may
    additionally implement `__getitems__()` — taking a list of batch sample
    indices and returning the corresponding list of samples — to speed up
    batched loading.
    """

    def __getitem__(self, index) -> T_co:
        # Abstract: concrete datasets must provide key -> sample lookup.
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    def __len__(self):
        # Abstract: concrete datasets must report their size.
        raise NotImplementedError("Subclasses of Dataset should implement __len__.")

afnio.utils.data.RandomSampler

Bases: Sampler[int]

Samples elements randomly. If without replacement, then sample from a shuffled dataset.

If with replacement, then user can specify num_samples to draw.

Source code in afnio/utils/data/sampler.py
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
class RandomSampler(Sampler[int]):
    """Samples elements randomly. If without replacement,
    then sample from a shuffled dataset.

    If with replacement, then user can specify `num_samples` to draw.
    """

    data_source: Sized
    replacement: bool

    def __init__(
        self,
        data_source: Sized,
        replacement: bool = False,
        num_samples: Optional[int] = None,
        seed: Optional[int] = None,
    ) -> None:
        """
        Initializes a `RandomSampler`.

        Args:
            data_source: Dataset to sample from.
            replacement: Samples are drawn on-demand with replacement if `True`.
            num_samples: Number of samples to draw, default=`len(dataset)`.
            seed: A number to set the seed for the random draws.

        Raises:
            TypeError: If `replacement` is not a boolean.
            ValueError: If `num_samples` is not a positive integer.
        """
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.seed = seed

        if not isinstance(self.replacement, bool):
            raise TypeError(
                f"replacement should be a boolean value, "
                f"but got replacement={self.replacement}"
            )

        # Reading `self.num_samples` falls back to len(data_source) when
        # `num_samples` was left unset.
        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(
                f"num_samples should be a positive integer value, "
                f"but got num_samples={self.num_samples}"
            )

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def _is_valid_random_state(self, state) -> bool:
        # Kept for backward compatibility; not used internally by this class.
        return isinstance(state, tuple) and len(state) > 0

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        # Use a dedicated RNG instance instead of `random.seed(...)` so that
        # iterating the sampler no longer clobbers the process-wide `random`
        # module state. Seeding a private `random.Random` with the same seed
        # produces exactly the same draws as the previous global-seeding code.
        # NOTE(review): with a fixed seed every epoch replays the identical
        # order — confirm this is the intended reproducibility semantics.
        rng = random.Random(self.seed)

        if self.replacement:
            # Draw in chunks of 32, then the remainder.
            for _ in range(self.num_samples // 32):
                yield from rng.choices(range(n), k=32)
            yield from rng.choices(range(n), k=self.num_samples % 32)
        else:
            # Full shuffled passes over the dataset, then a partial pass.
            for _ in range(self.num_samples // n):
                yield from rng.sample(range(n), n)
            yield from rng.sample(range(n), self.num_samples % n)

    def __len__(self) -> int:
        return self.num_samples

__init__(data_source, replacement=False, num_samples=None, seed=None)

Initializes a RandomSampler.

Parameters:

Name Type Description Default
data_source Sized

Dataset to sample from.

required
replacement bool

Samples are drawn on-demand with replacement if True.

False
num_samples int | None

Number of samples to draw, default=len(dataset).

None
seed int | None

A number to set the seed for the random draws.

None
Source code in afnio/utils/data/sampler.py
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
def __init__(
    self,
    data_source: Sized,
    replacement: bool = False,
    num_samples: Optional[int] = None,
    seed: Optional[int] = None,
) -> None:
    """Builds a `RandomSampler` over the given dataset.

    Args:
        data_source: Dataset to sample from.
        replacement: When `True`, indices are drawn on-demand with replacement.
        num_samples: How many indices to draw; defaults to `len(dataset)`.
        seed: Seed used for the random draws.
    """
    self.seed = seed
    self.data_source = data_source
    self._num_samples = num_samples
    self.replacement = replacement

    # `replacement` must be a genuine bool, not merely truthy.
    if not isinstance(self.replacement, bool):
        raise TypeError(
            f"replacement should be a boolean value, "
            f"but got replacement={self.replacement}"
        )

    # `num_samples` (possibly derived from len(data_source)) must be a
    # positive integer.
    valid_count = isinstance(self.num_samples, int) and self.num_samples > 0
    if not valid_count:
        raise ValueError(
            f"num_samples should be a positive integer value, "
            f"but got num_samples={self.num_samples}"
        )

afnio.utils.data.Sampler

Bases: Generic[T_co]

Base class for all Samplers.

Every Sampler subclass has to provide an __iter__() method, providing a way to iterate over indices or lists of indices (batches) of dataset elements, and may provide a __len__() method that returns the length of the returned iterator.

Examples:

>>> class AccedingSequenceLengthSampler(Sampler[int]):
>>>     def __init__(self, data: List[str]) -> None:
>>>         self.data = data
>>>
>>>     def __len__(self) -> int:
>>>         return len(self.data)
>>>
>>>     def __iter__(self) -> Iterator[int]:
>>>         sizes = [len(x) for x in self.data]
>>>         yield from sorted(range(len(sizes)), key=sizes.__getitem__)
>>>
>>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
>>>     def __init__(self, data: List[str], batch_size: int) -> None:
>>>         self.data = data
>>>         self.batch_size = batch_size
>>>
>>>     def __len__(self) -> int:
>>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
>>>
>>>     def __iter__(self) -> Iterator[List[int]]:
>>>         sizes = [len(x) for x in self.data]
>>>         sorted_indices = sorted(range(len(sizes)), key=sizes.__getitem__)
>>>         for start in range(0, len(sorted_indices), self.batch_size):
>>>             yield sorted_indices[start : start + self.batch_size]
Source code in afnio/utils/data/sampler.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class Sampler(Generic[T_co]):
    """Common base class for every sampler.

    Subclasses are required to implement `__iter__()`, which yields indices
    (or lists of indices, for batch samplers) into a dataset. They may also
    implement `__len__()` to report how many items the returned iterator
    will produce.

    Examples:
        >>> class AccedingSequenceLengthSampler(Sampler[int]):
        >>>     def __init__(self, data: List[str]) -> None:
        >>>         self.data = data
        >>>
        >>>     def __len__(self) -> int:
        >>>         return len(self.data)
        >>>
        >>>     def __iter__(self) -> Iterator[int]:
        >>>         sizes = [len(x) for x in self.data]
        >>>         yield from sorted(range(len(sizes)), key=sizes.__getitem__)
        >>>
        >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
        >>>     def __init__(self, data: List[str], batch_size: int) -> None:
        >>>         self.data = data
        >>>         self.batch_size = batch_size
        >>>
        >>>     def __len__(self) -> int:
        >>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
        >>>
        >>>     def __iter__(self) -> Iterator[List[int]]:
        >>>         sizes = [len(x) for x in self.data]
        >>>         sorted_indices = sorted(range(len(sizes)), key=sizes.__getitem__)
        >>>         for start in range(0, len(sorted_indices), self.batch_size):
        >>>             yield sorted_indices[start : start + self.batch_size]
    """

    def __init__(self) -> None:
        # Abstract: concrete samplers define their own construction.
        raise NotImplementedError

    def __iter__(self) -> Iterator[T_co]:
        # Abstract: concrete samplers must yield dataset indices.
        raise NotImplementedError

afnio.utils.data.SequentialSampler

Bases: Sampler[int]

Samples elements sequentially, always in the same order.

Source code in afnio/utils/data/sampler.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
class SequentialSampler(Sampler[int]):
    """Yields indices `0, 1, ..., len(data_source) - 1`, always in that order."""

    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        """Creates a `SequentialSampler`.

        Args:
            data_source: Dataset to sample from.
        """
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        # One ascending pass over all dataset indices.
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        # The sampler is exactly as long as the underlying dataset.
        return len(self.data_source)

__init__(data_source)

Initializes a SequentialSampler.

Parameters:

Name Type Description Default
data_source Sized

Dataset to sample from.

required
Source code in afnio/utils/data/sampler.py
54
55
56
57
58
59
60
def __init__(self, data_source: Sized) -> None:
    """Creates a `SequentialSampler` over the given dataset.

    Args:
        data_source: Dataset to sample from.
    """
    self.data_source = data_source

afnio.utils.data.WeightedRandomSampler

Bases: Sampler[int]

Samples elements from [0,..,len(weights)-1] with given probabilities (weights).

Examples:

>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[4, 4, 1, 4, 5]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]
Source code in afnio/utils/data/sampler.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
class WeightedRandomSampler(Sampler[int]):
    """Samples elements from `[0,..,len(weights)-1]` with given probabilities (weights).

    The weights are honored both with and without replacement. Example outputs
    below are illustrative; draws are random unless `seed` is set.

    Examples:
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """  # noqa: E501

    weights: Sequence[float]
    num_samples: int
    replacement: bool

    def __init__(
        self,
        weights: Sequence[float],
        num_samples: int,
        replacement: bool = True,
        seed: Optional[int] = None,
    ) -> None:
        """Initializes a `WeightedRandomSampler`.

        Args:
            weights: A sequence of weights, not necessarily summing up to one
            num_samples: Number of samples to draw
            replacement: If `True`, samples are drawn with replacement.
                If not, they are drawn without replacement, which means that when a
                sample index is drawn for a row, it cannot be drawn again for that row.
            seed: A number to set the seed for the random draws.

        Raises:
            ValueError: If `num_samples` is not a positive integer, `replacement`
                is not a boolean, `weights` is empty or non-numeric, or
                `num_samples` exceeds the population without replacement.
        """
        if (
            not isinstance(num_samples, int)
            or isinstance(num_samples, bool)
            or num_samples <= 0
        ):
            raise ValueError(
                f"num_samples should be a positive integer value, "
                f"but got num_samples={num_samples}"
            )
        if not isinstance(replacement, bool):
            raise ValueError(
                f"replacement should be a boolean value, "
                f"but got replacement={replacement}"
            )

        if len(weights) == 0 or not all(isinstance(w, (float, int)) for w in weights):
            raise ValueError("Weights must be a non-empty sequence of numbers.")

        if not replacement and num_samples > len(weights):
            raise ValueError(
                f"num_samples ({num_samples}) cannot be greater than "
                f"the population size ({len(weights)}) when replacement is False."
            )

        self.weights = weights
        self.num_samples = num_samples
        self.replacement = replacement
        self.seed = seed

    def __iter__(self) -> Iterator[int]:
        # Use a dedicated RNG instance instead of `random.seed(...)` so that
        # iterating the sampler does not clobber the process-wide `random`
        # state; a private `random.Random` seeded identically produces the
        # same draws as the previous global-seeding code.
        rng = random.Random(self.seed)

        total_weight = sum(self.weights)
        probabilities = [w / total_weight for w in self.weights]

        if self.replacement:
            yield from rng.choices(
                population=range(len(self.weights)),
                weights=probabilities,
                k=self.num_samples,
            )
        else:
            # Weighted sampling WITHOUT replacement: draw one index at a time
            # proportionally to the remaining weights, then remove it from the
            # candidate pool. (Previously this branch ignored the weights and
            # sampled uniformly, contradicting the class contract.)
            population = list(range(len(self.weights)))
            remaining = list(probabilities)
            for _ in range(self.num_samples):
                if any(remaining):
                    pick = rng.choices(
                        range(len(population)), weights=remaining, k=1
                    )[0]
                else:
                    # Only zero-weight candidates left: choose uniformly.
                    pick = rng.randrange(len(population))
                yield population.pop(pick)
                remaining.pop(pick)

    def __len__(self) -> int:
        return self.num_samples

__init__(weights, num_samples, replacement=True, seed=None)

Initializes a WeightedRandomSampler.

Parameters:

Name Type Description Default
weights Sequence[float]

A sequence of weights, not necessarily summing up to one

required
num_samples int

Number of samples to draw

required
replacement bool

If True, samples are drawn with replacement. If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.

True
seed int | None

A number to set the seed for the random draws.

None
Source code in afnio/utils/data/sampler.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
def __init__(
    self,
    weights: Sequence[float],
    num_samples: int,
    replacement: bool = True,
    seed: Optional[int] = None,
) -> None:
    """Builds a `WeightedRandomSampler`.

    Args:
        weights: A sequence of weights, not necessarily summing up to one
        num_samples: Number of samples to draw
        replacement: If `True`, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        seed: A number to set the seed for the random draws.
    """
    # `num_samples` must be a positive int; bool is explicitly excluded
    # even though it is an int subclass.
    valid_count = (
        isinstance(num_samples, int)
        and not isinstance(num_samples, bool)
        and num_samples > 0
    )
    if not valid_count:
        raise ValueError(
            f"num_samples should be a positive integer value, "
            f"but got num_samples={num_samples}"
        )

    # `replacement` must be a genuine bool, not merely truthy.
    if not isinstance(replacement, bool):
        raise ValueError(
            f"replacement should be a boolean value, "
            f"but got replacement={replacement}"
        )

    # Weights must be a non-empty sequence of ints/floats.
    if len(weights) == 0 or any(
        not isinstance(w, (float, int)) for w in weights
    ):
        raise ValueError("Weights must be a non-empty sequence of numbers.")

    # Without replacement, we cannot draw more samples than there are weights.
    if not replacement and num_samples > len(weights):
        raise ValueError(
            f"num_samples ({num_samples}) cannot be greater than "
            f"the population size ({len(weights)}) when replacement is False."
        )

    self.seed = seed
    self.weights = weights
    self.replacement = replacement
    self.num_samples = num_samples