Commit
Merge branch 'master' into ci/bump-pt-2.6
Borda authored Jan 7, 2025
2 parents 6e1f1ba + 76f0c54 commit daacd5d
Showing 4 changed files with 124 additions and 23 deletions.
85 changes: 64 additions & 21 deletions docs/source-pytorch/common/tbptt.rst
@@ -12,48 +12,91 @@ hidden states should be kept in-between each time-dimension split.
 .. code-block:: python

     import torch
+    import torch.nn as nn
+    import torch.nn.functional as F
     import torch.optim as optim
-    import pytorch_lightning as pl
-    from pytorch_lightning import LightningModule
+    from torch.utils.data import Dataset, DataLoader

-    class LitModel(LightningModule):
+    import lightning as L
+
+
+    class AverageDataset(Dataset):
+        def __init__(self, dataset_len=300, sequence_len=100):
+            self.dataset_len = dataset_len
+            self.sequence_len = sequence_len
+            self.input_seq = torch.randn(dataset_len, sequence_len, 10)
+            top, bottom = self.input_seq.chunk(2, -1)
+            self.output_seq = top + bottom.roll(shifts=1, dims=-1)
+
+        def __len__(self):
+            return self.dataset_len
+
+        def __getitem__(self, item):
+            return self.input_seq[item], self.output_seq[item]
+
+
+    class LitModel(L.LightningModule):
         def __init__(self):
             super().__init__()

+            self.batch_size = 10
+            self.in_features = 10
+            self.out_features = 5
+            self.hidden_dim = 20
+
             # 1. Switch to manual optimization
             self.automatic_optimization = False
             self.truncated_bptt_steps = 10
-            self.my_rnn = ParityModuleRNN()  # Define RNN model using ParityModuleRNN
+
+            self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
+            self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)
+
+        def forward(self, x, hs):
+            seq, hs = self.rnn(x, hs)
+            return self.linear_out(seq), hs

         # 2. Remove the `hiddens` argument
         def training_step(self, batch, batch_idx):
             # 3. Split the batch in chunks along the time dimension
-            split_batches = split_batch(batch, self.truncated_bptt_steps)
-
-            batch_size = 10
-            hidden_dim = 20
-            hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)
-            for split_batch in range(split_batches):
-                # 4. Perform the optimization in a loop
-                loss, hiddens = self.my_rnn(split_batch, hiddens)
-                self.backward(loss)
-                self.optimizer.step()
-                self.optimizer.zero_grad()
+            x, y = batch
+            split_x, split_y = [
+                x.tensor_split(self.truncated_bptt_steps, dim=1),
+                y.tensor_split(self.truncated_bptt_steps, dim=1)
+            ]
+
+            hiddens = None
+            optimizer = self.optimizers()
+            losses = []
+
+            # 4. Perform the optimization in a loop
+            for x, y in zip(split_x, split_y):
+                y_pred, hiddens = self(x, hiddens)
+                loss = F.mse_loss(y_pred, y)
+
+                optimizer.zero_grad()
+                self.manual_backward(loss)
+                optimizer.step()

                 # 5. "Truncate"
-                hiddens = hiddens.detach()
+                hiddens = [h.detach() for h in hiddens]
+                losses.append(loss.detach())
+
+            avg_loss = sum(losses) / len(losses)
+            self.log("train_loss", avg_loss, prog_bar=True)

             # 6. Remove the return of `hiddens`
             # Returning loss in manual optimization is not needed
             return None

         def configure_optimizers(self):
-            return optim.Adam(self.my_rnn.parameters(), lr=0.001)
+            return optim.Adam(self.parameters(), lr=0.001)
+
+        def train_dataloader(self):
+            return DataLoader(AverageDataset(), batch_size=self.batch_size)


     if __name__ == "__main__":
         model = LitModel()
-        trainer = pl.Trainer(max_epochs=5)
-        trainer.fit(model, train_dataloader)  # Define your own dataloader
+        trainer = L.Trainer(max_epochs=5)
+        trainer.fit(model)
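Note on the updated example: the chunk-and-detach pattern it introduces can be exercised with plain PyTorch, outside Lightning. A minimal sketch (shapes and names are illustrative, not taken from the diff):

    import torch
    import torch.nn as nn

    rnn = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
    x = torch.randn(4, 100, 10)  # (batch, time, features)
    hiddens = None
    for chunk in x.tensor_split(10, dim=1):  # ten chunks of ~10 time steps
        out, hiddens = rnn(chunk, hiddens)
        # Detaching the (h, c) state tuple is the "truncation": the values
        # carry over to the next chunk, but gradients stop at the boundary.
        hiddens = tuple(h.detach() for h in hiddens)

This is also why the example detaches every element of `hiddens` rather than the container: an LSTM's state is a pair of tensors, each with its own graph.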
2 changes: 1 addition & 1 deletion src/lightning/pytorch/utilities/model_helpers.py
@@ -104,7 +104,7 @@ def _check_mixed_imports(instance: object) -> None:
 _R_co = TypeVar("_R_co", covariant=True)  # return type of the decorated method


-class _restricted_classmethod_impl(Generic[_T, _P, _R_co]):
+class _restricted_classmethod_impl(Generic[_T, _R_co, _P]):
     """Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance
     instead of a class type."""

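Context for the one-line swap above: in Python's typing machinery, the order of type variables inside `Generic[...]` determines the order in which the class is subscripted, so moving `_P` after `_R_co` changes the subscription order callers must use rather than any runtime behavior. A minimal sketch of that rule (Python 3.10+; the class name is hypothetical):

    from typing import Generic, ParamSpec, TypeVar

    _T = TypeVar("_T")
    _P = ParamSpec("_P")
    _R_co = TypeVar("_R_co", covariant=True)

    class _Demo(Generic[_T, _R_co, _P]):
        ...

    # Subscription follows declaration order: owner type, return type,
    # then the parameter list bound to the ParamSpec.
    alias = _Demo[int, str, [bool]]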
51 changes: 51 additions & 0 deletions tests/tests_pytorch/helpers/advanced_models.py
@@ -219,3 +219,54 @@ def configure_optimizers(self):

     def train_dataloader(self):
         return DataLoader(MNIST(root=_PATH_DATASETS, train=True, download=True), batch_size=128, num_workers=1)
+
+
+class TBPTTModule(LightningModule):
+    def __init__(self):
+        super().__init__()
+
+        self.batch_size = 10
+        self.in_features = 10
+        self.out_features = 5
+        self.hidden_dim = 20
+
+        self.automatic_optimization = False
+        self.truncated_bptt_steps = 10
+
+        self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
+        self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)
+
+    def forward(self, x, hs):
+        seq, hs = self.rnn(x, hs)
+        return self.linear_out(seq), hs
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        split_x, split_y = [
+            x.tensor_split(self.truncated_bptt_steps, dim=1),
+            y.tensor_split(self.truncated_bptt_steps, dim=1),
+        ]
+
+        hiddens = None
+        optimizer = self.optimizers()
+        losses = []
+
+        for x, y in zip(split_x, split_y):
+            y_pred, hiddens = self(x, hiddens)
+            loss = F.mse_loss(y_pred, y)
+
+            optimizer.zero_grad()
+            self.manual_backward(loss)
+            optimizer.step()
+
+            # "Truncate"
+            hiddens = [h.detach() for h in hiddens]
+            losses.append(loss.detach())
+
+        return
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=0.001)
+
+    def train_dataloader(self):
+        return DataLoader(AverageDataset(), batch_size=self.batch_size)
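A quick sanity check on the `tensor_split` call in `training_step` (illustrative only, using `AverageDataset`'s default `sequence_len=100` and the module's `batch_size=10`):

    import torch

    x = torch.randn(10, 100, 10)  # (batch, time, features)
    chunks = x.tensor_split(10, dim=1)  # truncated_bptt_steps = 10
    assert len(chunks) == 10
    assert chunks[0].shape == (10, 10, 10)  # full batch, 10 time steps per chunk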
9 changes: 8 additions & 1 deletion tests/tests_pytorch/helpers/test_models.py
@@ -17,7 +17,7 @@
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel

-from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN
+from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN, TBPTTModule
 from tests_pytorch.helpers.datamodules import ClassifDataModule, RegressDataModule
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.helpers.simple_models import ClassificationModel, RegressionModel
@@ -49,3 +49,10 @@ def test_models(tmp_path, data_class, model_class):
     model.to_torchscript()
     if data_class:
         model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample)
+
+
+def test_tbptt(tmp_path):
+    model = TBPTTModule()
+
+    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
+    trainer.fit(model)
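To iterate on this test locally without a full epoch, a capped run along these lines should work (assuming `tests_pytorch` is importable from the repo root; `limit_train_batches` and the other flags are standard `Trainer` arguments, and the cap of 2 batches is an arbitrary choice):

    from lightning.pytorch import Trainer
    from tests_pytorch.helpers.advanced_models import TBPTTModule

    model = TBPTTModule()
    trainer = Trainer(max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False)
    trainer.fit(model)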
