Skip to content

Commit

Permalink
fix: BatchSampler shuffle state not detected correctly (#20327)
Browse files Browse the repository at this point in the history
* fix batchsampler does not work correctly

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add batch sampler shuffle state test
  • Loading branch information
dadwadw233 authored Nov 13, 2024
1 parent 1f2d7a1 commit bd5866b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/lightning/pytorch/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,8 @@ def _is_dataloader_shuffled(dataloader: object) -> bool:
if not hasattr(dataloader, "sampler"):
# shuffling is enabled via a sampler. No sampler, no shuffling
return False
sampler = dataloader.sampler
batch_sampler = dataloader.batch_sampler
sampler = batch_sampler.sampler if batch_sampler is not None else dataloader.sampler
if isinstance(sampler, SequentialSampler):
return False
return isinstance(sampler, RandomSampler)
28 changes: 27 additions & 1 deletion tests/tests_pytorch/utilities/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.data import (
_get_dataloader_init_args_and_kwargs,
_is_dataloader_shuffled,
_update_dataloader,
extract_batch_size,
has_len_all_ranks,
Expand All @@ -20,7 +21,7 @@
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning_utilities.test.warning import no_warning_call
from torch import Tensor
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler


def test_extract_batch_size():
Expand Down Expand Up @@ -304,6 +305,31 @@ def __init__(self, extra_arg):
_ = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING)


def test_batch_sampler_shuffle_setting():
    """Verify ``_is_dataloader_shuffled`` reports the shuffle state of a dataloader.

    Covers both dataloaders built around an explicit ``BatchSampler`` (where the
    shuffle state lives on the wrapped sampler) and plain dataloaders (where it
    lives on ``dataloader.sampler`` directly).
    """
    data = range(10)
    seq_sampler = SequentialSampler(data)

    # Explicit BatchSampler: shuffling must be read from the sampler wrapped
    # inside the batch sampler, not from the loader's own default sampler.
    loader_random = DataLoader(data, batch_sampler=BatchSampler(RandomSampler(data), batch_size=2, drop_last=False))
    loader_sequential = DataLoader(data, batch_sampler=BatchSampler(seq_sampler, batch_size=2, drop_last=False))
    # batch_size=1 with an explicit (sequential) batch sampler must still read as not shuffled
    loader_batch_size_one = DataLoader(data, batch_sampler=BatchSampler(seq_sampler, batch_size=1, drop_last=False))

    assert _is_dataloader_shuffled(loader_random)
    assert not _is_dataloader_shuffled(loader_sequential)
    assert not _is_dataloader_shuffled(loader_batch_size_one)

    # No explicit batch sampler: the DataLoader picks its samplers from the
    # `shuffle` flag, and the helper must still reflect that choice.
    plain_loader = DataLoader(data, batch_size=1)
    shuffled_plain_loader = DataLoader(data, batch_size=1, shuffle=True)
    assert not _is_dataloader_shuffled(plain_loader)
    assert _is_dataloader_shuffled(shuffled_plain_loader)


@pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
"""Test that DataLoader kwargs are not replaced when using Iterable Dataset."""
Expand Down

0 comments on commit bd5866b

Please sign in to comment.