From bd5866b29527e0d612184019171d41a7dcee5586 Mon Sep 17 00:00:00 2001
From: Yuanhong Yu <913217005@qq.com>
Date: Wed, 13 Nov 2024 21:01:47 +0800
Subject: [PATCH] fix batchsampler does not work correctly (#20327)

* fix batchsampler does not work correctly

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add batch sampler shuffle state test

---
 src/lightning/pytorch/utilities/data.py    |  3 ++-
 tests/tests_pytorch/utilities/test_data.py | 28 +++++++++++++++++++++++++++-
 2 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py
index 41c5ea86e50fb..b58142b3a4012 100644
--- a/src/lightning/pytorch/utilities/data.py
+++ b/src/lightning/pytorch/utilities/data.py
@@ -349,7 +349,8 @@ def _is_dataloader_shuffled(dataloader: object) -> bool:
     if not hasattr(dataloader, "sampler"):
         # shuffling is enabled via a sampler. No sampler, no shuffling
         return False
-    sampler = dataloader.sampler
+    batch_sampler = dataloader.batch_sampler
+    sampler = batch_sampler.sampler if batch_sampler is not None else dataloader.sampler
     if isinstance(sampler, SequentialSampler):
         return False
     return isinstance(sampler, RandomSampler)
diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py
index e9c80d95c58a4..d79e9e24383a0 100644
--- a/tests/tests_pytorch/utilities/test_data.py
+++ b/tests/tests_pytorch/utilities/test_data.py
@@ -12,6 +12,7 @@
 from lightning.pytorch.trainer.states import RunningStage
 from lightning.pytorch.utilities.data import (
     _get_dataloader_init_args_and_kwargs,
+    _is_dataloader_shuffled,
     _update_dataloader,
     extract_batch_size,
     has_len_all_ranks,
@@ -20,7 +21,7 @@
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning_utilities.test.warning import no_warning_call
 from torch import Tensor
-from torch.utils.data import BatchSampler, DataLoader, RandomSampler
+from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler


 def test_extract_batch_size():
@@ -304,6 +305,31 @@ def __init__(self, extra_arg):
     _ = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING)


+def test_batch_sampler_shuffle_setting():
+    """Test whether the `shuffle` state is correctly read from the `BatchSampler`."""
+
+    random_sampler = RandomSampler(range(10))
+    seq_sampler = SequentialSampler(range(10))
+    shuffled_dataloader = DataLoader(
+        range(10), batch_sampler=BatchSampler(random_sampler, batch_size=2, drop_last=False)
+    )
+    sequential_dataloader = DataLoader(
+        range(10), batch_sampler=BatchSampler(seq_sampler, batch_size=2, drop_last=False)
+    )
+
+    # with an explicit BatchSampler, PyTorch assigns a default SequentialSampler to dataloader.sampler
+    single_dataloader = DataLoader(range(10), batch_sampler=BatchSampler(seq_sampler, batch_size=1, drop_last=False))
+    assert _is_dataloader_shuffled(shuffled_dataloader)
+    assert not _is_dataloader_shuffled(sequential_dataloader)
+    assert not _is_dataloader_shuffled(single_dataloader)
+
+    # without an explicit batch_sampler, PyTorch wraps the default sampler in a BatchSampler
+    single_dataloader = DataLoader(range(10), batch_size=1)
+    shuffled_single_dataloader = DataLoader(range(10), batch_size=1, shuffle=True)
+    assert not _is_dataloader_shuffled(single_dataloader)
+    assert _is_dataloader_shuffled(shuffled_single_dataloader)
+
+
 @pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
 def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
     """Test that DataLoader kwargs are not replaced when using Iterable Dataset."""
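
Why the fix reads `batch_sampler.sampler` first: when a DataLoader is constructed with an
explicit BatchSampler, PyTorch still assigns a default SequentialSampler to
`dataloader.sampler`, so the old check (`sampler = dataloader.sampler`) reported such loaders
as unshuffled even when the batch sampler wraps a RandomSampler. A minimal sketch of that
behavior against stock PyTorch (not part of the patch):

    from torch.utils.data import BatchSampler, DataLoader, RandomSampler

    # a DataLoader whose shuffling lives entirely in an explicit BatchSampler
    loader = DataLoader(
        range(10),
        batch_sampler=BatchSampler(RandomSampler(range(10)), batch_size=2, drop_last=False),
    )

    # PyTorch fills `loader.sampler` with a default SequentialSampler ...
    print(type(loader.sampler).__name__)                # SequentialSampler
    # ... while the actual shuffle state sits on the inner sampler
    print(type(loader.batch_sampler.sampler).__name__)  # RandomSampler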