Fix random sampler patch state management. Also add unit tests for it
gokulavasan committed Apr 9, 2024
1 parent cf39424 commit a5b7695
Showing 2 changed files with 65 additions and 22 deletions.
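The change matters for mid-epoch checkpointing: a StatefulDataLoader state captured part-way through an epoch should, when restored, replay exactly the batches that had not yet been consumed. A minimal sketch of that contract, assuming the torchdata.stateful_dataloader API exercised in the diffs below (the toy list dataset and identity collate are illustrative, not part of the commit):

    from torchdata.stateful_dataloader import StatefulDataLoader

    dataset = list(range(10))
    # identity collate keeps each batch a plain list, so comparisons stay simple
    dl = StatefulDataLoader(dataset, shuffle=True, collate_fn=lambda b: b)

    it = iter(dl)
    head = [next(it) for _ in range(4)]  # consume a few batches
    snapshot = dl.state_dict()           # capture mid-epoch state
    tail = list(it)                      # finish the epoch

    dl.load_state_dict(snapshot)         # rewind to the snapshot
    assert list(iter(dl)) == tail        # the resumed run replays the same tail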
33 changes: 33 additions & 0 deletions test/stateful_dataloader/test_state_dict.py
@@ -547,6 +547,39 @@ def test_map_shuffle(self):
 
         self.assertEqual(batches, exp)
 
+    def test_map_interrupted_shuffle(self):
+        for pw, num_workers, every_n_steps in itertools.product([False, True], [0, 2], [1, 5, 10, 15]):
+            dataset = DummyMapDataset(10, shuffle=True)
+            dl = StatefulDataLoader(
+                dataset=dataset,
+                shuffle=True,
+                num_workers=num_workers,
+                collate_fn=identity,
+                snapshot_every_n_steps=every_n_steps,
+                persistent_workers=pw if num_workers > 0 else False,
+            )
+
+            it = iter(dl)
+            state0 = dl.state_dict()
+            exp = []
+            for _ in range(4):
+                exp.append(next(it))
+            state1 = dl.state_dict()
+
+            # Resume from the mid-epoch snapshot and collect the rest of the epoch.
+            dl.load_state_dict(state1)
+            it = iter(dl)
+            for data in it:
+                exp.append(data)
+
+            # Rewind to the start-of-epoch snapshot: a full pass must reproduce
+            # the first four batches plus the resumed remainder.
+            dl.load_state_dict(state0)
+            batches = []
+            for data in iter(dl):
+                batches.append(data)
+
+            self.assertEqual(batches, exp)
+
 
 class TestSnapshotEnd(unittest.TestCase):
     def test_generator(self):
54 changes: 32 additions & 22 deletions torchdata/stateful_dataloader/sampler.py
@@ -1,42 +1,52 @@
-from typing import Any, Dict, Optional, Sized
+from typing import Any, Dict, Iterator, Optional, Sized
 
 import torch.utils.data.sampler
 from torch.utils.data.dataloader import _InfiniteConstantSampler
 
 from .stateful import Stateful
 
 
-class RandomSampler(torch.utils.data.sampler.RandomSampler, Stateful):
+class _StatefulRandomSamplerIterator(Iterator[int], Stateful):
+    def __init__(self, sampler, parent_iterator: Iterator[int]):
+        self.sampler = sampler
+        self.parent_iterator = parent_iterator
+        self.yielded = 0
+        self.next_yielded = None
+        self.generator_state = sampler.generator.get_state()
+
+    def __next__(self) -> int:
+        if self.next_yielded is not None:
+            for _ in range(self.next_yielded):
+                next(self.parent_iterator)
+
+            self.yielded = self.next_yielded
+            self.next_yielded = None
+
+        val = next(self.parent_iterator)
+        self.yielded += 1
+        return val
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        self.generator_state = state_dict["generator"]
+        self.sampler.generator.set_state(state_dict["generator"])
+        self.next_yielded = state_dict["yielded"]
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {"generator": self.generator_state, "yielded": self.yielded}
+
+
+class RandomSampler(torch.utils.data.sampler.RandomSampler):
     def __init__(
         self, data_source: Sized, replacement: bool = False, num_samples: Optional[int] = None, generator=None
     ):
-
         if generator is None:
             # Ensure that underlying sampler has something repeatable
             generator = torch.Generator()
             generator.manual_seed(1)
         super().__init__(data_source, replacement, num_samples, generator)
-        self.yielded = 0
-        self.next_yielded = None
-
-    def state_dict(self) -> Dict[str, Any]:
-        return {"generator": self.generator.get_state() if self.generator else None, "yielded": self.yielded}
-
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        if state_dict["generator"] is not None:
-            self.generator.set_state(state_dict["generator"])
-        self.next_yielded = state_dict["yielded"]
 
     def __iter__(self):
-        super_iter = super().__iter__()
-        self.yielded = self.next_yielded or 0
-        while True:
-            try:
-                val = next(super_iter)
-                yield val
-                self.yielded += 1
-            except StopIteration:
-                return
+        return _StatefulRandomSamplerIterator(self, super().__iter__())
 
 
 torch.utils.data.sampler.RandomSampler = RandomSampler  # type: ignore[misc]
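The sampler previously kept yielded/next_yielded on the sampler object, where the counters were shared by every iterator created from it, and the old __iter__ never actually skipped the already-consumed draws. The new _StatefulRandomSamplerIterator owns that bookkeeping: it records the generator state at epoch start, and load_state_dict restores the RNG and schedules a fast-forward by the saved yielded count on the next __next__ call. A minimal sketch of the resulting semantics (range(10) is an illustrative data source):

    sampler = RandomSampler(range(10))

    it = iter(sampler)                  # a _StatefulRandomSamplerIterator
    head = [next(it) for _ in range(3)]
    mid = it.state_dict()               # {"generator": <epoch-start state>, "yielded": 3}
    tail = list(it)

    it2 = iter(sampler)
    it2.load_state_dict(mid)            # restore the RNG, fast-forward three draws
    assert list(it2) == tail            # the resumed iterator yields the same remainder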
