fix: model weight updates with automatic_optimization=False in mixed precision training #20460

Closed
48 changes: 48 additions & 0 deletions pytorch_lightning/plugins/precision/amp.py
@@ -0,0 +1,48 @@
import torch

from lightning.pytorch.utilities import rank_zero_warn


def optimizer_step(self, optimizer, model, optimizer_idx, closure, **kwargs):
    """Performs the actual optimizer step with proper gradient scaling."""
    scaler = self.scaler

    # Scale loss and compute gradients
    if closure is not None:
        with torch.cuda.amp.autocast():
            loss = closure()
        scaler.scale(loss).backward()

    try:
        # Unscale gradients before optimizer step
        scaler.unscale_(optimizer)

        # Check if gradients are finite
        valid_gradients = True
        for param_group in optimizer.param_groups:
            for param in param_group["params"]:
                if param.grad is not None and not torch.isfinite(param.grad).all():
                    valid_gradients = False
                    break
            if not valid_gradients:
                break

        if valid_gradients:
            # If gradients are valid, step optimizer and update scaler
            optimizer.step()
            scaler.update()
        else:
            # Skip step and adjust scaler
            scaler.update()
            rank_zero_warn(
                "Gradients have become NaN or inf. Skipping optimizer step but updating scaler. "
                "This may affect model convergence.",
                category=RuntimeWarning,
            )
    except RuntimeError as e:
        if "unscale_() has already been called" not in str(e):
            raise
        # Handle case where unscale_() was already called for this step
        optimizer.step()
        scaler.update()

    optimizer.zero_grad()
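
Reviewer note (not part of the diff): the explicit finiteness check above mirrors what torch.cuda.amp.GradScaler.step() already does internally, namely skipping the wrapped optimizer.step() when the unscaled gradients contain infs or NaNs; the addition here is the explicit warning when a step is skipped. For comparison, a minimal sketch of the canonical scaler loop in plain PyTorch, with an illustrative model, optimizer, and data that are not from this PR:

import torch

# Illustrative model, optimizer, and data; names are not from this PR
model = torch.nn.Linear(32, 32).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(8, 32, device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(x), x)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # internally skips the step if grads contain inf/NaN
    scaler.update()                # grows or shrinks the scale for the next iteration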
35 changes: 35 additions & 0 deletions tests/tests_pytorch/loops/optimization/test_manual_loop.py
@@ -42,3 +42,38 @@ def training_step(self, batch, batch_idx):

    with pytest.raises(MisconfigurationException, match="return a Tensor or have no return"):
        trainer.fit(model)


def test_amp_training_updates_weights(tmp_path):
    """Test that model weights are properly updated with mixed precision training."""

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            # The reported bug concerns automatic_optimization=False, so use manual optimization
            self.automatic_optimization = False
            self.previous_params = None
            self.layer = torch.nn.Linear(32, 32)  # Same input/output size

        def training_step(self, batch, batch_idx):
            # Track parameter changes between steps
            params = torch.cat([param.view(-1) for param in self.parameters()])
            if self.previous_params is not None:
                num_different_values = (self.previous_params != params).sum().item()
                assert num_different_values > 0, f"Parameters did not update at step {batch_idx}"
            self.previous_params = params.clone().detach()

            # Manual optimization: backward and step through Lightning so the
            # precision plugin applies gradient scaling
            opt = self.optimizers()
            opt.zero_grad()
            x = batch[0]
            output = self.layer(x)
            loss = torch.nn.functional.mse_loss(output, x)  # Autoencoder-style loss
            self.manual_backward(loss)
            opt.step()
            return loss

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=1,
        limit_train_batches=10,
        precision="16-mixed",
        accelerator="auto",
        devices=1,
    )
    trainer.fit(model)
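
Reviewer note (not part of the diff): the new test should be runnable in isolation with standard pytest selection, e.g. python -m pytest tests/tests_pytorch/loops/optimization/test_manual_loop.py -k test_amp_training_updates_weights.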
1 change: 1 addition & 0 deletions tests/tests_pytorch/utilities/test_model_helpers.py
@@ -43,6 +43,7 @@ def test_is_overridden():
def test_mixed_imports_unified():
    from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized as new_unwrap
    from lightning.pytorch.utilities.model_helpers import is_overridden as new_is_overridden

    from pytorch_lightning.callbacks import EarlyStopping as OldEarlyStopping
    from pytorch_lightning.demos.boring_classes import BoringModel as OldBoringModel