Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StreamingDataloader: Resolve typo #19370

Merged
merged 3 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 57 additions & 9 deletions src/lightning/data/streaming/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,10 +452,10 @@ def __init__(self, loader: DataLoader) -> None:

distributed_env = _DistributedEnv.detect()

if self._loader._profile_bactches and distributed_env.global_rank == 0 and _VIZ_TRACKER_AVAILABLE:
if self._loader._profile_batches and distributed_env.global_rank == 0 and _VIZ_TRACKER_AVAILABLE:
from torch.utils.data._utils import worker

worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_bactches, self._loader._profile_dir)
worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_batches, self._loader._profile_dir)

super().__init__(loader)

Expand All @@ -479,8 +479,56 @@ def _try_put_index(self) -> None:


class StreamingDataLoader(DataLoader):
"""The `StreamingDataLoader` keeps track of the number of samples fetched in order to enable resumability of the
dataset."""
r"""The StreamingDataLoader combines a dataset and a sampler, and provides an iterable over the given dataset.

The :class:`~lightning.data.streaming.dataloader.StreamingDataLoader` supports either a
StreamingDataset or a CombinedStreamingDataset, with single- or multi-process loading,
customizable loading order, optional automatic batching (collation), and memory pinning.

See :py:mod:`torch.utils.data` documentation page for more details.

Args:
dataset (Dataset): dataset from which to load the data.
batch_size (int, optional): how many samples per batch to load
(default: ``1``).
shuffle (bool, optional): set to ``True`` to have the data reshuffled
at every epoch (default: ``False``).
num_workers (int, optional): how many subprocesses to use for data
loading. ``0`` means that the data will be loaded in the main process.
(default: ``0``)
collate_fn (Callable, optional): merges a list of samples to form a
mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
into device/CUDA pinned memory before returning them. If your data elements
are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
see the memory pinning section of the :py:mod:`torch.utils.data` documentation.
timeout (numeric, optional): if positive, the timeout value for collecting a batch
from workers. Should always be non-negative. (default: ``0``)
worker_init_fn (Callable, optional): If not ``None``, this will be called on each
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
input, after seeding and before data loading. (default: ``None``)
multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If
``None``, the default multiprocessing context of your operating system will
be used. (default: ``None``)
generator (torch.Generator, optional): If not ``None``, this RNG will be used
by RandomSampler to generate random indexes and multiprocessing to generate
``base_seed`` for workers. (default: ``None``)
prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
in advance by each worker. ``2`` means there will be a total of
2 * num_workers batches prefetched across all workers. (default value depends
on the set value for num_workers. If value of num_workers=0 default is ``None``.
Otherwise, if value of ``num_workers > 0`` default is ``2``).
persistent_workers (bool, optional): If ``True``, the data loader will not shut down
the worker processes after a dataset has been consumed once. This keeps the
workers' `Dataset` instances alive. (default: ``False``)
pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
``True``.
profile_batches (int, bool, optional): Whether to record a data loading profile and generate a result.json file.
profile_dir (str, optional): Where to store the recorded trace when profile_batches is enabled.

"""

__doc__ = DataLoader.__doc__

Expand All @@ -490,7 +538,7 @@ def __init__(
*args: Any,
batch_size: int = 1,
num_workers: int = 0,
profile_bactches: Union[bool, int] = False,
profile_batches: Union[bool, int] = False,
profile_dir: Optional[str] = None,
prefetch_factor: Optional[int] = None,
**kwargs: Any,
Expand All @@ -501,16 +549,16 @@ def __init__(
f" Found {dataset}."
)

if profile_bactches and not _VIZ_TRACKER_AVAILABLE:
raise ModuleNotFoundError("To use profile_bactches, viztracer is required. Run `pip install viztracer`")
if profile_batches and not _VIZ_TRACKER_AVAILABLE:
raise ModuleNotFoundError("To use profile_batches, viztracer is required. Run `pip install viztracer`")

if profile_bactches and num_workers == 0:
if profile_batches and num_workers == 0:
raise ValueError("Profiling is supported only with num_workers >= 1.")

self.current_epoch = 0
self.batch_size = batch_size
self.num_workers = num_workers
self._profile_bactches = profile_bactches
self._profile_batches = profile_batches
self._profile_dir = profile_dir
self._num_samples_yielded_streaming = 0
self._num_samples_yielded_combined: Dict[int, List[Any]] = {}
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_data/streaming/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_dataloader_profiling(profile, tmpdir, monkeypatch):
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
)
dataloader = StreamingDataLoader(
dataset, batch_size=2, profile_bactches=profile, profile_dir=str(tmpdir), num_workers=1
dataset, batch_size=2, profile_batches=profile, profile_dir=str(tmpdir), num_workers=1
)
dataloader_iter = iter(dataloader)
batches = []
Expand Down
Loading