diff --git a/src/lightning/data/streaming/dataloader.py b/src/lightning/data/streaming/dataloader.py
index acd0ebef19af6..942d2f98a3090 100644
--- a/src/lightning/data/streaming/dataloader.py
+++ b/src/lightning/data/streaming/dataloader.py
@@ -452,10 +452,10 @@ def __init__(self, loader: DataLoader) -> None:
         distributed_env = _DistributedEnv.detect()
 
-        if self._loader._profile_bactches and distributed_env.global_rank == 0 and _VIZ_TRACKER_AVAILABLE:
+        if self._loader._profile_batches and distributed_env.global_rank == 0 and _VIZ_TRACKER_AVAILABLE:
             from torch.utils.data._utils import worker
 
-            worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_bactches, self._loader._profile_dir)
+            worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_batches, self._loader._profile_dir)
 
         super().__init__(loader)
 
@@ -479,8 +479,56 @@ def _try_put_index(self) -> None:
 
 
 class StreamingDataLoader(DataLoader):
-    """The `StreamingDataLoader` keeps track of the number of samples fetched in order to enable resumability of the
-    dataset."""
+    r"""The StreamingDataLoader combines a dataset and a sampler, and provides an iterable over the given dataset.
+
+    The :class:`~lightning.data.streaming.dataloader.StreamingDataLoader` supports either a
+    StreamingDataset or a CombinedStreamingDataset with single- or multi-process loading, customizable
+    loading order, and optional automatic batching (collation) and memory pinning.
+
+    See the :py:mod:`torch.utils.data` documentation page for more details.
+
+    Args:
+        dataset (Dataset): dataset from which to load the data.
+        batch_size (int, optional): how many samples per batch to load
+            (default: ``1``).
+        shuffle (bool, optional): set to ``True`` to have the data reshuffled
+            at every epoch (default: ``False``).
+        num_workers (int, optional): how many subprocesses to use for data
+            loading. ``0`` means that the data will be loaded in the main process.
+            (default: ``0``)
+        collate_fn (Callable, optional): merges a list of samples to form a
+            mini-batch of Tensor(s). Used when using batched loading from a
+            map-style dataset.
+        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
+            into device/CUDA pinned memory before returning them. If your data elements
+            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
+            see the :py:mod:`torch.utils.data` documentation for an example.
+        timeout (numeric, optional): if positive, the timeout value for collecting a batch
+            from workers. Should always be non-negative. (default: ``0``)
+        worker_init_fn (Callable, optional): If not ``None``, this will be called on each
+            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
+            input, after seeding and before data loading. (default: ``None``)
+        multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If
+            ``None``, the default `multiprocessing context`_ of your operating system will
+            be used. (default: ``None``)
+        generator (torch.Generator, optional): If not ``None``, this RNG will be used
+            by RandomSampler to generate random indexes and by multiprocessing to generate
+            ``base_seed`` for workers. (default: ``None``)
+        prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
+            in advance by each worker. ``2`` means there will be a total of
+            2 * num_workers batches prefetched across all workers. (the default
+            depends on ``num_workers``: if ``num_workers=0`` the default is ``None``;
+            otherwise the default is ``2``).
+        persistent_workers (bool, optional): If ``True``, the data loader will not shut down
+            the worker processes after a dataset has been consumed once. This keeps the
+            workers' `Dataset` instances alive. (default: ``False``)
+        pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
+            ``True``.
+        profile_batches (bool or int, optional): Whether to record a data-loading profile and generate a ``result.json`` file.
+        profile_dir (str, optional): Where to store the recorded trace when ``profile_batches`` is enabled.
+
+    """
 
     __doc__ = DataLoader.__doc__
 
@@ -490,7 +538,7 @@ def __init__(
         *args: Any,
         batch_size: int = 1,
         num_workers: int = 0,
-        profile_bactches: Union[bool, int] = False,
+        profile_batches: Union[bool, int] = False,
         profile_dir: Optional[str] = None,
         prefetch_factor: Optional[int] = None,
         **kwargs: Any,
@@ -501,16 +549,16 @@ def __init__(
                 f" Found {dataset}."
             )
 
-        if profile_bactches and not _VIZ_TRACKER_AVAILABLE:
-            raise ModuleNotFoundError("To use profile_bactches, viztracer is required. Run `pip install viztracer`")
+        if profile_batches and not _VIZ_TRACKER_AVAILABLE:
+            raise ModuleNotFoundError("To use profile_batches, viztracer is required. Run `pip install viztracer`")
 
-        if profile_bactches and num_workers == 0:
+        if profile_batches and num_workers == 0:
             raise ValueError("Profiling is supported only with num_workers >= 1.")
 
         self.current_epoch = 0
         self.batch_size = batch_size
         self.num_workers = num_workers
-        self._profile_bactches = profile_bactches
+        self._profile_batches = profile_batches
         self._profile_dir = profile_dir
         self._num_samples_yielded_streaming = 0
         self._num_samples_yielded_combined: Dict[int, List[Any]] = {}
diff --git a/tests/tests_data/streaming/test_dataloader.py b/tests/tests_data/streaming/test_dataloader.py
index 293a96636adae..d70575aaa5858 100644
--- a/tests/tests_data/streaming/test_dataloader.py
+++ b/tests/tests_data/streaming/test_dataloader.py
@@ -84,7 +84,7 @@ def test_dataloader_profiling(profile, tmpdir, monkeypatch):
         [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
     )
     dataloader = StreamingDataLoader(
-        dataset, batch_size=2, profile_bactches=profile, profile_dir=str(tmpdir), num_workers=1
+        dataset, batch_size=2, profile_batches=profile, profile_dir=str(tmpdir), num_workers=1
     )
     dataloader_iter = iter(dataloader)
     batches = []
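For reference, a minimal usage sketch of the renamed `profile_batches` flag follows. It is not part of the patch above: the `StreamingDataset` import path and its `input_dir` value are assumptions for illustration, while the behavior in the comments (viztracer requirement, `num_workers >= 1`, `result.json` output on global rank 0) comes from the code changed in this diff.

# Usage sketch (illustrative, not part of the patch).
# Assumes `pip install viztracer` and a hypothetical dataset location.
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import StreamingDataset  # assumed module path

dataset = StreamingDataset(input_dir="s3://my-bucket/optimized-data")  # hypothetical input_dir

dataloader = StreamingDataLoader(
    dataset,
    batch_size=8,
    num_workers=2,              # profiling raises ValueError when num_workers == 0
    profile_batches=10,         # bool or int, forwarded to _ProfileWorkerLoop
    profile_dir="/tmp/traces",  # where the viztracer result.json is written
)

for batch in dataloader:
    ...  # iterate as usual; the trace is recorded only on global rank 0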