diff --git a/pytorch_lightning/plugins/precision/amp.py b/pytorch_lightning/plugins/precision/amp.py
new file mode 100644
index 0000000000000..cb3b263dcb8d4
--- /dev/null
+++ b/pytorch_lightning/plugins/precision/amp.py
@@ -0,0 +1,48 @@
+import torch
+from lightning.pytorch.utilities import rank_zero_warn
+
+
+def optimizer_step(self, optimizer, model, optimizer_idx, closure, **kwargs):
+    """Performs the actual optimizer step with proper gradient scaling."""
+    scaler = self.scaler
+
+    # Scale loss and compute gradients
+    if closure is not None:
+        with torch.cuda.amp.autocast():
+            loss = closure()
+        scaler.scale(loss).backward()
+
+    try:
+        # Unscale gradients before optimizer step
+        scaler.unscale_(optimizer)
+
+        # Check if gradients are finite
+        valid_gradients = True
+        for param_group in optimizer.param_groups:
+            for param in param_group["params"]:
+                if param.grad is not None and not torch.isfinite(param.grad).all():
+                    valid_gradients = False
+                    break
+            if not valid_gradients:
+                break
+
+        if valid_gradients:
+            # If gradients are valid, step optimizer and update scaler
+            optimizer.step()
+            scaler.update()
+        else:
+            # Skip step and adjust scaler
+            scaler.update()
+            rank_zero_warn(
+                "Gradients have become NaN or inf. Skipping optimizer step but updating scaler. "
+                "This may affect model convergence.",
+                category=RuntimeWarning,
+            )
+    except RuntimeError as e:
+        if "unscale_() has already been called" not in str(e):
+            raise
+        # Handle case where unscale was already called
+        optimizer.step()
+        scaler.update()
+
+    optimizer.zero_grad()
diff --git a/tests/tests_pytorch/loops/optimization/test_manual_loop.py b/tests/tests_pytorch/loops/optimization/test_manual_loop.py
index 67be30b24e159..f927781e11f39 100644
--- a/tests/tests_pytorch/loops/optimization/test_manual_loop.py
+++ b/tests/tests_pytorch/loops/optimization/test_manual_loop.py
@@ -42,3 +42,38 @@ def training_step(self, batch, batch_idx):
 
     with pytest.raises(MisconfigurationException, match="return a Tensor or have no return"):
         trainer.fit(model)
+
+
+def test_amp_training_updates_weights(tmp_path):
+    """Test that model weights are properly updated with mixed precision training."""
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.previous_params = None
+            self.layer = torch.nn.Linear(32, 32)  # Same input/output size
+
+        def training_step(self, batch, batch_idx):
+            # Track parameter changes
+            params = torch.cat([param.view(-1) for param in self.parameters()])
+            if self.previous_params is not None:
+                num_different_values = (self.previous_params != params).sum().item()
+                assert num_different_values > 0, f"Parameters did not update at step {batch_idx}"
+            self.previous_params = params.clone().detach()
+
+            # Regular training step
+            x = batch[0]
+            output = self.layer(x)
+            loss = torch.nn.functional.mse_loss(output, x)  # Autoencoder-style loss
+            return loss
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        limit_train_batches=10,
+        precision="16-mixed",
+        accelerator="auto",
+        devices=1,
+    )
+    trainer.fit(model)
diff --git a/tests/tests_pytorch/utilities/test_model_helpers.py b/tests/tests_pytorch/utilities/test_model_helpers.py
index 78a63a7e9d2a7..2dca993d261d7 100644
--- a/tests/tests_pytorch/utilities/test_model_helpers.py
+++ b/tests/tests_pytorch/utilities/test_model_helpers.py
@@ -43,6 +43,7 @@ def test_is_overridden():
 def test_mixed_imports_unified():
     from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized as new_unwrap
     from lightning.pytorch.utilities.model_helpers import is_overridden as new_is_overridden
+    from pytorch_lightning.callbacks import EarlyStopping as OldEarlyStopping
     from pytorch_lightning.demos.boring_classes import BoringModel as OldBoringModel
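
For context, a minimal standalone sketch of the gradient-scaling pattern that the new `optimizer_step` follows (scale, backward, unscale, finiteness check, conditional step, scaler update), written outside of Lightning. The toy model, optimizer, and data below are illustrative placeholders, not part of this PR, and the snippet assumes a CUDA device is available:

```python
import torch

# Toy setup (illustrative only): any model, optimizer, and data would do.
model = torch.nn.Linear(32, 32).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for step in range(10):
    x = torch.randn(8, 32, device="cuda")

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(x), x)

    # Backward on the scaled loss, then unscale so the finiteness check
    # sees gradients in their true range.
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)

    grads_finite = all(
        torch.isfinite(p.grad).all()
        for group in optimizer.param_groups
        for p in group["params"]
        if p.grad is not None
    )

    if grads_finite:
        # Mirroring the diff, optimizer.step() is called directly here;
        # scaler.step(optimizer) would perform the same inf/NaN skip internally.
        optimizer.step()
    # update() grows or shrinks the loss scale depending on whether this
    # iteration overflowed, so it runs whether or not the step was taken.
    scaler.update()
```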