Model does not update its weights #20215

Open
kopalja opened this issue Aug 19, 2024 · 7 comments

@kopalja

kopalja commented Aug 19, 2024

Bug description

Hi, I am using PyTorch Lightning to implement some new optimization strategies with automatic_optimization=False. For a certain setting, my optimization strategy (using automatic_optimization=False) should yield the same results as the standard optimization process (automatic_optimization=True). However, I could not make the results match: my optimization process kept returning slightly different results than the default one. After a while I figured out that PyTorch Lightning sometimes does not update the model weights even when using the default automatic_optimization=True. I have put together a minimal example in which the model weights are not updated at step 5. With different hyper-parameters (e.g., batch size, lr) the weights still fail to update, just at a different training step.

Am I missing something, or does this look like a bug?
Thanks!
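
For reference, this is roughly the shape of the manual-optimization training step being compared against the default loop (an illustrative sketch only, not the exact strategy from this report; it reuses the CNN module defined in the reproduction below, and the class name ManualLoop is made up):

import pytorch_lightning as pl
import torch


class ManualLoop(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # take over the optimization step
        self.model = CNN()  # same CNN module as in the reproduction below

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.model(*batch)
        self.manual_backward(loss)  # lets Lightning apply precision/strategy hooks
        opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=2e-3)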

What version are you seeing the problem on?

v2.4

How to reproduce the bug

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(1, 64, 3, 1),
                nn.Conv2d(64, 64, 3, 1),
                nn.Conv2d(64, 128, 3, 1),
            ]
        )
        self.fc1 = nn.Linear(128, 10)

    def forward(self, x, target):
        for conv in self.convs:
            x = conv(x)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        logits = F.log_softmax(x, dim=1)
        return F.nll_loss(logits, target)


class MRELoop(pl.LightningModule):
    def __init__(self):
        super(MRELoop, self).__init__()
        self.model = CNN()
        self.dataset = datasets.MNIST(root=".mnist_data", download=True, transform=transforms.ToTensor())
        self.previous_params = None

    def training_step(self, batch, batch_idx):
        # Check whether the new model weights differ from the previous ones
        params = torch.cat([param.view(-1) for param in self.model.parameters()])
        if self.previous_params is not None:
            num_different_values = (self.previous_params != params).sum().item()
            self.trainer.should_stop = num_different_values == 0
        else:
            num_different_values = None

        self.previous_params = params
        loss = self.model.forward(*batch)
        print(
            f"step {batch_idx} | diff weights: {num_different_values} | all weights: {params.numel()} | weights mean: {torch.mean(params)} | loss: {loss.item()}"
        )
        return loss

    def configure_optimizers(self):
        # The bug also occurs with a different lr, just at a different training step
        return torch.optim.AdamW(self.parameters(), lr=2e-3)
        # return torch.optim.SGD(self.parameters(), lr=9e-4) # Also with SGD

    def train_dataloader(self):
        return DataLoader(self.dataset)


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")
    pl.seed_everything(1337)
    pl_trainer = pl.Trainer(
        precision="16-mixed",  # So far bug has occured only with 16-mixed
        deterministic=True,
        enable_progress_bar=False,
    )
    pl_trainer.fit(MRELoop())

Error messages and logs

/home/kopal/miniconda3/envs/overshoot/lib/python3.12/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python mvp.py ...
Using 16bit Automatic Mixed Precision (AMP)
/home/kopal/miniconda3/envs/overshoot/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/amp.py:52: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/kopal/miniconda3/envs/overshoot/lib/python3.12/site-packages/pytorch_lightning/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params | Mode 
---------------------------------------
0 | model | CNN  | 112 K  | train
---------------------------------------
112 K     Trainable params
0         Non-trainable params
112 K     Total params
0.451     Total estimated model params size (MB)
/home/kopal/miniconda3/envs/overshoot/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
step 0 | diff weights: None | all weights: 112714 | weights mean: 1.6999114450300112e-05 | loss: 2.334902763366699
step 1 | diff weights: 112714 | all weights: 112714 | weights mean: 3.690078665385954e-05 | loss: 2.32588529586792
step 2 | diff weights: 112714 | all weights: 112714 | weights mean: -0.00010425636719446629 | loss: 2.621901512145996
step 3 | diff weights: 112714 | all weights: 112714 | weights mean: -0.00030326732667163014 | loss: 2.4029626846313477
step 4 | diff weights: 112714 | all weights: 112714 | weights mean: -0.0005236949073150754 | loss: 2.657553195953369
step 5 | diff weights: 0 | all weights: 112714 | weights mean: -0.0005236949073150754 | loss: 2.5822641849517822

Environment

Current environment
* CUDA:
	- GPU:
		- NVIDIA A100-PCIE-40GB
	- available:         True
	- version:           12.1
* Lightning:
	- lightning-utilities: 0.11.6
	- pytorch-lightning: 2.3.3
	- torch:             2.4.0
	- torchmetrics:      1.4.1
	- torchvision:       0.19.0
* Packages:
	- absl-py:           2.1.0
	- aiohappyeyeballs:  2.3.4
	- aiohttp:           3.10.1
	- aiosignal:         1.3.1
	- asttokens:         2.4.1
	- attrs:             24.1.0
	- autocommand:       2.2.2
	- backports.tarfile: 1.2.0
	- beautifulsoup4:    4.12.3
	- black:             24.8.0
	- certifi:           2024.7.4
	- charset-normalizer: 3.3.2
	- click:             8.1.7
	- comm:              0.2.2
	- datasets:          2.20.0
	- debugpy:           1.8.5
	- decorator:         5.1.1
	- dill:              0.3.8
	- exceptiongroup:    1.2.2
	- executing:         2.0.1
	- filelock:          3.15.4
	- frozenlist:        1.4.1
	- fsspec:            2024.5.0
	- gdown:             5.2.0
	- grpcio:            1.65.4
	- huggingface-hub:   0.24.5
	- idna:              3.7
	- importlib-metadata: 8.2.0
	- importlib-resources: 6.4.0
	- inflect:           7.3.1
	- ipykernel:         6.29.5
	- ipython:           8.26.0
	- isort:             5.13.2
	- jaraco.context:    5.3.0
	- jaraco.functools:  4.0.1
	- jaraco.text:       3.12.1
	- jedi:              0.19.1
	- jinja2:            3.1.4
	- jupyter-client:    8.6.2
	- jupyter-core:      5.7.2
	- lightning-utilities: 0.11.6
	- markdown:          3.6
	- markupsafe:        2.1.5
	- matplotlib-inline: 0.1.7
	- more-itertools:    10.3.0
	- mpmath:            1.3.0
	- multidict:         6.0.5
	- multiprocess:      0.70.16
	- mypy-extensions:   1.0.0
	- nest-asyncio:      1.6.0
	- networkx:          3.3
	- numpy:             2.0.1
	- nvidia-cublas-cu12: 12.1.3.1
	- nvidia-cuda-cupti-cu12: 12.1.105
	- nvidia-cuda-nvrtc-cu12: 12.1.105
	- nvidia-cuda-runtime-cu12: 12.1.105
	- nvidia-cudnn-cu12: 9.1.0.70
	- nvidia-cufft-cu12: 11.0.2.54
	- nvidia-curand-cu12: 10.3.2.106
	- nvidia-cusolver-cu12: 11.4.5.107
	- nvidia-cusparse-cu12: 12.1.0.106
	- nvidia-nccl-cu12:  2.20.5
	- nvidia-nvjitlink-cu12: 12.6.20
	- nvidia-nvtx-cu12:  12.1.105
	- ordered-set:       4.1.0
	- packaging:         24.1
	- pandas:            2.2.2
	- parso:             0.8.4
	- pathspec:          0.12.1
	- pexpect:           4.9.0
	- pickleshare:       0.7.5
	- pillow:            10.4.0
	- pip:               24.2
	- platformdirs:      4.2.2
	- prompt-toolkit:    3.0.47
	- protobuf:          4.25.4
	- psutil:            6.0.0
	- ptyprocess:        0.7.0
	- pure-eval:         0.2.3
	- pyarrow:           17.0.0
	- pyarrow-hotfix:    0.6
	- pygments:          2.18.0
	- pynvml:            11.5.3
	- pysocks:           1.7.1
	- python-dateutil:   2.9.0
	- pytorch-lightning: 2.3.3
	- pytz:              2024.1
	- pyyaml:            6.0.1
	- pyzmq:             26.1.0
	- regex:             2024.7.24
	- requests:          2.32.3
	- safetensors:       0.4.4
	- setuptools:        72.1.0
	- six:               1.16.0
	- soupsieve:         2.5
	- stack-data:        0.6.2
	- sympy:             1.13.1
	- tensorboard:       2.17.0
	- tensorboard-data-server: 0.7.2
	- tiktoken:          0.7.0
	- tokenizers:        0.19.1
	- tomli:             2.0.1
	- torch:             2.4.0
	- torchmetrics:      1.4.1
	- torchvision:       0.19.0
	- tornado:           6.4.1
	- tqdm:              4.66.5
	- traitlets:         5.14.3
	- transformers:      4.44.0
	- triton:            3.0.0
	- typeguard:         4.3.0
	- typing-extensions: 4.12.2
	- tzdata:            2024.1
	- urllib3:           2.2.2
	- wcwidth:           0.2.13
	- werkzeug:          3.0.3
	- wheel:             0.44.0
	- xxhash:            3.4.1
	- yarl:              1.9.4
	- zipp:              3.19.2
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		- ELF
	- processor:         x86_64
	- python:            3.12.4
	- release:           3.10.0-1160.71.1.el7.x86_64
	- version:           #1 SMP Tue Jun 28 15:37:28 UTC 2022

More info

No response

@kopalja kopalja added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Aug 19, 2024
@richbai90

Thanks for reporting the issue. Setting precision to '32-true' fixes the problem for me.
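
For reference, a minimal sketch of that workaround, reusing MRELoop from the reproduction above with the precision swapped to full fp32:

pl_trainer = pl.Trainer(
    precision="32-true",  # full precision, no GradScaler involved
    deterministic=True,
    enable_progress_bar=False,
)
pl_trainer.fit(MRELoop())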

@kopalja
Author

kopalja commented Aug 19, 2024

Yes, but that is not really a solution. Besides, the underlying problem might still be present and simply manifest itself at a different training step.

@richbai90

Agreed it's not a fix, but it saved me from having to rewrite my implementation or tell my PI that we had to wait for a bug to be fixed before we could finish our paper.

@b5y

b5y commented Sep 15, 2024

Looks like Lightning version 2.3.3 is also affected.

@iamarunbrahma

PR #20460 fixes this issue, PTAL.

@lantiga lantiga removed the needs triage Waiting to be triaged by maintainers label Dec 4, 2024
@lantiga lantiga self-assigned this Dec 4, 2024
@lantiga
Collaborator

lantiga commented Dec 4, 2024

@kopalja thank you for the investigation and the reproduction

@lantiga
Collaborator

lantiga commented Dec 4, 2024

So this happens because automatic mixed precision in PyTorch is explicitly designed to work this way:

https://github.com/pytorch/pytorch/blob/51b7528e274d350c1d5091acc40572d6b43879b8/torch/amp/grad_scaler.py#L99

The GradScaler skips the optimizer step whenever it finds inf/NaN values in the gradients, in order to avoid applying a corrupted update; on those iterations the parameters are left unchanged.

Here is the equivalent raw PyTorch code:

# Imports assumed; CNN is the same module defined in the report above.
import torch
from torch.amp import GradScaler
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def train():
    model = CNN().cuda()  # place the model on the GPU so CUDA autocast and the scaler are active
    dataset = datasets.MNIST(root=".mnist_data", download=True, transform=transforms.ToTensor())
    dataloader = DataLoader(dataset)

    scaler = GradScaler(device="cuda")

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3)
    previous_params = None

    for batch_idx, batch in enumerate(dataloader):
        params = torch.cat([param.view(-1) for param in model.parameters()])
        if previous_params is not None:
            num_different_values = (previous_params != params).sum().item()
            assert num_different_values != 0  # with the scaler this eventually fails: the step was skipped
        else:
            num_different_values = None
        previous_params = params

        optimizer.zero_grad()

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            loss = model.forward(*[t.cuda() for t in batch])  # move the batch to the GPU as well

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        print(
            f"step {batch_idx} | diff weights: {num_different_values} | all weights: {params.numel()} | weights mean: {torch.mean(params)} | loss: {loss.item()}"
        )

if __name__ == '__main__':
    torch.set_float32_matmul_precision("high")
    torch.manual_seed(1337)

    train()

As you can see, you get steps with no weight updates, as expected.
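
One way to confirm the skip explicitly is to watch the loss scale: GradScaler.update() backs the scale off after an iteration with inf/NaN gradients, so a drop in scaler.get_scale() flags a skipped optimizer step. A minimal sketch, reusing the scaler, optimizer, and batch_idx from the loop above:

scale_before = scaler.get_scale()
scaler.step(optimizer)   # silently skips optimizer.step() on inf/NaN gradients
scaler.update()          # backs off the scale when inf/NaN gradients were found
if scaler.get_scale() < scale_before:
    print(f"step {batch_idx}: inf/NaN gradients, optimizer step skipped")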

Note that if you don't use the scaler to step, you won't hit the assert (the weights always change), but the updates will ultimately be incorrect and likely to blow up:

# Same imports and CNN module as in the previous snippet.
def train():
    model = CNN().cuda()  # GPU placement, as above
    dataset = datasets.MNIST(root=".mnist_data", download=True, transform=transforms.ToTensor())
    dataloader = DataLoader(dataset)

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3)
    previous_params = None

    for batch_idx, batch in enumerate(dataloader):
        params = torch.cat([param.view(-1) for param in model.parameters()])
        if previous_params is not None:
            num_different_values = (previous_params != params).sum().item()
            if num_different_values == 0:
                return
        else:
            num_different_values = None
        previous_params = params

        optimizer.zero_grad()

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            loss = model.forward(*[t.cuda() for t in batch])

        # Without the GradScaler, fp16 gradients (possibly containing inf/NaN) are applied directly
        loss.backward()
        optimizer.step()

        print(
            f"step {batch_idx} | diff weights: {num_different_values} | all weights: {params.numel()} | weights mean: {torch.mean(params)} | loss: {loss.item()}"
        )


if __name__ == '__main__':
    torch.set_float32_matmul_precision("high")
    torch.manual_seed(1337)

    train()

@lantiga lantiga added working as intended Working as intended and removed bug Something isn't working labels Dec 4, 2024