Adding all project files

This commit is contained in:
Martina Burlando 2025-08-02 02:00:33 +02:00
parent 6c9e127bdc
commit cd4316ad0f
42289 changed files with 8009643 additions and 0 deletions

View file

@ -0,0 +1,19 @@
# Functional DataPipe
from torch.utils.data.datapipes.map.callable import MapperMapDataPipe as Mapper
from torch.utils.data.datapipes.map.combinatorics import (
ShufflerIterDataPipe as Shuffler,
)
from torch.utils.data.datapipes.map.combining import (
ConcaterMapDataPipe as Concater,
ZipperMapDataPipe as Zipper,
)
from torch.utils.data.datapipes.map.grouping import BatcherMapDataPipe as Batcher
from torch.utils.data.datapipes.map.utils import (
SequenceWrapperMapDataPipe as SequenceWrapper,
)
# Public names of this package: the short aliases bound by the imports above.
__all__ = ["Batcher", "Concater", "Mapper", "SequenceWrapper", "Shuffler", "Zipper"]
# Please keep this list sorted; the assertion below fails at import time
# if a new entry is added out of order.
assert __all__ == sorted(__all__)

View file

@ -0,0 +1,65 @@
# mypy: allow-untyped-defs
from typing import Callable, TypeVar
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import MapDataPipe
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
# Names exported by this module.
__all__ = ["MapperMapDataPipe", "default_fn"]
# Covariant type variable used as the element type of the DataPipe below.
_T_co = TypeVar("_T_co", covariant=True)
def default_fn(data):
    """Return ``data`` unchanged.

    Used as the default ``fn`` of ``MapperMapDataPipe``. It is a named
    module-level function instead of a lambda so that the datapipe stays
    picklable.
    """
    return data
@functional_datapipe("map")
class MapperMapDataPipe(MapDataPipe[_T_co]):
    r"""
    Apply the input function over each item from the source DataPipe (functional name: ``map``).

    The function can be any regular Python function or partial object. Lambda
    functions are not recommended because they are not supported by pickle.

    Args:
        datapipe: Source MapDataPipe
        fn: Function being applied to each item. Defaults to ``default_fn``,
            which returns every item unchanged.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper, Mapper
        >>> def add_one(x):
        ...     return x + 1
        >>> dp = SequenceWrapper(range(10))
        >>> map_dp_1 = dp.map(add_one)
        >>> list(map_dp_1)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
        >>> list(map_dp_2)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    """

    # Source pipe whose items are transformed on access.
    datapipe: MapDataPipe
    # Callable applied to every item fetched from ``datapipe``.
    fn: Callable

    def __init__(
        self,
        datapipe: MapDataPipe,
        fn: Callable = default_fn,
    ) -> None:
        super().__init__()
        self.datapipe = datapipe
        # NOTE(review): presumably reports callables that cannot be pickled
        # (e.g. lambdas); the helper lives in ``utils.common`` - confirm there.
        _check_unpickable_fn(fn)
        self.fn = fn  # type: ignore[assignment]

    def __len__(self) -> int:
        """Return the length of the source DataPipe; mapping never adds or drops items."""
        return len(self.datapipe)

    def __getitem__(self, index) -> _T_co:
        """Fetch the item stored under ``index`` in the source and return ``fn(item)``."""
        source_item = self.datapipe[index]
        return self.fn(source_item)

View file

