diff --git a/docs/source-pytorch/extensions/loops.rst b/docs/source-pytorch/extensions/loops.rst index 5c7385ec7c0b5..c35fa27f296e0 100644 --- a/docs/source-pytorch/extensions/loops.rst +++ b/docs/source-pytorch/extensions/loops.rst @@ -259,28 +259,6 @@ run (optional) ---------- -Subloops --------- - -When you want to customize nested loops within loops use the :meth:`~pytorch_lightning.loops.loop.Loop.connect` method: - -.. code-block:: python - - # Optional: stitch back the trainer arguments - epoch_loop = MyEpochLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps) - # Optional: connect children loops as they might have existing state - epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop) - # Instantiate and connect the loop. - trainer.fit_loop.connect(epoch_loop=epoch_loop) - trainer.fit(model) - -More about the built-in loops and how they are composed is explained in the next section. - -.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/connect-epoch-loop.gif - :alt: Animation showing how to connect a custom subloop - ----------- - Built-in Loops -------------- @@ -342,71 +320,6 @@ Each of these :code:`for`-loops represents a class implementing the :class:`~pyt It simply iterates over each prediction dataloader from one to the next by calling :code:`PredictionEpochLoop.run()` in its :code:`advance()` method. ----------- - -Available Loops in Lightning Flash ----------------------------------- - -`Active Learning `__ is a machine learning practice in which the user interacts with the learner in order to provide new labels when required. - -You can find a real use case in `Lightning Flash `_. - -Flash implements the :code:`ActiveLearningLoop` that you can use together with the :code:`ActiveLearningDataModule` to label new data on the fly. -To run the following demo, install Flash and `BaaL `__ first: - -.. code-block:: bash - - pip install lightning-flash[image] baal - -.. code-block:: python - - import torch - - import flash - from flash.core.classification import Probabilities - from flash.core.data.utils import download_data - from flash.image import ImageClassificationData, ImageClassifier - from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop - - # 1. Create the DataModule - download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data") - - # Implement the research use-case where we mask labels from labelled dataset. - datamodule = ActiveLearningDataModule( - ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/", batch_size=2), - initial_num_labels=5, - val_split=0.1, - ) - - # 2. Build the task - head = torch.nn.Sequential( - torch.nn.Dropout(p=0.1), - torch.nn.Linear(512, datamodule.num_classes), - ) - model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, output=Probabilities()) - - - # 3.1 Create the trainer - trainer = flash.Trainer(max_epochs=3) - - # 3.2 Create the active learning loop and connect it to the trainer - active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1) - active_learning_loop.connect(trainer.fit_loop) - trainer.fit_loop = active_learning_loop - - # 3.3 Finetune - trainer.finetune(model, datamodule=datamodule, strategy="freeze") - - # 4. Predict what's on a few images! ants or bees? - predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg") - print(predictions) - - # 5. Save the model! - trainer.save_checkpoint("image_classification_model.pt") - -Here is the `Active Learning Loop example `_ and the `code for the active learning loop `_. - - ---------- Advanced Examples diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 1fd76c02c350c..7b063b330beeb 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -55,6 +55,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed support for loop customization * Removed `Loop.replace()` ([#16361](https://github.com/Lightning-AI/lightning/pull/16361)) + * Removed `Loop.connect()` ([#16384](https://github.com/Lightning-AI/lightning/pull/16384)) + * Removed the `trainer.{fit,validate,test,predict}_loop` properties ([#16384](https://github.com/Lightning-AI/lightning/pull/16384)) - Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172)) * Removed the `LightningModule.truncated_bptt_steps` attribute diff --git a/src/pytorch_lightning/loops/dataloader/evaluation_loop.py b/src/pytorch_lightning/loops/dataloader/evaluation_loop.py index 2c96a9e2fc130..f2d840590e1e0 100644 --- a/src/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/src/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -78,10 +78,6 @@ def prefetch_batches(self) -> int: is_unsized = batches[self.current_dataloader_idx] == float("inf") return int(is_unsized) - def connect(self, epoch_loop: EvaluationEpochLoop) -> None: # type: ignore[override] - """Connect the evaluation epoch loop with this loop.""" - self.epoch_loop = epoch_loop - @property def done(self) -> bool: """Returns whether all dataloaders are processed or evaluation should be skipped altogether.""" diff --git a/src/pytorch_lightning/loops/dataloader/prediction_loop.py b/src/pytorch_lightning/loops/dataloader/prediction_loop.py index 606bfcc4024ce..1f9df89a00501 100644 --- a/src/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/src/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -66,10 +66,6 @@ def dataloaders(self) -> Sequence[DataLoader]: def skip(self) -> bool: return sum(self.max_batches) == 0 - def connect(self, epoch_loop: PredictionEpochLoop) -> None: # type: ignore[override] - """Connect the prediction epoch loop with this loop.""" - self.epoch_loop = epoch_loop - def reset(self) -> None: """Resets the internal state of the loop for a new run.""" self.predictions = [] diff --git a/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index 77bca172d56b8..66794a8caf0ac 100644 --- a/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -38,9 +38,6 @@ def should_store_predictions(self) -> bool: any_pred = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks) return self.return_predictions or any_pred - def connect(self, **kwargs: "Loop") -> None: - raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.") - def reset(self) -> None: """Resets the loops internal state.""" self._seen_batch_indices = [] diff --git a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py index e63569f67a960..3abcfb95204d4 100644 --- a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -121,20 +121,6 @@ def done(self) -> bool: return False - def connect( # type: ignore[override] - self, - optimizer_loop: Optional[OptimizerLoop] = None, - manual_loop: Optional[ManualOptimization] = None, - val_loop: Optional["loops.EvaluationLoop"] = None, - ) -> None: - """Optionally connect a custom batch or validation loop to this training epoch loop.""" - if optimizer_loop is not None: - self.optimizer_loop = optimizer_loop - if manual_loop is not None: - self.manual_loop = manual_loop - if val_loop is not None: - self.val_loop = val_loop - def reset(self) -> None: """Resets the internal state of the loop for a new run.""" if self.restarting: diff --git a/src/pytorch_lightning/loops/fit_loop.py b/src/pytorch_lightning/loops/fit_loop.py index ad1e2c8b4e967..d47a5cec866c2 100644 --- a/src/pytorch_lightning/loops/fit_loop.py +++ b/src/pytorch_lightning/loops/fit_loop.py @@ -169,10 +169,6 @@ def skip(self) -> bool: # until `on_run_start`, we use `limit_train_batches` instead return self.done or self.trainer.limit_train_batches == 0 - def connect(self, epoch_loop: TrainingEpochLoop) -> None: # type: ignore[override] - """Connects a training epoch loop to this fit loop.""" - self.epoch_loop = epoch_loop - def reset(self) -> None: """Resets the internal state of this loop.""" if self.restarting: diff --git a/src/pytorch_lightning/loops/loop.py b/src/pytorch_lightning/loops/loop.py index 9047be149f68e..461b0a6a2f6f1 100644 --- a/src/pytorch_lightning/loops/loop.py +++ b/src/pytorch_lightning/loops/loop.py @@ -100,12 +100,6 @@ def skip(self): """ return False - def connect(self, **kwargs: "Loop") -> None: - """Optionally connect one or multiple loops to this one. - - Linked loops should form a tree. - """ - def on_skip(self) -> T: """The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`. diff --git a/src/pytorch_lightning/loops/optimization/optimizer_loop.py b/src/pytorch_lightning/loops/optimization/optimizer_loop.py index 819178cff15c0..07284198aa183 100644 --- a/src/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/src/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -172,9 +172,6 @@ def done(self) -> bool: """Returns ``True`` when the last optimizer in the sequence has run.""" return self.optim_progress.optimizer_position >= len(self._indices) - def connect(self, **kwargs: "Loop") -> None: - raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.") - def reset(self) -> None: if not self.restarting: # when reset() is called from outside (manually), we reset the loop progress diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 1f06fbd11ecaf..0dab4ee15f396 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -371,21 +371,16 @@ def __init__( self._signal_connector = SignalConnector(self) self.tuner = Tuner(self) - fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs) - training_epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps) - fit_loop.connect(epoch_loop=training_epoch_loop) - - # default .fit() loop - self.fit_loop = fit_loop - - # default .validate() loop + # init loops + self.fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs) + self.fit_loop.epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps) self.validate_loop = EvaluationLoop() - - # default .test() loop self.test_loop = EvaluationLoop() - - # default .predict() loop self.predict_loop = PredictionLoop() + self.fit_loop.trainer = self + self.validate_loop.trainer = self + self.test_loop.trainer = self + self.predict_loop.trainer = self # init callbacks # Declare attributes to be set in _callback_connector on_trainer_init @@ -1103,8 +1098,6 @@ def _run_train(self) -> None: self.model.train() torch.set_grad_enabled(True) - self.fit_loop.trainer = self - with torch.autograd.set_detect_anomaly(self._detect_anomaly): self.fit_loop.run() @@ -1114,9 +1107,6 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT: # reload dataloaders self._evaluation_loop._reload_evaluation_dataloaders() - # reset trainer on this loop and all child loops in case user connected a custom loop - self._evaluation_loop.trainer = self - with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context( self.accelerator, self._inference_mode ): @@ -1133,8 +1123,6 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT: def _run_predict(self) -> Optional[_PREDICT_OUTPUT]: self.reset_predict_dataloader(self.lightning_module) - # reset trainer on this loop and all child loops in case user connected a custom loop - self.predict_loop.trainer = self with _evaluation_context(self.accelerator, self._inference_mode): return self.predict_loop.run() @@ -1955,63 +1943,6 @@ def is_last_batch(self) -> bool: """Whether trainer is executing the last batch.""" return self.fit_loop.epoch_loop.batch_progress.is_last_batch - @property - def fit_loop(self) -> FitLoop: - return self._fit_loop - - @fit_loop.setter - def fit_loop(self, loop: FitLoop) -> None: - """Attach a custom fit loop to this Trainer. - - It will run with - :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`. - """ - loop.trainer = self - self._fit_loop = loop - - @property - def validate_loop(self) -> EvaluationLoop: - return self._validate_loop - - @validate_loop.setter - def validate_loop(self, loop: EvaluationLoop) -> None: - """Attach a custom validation loop to this Trainer. - - It will run with - :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`. Note that this loop is different from the one - running during training inside the :meth:`pytorch_lightning.trainer.trainer.Trainer.fit` call. - """ - loop.trainer = self - self._validate_loop = loop - - @property - def test_loop(self) -> EvaluationLoop: - return self._test_loop - - @test_loop.setter - def test_loop(self, loop: EvaluationLoop) -> None: - """Attach a custom test loop to this Trainer. - - It will run with - :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`. - """ - loop.trainer = self - self._test_loop = loop - - @property - def predict_loop(self) -> PredictionLoop: - return self._predict_loop - - @predict_loop.setter - def predict_loop(self, loop: PredictionLoop) -> None: - """Attach a custom prediction loop to this Trainer. - - It will run with - :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. - """ - loop.trainer = self - self._predict_loop = loop - @property def _evaluation_loop(self) -> EvaluationLoop: if self.state.fn == TrainerFn.FITTING: diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 0701071df2a7c..dbb944ae33352 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -25,72 +25,27 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset -from pytorch_lightning.loops import Loop, OptimizerLoop +from pytorch_lightning.loops import Loop from pytorch_lightning.trainer.progress import BaseProgress from tests_pytorch.helpers.runif import RunIf -class NestedLoop(Loop): - def __init__(self): - super().__init__() - self.child_loop0 = None - self.child_loop1 = None - - @property - def done(self) -> bool: - return False - - def connect(self, child0, child1): - self.child_loop0 = child0 - self.child_loop1 = child1 - - def reset(self) -> None: - pass - - def advance(self, *args, **kwargs): - pass - - -@pytest.mark.parametrize("loop_name", ["fit_loop", "validate_loop", "test_loop", "predict_loop"]) -def test_connect_loops_direct(loop_name): - """Test Trainer references in loops on assignment.""" - loop = NestedLoop() - - with pytest.raises(RuntimeError, match="The loop is not attached to a Trainer"): - _ = loop.trainer - - trainer = Trainer() - - # trainer.loop_name = loop - setattr(trainer, loop_name, loop) - assert loop.trainer is trainer - - -def test_connect_loops_recursive(): - """Test Trainer references in a nested loop assigned to a Trainer.""" - main_loop = NestedLoop() - child0 = NestedLoop() - child1 = NestedLoop() - main_loop.connect(child0, child1) - - with pytest.raises(RuntimeError, match="The loop is not attached to a Trainer"): - _ = main_loop.trainer - - with pytest.raises(RuntimeError, match="The loop is not attached to a Trainer"): - _ = main_loop.child_loop0.trainer - - trainer = Trainer() - trainer.fit_loop = main_loop - assert child0.trainer is child1.trainer - assert child0.trainer is trainer - - def test_restarting_loops_recursive(): - class MyLoop(NestedLoop): + class MyLoop(Loop): def __init__(self, loop=None): super().__init__() self.child = loop + @property + def done(self) -> bool: + return False + + def reset(self) -> None: + pass + + def advance(self, *args, **kwargs): + pass + loop = MyLoop(MyLoop(MyLoop())) assert not loop.restarting @@ -102,23 +57,6 @@ def __init__(self, loop=None): assert loop.child.child.restarting -def test_connect_subloops(tmpdir): - """Test connecting individual subloops by calling `trainer.x.y.connect()`""" - model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - - epoch_loop = trainer.fit_loop.epoch_loop - new_optimizer_loop = OptimizerLoop() - epoch_loop.connect(optimizer_loop=new_optimizer_loop) - assert epoch_loop.optimizer_loop is new_optimizer_loop - - with pytest.raises(RuntimeError, match="The loop is not attached to a Trainer"): - _ = new_optimizer_loop.trainer - - trainer.fit(model) - assert new_optimizer_loop.trainer is trainer - - class CustomException(Exception): pass diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py index 8cc06315e3e55..ffa3e40995c04 100644 --- a/tests/tests_pytorch/loops/test_training_loop.py +++ b/tests/tests_pytorch/loops/test_training_loop.py @@ -158,7 +158,7 @@ def test_fit_loop_done_log_messages(caplog): epoch_loop = Mock() epoch_loop.global_step = 10 - fit_loop.connect(epoch_loop=epoch_loop) + fit_loop.epoch_loop = epoch_loop fit_loop.max_steps = 10 assert fit_loop.done assert "max_steps=10` reached" in caplog.text