[Nodes] Add Prebatch setting to ParallelMapper (#1417)
andrewkho authored Jan 2, 2025
1 parent 88c7b96 commit 0d2b0a0
Showing 3 changed files with 145 additions and 36 deletions.
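
For orientation before the diff: prebatch groups items from the source into small batches before handing them to workers, then transparently unbatches the mapped results, cutting per-item dispatch overhead. A minimal usage sketch (the IterableWrapper source and top-level import path are assumptions for illustration; ParallelMapper, Prefetcher, and the prebatch argument are what this commit touches):

    from torchdata.nodes import IterableWrapper, ParallelMapper, Prefetcher

    source = IterableWrapper(range(10_000))  # many small items
    node = ParallelMapper(
        source,
        map_fn=lambda x: x + 1,
        num_workers=4,
        method="thread",
        prebatch=16,  # new: dispatch items to workers 16 at a time
    )
    node = Prefetcher(node, prefetch_factor=2)
    for item in node:  # still yields single mapped items, same as without prebatch
        pass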
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -34,6 +34,6 @@ repos:
           - usort == 1.0.0
 
   - repo: https://github.com/pycqa/flake8
-    rev: 5.0.4
+    rev: 6.1.0
    hooks:
      - id: flake8
37 changes: 29 additions & 8 deletions test/nodes/test_map.py
@@ -7,7 +7,7 @@
 import itertools
 
 import unittest
-from typing import List
+from typing import List, Optional
 
 from parameterized import parameterized
 from torch.testing._internal.common_utils import IS_WINDOWS, TEST_CUDA, TestCase
@@ -55,7 +55,7 @@ def test_exception_handling_mapper_multiprocess(self):
     def test_exception_handling_mapper_multiprocess_cuda(self):
         self._test_exception_handling_mapper(True, "process")
 
-    def _test_map(self, in_order, method) -> None:
+    def _test_map(self, in_order, method, prebatch) -> None:
         batch_size = 6
         n = 80
         multiprocessing_context = None if IS_WINDOWS else "forkserver"
@@ -68,6 +68,7 @@ def _test_map(self, in_order, method) -> None:
             in_order=in_order,
             method=method,
             multiprocessing_context=multiprocessing_context,
+            prebatch=prebatch,
         )
         node = Prefetcher(node, prefetch_factor=2)
 
@@ -98,25 +98,40 @@ def _test_map(self, in_order, method) -> None:
         )
 
     def test_in_order_threads(self):
-        self._test_map(True, "thread")
+        self._test_map(True, "thread", None)
 
     def test_out_of_order_threads(self):
-        self._test_map(False, "thread")
+        self._test_map(False, "thread", None)
 
     def test_in_order_process(self):
-        self._test_map(True, "process")
+        self._test_map(True, "process", None)
 
     def test_out_of_order_process(self):
-        self._test_map(False, "process")
+        self._test_map(False, "process", None)
 
+    def test_in_order_thread_prebatch(self):
+        self._test_map(True, "thread", 3)
+
+    def test_out_of_order_thread_prebatch(self):
+        self._test_map(False, "thread", 3)
+
+    def test_in_order_process_prebatch(self):
+        self._test_map(True, "process", 3)
+
+    def test_out_of_order_process_prebatch(self):
+        self._test_map(False, "process", 3)
+
     @parameterized.expand(
         itertools.product(
             [0, 7, 13],
             [True],  # TODO: define and fix in_order = False
             [0, 1, 9],  # snapshot_frequency
+            [None, 3],  # prebatch
         )
     )
-    def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_frequency: int):
+    def test_save_load_state_thread(
+        self, midpoint: int, in_order: bool, snapshot_frequency: int, prebatch: Optional[int]
+    ):
         method = "thread"
         batch_size = 6
         n = 80
@@ -129,6 +145,7 @@ def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_frequency: int):
             in_order=in_order,
             method=method,
             snapshot_frequency=snapshot_frequency,
+            prebatch=prebatch,
         )
         node = Prefetcher(node, prefetch_factor=2)
         run_test_save_load_state(self, node, midpoint)
@@ -138,9 +155,12 @@ def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_frequency: int):
             [0, 7, 13],
             [True],  # TODO: define and fix in_order = False
             [0, 1, 9],  # snapshot_frequency
+            [None, 3],  # prebatch
         )
     )
-    def test_save_load_state_process(self, midpoint: int, in_order: bool, snapshot_frequency: int):
+    def test_save_load_state_process(
+        self, midpoint: int, in_order: bool, snapshot_frequency: int, prebatch: Optional[int]
+    ):
         method = "process"
         batch_size = 6
         n = 80
@@ -155,6 +175,7 @@ def test_save_load_state_process(self, midpoint: int, in_order: bool, snapshot_frequency: int):
             method=method,
             multiprocessing_context=multiprocessing_context,
             snapshot_frequency=snapshot_frequency,
+            prebatch=prebatch,
         )
         node = Prefetcher(node, prefetch_factor=2)
         run_test_save_load_state(self, node, midpoint)
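
The save/load tests above advance the pipeline to a midpoint, snapshot it, restore a node from the snapshot, and check that iteration resumes identically. Roughly, the round-trip they exercise looks like this (a sketch assuming BaseNode's state_dict()/reset(state) contract; source and udf are hypothetical stand-ins):

    udf = lambda x: x + 1
    node = ParallelMapper(source, udf, num_workers=2, method="thread", prebatch=3)
    node.reset()
    for _ in range(7):         # advance to some midpoint
        next(node)
    state = node.state_dict()  # snapshot mid-stream

    node.reset(state)          # restore from the snapshot
    rest = list(node)          # resumes where the snapshot was taken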
142 changes: 115 additions & 27 deletions torchdata/nodes/map.py
@@ -7,10 +7,11 @@
 import queue
 import threading
 import time
-from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Protocol, TypeVar, Union
+from typing import Any, Callable, Dict, Generic, Iterator, List, Literal, Optional, Protocol, Sequence, TypeVar, Union
 
 import torch.multiprocessing as mp
 from torchdata.nodes.base_node import BaseNode, T
+from torchdata.nodes.batch import Batcher, Unbatcher
 from torchdata.nodes.exception_wrapper import ExceptionWrapper, StartupExceptionWrapper
 from torchdata.nodes.snapshot_store import QueueSnapshotStore, SnapshotStore
 
@@ -52,6 +53,18 @@ def Mapper(source: BaseNode[X], map_fn: Callable[[X], T]) -> "ParallelMapper[T]":
     )
 
 
+Xseq = Sequence[X]
+Tseq = Sequence[T]
+
+
+class MapOverBatch(Generic[X, T]):
+    def __init__(self, map_fn: Callable[[X], T]):
+        self.map_fn = map_fn
+
+    def __call__(self, xlist: Sequence[X]) -> Sequence[T]:
+        return [self.map_fn(x) for x in xlist]
+
+
 def _sort_worker(in_q: Union[queue.Queue, mp.Queue], out_q: queue.Queue, stop_event: threading.Event):
     buffer: Dict[int, Any] = {}
     cur_idx = 0
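
MapOverBatch is what makes prebatching transparent to user code: it lifts a per-item map_fn into a function over a sequence of items, so existing map functions need no changes. A standalone sketch of its behavior:

    def add_one(x: int) -> int:
        return x + 1

    batched_fn = MapOverBatch(map_fn=add_one)
    assert batched_fn([1, 2, 3]) == [2, 3, 4]  # one call maps a whole prebatch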
@@ -272,6 +285,78 @@ def _shutdown(self):
         t.join(timeout=QUEUE_TIMEOUT * 5)
 
 
+class _ParallelMapperImpl(BaseNode[T]):
+    """This class implements _ParallelMapperIter and _InlineMapperIter as a BaseNode,
+    allowing them to be composed with other BaseNodes.
+    TODO: In the future, this class may go away once we implement reset() on
+    _ParallelMapperIter and _InlineMapperIter themselves so we don't need this
+    additional level of abstraction.
+    """
+
+    def __init__(
+        self,
+        source: BaseNode[X],
+        map_fn: Callable[[X], T],
+        num_workers: int,
+        in_order: bool = True,
+        method: Literal["thread", "process"] = "thread",
+        multiprocessing_context: Optional[str] = None,
+        max_concurrent: Optional[int] = None,
+        snapshot_frequency: int = 1,
+    ):
+        super().__init__()
+        assert method in ["thread", "process"]
+        self.source = source
+        self.map_fn = map_fn
+        self.num_workers = num_workers
+        self.in_order = in_order
+        self.method = method
+        self.multiprocessing_context = multiprocessing_context
+        self._mp_context: Any = mp
+        if self.method == "process" and self.multiprocessing_context is not None:
+            self._mp_context = mp.get_context(self.multiprocessing_context)
+
+        if max_concurrent is not None and num_workers > 0:
+            if not isinstance(max_concurrent, int) or max_concurrent > num_workers:
+                raise ValueError(f"{max_concurrent=} should be an int <= {num_workers=}!")
+        self.max_concurrent = max_concurrent
+        self.snapshot_frequency = snapshot_frequency
+        self._it: Optional[Union[_InlineMapperIter[T], _ParallelMapperIter[T]]] = None
+
+    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
+        super().reset(initial_state)
+        if self._it is not None:
+            del self._it
+
+        if self.num_workers > 0:
+            self._it = self._parallel_reset(initial_state)
+        else:
+            self._it = self._inline_reset(initial_state)
+
+    def _inline_reset(self, initial_state: Optional[Dict[str, Any]]):
+        return _InlineMapperIter(source=self.source, map_fn=self.map_fn, initial_state=initial_state)
+
+    def _parallel_reset(self, initial_state: Optional[Dict[str, Any]]):
+        return _ParallelMapperIter(
+            source=self.source,
+            map_fn=self.map_fn,
+            num_workers=self.num_workers,
+            in_order=self.in_order,
+            method=self.method,
+            mp_context=self._mp_context,
+            max_concurrent=self.max_concurrent,
+            snapshot_frequency=self.snapshot_frequency,
+            initial_state=initial_state,
+        )
+
+    def next(self) -> T:
+        return next(self._it)  # type: ignore[arg-type, union-attr]
+
+    def get_state(self) -> Dict[str, Any]:
+        return self._it.get_state()  # type: ignore[union-attr]
+
+
 class ParallelMapper(BaseNode[T]):
     """ParallelMapper executes map_fn in parallel either in num_workers threads or
     processes. For processes, multiprocessing_context can be spawn, forkserver, fork,
@@ -294,8 +379,12 @@ class ParallelMapper(BaseNode[T]):
         multiprocessing_context (Optional[str]): The multiprocessing context to use for parallel processing. Default is None.
         max_concurrent (Optional[int]): The maximum number of items to process at once. Default is None.
         snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is 1.
+        prebatch (Optional[int]): Optionally perform pre-batching of items from source before mapping.
+            For small items, this may improve throughput at the expense of peak memory.
     """
 
+    IT_STATE_KEY = "it_state"
+
     def __init__(
         self,
         source: BaseNode[X],
@@ -306,58 +395,57 @@ def __init__(
         multiprocessing_context: Optional[str] = None,
         max_concurrent: Optional[int] = None,
         snapshot_frequency: int = 1,
+        prebatch: Optional[int] = None,
     ):
         super().__init__()
         assert method in ["thread", "process"]
-        self.source = source
-        self.map_fn = map_fn
         self.num_workers = num_workers
         self.in_order = in_order
         self.method = method
         self.multiprocessing_context = multiprocessing_context
-        self._mp_context: Any = mp
-        if self.method == "process" and self.multiprocessing_context is not None:
-            self._mp_context = mp.get_context(self.multiprocessing_context)
-
-        if max_concurrent is not None and num_workers > 0:
-            if not isinstance(max_concurrent, int) and max_concurrent > num_workers:
-                raise ValueError(f"{max_concurrent=} should be >= {num_workers=}!")
         self.max_concurrent = max_concurrent
         self.snapshot_frequency = snapshot_frequency
-        self._it: Optional[Union[_InlineMapperIter[T], _ParallelMapperIter[T]]] = None
-
-    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
-        super().reset(initial_state)
-        if self._it is not None:
-            self._it._shutdown()
-            del self._it
-
-        if self.num_workers > 0:
-            self._parallel_reset(initial_state)
+        self.prebatch = prebatch
+        if prebatch is None:
+            self.map_fn = map_fn
+            self.source = source
         else:
-            self._inline_reset(initial_state)
-
-    def _inline_reset(self, initial_state: Optional[Dict[str, Any]]):
-        self._it = _InlineMapperIter(source=self.source, map_fn=self.map_fn, initial_state=initial_state)
+            if prebatch <= 0:
+                raise ValueError(f"{prebatch=} must be a positive integer!")
+            self.map_fn = MapOverBatch(map_fn=map_fn)  # type: ignore[assignment]
+            self.source = Batcher(source, batch_size=prebatch, drop_last=False)  # type: ignore[assignment]
 
-    def _parallel_reset(self, initial_state: Optional[Dict[str, Any]]):
-        self._it = _ParallelMapperIter(
+        _it = _ParallelMapperImpl(
             source=self.source,
             map_fn=self.map_fn,
             num_workers=self.num_workers,
             in_order=self.in_order,
             method=self.method,
-            mp_context=self._mp_context,
+            multiprocessing_context=self.multiprocessing_context,
             max_concurrent=self.max_concurrent,
             snapshot_frequency=self.snapshot_frequency,
-            initial_state=initial_state,
         )
 
-    def next(self):
+        if self.prebatch is None:
+            self._it = _it
+        else:
+            self._it = Unbatcher(_it)  # type: ignore[arg-type, assignment]
+
+    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
+        super().reset(initial_state)
+        if initial_state is not None:
+            self._it.reset(initial_state[self.IT_STATE_KEY])
+        else:
+            self._it.reset()
+
+    def next(self) -> T:
         return next(self._it)  # type: ignore[arg-type, union-attr]
 
     def get_state(self) -> Dict[str, Any]:
-        return self._it.get_state()  # type: ignore[union-attr]
+        return {self.IT_STATE_KEY: self._it.state_dict()}  # type: ignore[union-attr]
 
 
 _WorkerType = Callable[
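
Taken together, prebatch=N makes ParallelMapper behave roughly like the composition below (a conceptual sketch, not the literal code path, which routes through _ParallelMapperImpl; source, map_fn, N, and num_workers are placeholders):

    from torchdata.nodes.batch import Batcher, Unbatcher

    inner = ParallelMapper(
        Batcher(source, batch_size=N, drop_last=False),  # group N items per work unit
        MapOverBatch(map_fn),                            # map each group in a single call
        num_workers=num_workers,
    )
    node = Unbatcher(inner)  # flatten mapped groups back into single items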