@ -0,0 +1,130 @@
# mypy: allow-untyped-defs
import random
from collections.abc import Iterator
from typing import Optional, TypeVar
import torch
from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe
# Names exported by this module.
__all__ = ["ShufflerIterDataPipe"]
# Covariant type variable used as the element type of the DataPipe below.
_T_co = TypeVar("_T_co", covariant=True)
# ``shuffle`` is not registered through the decorator below; it is attached by
# the explicit ``MapDataPipe.register_datapipe_as_function`` call that follows
# this class.
# @functional_datapipe('shuffle')
class ShufflerIterDataPipe(IterDataPipe[_T_co]):
    r"""
    Shuffle the input MapDataPipe via its indices (functional name: ``shuffle``).

    When it is used with :class:`~torch.utils.data.DataLoader`, the methods to
    set up random seed are different based on :attr:`num_workers`.

    For single-process mode (:attr:`num_workers == 0`), the random seed is set before
    the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
    mode (:attr:`num_workers > 0`), ``worker_init_fn`` is used to set up a random seed
    for each worker process.

    Args:
        datapipe: MapDataPipe being shuffled
        indices: a list of indices of the MapDataPipe. If not provided, we assume it uses 0-based indexing

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp = SequenceWrapper(range(10))
        >>> shuffle_dp = dp.shuffle().set_seed(0)
        >>> list(shuffle_dp)
        [7, 8, 1, 5, 3, 4, 2, 0, 9, 6]
        >>> list(shuffle_dp)
        [6, 1, 9, 5, 2, 4, 7, 3, 8, 0]
        >>> # Reset seed for Shuffler
        >>> shuffle_dp = shuffle_dp.set_seed(0)
        >>> list(shuffle_dp)
        [7, 8, 1, 5, 3, 4, 2, 0, 9, 6]

    Note:
        Even though this ``shuffle`` operation takes a ``MapDataPipe`` as the input, it would return an
        ``IterDataPipe`` rather than a ``MapDataPipe``, because ``MapDataPipe`` should be insensitive to
        the order of data for the sake of random reads, but ``IterDataPipe`` depends on the order
        of data during data-processing.
    """

    # Source pipe that is read by index.
    datapipe: MapDataPipe[_T_co]
    # When False, items are yielded in the order given by ``indices``.
    _enabled: bool
    # Seed for the next ``reset``; cleared once it has been consumed.
    _seed: Optional[int]
    # Private generator, so shuffling does not touch the global ``random`` state.
    _rng: random.Random

    def __init__(
        self,
        datapipe: MapDataPipe[_T_co],
        *,
        indices: Optional[list] = None,
    ) -> None:
        super().__init__()
        self.datapipe = datapipe
        self.indices = list(range(len(datapipe))) if indices is None else indices
        self._enabled = True
        self._seed = None
        self._rng = random.Random()
        # Keep a separate list: ``__iter__`` pops from ``_shuffled_indices``, so
        # aliasing ``indices`` here would let an iteration that runs before
        # ``reset`` empty the source index list (and the caller's own list).
        self._shuffled_indices: list = list(self.indices)

    def set_shuffle(self, shuffle=True):
        """Enable or disable shuffling; returns ``self`` so calls can be chained."""
        self._enabled = shuffle
        return self

    def set_seed(self, seed: int):
        """Set the seed used by the next ``reset``; returns ``self`` so calls can be chained."""
        self._seed = seed
        return self

    def __iter__(self) -> Iterator[_T_co]:
        """Yield items of the source pipe, in shuffled order unless shuffling is disabled."""
        if not self._enabled:
            for idx in self.indices:
                yield self.datapipe[idx]
        else:
            # The permutation is consumed from its end as items are produced.
            while self._shuffled_indices:
                idx = self._shuffled_indices.pop()
                yield self.datapipe[idx]

    def reset(self) -> None:
        """Reseed the private generator and draw a fresh permutation of ``indices``.

        NOTE(review): presumably invoked by the ``IterDataPipe`` machinery
        whenever a new iterator is created - confirm in ``datapipe.py``.
        """
        if self._enabled and self._seed is None:
            # No explicit seed was requested: derive one from torch's RNG.
            self._seed = int(torch.empty((), dtype=torch.int64).random_().item())
        self._rng.seed(self._seed)
        # A seed applies to a single pass only.
        self._seed = None
        self._shuffled_indices = self._rng.sample(self.indices, len(self.indices))

    def __len__(self) -> int:
        """Return the length of the source DataPipe."""
        return len(self.datapipe)

    def __getstate__(self):
        """Return the picklable state tuple, passed through ``getstate_hook`` when one is set."""
        state = (
            self.datapipe,
            self.indices,
            self._enabled,
            self._seed,
            # ``random.Random`` is stored through its state tuple.
            self._rng.getstate(),
            self._shuffled_indices,
            # NOTE(review): the two attributes below are not set in this
            # class; presumably maintained by ``IterDataPipe`` - confirm.
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        """Restore the attributes from the tuple produced by ``__getstate__``."""
        (
            self.datapipe,
            self.indices,
            self._enabled,
            self._seed,
            rng_state,
            self._shuffled_indices,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        # Rebuild the generator and put it back into the saved state.
        self._rng = random.Random()
        self._rng.setstate(rng_state)
# Attach ``shuffle`` as a functional method of ``MapDataPipe`` explicitly:
# ShufflerIterDataPipe derives from IterDataPipe but takes a MapDataPipe as
# its input, so the method has to be available on Map-style pipes.
MapDataPipe.register_datapipe_as_function("shuffle", ShufflerIterDataPipe)

View file

@ -0,0 +1,105 @@
# mypy: allow-untyped-defs
from collections.abc import Sized
from typing import TypeVar
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import MapDataPipe
# Names exported by this module.
__all__ = ["ConcaterMapDataPipe", "ZipperMapDataPipe"]
# Covariant type variable used as the element type of the DataPipes below.
_T_co = TypeVar("_T_co", covariant=True)
@functional_datapipe("concat")
class ConcaterMapDataPipe(MapDataPipe):
    r"""
    Concatenate multiple Map DataPipes (functional name: ``concat``).

    The new index is offset by the cumulative length of the preceding source
    DataPipes. For example, if there are 2 source DataPipes both with length 5,
    index 0 to 4 of the resulting ``ConcaterMapDataPipe`` would refer to
    elements of the first DataPipe, and 5 to 9 would refer to elements
    of the second DataPipe. Negative indices count from the end of the
    concatenation.

    Args:
        datapipes: Map DataPipes being concatenated

    Raises:
        ValueError: if no DataPipe is given.
        TypeError: if an input is not a ``MapDataPipe`` or is not ``Sized``.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp1 = SequenceWrapper(range(3))
        >>> dp2 = SequenceWrapper(range(3))
        >>> concat_dp = dp1.concat(dp2)
        >>> list(concat_dp)
        [0, 1, 2, 0, 1, 2]
    """

    # Variadic tuple: any number of source pipes, kept in concatenation order.
    datapipes: tuple[MapDataPipe, ...]

    def __init__(self, *datapipes: MapDataPipe):
        if len(datapipes) == 0:
            raise ValueError("Expected at least one DataPipe, but got nothing")
        if not all(isinstance(dp, MapDataPipe) for dp in datapipes):
            raise TypeError("Expected all inputs to be `MapDataPipe`")
        if not all(isinstance(dp, Sized) for dp in datapipes):
            raise TypeError("Expected all inputs to be `Sized`")
        self.datapipes = datapipes

    def __getitem__(self, index) -> _T_co:  # type: ignore[type-var]
        """Return the item at ``index`` of the concatenation.

        Raises:
            IndexError: if ``index`` lies outside ``[-len(self), len(self))``.
        """
        position = index
        if position < 0:
            # Resolve against the total length; otherwise a negative index
            # would always be served by the first DataPipe.
            position += len(self)
            if position < 0:
                raise IndexError(f"Index {index} is out of range.")
        for dp in self.datapipes:
            size = len(dp)
            if position < size:
                return dp[position]
            # Skip this pipe and make the position relative to the next one.
            position -= size
        raise IndexError(f"Index {index} is out of range.")

    def __len__(self) -> int:
        """Return the sum of the lengths of all source DataPipes."""
        return sum(len(dp) for dp in self.datapipes)
@functional_datapipe("zip")
class ZipperMapDataPipe(MapDataPipe[tuple[_T_co, ...]]):
    r"""
    Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).

    This DataPipe is out of bound as soon as the shortest input DataPipe is exhausted.

    Args:
        *datapipes: Map DataPipes being aggregated

    Raises:
        ValueError: if no DataPipe is given.
        TypeError: if an input is not a ``MapDataPipe`` or is not ``Sized``.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp1 = SequenceWrapper(range(3))
        >>> dp2 = SequenceWrapper(range(10, 13))
        >>> zip_dp = dp1.zip(dp2)
        >>> list(zip_dp)
        [(0, 10), (1, 11), (2, 12)]
    """

    # Source pipes, in the order their items appear in each output tuple.
    datapipes: tuple[MapDataPipe[_T_co], ...]

    def __init__(self, *datapipes: MapDataPipe[_T_co]) -> None:
        if not datapipes:
            raise ValueError("Expected at least one DataPipe, but got nothing")
        if any(not isinstance(dp, MapDataPipe) for dp in datapipes):
            raise TypeError("Expected all inputs to be `MapDataPipe`")
        if any(not isinstance(dp, Sized) for dp in datapipes):
            raise TypeError("Expected all inputs to be `Sized`")
        self.datapipes = datapipes

    def __getitem__(self, index) -> tuple[_T_co, ...]:
        """Return a tuple holding the item stored under ``index`` in every source DataPipe.

        Raises:
            IndexError: if any source DataPipe has no item at ``index``; the
                message names the pipe that failed.
        """
        collected = []
        for dp in self.datapipes:
            try:
                item = dp[index]
            except IndexError as e:
                raise IndexError(
                    f"Index {index} is out of range for one of the input MapDataPipes {dp}."
                ) from e
            collected.append(item)
        return tuple(collected)

    def __len__(self) -> int:
        """Return the length of the shortest source DataPipe."""
        return min(map(len, self.datapipes))

View file

@ -0,0 +1,74 @@
# mypy: allow-untyped-defs
from collections.abc import Sized
from typing import TypeVar
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import DataChunk, MapDataPipe
# Names exported by this module.
__all__ = ["BatcherMapDataPipe"]
# Type variable for the element type of the DataPipe that gets batched.
_T = TypeVar("_T")
@functional_datapipe("batch")
class BatcherMapDataPipe(MapDataPipe[DataChunk]):
    r"""
    Create mini-batches of data (functional name: ``batch``).

    An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``,
    or ``length % batch_size`` for the last batch if ``drop_last`` is set to ``False``.

    Args:
        datapipe: Map DataPipe being batched
        batch_size: The size of each batch
        drop_last: Option to drop the last batch if it's not full
        wrapper_class: Class used to wrap the list of items of each batch
            (``DataChunk`` by default)

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp = SequenceWrapper(range(10))
        >>> batch_dp = dp.batch(batch_size=2)
        >>> list(batch_dp)
        [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
    """

    # Source pipe the batches are read from.
    datapipe: MapDataPipe
    # Number of source items per full batch (always > 0).
    batch_size: int
    # Whether an incomplete trailing batch is discarded.
    drop_last: bool

    def __init__(
        self,
        datapipe: MapDataPipe[_T],
        batch_size: int,
        drop_last: bool = False,
        wrapper_class: type[DataChunk] = DataChunk,
    ) -> None:
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        super().__init__()
        self.datapipe = datapipe
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.wrapper_class = wrapper_class

    def __getitem__(self, index) -> DataChunk:
        """Return batch number ``index``, wrapped in ``wrapper_class``.

        Raises:
            IndexError: if the batch is empty, or incomplete while
                ``drop_last`` is set.
        """
        first = index * self.batch_size
        stop = (index + 1) * self.batch_size
        batch: list = []
        try:
            for i in range(first, stop):
                batch.append(self.datapipe[i])
            return self.wrapper_class(batch)
        except IndexError as e:
            # Items fetched before the failure are still in ``batch``; they
            # form the short trailing batch unless it has to be dropped.
            if self.drop_last or len(batch) == 0:
                raise IndexError(f"Index {index} is out of bound.") from e
            return self.wrapper_class(batch)

    def __len__(self) -> int:
        """Return the number of batches.

        Raises:
            TypeError: if the source DataPipe is not ``Sized``.
        """
        if not isinstance(self.datapipe, Sized):
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
        total = len(self.datapipe)
        if self.drop_last:
            return total // self.batch_size
        # Ceiling division: a partial trailing batch counts as one batch.
        return (total + self.batch_size - 1) // self.batch_size

View file

@ -0,0 +1,53 @@
# mypy: allow-untyped-defs
import copy
import warnings
from torch.utils.data.datapipes.datapipe import MapDataPipe
# Names exported by this module.
__all__ = ["SequenceWrapperMapDataPipe"]
class SequenceWrapperMapDataPipe(MapDataPipe):
    r"""
    Wraps a sequence object into a MapDataPipe.

    Args:
        sequence: Sequence object to be wrapped into a MapDataPipe
        deepcopy: Option to deepcopy input sequence object

    .. note::
      If ``deepcopy`` is set to False explicitly, users should ensure
      that the data pipeline doesn't contain any in-place operations over
      the iterable instance, in order to prevent data inconsistency
      across iterations.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp = SequenceWrapper(range(10))
        >>> list(dp)
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        >>> dp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
        >>> dp['a']
        100
    """

    def __init__(self, sequence, deepcopy=True):
        # Start from the caller's object; swap in a private copy when one is
        # requested and can actually be made.
        wrapped = sequence
        if deepcopy:
            try:
                wrapped = copy.deepcopy(sequence)
            except TypeError:
                warnings.warn(
                    "The input sequence can not be deepcopied, "
                    "please be aware of in-place modification would affect source data"
                )
        self.sequence = wrapped

    def __getitem__(self, index):
        """Return ``sequence[index]``; ``index`` may be any key the wrapped object accepts."""
        return self.sequence[index]

    def __len__(self):
        """Return the length of the wrapped object."""
        return len(self.sequence)