diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 775dc5dee77dc..51bd2056177bb 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -182,7 +182,7 @@ We welcome any useful contribution! For your convenience here's a recommended wo 1. Use tags in PR name for the following cases: - **\[blocked by #\]** if your work is dependent on other PRs. - - **\[wip\]** when you start to re-edit your work, mark it so no one will accidentally merge it in meantime. + - **[wip]** when you start to re-edit your work, mark it so no one will accidentally merge it in meantime. ### Question & Answer diff --git a/.github/checkgroup.yml b/.github/checkgroup.yml index b1d54bc5e12fc..742d748ed00c6 100644 --- a/.github/checkgroup.yml +++ b/.github/checkgroup.yml @@ -23,26 +23,26 @@ subprojects: - "pl-cpu (macOS-14, lightning, 3.10, 2.1)" - "pl-cpu (macOS-14, lightning, 3.11, 2.2.2)" - "pl-cpu (macOS-14, lightning, 3.11, 2.3)" - - "pl-cpu (macOS-14, lightning, 3.12, 2.4.1)" - - "pl-cpu (macOS-14, lightning, 3.12, 2.5.1)" + - "pl-cpu (macOS-14, lightning, 3.12.7, 2.4.1)" + - "pl-cpu (macOS-14, lightning, 3.12.7, 2.5.1)" - "pl-cpu (ubuntu-20.04, lightning, 3.9, 2.1, oldest)" - "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1)" - "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.2.2)" - "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.3)" - - "pl-cpu (ubuntu-22.04, lightning, 3.12, 2.4.1)" - - "pl-cpu (ubuntu-22.04, lightning, 3.12, 2.5.1)" + - "pl-cpu (ubuntu-22.04, lightning, 3.12.7, 2.4.1)" + - "pl-cpu (ubuntu-22.04, lightning, 3.12.7, 2.5.1)" - "pl-cpu (windows-2022, lightning, 3.9, 2.1, oldest)" - "pl-cpu (windows-2022, lightning, 3.10, 2.1)" - "pl-cpu (windows-2022, lightning, 3.11, 2.2.2)" - "pl-cpu (windows-2022, lightning, 3.11, 2.3)" - - "pl-cpu (windows-2022, lightning, 3.12, 2.4.1)" - - "pl-cpu (windows-2022, lightning, 3.12, 2.5.1)" + - "pl-cpu (windows-2022, lightning, 3.12.7, 2.4.1)" + - "pl-cpu (windows-2022, lightning, 3.12.7, 2.5.1)" - "pl-cpu (macOS-14, pytorch, 3.9, 2.1)" - "pl-cpu (ubuntu-20.04, pytorch, 3.9, 2.1)" - "pl-cpu (windows-2022, pytorch, 3.9, 2.1)" - - "pl-cpu (macOS-14, pytorch, 3.12, 2.5.1)" - - "pl-cpu (ubuntu-22.04, pytorch, 3.12, 2.5.1)" - - "pl-cpu (windows-2022, pytorch, 3.12, 2.5.1)" + - "pl-cpu (macOS-14, pytorch, 3.12.7, 2.5.1)" + - "pl-cpu (ubuntu-22.04, pytorch, 3.12.7, 2.5.1)" + - "pl-cpu (windows-2022, pytorch, 3.12.7, 2.5.1)" - id: "pytorch_lightning: Azure GPU" paths: @@ -176,26 +176,26 @@ subprojects: - "fabric-cpu (macOS-14, lightning, 3.10, 2.1)" - "fabric-cpu (macOS-14, lightning, 3.11, 2.2.2)" - "fabric-cpu (macOS-14, lightning, 3.11, 2.3)" - - "fabric-cpu (macOS-14, lightning, 3.12, 2.4.1)" - - "fabric-cpu (macOS-14, lightning, 3.12, 2.5.1)" + - "fabric-cpu (macOS-14, lightning, 3.12.7, 2.4.1)" + - "fabric-cpu (macOS-14, lightning, 3.12.7, 2.5.1)" - "fabric-cpu (ubuntu-20.04, lightning, 3.9, 2.1, oldest)" - "fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.1)" - "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2.2)" - "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.3)" - - "fabric-cpu (ubuntu-22.04, lightning, 3.12, 2.4.1)" - - "fabric-cpu (ubuntu-22.04, lightning, 3.12, 2.5.1)" + - "fabric-cpu (ubuntu-22.04, lightning, 3.12.7, 2.4.1)" + - "fabric-cpu (ubuntu-22.04, lightning, 3.12.7, 2.5.1)" - "fabric-cpu (windows-2022, lightning, 3.9, 2.1, oldest)" - "fabric-cpu (windows-2022, lightning, 3.10, 2.1)" - "fabric-cpu (windows-2022, lightning, 3.11, 2.2.2)" - "fabric-cpu (windows-2022, lightning, 3.11, 2.3)" - - "fabric-cpu (windows-2022, lightning, 3.12, 2.4.1)" - - "fabric-cpu (windows-2022, lightning, 3.12, 2.5.1)" + - "fabric-cpu (windows-2022, lightning, 3.12.7, 2.4.1)" + - "fabric-cpu (windows-2022, lightning, 3.12.7, 2.5.1)" - "fabric-cpu (macOS-14, fabric, 3.9, 2.1)" - "fabric-cpu (ubuntu-20.04, fabric, 3.9, 2.1)" - "fabric-cpu (windows-2022, fabric, 3.9, 2.1)" - - "fabric-cpu (macOS-14, fabric, 3.12, 2.5.1)" - - "fabric-cpu (ubuntu-22.04, fabric, 3.12, 2.5.1)" - - "fabric-cpu (windows-2022, fabric, 3.12, 2.5.1)" + - "fabric-cpu (macOS-14, fabric, 3.12.7, 2.5.1)" + - "fabric-cpu (ubuntu-22.04, fabric, 3.12.7, 2.5.1)" + - "fabric-cpu (windows-2022, fabric, 3.12.7, 2.5.1)" - id: "lightning_fabric: Azure GPU" paths: diff --git a/.github/workflows/README.md b/.github/workflows/README.md index 58f4afe529509..3bdc8f9a0b07f 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -16,7 +16,7 @@ Brief description of all our automation tools used for boosting development perf | .azure-pipelines/gpu-benchmarks.yml | Run speed/memory benchmarks for parity with vanila PyTorch. | GPU | | .github/workflows/ci-flagship-apps.yml | Run end-2-end tests with full applications, including deployment to the production cloud. | CPU | | .github/workflows/ci-tests-pytorch.yml | Run all tests except for accelerator-specific, standalone and slow tests. | CPU | -| .github/workflows/tpu-tests.yml | Run only TPU-specific tests. Requires that the PR title contains '\[TPU\]' | TPU | +| .github/workflows/tpu-tests.yml | Run only TPU-specific tests. Requires that the PR title contains '[TPU]' | TPU | \* Each standalone test needs to be run in separate processes to avoid unwanted interactions between test cases. diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index d84b7bed7a34a..c566f6c4611f1 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -174,6 +174,21 @@ jobs: with: project_id: ${{ secrets.GCS_PROJECT }} + # Uploading docs as archive to GCS, so they can be as backup + - name: Upload docs as archive to GCS 🪣 + if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' + working-directory: docs/build + run: | + zip ${{ env.VERSION }}.zip -r html/ + gsutil cp ${{ env.VERSION }}.zip ${GCP_TARGET} + + - name: Inject version selector + working-directory: docs/build + run: | + pip install -q wget + python -m wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/inject-selector-script.py + python inject-selector-script.py html ${{ matrix.pkg-name }} + # Uploading docs to GCS, so they can be served on lightning.ai - name: Upload docs/${{ matrix.pkg-name }}/stable to GCS 🪣 if: startsWith(github.ref, 'refs/heads/release/') && github.event_name == 'push' @@ -188,11 +203,3 @@ jobs: - name: Upload docs/${{ matrix.pkg-name }}/release to GCS 🪣 if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/${{ env.VERSION }} - - # Uploading docs as archive to GCS, so they can be as backup - - name: Upload docs as archive to GCS 🪣 - if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' - working-directory: docs/build - run: | - zip ${{ env.VERSION }}.zip -r html/ - gsutil cp ${{ env.VERSION }}.zip ${GCP_TARGET} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c5e65de1d7eb7..f2e475f602913 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace @@ -65,12 +65,12 @@ repos: args: ["--in-place"] - repo: https://github.com/sphinx-contrib/sphinx-lint - rev: v0.9.1 + rev: v1.0.0 hooks: - id: sphinx-lint - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.0 + rev: v0.8.6 hooks: # try to fix what is possible - id: ruff @@ -81,7 +81,7 @@ repos: - id: ruff - repo: https://github.com/executablebooks/mdformat - rev: 0.7.17 + rev: 0.7.21 hooks: - id: mdformat additional_dependencies: diff --git a/docs/source-pytorch/common/tbptt.rst b/docs/source-pytorch/common/tbptt.rst index 063ef8c33d319..04b8ea33b9235 100644 --- a/docs/source-pytorch/common/tbptt.rst +++ b/docs/source-pytorch/common/tbptt.rst @@ -12,48 +12,91 @@ hidden states should be kept in-between each time-dimension split. .. code-block:: python import torch + import torch.nn as nn + import torch.nn.functional as F import torch.optim as optim - import pytorch_lightning as pl - from pytorch_lightning import LightningModule + from torch.utils.data import Dataset, DataLoader - class LitModel(LightningModule): + import lightning as L + + + class AverageDataset(Dataset): + def __init__(self, dataset_len=300, sequence_len=100): + self.dataset_len = dataset_len + self.sequence_len = sequence_len + self.input_seq = torch.randn(dataset_len, sequence_len, 10) + top, bottom = self.input_seq.chunk(2, -1) + self.output_seq = top + bottom.roll(shifts=1, dims=-1) + + def __len__(self): + return self.dataset_len + + def __getitem__(self, item): + return self.input_seq[item], self.output_seq[item] + + + class LitModel(L.LightningModule): def __init__(self): super().__init__() + self.batch_size = 10 + self.in_features = 10 + self.out_features = 5 + self.hidden_dim = 20 + # 1. Switch to manual optimization self.automatic_optimization = False - self.truncated_bptt_steps = 10 - self.my_rnn = ParityModuleRNN() # Define RNN model using ParityModuleRNN + + self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True) + self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features) + + def forward(self, x, hs): + seq, hs = self.rnn(x, hs) + return self.linear_out(seq), hs # 2. Remove the `hiddens` argument def training_step(self, batch, batch_idx): - # 3. Split the batch in chunks along the time dimension - split_batches = split_batch(batch, self.truncated_bptt_steps) - - batch_size = 10 - hidden_dim = 20 - hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device) - for split_batch in range(split_batches): - # 4. Perform the optimization in a loop - loss, hiddens = self.my_rnn(split_batch, hiddens) - self.backward(loss) - self.optimizer.step() - self.optimizer.zero_grad() + x, y = batch + split_x, split_y = [ + x.tensor_split(self.truncated_bptt_steps, dim=1), + y.tensor_split(self.truncated_bptt_steps, dim=1) + ] + + hiddens = None + optimizer = self.optimizers() + losses = [] + + # 4. Perform the optimization in a loop + for x, y in zip(split_x, split_y): + y_pred, hiddens = self(x, hiddens) + loss = F.mse_loss(y_pred, y) + + optimizer.zero_grad() + self.manual_backward(loss) + optimizer.step() # 5. "Truncate" - hiddens = hiddens.detach() + hiddens = [h.detach() for h in hiddens] + losses.append(loss.detach()) + + avg_loss = sum(losses) / len(losses) + self.log("train_loss", avg_loss, prog_bar=True) # 6. Remove the return of `hiddens` # Returning loss in manual optimization is not needed return None def configure_optimizers(self): - return optim.Adam(self.my_rnn.parameters(), lr=0.001) + return optim.Adam(self.parameters(), lr=0.001) + + def train_dataloader(self): + return DataLoader(AverageDataset(), batch_size=self.batch_size) + if __name__ == "__main__": model = LitModel() - trainer = pl.Trainer(max_epochs=5) - trainer.fit(model, train_dataloader) # Define your own dataloader + trainer = L.Trainer(max_epochs=5) + trainer.fit(model) diff --git a/docs/source-pytorch/tuning/profiler_intermediate.rst b/docs/source-pytorch/tuning/profiler_intermediate.rst index 802bfc5e6db4e..87aed86ac3653 100644 --- a/docs/source-pytorch/tuning/profiler_intermediate.rst +++ b/docs/source-pytorch/tuning/profiler_intermediate.rst @@ -55,7 +55,7 @@ The profiler will generate an output like this: Self CPU time total: 1.681ms .. note:: - When using the PyTorch Profiler, wall clock time will not not be representative of the true wall clock time. + When using the PyTorch Profiler, wall clock time will not be representative of the true wall clock time. This is due to forcing profiled operations to be measured synchronously, when many CUDA ops happen asynchronously. It is recommended to use this Profiler to find bottlenecks/breakdowns, however for end to end wall clock time use the ``SimpleProfiler``. @@ -142,7 +142,7 @@ This profiler will record ``training_step``, ``validation_step``, ``test_step``, The output above shows the profiling for the action ``training_step``. .. note:: - When using the PyTorch Profiler, wall clock time will not not be representative of the true wall clock time. + When using the PyTorch Profiler, wall clock time will not be representative of the true wall clock time. This is due to forcing profiled operations to be measured synchronously, when many CUDA ops happen asynchronously. It is recommended to use this Profiler to find bottlenecks/breakdowns, however for end to end wall clock time use the ``SimpleProfiler``. diff --git a/examples/fabric/build_your_own_trainer/run.py b/examples/fabric/build_your_own_trainer/run.py index c0c2ff28ddc41..936b590f5041a 100644 --- a/examples/fabric/build_your_own_trainer/run.py +++ b/examples/fabric/build_your_own_trainer/run.py @@ -1,8 +1,9 @@ -import lightning as L import torch from torchmetrics.functional.classification.accuracy import accuracy from trainer import MyCustomTrainer +import lightning as L + class MNISTModule(L.LightningModule): def __init__(self) -> None: diff --git a/examples/fabric/build_your_own_trainer/trainer.py b/examples/fabric/build_your_own_trainer/trainer.py index f4f31c114f084..d9d081a2aea69 100644 --- a/examples/fabric/build_your_own_trainer/trainer.py +++ b/examples/fabric/build_your_own_trainer/trainer.py @@ -3,15 +3,16 @@ from functools import partial from typing import Any, Literal, Optional, Union, cast -import lightning as L import torch +from lightning_utilities import apply_to_collection +from tqdm import tqdm + +import lightning as L from lightning.fabric.accelerators import Accelerator from lightning.fabric.loggers import Logger from lightning.fabric.strategies import Strategy from lightning.fabric.wrappers import _unwrap_objects from lightning.pytorch.utilities.model_helpers import is_overridden -from lightning_utilities import apply_to_collection -from tqdm import tqdm class MyCustomTrainer: diff --git a/examples/fabric/dcgan/train_fabric.py b/examples/fabric/dcgan/train_fabric.py index f7a18b2b5bc17..66f11e1c6fcfe 100644 --- a/examples/fabric/dcgan/train_fabric.py +++ b/examples/fabric/dcgan/train_fabric.py @@ -16,9 +16,10 @@ import torch.utils.data import torchvision.transforms as transforms import torchvision.utils -from lightning.fabric import Fabric, seed_everything from torchvision.datasets import CelebA +from lightning.fabric import Fabric, seed_everything + # Root directory for dataset dataroot = "data/" # Number of workers for dataloader diff --git a/examples/fabric/fp8_distributed_transformer/train.py b/examples/fabric/fp8_distributed_transformer/train.py index ba88603268945..a30e2de2fc5ed 100644 --- a/examples/fabric/fp8_distributed_transformer/train.py +++ b/examples/fabric/fp8_distributed_transformer/train.py @@ -1,15 +1,16 @@ -import lightning as L import torch import torch.nn as nn import torch.nn.functional as F -from lightning.fabric.strategies import ModelParallelStrategy -from lightning.pytorch.demos import Transformer, WikiText2 from torch.distributed._composable.fsdp.fully_shard import fully_shard from torch.distributed.device_mesh import DeviceMesh from torch.utils.data import DataLoader from torchao.float8 import Float8LinearConfig, convert_to_float8_training from tqdm import tqdm +import lightning as L +from lightning.fabric.strategies import ModelParallelStrategy +from lightning.pytorch.demos import Transformer, WikiText2 + def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module: float8_config = Float8LinearConfig( diff --git a/examples/fabric/image_classifier/train_fabric.py b/examples/fabric/image_classifier/train_fabric.py index 02487a65e3989..d207595e9d2ba 100644 --- a/examples/fabric/image_classifier/train_fabric.py +++ b/examples/fabric/image_classifier/train_fabric.py @@ -36,11 +36,12 @@ import torch.nn.functional as F import torch.optim as optim import torchvision.transforms as T -from lightning.fabric import Fabric, seed_everything from torch.optim.lr_scheduler import StepLR from torchmetrics.classification import Accuracy from torchvision.datasets import MNIST +from lightning.fabric import Fabric, seed_everything + DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "..", "Datasets") diff --git a/examples/fabric/kfold_cv/train_fabric.py b/examples/fabric/kfold_cv/train_fabric.py index b3aa08e9aae9b..05d9885190dbc 100644 --- a/examples/fabric/kfold_cv/train_fabric.py +++ b/examples/fabric/kfold_cv/train_fabric.py @@ -20,12 +20,13 @@ import torch.nn.functional as F import torch.optim as optim import torchvision.transforms as T -from lightning.fabric import Fabric, seed_everything from sklearn import model_selection from torch.utils.data import DataLoader, SubsetRandomSampler from torchmetrics.classification import Accuracy from torchvision.datasets import MNIST +from lightning.fabric import Fabric, seed_everything + DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "..", "Datasets") diff --git a/examples/fabric/language_model/train.py b/examples/fabric/language_model/train.py index cafe6ceeb18b1..01947893be926 100644 --- a/examples/fabric/language_model/train.py +++ b/examples/fabric/language_model/train.py @@ -1,9 +1,10 @@ -import lightning as L import torch import torch.nn.functional as F -from lightning.pytorch.demos import Transformer, WikiText2 from torch.utils.data import DataLoader, random_split +import lightning as L +from lightning.pytorch.demos import Transformer, WikiText2 + def main(): L.seed_everything(42) diff --git a/examples/fabric/meta_learning/train_fabric.py b/examples/fabric/meta_learning/train_fabric.py index 3cf1390477aeb..203155f7b2ada 100644 --- a/examples/fabric/meta_learning/train_fabric.py +++ b/examples/fabric/meta_learning/train_fabric.py @@ -18,6 +18,7 @@ import cherry import learn2learn as l2l import torch + from lightning.fabric import Fabric, seed_everything diff --git a/examples/fabric/reinforcement_learning/rl/agent.py b/examples/fabric/reinforcement_learning/rl/agent.py index 16a4cd6d86c73..b3d024d720d11 100644 --- a/examples/fabric/reinforcement_learning/rl/agent.py +++ b/examples/fabric/reinforcement_learning/rl/agent.py @@ -3,11 +3,11 @@ import gymnasium as gym import torch import torch.nn.functional as F -from lightning.pytorch import LightningModule from torch import Tensor from torch.distributions import Categorical from torchmetrics import MeanMetric +from lightning.pytorch import LightningModule from rl.loss import entropy_loss, policy_loss, value_loss from rl.utils import layer_init diff --git a/examples/fabric/reinforcement_learning/train_fabric.py b/examples/fabric/reinforcement_learning/train_fabric.py index 4df52d7cd0455..7c9536fbd9532 100644 --- a/examples/fabric/reinforcement_learning/train_fabric.py +++ b/examples/fabric/reinforcement_learning/train_fabric.py @@ -25,13 +25,14 @@ import gymnasium as gym import torch import torchmetrics -from lightning.fabric import Fabric -from lightning.fabric.loggers import TensorBoardLogger from rl.agent import PPOLightningAgent from rl.utils import linear_annealing, make_env, parse_args, test from torch import Tensor from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler +from lightning.fabric import Fabric +from lightning.fabric.loggers import TensorBoardLogger + def train( fabric: Fabric, diff --git a/examples/fabric/reinforcement_learning/train_fabric_decoupled.py b/examples/fabric/reinforcement_learning/train_fabric_decoupled.py index 7150ac3a12529..1d8967ec6f1dc 100644 --- a/examples/fabric/reinforcement_learning/train_fabric_decoupled.py +++ b/examples/fabric/reinforcement_learning/train_fabric_decoupled.py @@ -25,17 +25,18 @@ import gymnasium as gym import torch -from lightning.fabric import Fabric -from lightning.fabric.loggers import TensorBoardLogger -from lightning.fabric.plugins.collectives import TorchCollective -from lightning.fabric.plugins.collectives.collective import CollectibleGroup -from lightning.fabric.strategies import DDPStrategy from rl.agent import PPOLightningAgent from rl.utils import linear_annealing, make_env, parse_args, test from torch.distributed.algorithms.join import Join from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler from torchmetrics import MeanMetric +from lightning.fabric import Fabric +from lightning.fabric.loggers import TensorBoardLogger +from lightning.fabric.plugins.collectives import TorchCollective +from lightning.fabric.plugins.collectives.collective import CollectibleGroup +from lightning.fabric.strategies import DDPStrategy + @torch.no_grad() def player(args, world_collective: TorchCollective, player_trainer_collective: TorchCollective): diff --git a/examples/fabric/tensor_parallel/README.md b/examples/fabric/tensor_parallel/README.md index e66d9acd2848b..1f551109cc5e7 100644 --- a/examples/fabric/tensor_parallel/README.md +++ b/examples/fabric/tensor_parallel/README.md @@ -41,5 +41,5 @@ Training successfully completed! Peak memory usage: 17.95 GB ``` -> \[!NOTE\] +> [!NOTE] > The `ModelParallelStrategy` is experimental and subject to change. Report issues on [GitHub](https://github.com/Lightning-AI/pytorch-lightning/issues). diff --git a/examples/fabric/tensor_parallel/train.py b/examples/fabric/tensor_parallel/train.py index 4a98f12cf6168..35ee9074f18a8 100644 --- a/examples/fabric/tensor_parallel/train.py +++ b/examples/fabric/tensor_parallel/train.py @@ -1,13 +1,14 @@ -import lightning as L import torch import torch.nn.functional as F from data import RandomTokenDataset -from lightning.fabric.strategies import ModelParallelStrategy from model import ModelArgs, Transformer from parallelism import parallelize from torch.distributed.tensor.parallel import loss_parallel from torch.utils.data import DataLoader +import lightning as L +from lightning.fabric.strategies import ModelParallelStrategy + def train(): strategy = ModelParallelStrategy( diff --git a/examples/pytorch/basics/autoencoder.py b/examples/pytorch/basics/autoencoder.py index d6f594b12f57b..332c9a811e3e4 100644 --- a/examples/pytorch/basics/autoencoder.py +++ b/examples/pytorch/basics/autoencoder.py @@ -22,13 +22,14 @@ import torch import torch.nn.functional as F +from torch import nn +from torch.utils.data import DataLoader, random_split + from lightning.pytorch import LightningDataModule, LightningModule, Trainer, callbacks, cli_lightning_logo from lightning.pytorch.cli import LightningCLI from lightning.pytorch.demos.mnist_datamodule import MNIST from lightning.pytorch.utilities import rank_zero_only from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE -from torch import nn -from torch.utils.data import DataLoader, random_split if _TORCHVISION_AVAILABLE: import torchvision diff --git a/examples/pytorch/basics/backbone_image_classifier.py b/examples/pytorch/basics/backbone_image_classifier.py index fceb97dc41cff..965f636d7fc0b 100644 --- a/examples/pytorch/basics/backbone_image_classifier.py +++ b/examples/pytorch/basics/backbone_image_classifier.py @@ -21,12 +21,13 @@ from typing import Optional import torch +from torch.nn import functional as F +from torch.utils.data import DataLoader, random_split + from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo from lightning.pytorch.cli import LightningCLI from lightning.pytorch.demos.mnist_datamodule import MNIST from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE -from torch.nn import functional as F -from torch.utils.data import DataLoader, random_split if _TORCHVISION_AVAILABLE: from torchvision import transforms diff --git a/examples/pytorch/basics/profiler_example.py b/examples/pytorch/basics/profiler_example.py index 5b9ba08fec761..366aaecefe7e4 100644 --- a/examples/pytorch/basics/profiler_example.py +++ b/examples/pytorch/basics/profiler_example.py @@ -28,6 +28,7 @@ import torch import torchvision import torchvision.transforms as T + from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo from lightning.pytorch.cli import LightningCLI from lightning.pytorch.profilers.pytorch import PyTorchProfiler diff --git a/examples/pytorch/basics/transformer.py b/examples/pytorch/basics/transformer.py index 93cb39d829acc..dbd990d7f2759 100644 --- a/examples/pytorch/basics/transformer.py +++ b/examples/pytorch/basics/transformer.py @@ -1,9 +1,10 @@ -import lightning as L import torch import torch.nn.functional as F -from lightning.pytorch.demos import Transformer, WikiText2 from torch.utils.data import DataLoader, random_split +import lightning as L +from lightning.pytorch.demos import Transformer, WikiText2 + class LanguageModel(L.LightningModule): def __init__(self, vocab_size): diff --git a/examples/pytorch/bug_report/bug_report_model.py b/examples/pytorch/bug_report/bug_report_model.py index aa3f4cad710fe..551ea21721754 100644 --- a/examples/pytorch/bug_report/bug_report_model.py +++ b/examples/pytorch/bug_report/bug_report_model.py @@ -1,9 +1,10 @@ import os import torch -from lightning.pytorch import LightningModule, Trainer from torch.utils.data import DataLoader, Dataset +from lightning.pytorch import LightningModule, Trainer + class RandomDataset(Dataset): def __init__(self, size, length): diff --git a/examples/pytorch/domain_templates/computer_vision_fine_tuning.py b/examples/pytorch/domain_templates/computer_vision_fine_tuning.py index d03bb0c4edd16..69721214748ee 100644 --- a/examples/pytorch/domain_templates/computer_vision_fine_tuning.py +++ b/examples/pytorch/domain_templates/computer_vision_fine_tuning.py @@ -46,11 +46,6 @@ import torch import torch.nn.functional as F -from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo -from lightning.pytorch.callbacks.finetuning import BaseFinetuning -from lightning.pytorch.cli import LightningCLI -from lightning.pytorch.utilities import rank_zero_info -from lightning.pytorch.utilities.model_helpers import get_torchvision_model from torch import nn, optim from torch.optim.lr_scheduler import MultiStepLR from torch.optim.optimizer import Optimizer @@ -60,6 +55,12 @@ from torchvision.datasets import ImageFolder from torchvision.datasets.utils import download_and_extract_archive +from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo +from lightning.pytorch.callbacks.finetuning import BaseFinetuning +from lightning.pytorch.cli import LightningCLI +from lightning.pytorch.utilities import rank_zero_info +from lightning.pytorch.utilities.model_helpers import get_torchvision_model + log = logging.getLogger(__name__) DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip" diff --git a/examples/pytorch/domain_templates/generative_adversarial_net.py b/examples/pytorch/domain_templates/generative_adversarial_net.py index 417e167df0d93..7ce7682d82c76 100644 --- a/examples/pytorch/domain_templates/generative_adversarial_net.py +++ b/examples/pytorch/domain_templates/generative_adversarial_net.py @@ -25,6 +25,7 @@ import torch import torch.nn as nn import torch.nn.functional as F + from lightning.pytorch import cli_lightning_logo from lightning.pytorch.core import LightningModule from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule diff --git a/examples/pytorch/domain_templates/imagenet.py b/examples/pytorch/domain_templates/imagenet.py index 49f9b509ce6e7..fd2050e2ed38b 100644 --- a/examples/pytorch/domain_templates/imagenet.py +++ b/examples/pytorch/domain_templates/imagenet.py @@ -43,13 +43,14 @@ import torch.utils.data.distributed import torchvision.datasets as datasets import torchvision.transforms as transforms +from torch.utils.data import Dataset +from torchmetrics import Accuracy + from lightning.pytorch import LightningModule from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar from lightning.pytorch.cli import LightningCLI from lightning.pytorch.strategies import ParallelStrategy from lightning.pytorch.utilities.model_helpers import get_torchvision_model -from torch.utils.data import Dataset -from torchmetrics import Accuracy class ImageNetLightningModel(LightningModule): diff --git a/examples/pytorch/domain_templates/reinforce_learn_Qnet.py b/examples/pytorch/domain_templates/reinforce_learn_Qnet.py index b3bfaaea93e7f..193e6495a4182 100644 --- a/examples/pytorch/domain_templates/reinforce_learn_Qnet.py +++ b/examples/pytorch/domain_templates/reinforce_learn_Qnet.py @@ -41,11 +41,12 @@ import torch import torch.nn as nn import torch.optim as optim -from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo, seed_everything from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader from torch.utils.data.dataset import IterableDataset +from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo, seed_everything + class DQN(nn.Module): """Simple MLP network. diff --git a/examples/pytorch/domain_templates/reinforce_learn_ppo.py b/examples/pytorch/domain_templates/reinforce_learn_ppo.py index 1fb083894c284..af503dbb925cd 100644 --- a/examples/pytorch/domain_templates/reinforce_learn_ppo.py +++ b/examples/pytorch/domain_templates/reinforce_learn_ppo.py @@ -35,12 +35,13 @@ import gym import torch -from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo, seed_everything from torch import nn from torch.distributions import Categorical, Normal from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader, IterableDataset +from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo, seed_everything + def create_mlp(input_shape: tuple[int], n_actions: int, hidden_size: int = 128): """Simple Multi-Layer Perceptron network.""" diff --git a/examples/pytorch/domain_templates/semantic_segmentation.py b/examples/pytorch/domain_templates/semantic_segmentation.py index 12ecbeeb5f0a9..0f19349f7a0fc 100644 --- a/examples/pytorch/domain_templates/semantic_segmentation.py +++ b/examples/pytorch/domain_templates/semantic_segmentation.py @@ -19,11 +19,12 @@ import torch import torch.nn.functional as F import torchvision.transforms as transforms -from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo from PIL import Image from torch import nn from torch.utils.data import DataLoader, Dataset +from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo + DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1) DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33) diff --git a/examples/pytorch/fp8_distributed_transformer/train.py b/examples/pytorch/fp8_distributed_transformer/train.py index 6c7be98ee7dbd..78aa6f13be6c2 100644 --- a/examples/pytorch/fp8_distributed_transformer/train.py +++ b/examples/pytorch/fp8_distributed_transformer/train.py @@ -1,13 +1,14 @@ -import lightning as L import torch import torch.nn as nn import torch.nn.functional as F -from lightning.pytorch.demos import Transformer, WikiText2 -from lightning.pytorch.strategies import ModelParallelStrategy from torch.distributed._composable.fsdp.fully_shard import fully_shard from torch.utils.data import DataLoader from torchao.float8 import Float8LinearConfig, convert_to_float8_training +import lightning as L +from lightning.pytorch.demos import Transformer, WikiText2 +from lightning.pytorch.strategies import ModelParallelStrategy + class LanguageModel(L.LightningModule): def __init__(self, vocab_size): diff --git a/examples/pytorch/hpu/mnist_sample.py b/examples/pytorch/hpu/mnist_sample.py index 4d2e22c03fe7e..0d04074519c8c 100644 --- a/examples/pytorch/hpu/mnist_sample.py +++ b/examples/pytorch/hpu/mnist_sample.py @@ -13,11 +13,12 @@ # limitations under the License. import torch from jsonargparse import lazy_instance +from lightning_habana import HPUPrecisionPlugin +from torch.nn import functional as F + from lightning.pytorch import LightningModule from lightning.pytorch.cli import LightningCLI from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule -from lightning_habana import HPUPrecisionPlugin -from torch.nn import functional as F class LitClassifier(LightningModule): diff --git a/examples/pytorch/servable_module/production.py b/examples/pytorch/servable_module/production.py index da0c42d12a865..854ff1176b619 100644 --- a/examples/pytorch/servable_module/production.py +++ b/examples/pytorch/servable_module/production.py @@ -8,11 +8,12 @@ import torch import torchvision import torchvision.transforms as T +from PIL import Image as PILImage + from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo from lightning.pytorch.cli import LightningCLI from lightning.pytorch.serve import ServableModule, ServableModuleValidator from lightning.pytorch.utilities.model_helpers import get_torchvision_model -from PIL import Image as PILImage DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") diff --git a/examples/pytorch/tensor_parallel/README.md b/examples/pytorch/tensor_parallel/README.md index d8b81b6de1bff..92e6fcb038268 100644 --- a/examples/pytorch/tensor_parallel/README.md +++ b/examples/pytorch/tensor_parallel/README.md @@ -45,5 +45,5 @@ Training successfully completed! Peak memory usage: 36.73 GB ``` -> \[!NOTE\] +> [!NOTE] > The `ModelParallelStrategy` is experimental and subject to change. Report issues on [GitHub](https://github.com/Lightning-AI/pytorch-lightning/issues). diff --git a/examples/pytorch/tensor_parallel/train.py b/examples/pytorch/tensor_parallel/train.py index 6a91e1242e4af..f9b971fd39f82 100644 --- a/examples/pytorch/tensor_parallel/train.py +++ b/examples/pytorch/tensor_parallel/train.py @@ -1,13 +1,14 @@ -import lightning as L import torch import torch.nn.functional as F from data import RandomTokenDataset -from lightning.pytorch.strategies import ModelParallelStrategy from model import ModelArgs, Transformer from parallelism import parallelize from torch.distributed.tensor.parallel import loss_parallel from torch.utils.data import DataLoader +import lightning as L +from lightning.pytorch.strategies import ModelParallelStrategy + class Llama3(L.LightningModule): def __init__(self): diff --git a/requirements/collect_env_details.py b/requirements/collect_env_details.py index 5e6f9ba3dd350..392dd637fa8fd 100644 --- a/requirements/collect_env_details.py +++ b/requirements/collect_env_details.py @@ -70,7 +70,7 @@ def nice_print(details: dict, level: int = 0) -> list: lines += [level * LEVEL_OFFSET + key] lines += [(level + 1) * LEVEL_OFFSET + "- " + v for v in details[k]] else: - template = "{:%is} {}" % KEY_PADDING + template = "{:%is} {}" % KEY_PADDING # noqa: UP031 key_val = template.format(key, details[k]) lines += [(level * LEVEL_OFFSET) + key_val] return lines diff --git a/src/lightning/data/README.md b/src/lightning/data/README.md index 525a7e14f894d..c61e7eacf26f2 100644 --- a/src/lightning/data/README.md +++ b/src/lightning/data/README.md @@ -31,11 +31,11 @@ Find the reproducible [Studio Benchmark](https://lightning.ai/lightning-ai/studi ### Imagenet-1.2M Streaming from AWS S3 -| Framework | Images / sec 1st Epoch (float32) | Images / sec 2nd Epoch (float32) | Images / sec 1st Epoch (torch16) | Images / sec 2nd Epoch (torch16) | -| ----------- | --------------------------------- | ---------------------------------- | -------------------------------- | -------------------------------- | -| PL Data | **5800.34** | **6589.98** | **6282.17** | **7221.88** | -| Web Dataset | 3134.42 | 3924.95 | 3343.40 | 4424.62 | -| Mosaic ML | 2898.61 | 5099.93 | 2809.69 | 5158.98 | +| Framework | Images / sec 1st Epoch (float32) | Images / sec 2nd Epoch (float32) | Images / sec 1st Epoch (torch16) | Images / sec 2nd Epoch (torch16) | +| ----------- | -------------------------------- | -------------------------------- | -------------------------------- | -------------------------------- | +| PL Data | **5800.34** | **6589.98** | **6282.17** | **7221.88** | +| Web Dataset | 3134.42 | 3924.95 | 3343.40 | 4424.62 | +| Mosaic ML | 2898.61 | 5099.93 | 2809.69 | 5158.98 | Higher is better. diff --git a/src/lightning/fabric/_graveyard/tpu.py b/src/lightning/fabric/_graveyard/tpu.py index c537ffc032322..138830e4e3b1b 100644 --- a/src/lightning/fabric/_graveyard/tpu.py +++ b/src/lightning/fabric/_graveyard/tpu.py @@ -71,7 +71,7 @@ class TPUPrecision(XLAPrecision): def __init__(self, *args: Any, **kwargs: Any) -> None: rank_zero_deprecation( - "The `TPUPrecision` class is deprecated. Use `lightning.fabric.plugins.precision.XLAPrecision`" " instead." + "The `TPUPrecision` class is deprecated. Use `lightning.fabric.plugins.precision.XLAPrecision` instead." ) super().__init__(precision="32-true") @@ -85,8 +85,7 @@ class XLABf16Precision(XLAPrecision): def __init__(self, *args: Any, **kwargs: Any) -> None: rank_zero_deprecation( - "The `XLABf16Precision` class is deprecated. Use" - " `lightning.fabric.plugins.precision.XLAPrecision` instead." + "The `XLABf16Precision` class is deprecated. Use `lightning.fabric.plugins.precision.XLAPrecision` instead." ) super().__init__(precision="bf16-true") @@ -100,8 +99,7 @@ class TPUBf16Precision(XLABf16Precision): def __init__(self, *args: Any, **kwargs: Any) -> None: rank_zero_deprecation( - "The `TPUBf16Precision` class is deprecated. Use" - " `lightning.fabric.plugins.precision.XLAPrecision` instead." + "The `TPUBf16Precision` class is deprecated. Use `lightning.fabric.plugins.precision.XLAPrecision` instead." ) super().__init__(*args, **kwargs) diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 1f679ba7ffe1a..1c2d62a3c59cb 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -253,8 +253,7 @@ def _check_config_and_set_final_flags( if plugins_flags_types.get(Precision.__name__) and precision_input is not None: raise ValueError( - f"Received both `precision={precision_input}` and `plugins={self._precision_instance}`." - f" Choose one." + f"Received both `precision={precision_input}` and `plugins={self._precision_instance}`. Choose one." ) self._precision_input = "32-true" if precision_input is None else precision_input diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 82c636a282b25..355e55b576894 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -144,7 +144,7 @@ def __init__( nvme_path: Filesystem path for NVMe device for optimizer/parameter state offloading. optimizer_buffer_count: Number of buffers in buffer pool for optimizer state offloading - when ``offload_optimizer_device`` is set to to ``nvme``. + when ``offload_optimizer_device`` is set to ``nvme``. This should be at least the number of states maintained per parameter by the optimizer. For example, Adam optimizer has 4 states (parameter, gradient, momentum, and variance). @@ -335,8 +335,7 @@ def setup_module_and_optimizers( """ if len(optimizers) != 1: raise ValueError( - f"Currently only one optimizer is supported with DeepSpeed." - f" Got {len(optimizers)} optimizers instead." + f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead." ) self._deepspeed_engine, optimizer = self._initialize_engine(module, optimizers[0]) diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index ad1fc19074d06..ace23a9c7a2c5 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -292,8 +292,7 @@ def load_checkpoint( if isinstance(state, Optimizer): raise NotImplementedError( - f"Loading a single optimizer object from a checkpoint is not supported yet with" - f" {type(self).__name__}." + f"Loading a single optimizer object from a checkpoint is not supported yet with {type(self).__name__}." ) return _load_checkpoint(path=path, state=state, strict=strict) diff --git a/src/lightning/pytorch/callbacks/early_stopping.py b/src/lightning/pytorch/callbacks/early_stopping.py index 78c4215f9ce23..d108894f614e6 100644 --- a/src/lightning/pytorch/callbacks/early_stopping.py +++ b/src/lightning/pytorch/callbacks/early_stopping.py @@ -145,7 +145,7 @@ def _validate_condition_metric(self, logs: dict[str, Tensor]) -> bool: error_msg = ( f"Early stopping conditioned on metric `{self.monitor}` which is not available." " Pass in or modify your `EarlyStopping` callback to use any of the following:" - f' `{"`, `".join(list(logs.keys()))}`' + f" `{'`, `'.join(list(logs.keys()))}`" ) if monitor_val is None: diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index f1d1da924eac4..b8624daac3fa3 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -979,7 +979,7 @@ def configure_optimizers(self) -> OptimizerLRScheduler: # `scheduler.step()`. 1 corresponds to updating the learning # rate after every epoch/step. "frequency": 1, - # Metric to to monitor for schedulers like `ReduceLROnPlateau` + # Metric to monitor for schedulers like `ReduceLROnPlateau` "monitor": "val_loss", # If set to `True`, will enforce that the value specified 'monitor' # is available when the scheduler is updated, thus stopping diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index e1752c67d9183..ced8a6f1f2bd3 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -74,7 +74,7 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None: continue lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key] if ( - type(lm_val) != type(dm_val) + type(lm_val) != type(dm_val) # noqa: E721 or (isinstance(lm_val, Tensor) and id(lm_val) != id(dm_val)) or lm_val != dm_val ): diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index df9f97b6fd45c..4a1ae6f1766ae 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -166,7 +166,7 @@ def __init__( nvme_path: Filesystem path for NVMe device for optimizer/parameter state offloading. optimizer_buffer_count: Number of buffers in buffer pool for optimizer state offloading - when ``offload_optimizer_device`` is set to to ``nvme``. + when ``offload_optimizer_device`` is set to ``nvme``. This should be at least the number of states maintained per parameter by the optimizer. For example, Adam optimizer has 4 states (parameter, gradient, momentum, and variance). @@ -410,8 +410,7 @@ def _setup_model_and_optimizers( """ if len(optimizers) != 1: raise ValueError( - f"Currently only one optimizer is supported with DeepSpeed." - f" Got {len(optimizers)} optimizers instead." + f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead." ) # train_micro_batch_size_per_gpu is used for throughput logging purposes diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py index fdde19aa80eea..0881ac0b3fa08 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py @@ -157,7 +157,7 @@ def forked(self) -> bool: def forked_name(self, on_step: bool) -> str: if self.forked: - return f'{self.name}_{"step" if on_step else "epoch"}' + return f"{self.name}_{'step' if on_step else 'epoch'}" return self.name @property diff --git a/src/lightning/pytorch/utilities/model_helpers.py b/src/lightning/pytorch/utilities/model_helpers.py index 44591aa7f4dc1..daf1c400c03df 100644 --- a/src/lightning/pytorch/utilities/model_helpers.py +++ b/src/lightning/pytorch/utilities/model_helpers.py @@ -104,7 +104,7 @@ def _check_mixed_imports(instance: object) -> None: _R_co = TypeVar("_R_co", covariant=True) # return type of the decorated method -class _restricted_classmethod_impl(Generic[_T, _P, _R_co]): +class _restricted_classmethod_impl(Generic[_T, _R_co, _P]): """Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance instead of a class type.""" diff --git a/src/lightning_fabric/README.md b/src/lightning_fabric/README.md index d842c0d19118d..076800caeb71e 100644 --- a/src/lightning_fabric/README.md +++ b/src/lightning_fabric/README.md @@ -215,7 +215,7 @@ Lightning is rigorously tested across multiple CPUs and GPUs and against major P | System / PyTorch ver. | 1.12 | 1.13 | 2.0 | 2.1 | | :--------------------------------: | :-------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------: | -| Linux py3.9 \[GPUs\] | | | | ![Build Status](https://dev.azure.com/Lightning-AI/lightning/_apis/build/status%2Flightning-fabric%20%28GPUs%29) | +| Linux py3.9 [GPUs] | | | | ![Build Status](https://dev.azure.com/Lightning-AI/lightning/_apis/build/status%2Flightning-fabric%20%28GPUs%29) | | Linux (multiple Python versions) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | | OSX (multiple Python versions) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | | Windows (multiple Python versions) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | diff --git a/src/pytorch_lightning/README.md b/src/pytorch_lightning/README.md index ae9339dfb2b0d..f3fb8cb2fd2b3 100644 --- a/src/pytorch_lightning/README.md +++ b/src/pytorch_lightning/README.md @@ -79,7 +79,7 @@ Lightning is rigorously tested across multiple CPUs, GPUs and TPUs and against m | System / PyTorch ver. | 1.12 | 1.13 | 2.0 | 2.1 | | :--------------------------------: | :---------------------------------------------------------------------------------------------------------: | ----------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------- | -| Linux py3.9 \[GPUs\] | | | | ![Build Status](https://dev.azure.com/Lightning-AI/lightning/_apis/build/status%2Fpytorch-lightning%20%28GPUs%29) | +| Linux py3.9 [GPUs] | | | | ![Build Status](https://dev.azure.com/Lightning-AI/lightning/_apis/build/status%2Fpytorch-lightning%20%28GPUs%29) | | Linux (multiple Python versions) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | | OSX (multiple Python versions) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | | Windows (multiple Python versions) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | diff --git a/tests/legacy/simple_classif_training.py b/tests/legacy/simple_classif_training.py index d2cf4cd2166f3..dd767ee9075f2 100644 --- a/tests/legacy/simple_classif_training.py +++ b/tests/legacy/simple_classif_training.py @@ -14,13 +14,14 @@ import os import sys -import lightning.pytorch as pl import torch -from lightning.pytorch import seed_everything -from lightning.pytorch.callbacks import EarlyStopping from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.simple_models import ClassificationModel +import lightning.pytorch as pl +from lightning.pytorch import seed_everything +from lightning.pytorch.callbacks import EarlyStopping + PATH_LEGACY = os.path.dirname(__file__) diff --git a/tests/parity_fabric/conftest.py b/tests/parity_fabric/conftest.py index ceb19e061c774..9fc6f9d908d81 100644 --- a/tests/parity_fabric/conftest.py +++ b/tests/parity_fabric/conftest.py @@ -17,7 +17,7 @@ import torch.distributed -@pytest.fixture() +@pytest.fixture def reset_deterministic_algorithm(): """Ensures that torch determinism settings are reset before the next test runs.""" yield @@ -25,7 +25,7 @@ def reset_deterministic_algorithm(): torch.use_deterministic_algorithms(False) -@pytest.fixture() +@pytest.fixture def reset_cudnn_benchmark(): """Ensures that the `torch.backends.cudnn.benchmark` setting gets reset before the next test runs.""" yield diff --git a/tests/parity_fabric/test_parity_ddp.py b/tests/parity_fabric/test_parity_ddp.py index 217d401ad6fba..d30d2b6233886 100644 --- a/tests/parity_fabric/test_parity_ddp.py +++ b/tests/parity_fabric/test_parity_ddp.py @@ -18,11 +18,11 @@ import torch import torch.distributed import torch.nn.functional -from lightning.fabric.fabric import Fabric from torch.nn.parallel.distributed import DistributedDataParallel from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from lightning.fabric.fabric import Fabric from parity_fabric.models import ConvNet from parity_fabric.utils import ( cuda_reset, diff --git a/tests/parity_fabric/test_parity_simple.py b/tests/parity_fabric/test_parity_simple.py index a97d39dfa1035..54c0de7297ac5 100644 --- a/tests/parity_fabric/test_parity_simple.py +++ b/tests/parity_fabric/test_parity_simple.py @@ -19,9 +19,9 @@ import torch import torch.distributed import torch.nn.functional -from lightning.fabric.fabric import Fabric from tests_fabric.helpers.runif import RunIf +from lightning.fabric.fabric import Fabric from parity_fabric.models import ConvNet from parity_fabric.utils import ( cuda_reset, diff --git a/tests/parity_fabric/utils.py b/tests/parity_fabric/utils.py index 7f0028dc23421..7d7a14732bb0e 100644 --- a/tests/parity_fabric/utils.py +++ b/tests/parity_fabric/utils.py @@ -14,6 +14,7 @@ import os import torch + from lightning.fabric.accelerators.cuda import _clear_cuda_memory diff --git a/tests/parity_pytorch/__init__.py b/tests/parity_pytorch/__init__.py index 148237ab9c718..6d7cadefc20fa 100644 --- a/tests/parity_pytorch/__init__.py +++ b/tests/parity_pytorch/__init__.py @@ -1,4 +1,5 @@ import pytest + from lightning.pytorch.utilities.testing import _runif_reasons diff --git a/tests/parity_pytorch/models.py b/tests/parity_pytorch/models.py index f55b0d6f1f36e..17cbef6a76faa 100644 --- a/tests/parity_pytorch/models.py +++ b/tests/parity_pytorch/models.py @@ -14,11 +14,12 @@ import torch import torch.nn.functional as F +from tests_pytorch import _PATH_DATASETS +from torch.utils.data import DataLoader + from lightning.pytorch.core.module import LightningModule from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE from lightning.pytorch.utilities.model_helpers import get_torchvision_model -from tests_pytorch import _PATH_DATASETS -from torch.utils.data import DataLoader if _TORCHVISION_AVAILABLE: from torchvision import transforms diff --git a/tests/parity_pytorch/test_basic_parity.py b/tests/parity_pytorch/test_basic_parity.py index d7f15815c831f..6b413dada8659 100644 --- a/tests/parity_pytorch/test_basic_parity.py +++ b/tests/parity_pytorch/test_basic_parity.py @@ -16,9 +16,9 @@ import numpy as np import pytest import torch -from lightning.pytorch import LightningModule, Trainer, seed_everything from tests_pytorch.helpers.advanced_models import ParityModuleMNIST, ParityModuleRNN +from lightning.pytorch import LightningModule, Trainer, seed_everything from parity_pytorch.measure import measure_loops from parity_pytorch.models import ParityModuleCIFAR diff --git a/tests/parity_pytorch/test_sync_batchnorm_parity.py b/tests/parity_pytorch/test_sync_batchnorm_parity.py index 4d2300cf15670..af22d7470e524 100644 --- a/tests/parity_pytorch/test_sync_batchnorm_parity.py +++ b/tests/parity_pytorch/test_sync_batchnorm_parity.py @@ -14,9 +14,9 @@ import torch import torch.nn as nn -from lightning.pytorch import LightningModule, Trainer, seed_everything from torch.utils.data import DataLoader, DistributedSampler +from lightning.pytorch import LightningModule, Trainer, seed_everything from parity_pytorch import RunIf diff --git a/tests/tests_fabric/accelerators/test_cpu.py b/tests/tests_fabric/accelerators/test_cpu.py index 5efb5d6afddbc..7c7029f9ec9e3 100644 --- a/tests/tests_fabric/accelerators/test_cpu.py +++ b/tests/tests_fabric/accelerators/test_cpu.py @@ -14,6 +14,7 @@ import pytest import torch + from lightning.fabric.accelerators.cpu import CPUAccelerator, _parse_cpu_cores diff --git a/tests/tests_fabric/accelerators/test_cuda.py b/tests/tests_fabric/accelerators/test_cuda.py index 0aed3675d93e1..037eb8d400825 100644 --- a/tests/tests_fabric/accelerators/test_cuda.py +++ b/tests/tests_fabric/accelerators/test_cuda.py @@ -18,15 +18,15 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric import pytest import torch + +import lightning.fabric from lightning.fabric.accelerators.cuda import ( CUDAAccelerator, _check_cuda_matmul_precision, find_usable_cuda_devices, ) - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/accelerators/test_mps.py b/tests/tests_fabric/accelerators/test_mps.py index 20dc9b6a93581..612bd8f74a640 100644 --- a/tests/tests_fabric/accelerators/test_mps.py +++ b/tests/tests_fabric/accelerators/test_mps.py @@ -13,9 +13,9 @@ # limitations under the License. import pytest import torch + from lightning.fabric.accelerators.mps import MPSAccelerator from lightning.fabric.utilities.exceptions import MisconfigurationException - from tests_fabric.helpers.runif import RunIf _MAYBE_MPS = "mps" if MPSAccelerator.is_available() else "cpu" diff --git a/tests/tests_fabric/accelerators/test_registry.py b/tests/tests_fabric/accelerators/test_registry.py index 383bcd3a9c0a6..7cb5a1dad3730 100644 --- a/tests/tests_fabric/accelerators/test_registry.py +++ b/tests/tests_fabric/accelerators/test_registry.py @@ -14,6 +14,7 @@ from typing import Any import torch + from lightning.fabric.accelerators import ACCELERATOR_REGISTRY, Accelerator diff --git a/tests/tests_fabric/accelerators/test_xla.py b/tests/tests_fabric/accelerators/test_xla.py index 7a906c8ae0c54..95d5cab90a5f2 100644 --- a/tests/tests_fabric/accelerators/test_xla.py +++ b/tests/tests_fabric/accelerators/test_xla.py @@ -13,8 +13,8 @@ # limitations under the License import pytest -from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, XLAAccelerator +from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, XLAAccelerator from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index dd272257b3923..ad16c96305135 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -18,9 +18,10 @@ from pathlib import Path from unittest.mock import Mock -import lightning.fabric import pytest import torch.distributed + +import lightning.fabric from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver from lightning.fabric.utilities.distributed import _destroy_dist_connection @@ -29,9 +30,10 @@ @pytest.fixture(autouse=True) def preserve_global_rank_variable(): """Ensures that the rank_zero_only.rank global variable gets reset in each test.""" - from lightning.fabric.utilities.rank_zero import rank_zero_only as rank_zero_only_fabric from lightning_utilities.core.rank_zero import rank_zero_only as rank_zero_only_utilities + from lightning.fabric.utilities.rank_zero import rank_zero_only as rank_zero_only_fabric + functions = (rank_zero_only_fabric, rank_zero_only_utilities) ranks = [getattr(fn, "rank", None) for fn in functions] yield @@ -126,7 +128,7 @@ def reset_in_fabric_backward(): wrappers._in_fabric_backward = False -@pytest.fixture() +@pytest.fixture def reset_deterministic_algorithm(): """Ensures that torch determinism settings are reset before the next test runs.""" yield @@ -134,7 +136,7 @@ def reset_deterministic_algorithm(): torch.use_deterministic_algorithms(False) -@pytest.fixture() +@pytest.fixture def reset_cudnn_benchmark(): """Ensures that the `torch.backends.cudnn.benchmark` setting gets reset before the next test runs.""" yield @@ -155,7 +157,7 @@ def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N monkeypatch.setitem(sys.modules, "torch_xla.distributed.fsdp.wrap", Mock()) -@pytest.fixture() +@pytest.fixture def xla_available(monkeypatch: pytest.MonkeyPatch) -> None: mock_xla_available(monkeypatch) @@ -166,12 +168,12 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8) -@pytest.fixture() +@pytest.fixture def tpu_available(monkeypatch: pytest.MonkeyPatch) -> None: mock_tpu_available(monkeypatch) -@pytest.fixture() +@pytest.fixture def caplog(caplog): """Workaround for https://github.com/pytest-dev/pytest/issues/3697. diff --git a/tests/tests_fabric/helpers/runif.py b/tests/tests_fabric/helpers/runif.py index 813c4f93b1788..23a620295bcbf 100644 --- a/tests/tests_fabric/helpers/runif.py +++ b/tests/tests_fabric/helpers/runif.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest + from lightning.fabric.utilities.testing import _runif_reasons diff --git a/tests/tests_fabric/loggers/test_csv.py b/tests/tests_fabric/loggers/test_csv.py index 784bb6fb45aba..f03e8c5125e54 100644 --- a/tests/tests_fabric/loggers/test_csv.py +++ b/tests/tests_fabric/loggers/test_csv.py @@ -17,6 +17,7 @@ import pytest import torch + from lightning.fabric.loggers import CSVLogger from lightning.fabric.loggers.csv_logs import _ExperimentWriter diff --git a/tests/tests_fabric/loggers/test_tensorboard.py b/tests/tests_fabric/loggers/test_tensorboard.py index fa685241ea1b5..4dcb86f0e7406 100644 --- a/tests/tests_fabric/loggers/test_tensorboard.py +++ b/tests/tests_fabric/loggers/test_tensorboard.py @@ -19,10 +19,10 @@ import numpy as np import pytest import torch + from lightning.fabric.loggers import TensorBoardLogger from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE from lightning.fabric.wrappers import _FabricModule - from tests_fabric.test_fabric import BoringModel diff --git a/tests/tests_fabric/plugins/collectives/test_single_device.py b/tests/tests_fabric/plugins/collectives/test_single_device.py index e7aefdb6078b1..a5a909da64b8a 100644 --- a/tests/tests_fabric/plugins/collectives/test_single_device.py +++ b/tests/tests_fabric/plugins/collectives/test_single_device.py @@ -1,6 +1,7 @@ from unittest import mock import pytest + from lightning.fabric.plugins.collectives import SingleDeviceCollective diff --git a/tests/tests_fabric/plugins/collectives/test_torch_collective.py b/tests/tests_fabric/plugins/collectives/test_torch_collective.py index b4c223e770282..5cef70f3b91ba 100644 --- a/tests/tests_fabric/plugins/collectives/test_torch_collective.py +++ b/tests/tests_fabric/plugins/collectives/test_torch_collective.py @@ -6,12 +6,12 @@ import pytest import torch + from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator from lightning.fabric.plugins.collectives import TorchCollective from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies.ddp import DDPStrategy from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher - from tests_fabric.helpers.runif import RunIf if TorchCollective.is_available(): diff --git a/tests/tests_fabric/plugins/environments/test_kubeflow.py b/tests/tests_fabric/plugins/environments/test_kubeflow.py index 3c44273825510..3436adc9ce2aa 100644 --- a/tests/tests_fabric/plugins/environments/test_kubeflow.py +++ b/tests/tests_fabric/plugins/environments/test_kubeflow.py @@ -16,6 +16,7 @@ from unittest import mock import pytest + from lightning.fabric.plugins.environments import KubeflowEnvironment diff --git a/tests/tests_fabric/plugins/environments/test_lightning.py b/tests/tests_fabric/plugins/environments/test_lightning.py index 02100329d173c..cc0179af4c3f5 100644 --- a/tests/tests_fabric/plugins/environments/test_lightning.py +++ b/tests/tests_fabric/plugins/environments/test_lightning.py @@ -15,6 +15,7 @@ from unittest import mock import pytest + from lightning.fabric.plugins.environments import LightningEnvironment diff --git a/tests/tests_fabric/plugins/environments/test_lsf.py b/tests/tests_fabric/plugins/environments/test_lsf.py index 4e60d968dc953..31cc5976cfe09 100644 --- a/tests/tests_fabric/plugins/environments/test_lsf.py +++ b/tests/tests_fabric/plugins/environments/test_lsf.py @@ -15,6 +15,7 @@ from unittest import mock import pytest + from lightning.fabric.plugins.environments import LSFEnvironment diff --git a/tests/tests_fabric/plugins/environments/test_mpi.py b/tests/tests_fabric/plugins/environments/test_mpi.py index 649d4dcb1dab2..3df0000cf2766 100644 --- a/tests/tests_fabric/plugins/environments/test_mpi.py +++ b/tests/tests_fabric/plugins/environments/test_mpi.py @@ -16,8 +16,9 @@ from unittest import mock from unittest.mock import MagicMock -import lightning.fabric.plugins.environments.mpi import pytest + +import lightning.fabric.plugins.environments.mpi from lightning.fabric.plugins.environments import MPIEnvironment diff --git a/tests/tests_fabric/plugins/environments/test_slurm.py b/tests/tests_fabric/plugins/environments/test_slurm.py index f237478a533f4..75ca43577d579 100644 --- a/tests/tests_fabric/plugins/environments/test_slurm.py +++ b/tests/tests_fabric/plugins/environments/test_slurm.py @@ -18,10 +18,10 @@ from unittest import mock import pytest -from lightning.fabric.plugins.environments import SLURMEnvironment -from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning_utilities.test.warning import no_warning_call +from lightning.fabric.plugins.environments import SLURMEnvironment +from lightning.fabric.utilities.warnings import PossibleUserWarning from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/environments/test_torchelastic.py b/tests/tests_fabric/plugins/environments/test_torchelastic.py index 3cf8619a00c22..161d42894df30 100644 --- a/tests/tests_fabric/plugins/environments/test_torchelastic.py +++ b/tests/tests_fabric/plugins/environments/test_torchelastic.py @@ -17,6 +17,7 @@ from unittest import mock import pytest + from lightning.fabric.plugins.environments import TorchElasticEnvironment diff --git a/tests/tests_fabric/plugins/environments/test_xla.py b/tests/tests_fabric/plugins/environments/test_xla.py index 7a3610b65b4bb..7e33d5db87dd4 100644 --- a/tests/tests_fabric/plugins/environments/test_xla.py +++ b/tests/tests_fabric/plugins/environments/test_xla.py @@ -14,11 +14,11 @@ import os from unittest import mock -import lightning.fabric import pytest + +import lightning.fabric from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1 from lightning.fabric.plugins.environments import XLAEnvironment - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_all.py b/tests/tests_fabric/plugins/precision/test_all.py index 5e86a35647489..94e5efaa74eed 100644 --- a/tests/tests_fabric/plugins/precision/test_all.py +++ b/tests/tests_fabric/plugins/precision/test_all.py @@ -1,5 +1,6 @@ import pytest import torch + from lightning.fabric.plugins import DeepSpeedPrecision, DoublePrecision, FSDPPrecision, HalfPrecision diff --git a/tests/tests_fabric/plugins/precision/test_amp.py b/tests/tests_fabric/plugins/precision/test_amp.py index 93d53eb406f71..73507f085936b 100644 --- a/tests/tests_fabric/plugins/precision/test_amp.py +++ b/tests/tests_fabric/plugins/precision/test_amp.py @@ -16,6 +16,7 @@ import pytest import torch + from lightning.fabric.plugins.precision.amp import MixedPrecision from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 diff --git a/tests/tests_fabric/plugins/precision/test_amp_integration.py b/tests/tests_fabric/plugins/precision/test_amp_integration.py index aa6c6cfce4504..bcbd9435d47ac 100644 --- a/tests/tests_fabric/plugins/precision/test_amp_integration.py +++ b/tests/tests_fabric/plugins/precision/test_amp_integration.py @@ -16,9 +16,9 @@ import pytest import torch import torch.nn as nn + from lightning.fabric import Fabric, seed_everything from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py index b8b9020b201a7..152f9a1c01fe9 100644 --- a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py +++ b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py @@ -14,16 +14,16 @@ import sys from unittest.mock import Mock -import lightning.fabric import pytest import torch import torch.distributed + +import lightning.fabric from lightning.fabric import Fabric from lightning.fabric.connector import _Connector from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision from lightning.fabric.utilities.init import _materialize_meta_tensors from lightning.fabric.utilities.load import _lazy_load - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_deepspeed.py b/tests/tests_fabric/plugins/precision/test_deepspeed.py index 248f616646842..170f15afaa2ad 100644 --- a/tests/tests_fabric/plugins/precision/test_deepspeed.py +++ b/tests/tests_fabric/plugins/precision/test_deepspeed.py @@ -16,9 +16,9 @@ import pytest import torch + from lightning.fabric.plugins.precision.deepspeed import DeepSpeedPrecision from lightning.fabric.utilities.types import Steppable - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py b/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py index 27e44398fc095..e989534343b8c 100644 --- a/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py +++ b/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py @@ -14,10 +14,10 @@ from unittest import mock import pytest + from lightning.fabric.connector import _Connector from lightning.fabric.plugins import DeepSpeedPrecision from lightning.fabric.strategies import DeepSpeedStrategy - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_double.py b/tests/tests_fabric/plugins/precision/test_double.py index 97d4a2303100e..4921e0f4e659b 100644 --- a/tests/tests_fabric/plugins/precision/test_double.py +++ b/tests/tests_fabric/plugins/precision/test_double.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch + from lightning.fabric.plugins.precision.double import DoublePrecision diff --git a/tests/tests_fabric/plugins/precision/test_double_integration.py b/tests/tests_fabric/plugins/precision/test_double_integration.py index 6701bc1a80e59..8f96f75ad1c67 100644 --- a/tests/tests_fabric/plugins/precision/test_double_integration.py +++ b/tests/tests_fabric/plugins/precision/test_double_integration.py @@ -15,8 +15,8 @@ import torch import torch.nn as nn -from lightning.fabric import Fabric +from lightning.fabric import Fabric from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_fsdp.py b/tests/tests_fabric/plugins/precision/test_fsdp.py index 6a4968736ea86..3b8d916e20c8f 100644 --- a/tests/tests_fabric/plugins/precision/test_fsdp.py +++ b/tests/tests_fabric/plugins/precision/test_fsdp.py @@ -15,9 +15,9 @@ import pytest import torch + from lightning.fabric.plugins import FSDPPrecision from lightning.fabric.plugins.precision.utils import _DtypeContextManager - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_half.py b/tests/tests_fabric/plugins/precision/test_half.py index 4037feebbd178..00d23df4ae5b6 100644 --- a/tests/tests_fabric/plugins/precision/test_half.py +++ b/tests/tests_fabric/plugins/precision/test_half.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest import torch + from lightning.fabric.plugins.precision import HalfPrecision diff --git a/tests/tests_fabric/plugins/precision/test_transformer_engine.py b/tests/tests_fabric/plugins/precision/test_transformer_engine.py index c003715dead8a..033484aca9c90 100644 --- a/tests/tests_fabric/plugins/precision/test_transformer_engine.py +++ b/tests/tests_fabric/plugins/precision/test_transformer_engine.py @@ -14,10 +14,11 @@ import sys from unittest.mock import Mock -import lightning.fabric import pytest import torch import torch.distributed + +import lightning.fabric from lightning.fabric.connector import _Connector from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision diff --git a/tests/tests_fabric/plugins/precision/test_utils.py b/tests/tests_fabric/plugins/precision/test_utils.py index 74899c86e9e2d..6e459c3fe6637 100644 --- a/tests/tests_fabric/plugins/precision/test_utils.py +++ b/tests/tests_fabric/plugins/precision/test_utils.py @@ -1,5 +1,6 @@ import pytest import torch + from lightning.fabric.plugins.precision.utils import _ClassReplacementContextManager, _DtypeContextManager diff --git a/tests/tests_fabric/plugins/precision/test_xla.py b/tests/tests_fabric/plugins/precision/test_xla.py index 0cdc11b00b99a..cfdc32112a957 100644 --- a/tests/tests_fabric/plugins/precision/test_xla.py +++ b/tests/tests_fabric/plugins/precision/test_xla.py @@ -17,6 +17,7 @@ import pytest import torch + from lightning.fabric.plugins import XLAPrecision diff --git a/tests/tests_fabric/plugins/precision/test_xla_integration.py b/tests/tests_fabric/plugins/precision/test_xla_integration.py index 14a5cd1442e4a..75dede49e2fe0 100644 --- a/tests/tests_fabric/plugins/precision/test_xla_integration.py +++ b/tests/tests_fabric/plugins/precision/test_xla_integration.py @@ -17,9 +17,9 @@ import pytest import torch import torch.nn as nn + from lightning.fabric import Fabric from lightning.fabric.plugins import XLAPrecision - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing.py index 6c595fba7acab..5bb85e070f17d 100644 --- a/tests/tests_fabric/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing.py @@ -16,8 +16,8 @@ import pytest import torch -from lightning.fabric.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher +from lightning.fabric.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py index c3bc7d3c2c6cd..6ae96b9bcafc6 100644 --- a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py +++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py @@ -15,8 +15,8 @@ import pytest import torch import torch.nn as nn -from lightning.fabric import Fabric +from lightning.fabric import Fabric from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/launchers/test_subprocess_script.py b/tests/tests_fabric/strategies/launchers/test_subprocess_script.py index a2d04e29bc0d6..70587ca2877a6 100644 --- a/tests/tests_fabric/strategies/launchers/test_subprocess_script.py +++ b/tests/tests_fabric/strategies/launchers/test_subprocess_script.py @@ -17,8 +17,9 @@ from unittest import mock from unittest.mock import ANY, Mock -import lightning.fabric import pytest + +import lightning.fabric from lightning.fabric.strategies.launchers.subprocess_script import ( _HYDRA_AVAILABLE, _ChildProcessObserver, diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py index b98d5f8226dc2..fa5c975228a5e 100644 --- a/tests/tests_fabric/strategies/test_ddp.py +++ b/tests/tests_fabric/strategies/test_ddp.py @@ -19,12 +19,12 @@ import pytest import torch +from torch.nn.parallel import DistributedDataParallel + from lightning.fabric.plugins import DoublePrecision, HalfPrecision, Precision from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import DDPStrategy from lightning.fabric.strategies.ddp import _DDPBackwardSyncControl -from torch.nn.parallel import DistributedDataParallel - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/test_ddp_integration.py b/tests/tests_fabric/strategies/test_ddp_integration.py index a7ed09b00b09e..70dd25aa99603 100644 --- a/tests/tests_fabric/strategies/test_ddp_integration.py +++ b/tests/tests_fabric/strategies/test_ddp_integration.py @@ -18,11 +18,11 @@ import pytest import torch -from lightning.fabric import Fabric from lightning_utilities.core.imports import RequirementCache from torch._dynamo import OptimizedModule from torch.nn.parallel.distributed import DistributedDataParallel +from lightning.fabric import Fabric from tests_fabric.helpers.runif import RunIf from tests_fabric.strategies.test_single_device import _run_test_clip_gradients from tests_fabric.test_fabric import BoringModel diff --git a/tests/tests_fabric/strategies/test_deepspeed.py b/tests/tests_fabric/strategies/test_deepspeed.py index 4ee87b265b086..032ee63cd4721 100644 --- a/tests/tests_fabric/strategies/test_deepspeed.py +++ b/tests/tests_fabric/strategies/test_deepspeed.py @@ -18,14 +18,14 @@ import pytest import torch -from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator -from lightning.fabric.strategies import DeepSpeedStrategy from torch.optim import Optimizer +from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator +from lightning.fabric.strategies import DeepSpeedStrategy from tests_fabric.helpers.runif import RunIf -@pytest.fixture() +@pytest.fixture def deepspeed_config(): return { "optimizer": {"type": "SGD", "params": {"lr": 3e-5}}, @@ -36,7 +36,7 @@ def deepspeed_config(): } -@pytest.fixture() +@pytest.fixture def deepspeed_zero_config(deepspeed_config): return {**deepspeed_config, "zero_allow_untested_optimizer": True, "zero_optimization": {"stage": 2}} diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py index 4811599ed05ab..5970b673cee5f 100644 --- a/tests/tests_fabric/strategies/test_deepspeed_integration.py +++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py @@ -20,11 +20,11 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.utils.data import DataLoader + from lightning.fabric import Fabric from lightning.fabric.plugins import DeepSpeedPrecision from lightning.fabric.strategies import DeepSpeedStrategy -from torch.utils.data import DataLoader - from tests_fabric.helpers.datasets import RandomDataset, RandomIterableDataset from tests_fabric.helpers.runif import RunIf from tests_fabric.test_fabric import BoringModel diff --git a/tests/tests_fabric/strategies/test_dp.py b/tests/tests_fabric/strategies/test_dp.py index e50abb1882870..ff470c646f4b3 100644 --- a/tests/tests_fabric/strategies/test_dp.py +++ b/tests/tests_fabric/strategies/test_dp.py @@ -16,9 +16,9 @@ import pytest import torch + from lightning.fabric import Fabric from lightning.fabric.strategies import DataParallelStrategy - from tests_fabric.helpers.runif import RunIf from tests_fabric.strategies.test_single_device import _run_test_clip_gradients diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py index cb6542cdb6243..d5f82752a9176 100644 --- a/tests/tests_fabric/strategies/test_fsdp.py +++ b/tests/tests_fabric/strategies/test_fsdp.py @@ -19,6 +19,10 @@ import pytest import torch import torch.nn as nn +from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.optim import Adam + from lightning.fabric.plugins import HalfPrecision from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import FSDPStrategy @@ -28,9 +32,6 @@ _is_sharded_checkpoint, ) from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision -from torch.distributed.fsdp.wrap import ModuleWrapPolicy -from torch.optim import Adam def test_custom_mixed_precision(): diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py index 0697c3043d496..11a7a1a6f8f7f 100644 --- a/tests/tests_fabric/strategies/test_fsdp_integration.py +++ b/tests/tests_fabric/strategies/test_fsdp_integration.py @@ -20,17 +20,17 @@ import pytest import torch import torch.nn as nn -from lightning.fabric import Fabric -from lightning.fabric.plugins import FSDPPrecision -from lightning.fabric.strategies import FSDPStrategy -from lightning.fabric.utilities.load import _load_distributed_checkpoint -from lightning.fabric.wrappers import _FabricOptimizer from torch._dynamo import OptimizedModule from torch.distributed.fsdp import FlatParameter, FullyShardedDataParallel, OptimStateKeyType from torch.distributed.fsdp.wrap import always_wrap_policy, wrap from torch.nn import Parameter from torch.utils.data import DataLoader +from lightning.fabric import Fabric +from lightning.fabric.plugins import FSDPPrecision +from lightning.fabric.strategies import FSDPStrategy +from lightning.fabric.utilities.load import _load_distributed_checkpoint +from lightning.fabric.wrappers import _FabricOptimizer from tests_fabric.helpers.datasets import RandomDataset from tests_fabric.helpers.runif import RunIf from tests_fabric.test_fabric import BoringModel diff --git a/tests/tests_fabric/strategies/test_model_parallel.py b/tests/tests_fabric/strategies/test_model_parallel.py index 1f8b5b783b73e..78622adf66fa6 100644 --- a/tests/tests_fabric/strategies/test_model_parallel.py +++ b/tests/tests_fabric/strategies/test_model_parallel.py @@ -19,12 +19,12 @@ import pytest import torch import torch.nn as nn +from torch.optim import Adam + from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import ModelParallelStrategy from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint from lightning.fabric.strategies.model_parallel import _ParallelBackwardSyncControl -from torch.optim import Adam - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/test_model_parallel_integration.py b/tests/tests_fabric/strategies/test_model_parallel_integration.py index b04a29b691529..bddfadd9a2c54 100644 --- a/tests/tests_fabric/strategies/test_model_parallel_integration.py +++ b/tests/tests_fabric/strategies/test_model_parallel_integration.py @@ -20,16 +20,16 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.utils.data import DataLoader, DistributedSampler + from lightning.fabric import Fabric from lightning.fabric.strategies.model_parallel import ModelParallelStrategy, _load_raw_module_state from lightning.fabric.utilities.load import _load_distributed_checkpoint -from torch.utils.data import DataLoader, DistributedSampler - from tests_fabric.helpers.datasets import RandomDataset from tests_fabric.helpers.runif import RunIf -@pytest.fixture() +@pytest.fixture def distributed(): yield if torch.distributed.is_initialized(): diff --git a/tests/tests_fabric/strategies/test_single_device.py b/tests/tests_fabric/strategies/test_single_device.py index 95ed9787f40a2..fff7175909222 100644 --- a/tests/tests_fabric/strategies/test_single_device.py +++ b/tests/tests_fabric/strategies/test_single_device.py @@ -15,9 +15,9 @@ import pytest import torch + from lightning.fabric import Fabric from lightning.fabric.strategies import SingleDeviceStrategy - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/test_strategy.py b/tests/tests_fabric/strategies/test_strategy.py index a7a1dba87cb97..37ccbea5e6c95 100644 --- a/tests/tests_fabric/strategies/test_strategy.py +++ b/tests/tests_fabric/strategies/test_strategy.py @@ -16,10 +16,10 @@ import pytest import torch + from lightning.fabric.plugins import DoublePrecision, HalfPrecision, Precision from lightning.fabric.strategies import SingleDeviceStrategy from lightning.fabric.utilities.types import _Stateful - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/test_xla.py b/tests/tests_fabric/strategies/test_xla.py index f711eb3470b45..a260b3f231e1d 100644 --- a/tests/tests_fabric/strategies/test_xla.py +++ b/tests/tests_fabric/strategies/test_xla.py @@ -18,13 +18,13 @@ import pytest import torch +from torch.utils.data import DataLoader + from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1, XLAAccelerator from lightning.fabric.strategies import XLAStrategy from lightning.fabric.strategies.launchers.xla import _XLALauncher from lightning.fabric.utilities.distributed import ReduceOp from lightning.fabric.utilities.seed import seed_everything -from torch.utils.data import DataLoader - from tests_fabric.helpers.datasets import RandomDataset from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/test_xla_fsdp.py b/tests/tests_fabric/strategies/test_xla_fsdp.py index 879a55cf77f34..c2634283ad110 100644 --- a/tests/tests_fabric/strategies/test_xla_fsdp.py +++ b/tests/tests_fabric/strategies/test_xla_fsdp.py @@ -18,12 +18,12 @@ import pytest import torch.nn import torch.nn as nn +from torch.optim import Adam + from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.plugins import XLAPrecision from lightning.fabric.strategies import XLAFSDPStrategy from lightning.fabric.strategies.xla_fsdp import _activation_checkpointing_auto_wrapper, _XLAFSDPBackwardSyncControl -from torch.optim import Adam - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/test_xla_fsdp_integration.py b/tests/tests_fabric/strategies/test_xla_fsdp_integration.py index 20c2ef042272e..b77803744b6c4 100644 --- a/tests/tests_fabric/strategies/test_xla_fsdp_integration.py +++ b/tests/tests_fabric/strategies/test_xla_fsdp_integration.py @@ -18,10 +18,10 @@ import pytest import torch -from lightning.fabric import Fabric -from lightning.fabric.strategies import XLAFSDPStrategy from torch.utils.data import DataLoader +from lightning.fabric import Fabric +from lightning.fabric.strategies import XLAFSDPStrategy from tests_fabric.helpers.datasets import RandomDataset from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/test_cli.py b/tests/tests_fabric/test_cli.py index a57f413ff6081..ec3160111e6ee 100644 --- a/tests/tests_fabric/test_cli.py +++ b/tests/tests_fabric/test_cli.py @@ -20,12 +20,12 @@ from unittest.mock import Mock import pytest -from lightning.fabric.cli import _consolidate, _get_supported_strategies, _run +from lightning.fabric.cli import _consolidate, _get_supported_strategies, _run from tests_fabric.helpers.runif import RunIf -@pytest.fixture() +@pytest.fixture def fake_script(tmp_path): script = tmp_path / "script.py" script.touch() @@ -184,8 +184,7 @@ def test_run_through_lightning_entry_point(): result = subprocess.run("lightning run model --help", capture_output=True, text=True, shell=True) deprecation_message = ( - "`lightning run model` is deprecated and will be removed in future versions. " - "Please call `fabric run` instead" + "`lightning run model` is deprecated and will be removed in future versions. Please call `fabric run` instead" ) message = "Usage: lightning run [OPTIONS] SCRIPT [SCRIPT_ARGS]" assert deprecation_message in result.stdout diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index 17817dee64abe..f5c9a7c714b78 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -19,10 +19,12 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric import pytest import torch import torch.distributed +from lightning_utilities.test.warning import no_warning_call + +import lightning.fabric from lightning.fabric import Fabric from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.accelerators.accelerator import Accelerator @@ -63,8 +65,6 @@ from lightning.fabric.strategies.ddp import _DDP_FORK_ALIASES from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from lightning.fabric.utilities.imports import _IS_WINDOWS -from lightning_utilities.test.warning import no_warning_call - from tests_fabric.conftest import mock_tpu_available from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index 7bb6b29eceaf2..ee002b5d8061c 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -17,11 +17,15 @@ from unittest import mock from unittest.mock import ANY, MagicMock, Mock, PropertyMock, call -import lightning.fabric import pytest import torch import torch.distributed import torch.nn.functional +from lightning_utilities.test.warning import no_warning_call +from torch import nn +from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler, TensorDataset + +import lightning.fabric from lightning.fabric.fabric import Fabric from lightning.fabric.strategies import ( DataParallelStrategy, @@ -37,10 +41,6 @@ from lightning.fabric.utilities.seed import pl_worker_init_function, seed_everything from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer -from lightning_utilities.test.warning import no_warning_call -from torch import nn -from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler, TensorDataset - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py index 363da022d285b..de8ae208e4dc4 100644 --- a/tests/tests_fabric/test_wrappers.py +++ b/tests/tests_fabric/test_wrappers.py @@ -16,6 +16,10 @@ import pytest import torch +from torch._dynamo import OptimizedModule +from torch.utils.data import BatchSampler, DistributedSampler +from torch.utils.data.dataloader import DataLoader + from lightning.fabric.fabric import Fabric from lightning.fabric.plugins import Precision from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin @@ -28,10 +32,6 @@ _unwrap_objects, is_wrapped, ) -from torch._dynamo import OptimizedModule -from torch.utils.data import BatchSampler, DistributedSampler -from torch.utils.data.dataloader import DataLoader - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/utilities/test_apply_func.py b/tests/tests_fabric/utilities/test_apply_func.py index 9e137561aa525..055fa89101c96 100644 --- a/tests/tests_fabric/utilities/test_apply_func.py +++ b/tests/tests_fabric/utilities/test_apply_func.py @@ -13,9 +13,10 @@ # limitations under the License. import pytest import torch -from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars, move_data_to_device from torch import Tensor +from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars, move_data_to_device + @pytest.mark.parametrize("should_return", [False, True]) def test_wrongly_implemented_transferable_data_type(should_return): diff --git a/tests/tests_fabric/utilities/test_cloud_io.py b/tests/tests_fabric/utilities/test_cloud_io.py index d502199da1493..e1333ddff87f3 100644 --- a/tests/tests_fabric/utilities/test_cloud_io.py +++ b/tests/tests_fabric/utilities/test_cloud_io.py @@ -16,6 +16,7 @@ import fsspec from fsspec.implementations.local import LocalFileSystem from fsspec.spec import AbstractFileSystem + from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem diff --git a/tests/tests_fabric/utilities/test_consolidate_checkpoint.py b/tests/tests_fabric/utilities/test_consolidate_checkpoint.py index 2584aab8bdc2e..78690a9870982 100644 --- a/tests/tests_fabric/utilities/test_consolidate_checkpoint.py +++ b/tests/tests_fabric/utilities/test_consolidate_checkpoint.py @@ -16,8 +16,9 @@ from pathlib import Path from unittest import mock -import lightning.fabric import pytest + +import lightning.fabric from lightning.fabric.utilities.consolidate_checkpoint import _parse_cli_args, _process_cli_args from lightning.fabric.utilities.load import _METADATA_FILENAME diff --git a/tests/tests_fabric/utilities/test_data.py b/tests/tests_fabric/utilities/test_data.py index 656b9cac3d77e..faff6e182a06f 100644 --- a/tests/tests_fabric/utilities/test_data.py +++ b/tests/tests_fabric/utilities/test_data.py @@ -4,10 +4,14 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric import numpy as np import pytest import torch +from lightning_utilities.test.warning import no_warning_call +from torch import Tensor +from torch.utils.data import BatchSampler, DataLoader, RandomSampler + +import lightning.fabric from lightning.fabric.utilities.data import ( AttributeDict, _get_dataloader_init_args_and_kwargs, @@ -21,10 +25,6 @@ suggested_max_num_workers, ) from lightning.fabric.utilities.exceptions import MisconfigurationException -from lightning_utilities.test.warning import no_warning_call -from torch import Tensor -from torch.utils.data import BatchSampler, DataLoader, RandomSampler - from tests_fabric.helpers.datasets import RandomDataset, RandomIterableDataset diff --git a/tests/tests_fabric/utilities/test_device_dtype_mixin.py b/tests/tests_fabric/utilities/test_device_dtype_mixin.py index 9958e48c624ee..1261ca5e0accb 100644 --- a/tests/tests_fabric/utilities/test_device_dtype_mixin.py +++ b/tests/tests_fabric/utilities/test_device_dtype_mixin.py @@ -1,8 +1,8 @@ import pytest import torch -from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from torch import nn as nn +from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/utilities/test_device_parser.py b/tests/tests_fabric/utilities/test_device_parser.py index 9b5a09e370860..f9f9e49cf6afe 100644 --- a/tests/tests_fabric/utilities/test_device_parser.py +++ b/tests/tests_fabric/utilities/test_device_parser.py @@ -14,6 +14,7 @@ from unittest import mock import pytest + from lightning.fabric.utilities import device_parser from lightning.fabric.utilities.exceptions import MisconfigurationException diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py index f5a78a1529a52..9282f00f1ffb6 100644 --- a/tests/tests_fabric/utilities/test_distributed.py +++ b/tests/tests_fabric/utilities/test_distributed.py @@ -5,9 +5,11 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric import pytest import torch +from lightning_utilities.core.imports import RequirementCache + +import lightning.fabric from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import DDPStrategy, SingleDeviceStrategy @@ -23,8 +25,6 @@ _sync_ddp, is_shared_filesystem, ) -from lightning_utilities.core.imports import RequirementCache - from tests_fabric.helpers.runif import RunIf @@ -105,6 +105,8 @@ def _test_all_reduce(strategy): assert result is tensor # inplace +# flaky with "process 0 terminated with signal SIGABRT" (GLOO) +@pytest.mark.flaky(reruns=3, only_rerun="torch.multiprocessing.spawn.ProcessExitedException") @RunIf(skip_windows=True) @pytest.mark.parametrize( "process", diff --git a/tests/tests_fabric/utilities/test_init.py b/tests/tests_fabric/utilities/test_init.py index dd08dec020669..69758c5a3e17e 100644 --- a/tests/tests_fabric/utilities/test_init.py +++ b/tests/tests_fabric/utilities/test_init.py @@ -16,12 +16,12 @@ import pytest import torch.nn + from lightning.fabric.utilities.init import ( _EmptyInit, _has_meta_device_parameters_or_buffers, _materialize_meta_tensors, ) - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/utilities/test_load.py b/tests/tests_fabric/utilities/test_load.py index 39d257f8b685b..ed38aa2459af7 100644 --- a/tests/tests_fabric/utilities/test_load.py +++ b/tests/tests_fabric/utilities/test_load.py @@ -14,6 +14,7 @@ import pytest import torch import torch.nn as nn + from lightning.fabric.utilities.load import ( _lazy_load, _materialize_tensors, @@ -55,7 +56,7 @@ def test_lazy_load_tensor(tmp_path): for t0, t1 in zip(expected.values(), loaded.values()): assert isinstance(t1, _NotYetLoadedTensor) t1_materialized = _materialize_tensors(t1) - assert type(t0) == type(t1_materialized) + assert type(t0) == type(t1_materialized) # noqa: E721 assert torch.equal(t0, t1_materialized) @@ -91,7 +92,7 @@ def test_materialize_tensors(tmp_path): loaded = _lazy_load(tmp_path / "tensor.pt") materialized = _materialize_tensors(loaded) assert torch.equal(materialized, tensor) - assert type(tensor) == type(materialized) + assert type(tensor) == type(materialized) # noqa: E721 # Collection of tensors collection = { diff --git a/tests/tests_fabric/utilities/test_logger.py b/tests/tests_fabric/utilities/test_logger.py index 0f6500cb42be1..26823143102a7 100644 --- a/tests/tests_fabric/utilities/test_logger.py +++ b/tests/tests_fabric/utilities/test_logger.py @@ -17,6 +17,7 @@ import numpy as np import torch + from lightning.fabric.utilities.logger import ( _add_prefix, _convert_json_serializable, diff --git a/tests/tests_fabric/utilities/test_optimizer.py b/tests/tests_fabric/utilities/test_optimizer.py index 83c7ed44120b9..d96c58049ed3a 100644 --- a/tests/tests_fabric/utilities/test_optimizer.py +++ b/tests/tests_fabric/utilities/test_optimizer.py @@ -2,9 +2,9 @@ import pytest import torch -from lightning.fabric.utilities.optimizer import _optimizer_to_device from torch import Tensor +from lightning.fabric.utilities.optimizer import _optimizer_to_device from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/utilities/test_rank_zero.py b/tests/tests_fabric/utilities/test_rank_zero.py index 0c1b39fe9d9b8..b8ea54f90625d 100644 --- a/tests/tests_fabric/utilities/test_rank_zero.py +++ b/tests/tests_fabric/utilities/test_rank_zero.py @@ -3,6 +3,7 @@ from unittest import mock import pytest + from lightning.fabric.utilities.rank_zero import _get_rank diff --git a/tests/tests_fabric/utilities/test_seed.py b/tests/tests_fabric/utilities/test_seed.py index be2ecba3294b1..0973709bf84fd 100644 --- a/tests/tests_fabric/utilities/test_seed.py +++ b/tests/tests_fabric/utilities/test_seed.py @@ -6,6 +6,7 @@ import numpy import pytest import torch + from lightning.fabric.utilities.seed import ( _collect_rng_states, _set_rng_states, diff --git a/tests/tests_fabric/utilities/test_spike.py b/tests/tests_fabric/utilities/test_spike.py index 9739540af7f18..6054bf224d3df 100644 --- a/tests/tests_fabric/utilities/test_spike.py +++ b/tests/tests_fabric/utilities/test_spike.py @@ -3,6 +3,7 @@ import pytest import torch + from lightning.fabric import Fabric from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, SpikeDetection, TrainingSpikeException diff --git a/tests/tests_fabric/utilities/test_throughput.py b/tests/tests_fabric/utilities/test_throughput.py index d410d0766d97b..00dafbb72cb8f 100644 --- a/tests/tests_fabric/utilities/test_throughput.py +++ b/tests/tests_fabric/utilities/test_throughput.py @@ -3,6 +3,7 @@ import pytest import torch + from lightning.fabric import Fabric from lightning.fabric.plugins import Precision from lightning.fabric.utilities.throughput import ( @@ -12,7 +13,6 @@ get_available_flops, measure_flops, ) - from tests_fabric.test_fabric import BoringModel diff --git a/tests/tests_fabric/utilities/test_warnings.py b/tests/tests_fabric/utilities/test_warnings.py index b7989d85b5932..bfccaaa8481d9 100644 --- a/tests/tests_fabric/utilities/test_warnings.py +++ b/tests/tests_fabric/utilities/test_warnings.py @@ -27,16 +27,17 @@ from pathlib import Path from unittest import mock -import lightning.fabric import pytest +from lightning_utilities.core.rank_zero import WarningCache, _warn +from lightning_utilities.test.warning import no_warning_call + +import lightning.fabric from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn from lightning.fabric.utilities.warnings import ( PossibleUserWarning, _is_path_in_lightning, disable_possible_user_warnings, ) -from lightning_utilities.core.rank_zero import WarningCache, _warn -from lightning_utilities.test.warning import no_warning_call def line_number(): diff --git a/tests/tests_pytorch/__init__.py b/tests/tests_pytorch/__init__.py index a43ffae6a83b4..efbfa8bb14c76 100644 --- a/tests/tests_pytorch/__init__.py +++ b/tests/tests_pytorch/__init__.py @@ -25,7 +25,7 @@ # todo: this setting `PYTHONPATH` may not be used by other evns like Conda for import packages if str(_PROJECT_ROOT) not in os.getenv("PYTHONPATH", ""): splitter = ":" if os.environ.get("PYTHONPATH", "") else "" - os.environ["PYTHONPATH"] = f'{_PROJECT_ROOT}{splitter}{os.environ.get("PYTHONPATH", "")}' + os.environ["PYTHONPATH"] = f"{_PROJECT_ROOT}{splitter}{os.environ.get('PYTHONPATH', '')}" # Ignore cleanup warnings from pytest (rarely happens due to a race condition when executing pytest in parallel) warnings.filterwarnings("ignore", category=pytest.PytestWarning, message=r".*\(rm_rf\) error removing.*") diff --git a/tests/tests_pytorch/accelerators/test_common.py b/tests/tests_pytorch/accelerators/test_common.py index 6967bffd9ffa2..42fd66a247d1d 100644 --- a/tests/tests_pytorch/accelerators/test_common.py +++ b/tests/tests_pytorch/accelerators/test_common.py @@ -14,6 +14,7 @@ from typing import Any import torch + from lightning.pytorch import Trainer from lightning.pytorch.accelerators import Accelerator from lightning.pytorch.strategies import DDPStrategy diff --git a/tests/tests_pytorch/accelerators/test_cpu.py b/tests/tests_pytorch/accelerators/test_cpu.py index 844556064621d..ec2aabf559dc7 100644 --- a/tests/tests_pytorch/accelerators/test_cpu.py +++ b/tests/tests_pytorch/accelerators/test_cpu.py @@ -3,16 +3,16 @@ from typing import Any, Union from unittest.mock import Mock -import lightning.pytorch as pl import pytest import torch + +import lightning.pytorch as pl from lightning.fabric.plugins import TorchCheckpointIO from lightning.pytorch import Trainer from lightning.pytorch.accelerators import CPUAccelerator from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins.precision.precision import Precision from lightning.pytorch.strategies import SingleDeviceStrategy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/accelerators/test_gpu.py b/tests/tests_pytorch/accelerators/test_gpu.py index e175f8aa7647c..5a71887e17eec 100644 --- a/tests/tests_pytorch/accelerators/test_gpu.py +++ b/tests/tests_pytorch/accelerators/test_gpu.py @@ -15,11 +15,11 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.accelerators import CUDAAccelerator from lightning.pytorch.accelerators.cuda import get_nvidia_gpu_stats from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/accelerators/test_mps.py b/tests/tests_pytorch/accelerators/test_mps.py index 73d0785f9592d..c0a28840f0ef6 100644 --- a/tests/tests_pytorch/accelerators/test_mps.py +++ b/tests/tests_pytorch/accelerators/test_mps.py @@ -16,11 +16,11 @@ import pytest import torch + +import tests_pytorch.helpers.pipelines as tpipes from lightning.pytorch import Trainer from lightning.pytorch.accelerators import MPSAccelerator from lightning.pytorch.demos.boring_classes import BoringModel - -import tests_pytorch.helpers.pipelines as tpipes from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/accelerators/test_xla.py b/tests/tests_pytorch/accelerators/test_xla.py index 48b346006786f..6322a36a6cbc1 100644 --- a/tests/tests_pytorch/accelerators/test_xla.py +++ b/tests/tests_pytorch/accelerators/test_xla.py @@ -17,9 +17,12 @@ from unittest import mock from unittest.mock import MagicMock, call, patch -import lightning.fabric import pytest import torch +from torch import nn +from torch.utils.data import DataLoader + +import lightning.fabric from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.pytorch import Trainer from lightning.pytorch.accelerators import CPUAccelerator, XLAAccelerator @@ -27,9 +30,6 @@ from lightning.pytorch.plugins import Precision, XLACheckpointIO, XLAPrecision from lightning.pytorch.strategies import DDPStrategy, XLAStrategy from lightning.pytorch.utilities import find_shared_parameters -from torch import nn -from torch.utils.data import DataLoader - from tests_pytorch.helpers.runif import RunIf from tests_pytorch.trainer.connectors.test_accelerator_connector import DeviceMock from tests_pytorch.trainer.optimization.test_manual_optimization import assert_emtpy_grad diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 89c1effe839a8..430fb9842cddc 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -17,14 +17,15 @@ from unittest.mock import DEFAULT, Mock import pytest +from tests_pytorch.helpers.runif import RunIf +from torch.utils.data import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ProgressBar, RichProgressBar from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.loggers.logger import DummyLogger -from tests_pytorch.helpers.runif import RunIf -from torch.utils.data import DataLoader @RunIf(rich=True) diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index d5187d5a1e325..d93bf1cf60e9c 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -22,6 +22,9 @@ import pytest import torch +from tests_pytorch.helpers.runif import RunIf +from torch.utils.data.dataloader import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint, ProgressBar, TQDMProgressBar from lightning.pytorch.callbacks.progress.tqdm_progress import Tqdm @@ -30,8 +33,6 @@ from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.loggers.logger import DummyLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException -from tests_pytorch.helpers.runif import RunIf -from torch.utils.data.dataloader import DataLoader class MockTqdm(Tqdm): diff --git a/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py b/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py index 8f32c756881da..366a924a5867c 100644 --- a/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py +++ b/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest + from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/callbacks/test_callbacks.py b/tests/tests_pytorch/callbacks/test_callbacks.py index 38b9428526505..53ea109b6ddf3 100644 --- a/tests/tests_pytorch/callbacks/test_callbacks.py +++ b/tests/tests_pytorch/callbacks/test_callbacks.py @@ -16,10 +16,11 @@ from unittest.mock import Mock import pytest +from lightning_utilities.test.warning import no_warning_call + from lightning.pytorch import Callback, Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel -from lightning_utilities.test.warning import no_warning_call def test_callbacks_configured_in_model(tmp_path): diff --git a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py index f1d999f1df61a..290a0921cb06d 100644 --- a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py +++ b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py @@ -20,6 +20,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.accelerators.cpu import _CPU_PERCENT, _CPU_SWAP_PERCENT, _CPU_VM_PERCENT, get_cpu_stats from lightning.pytorch.callbacks import DeviceStatsMonitor @@ -28,7 +29,6 @@ from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_zero_only - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index a3d56bb0135c3..c11aa37b456dc 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -22,11 +22,11 @@ import cloudpickle import pytest import torch + from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.exceptions import MisconfigurationException - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/callbacks/test_finetuning_callback.py b/tests/tests_pytorch/callbacks/test_finetuning_callback.py index 0c09ae5d5042a..07343c1ecc12a 100644 --- a/tests/tests_pytorch/callbacks/test_finetuning_callback.py +++ b/tests/tests_pytorch/callbacks/test_finetuning_callback.py @@ -15,14 +15,14 @@ import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 -from lightning.pytorch import LightningModule, Trainer, seed_everything -from lightning.pytorch.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint -from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from torch import nn from torch.optim import SGD, Optimizer from torch.utils.data import DataLoader +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 +from lightning.pytorch import LightningModule, Trainer, seed_everything +from lightning.pytorch.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint +from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/callbacks/test_gradient_accumulation_scheduler.py b/tests/tests_pytorch/callbacks/test_gradient_accumulation_scheduler.py index cc584b2da624c..9ad9759ff248c 100644 --- a/tests/tests_pytorch/callbacks/test_gradient_accumulation_scheduler.py +++ b/tests/tests_pytorch/callbacks/test_gradient_accumulation_scheduler.py @@ -15,6 +15,7 @@ from unittest.mock import Mock, patch import pytest + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import GradientAccumulationScheduler from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/callbacks/test_lambda_function.py b/tests/tests_pytorch/callbacks/test_lambda_function.py index 40d694bb35ebc..2b5e025653940 100644 --- a/tests/tests_pytorch/callbacks/test_lambda_function.py +++ b/tests/tests_pytorch/callbacks/test_lambda_function.py @@ -14,10 +14,10 @@ from functools import partial import pytest + from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import Callback, LambdaCallback from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.models.test_hooks import get_members diff --git a/tests/tests_pytorch/callbacks/test_lr_monitor.py b/tests/tests_pytorch/callbacks/test_lr_monitor.py index 4aedb4f23fa14..4f3a20ce82de5 100644 --- a/tests/tests_pytorch/callbacks/test_lr_monitor.py +++ b/tests/tests_pytorch/callbacks/test_lr_monitor.py @@ -13,6 +13,8 @@ # limitations under the License. import pytest import torch +from torch import optim + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor from lightning.pytorch.callbacks.callback import Callback @@ -20,8 +22,6 @@ from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch import optim - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel @@ -44,9 +44,9 @@ def test_lr_monitor_single_lr(tmp_path): assert lr_monitor.lrs, "No learning rates logged" assert all(v is None for v in lr_monitor.last_momentum_values.values()), "Momentum should not be logged by default" - assert all( - v is None for v in lr_monitor.last_weight_decay_values.values() - ), "Weight decay should not be logged by default" + assert all(v is None for v in lr_monitor.last_weight_decay_values.values()), ( + "Weight decay should not be logged by default" + ) assert len(lr_monitor.lrs) == len(trainer.lr_scheduler_configs) assert list(lr_monitor.lrs) == ["lr-SGD"] @@ -87,9 +87,9 @@ def configure_optimizers(self): assert len(lr_monitor.last_momentum_values) == len(trainer.lr_scheduler_configs) assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values) - assert all( - v is not None for v in lr_monitor.last_weight_decay_values.values() - ), "Expected weight decay to be logged" + assert all(v is not None for v in lr_monitor.last_weight_decay_values.values()), ( + "Expected weight decay to be logged" + ) assert len(lr_monitor.last_weight_decay_values) == len(trainer.lr_scheduler_configs) assert all(k == f"lr-{opt}-weight_decay" for k in lr_monitor.last_weight_decay_values) diff --git a/tests/tests_pytorch/callbacks/test_prediction_writer.py b/tests/tests_pytorch/callbacks/test_prediction_writer.py index 02604f5a195fe..7249d5739686b 100644 --- a/tests/tests_pytorch/callbacks/test_prediction_writer.py +++ b/tests/tests_pytorch/callbacks/test_prediction_writer.py @@ -14,11 +14,12 @@ from unittest.mock import ANY, Mock, call import pytest +from torch.utils.data import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import BasePredictionWriter from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch.utils.data import DataLoader class DummyPredictionWriter(BasePredictionWriter): diff --git a/tests/tests_pytorch/callbacks/test_pruning.py b/tests/tests_pytorch/callbacks/test_pruning.py index f3ec7e2ccc029..d70ab68b78b32 100644 --- a/tests/tests_pytorch/callbacks/test_pruning.py +++ b/tests/tests_pytorch/callbacks/test_pruning.py @@ -19,13 +19,13 @@ import pytest import torch import torch.nn.utils.prune as pytorch_prune +from torch import nn +from torch.nn import Sequential + from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint, ModelPruning from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch import nn -from torch.nn import Sequential - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/callbacks/test_rich_model_summary.py b/tests/tests_pytorch/callbacks/test_rich_model_summary.py index 73709fd80a833..7534c23d5679c 100644 --- a/tests/tests_pytorch/callbacks/test_rich_model_summary.py +++ b/tests/tests_pytorch/callbacks/test_rich_model_summary.py @@ -16,11 +16,11 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import RichModelSummary, RichProgressBar from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.model_summary import summarize - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/callbacks/test_spike.py b/tests/tests_pytorch/callbacks/test_spike.py index 5634feaf221cd..692a28dcc38c4 100644 --- a/tests/tests_pytorch/callbacks/test_spike.py +++ b/tests/tests_pytorch/callbacks/test_spike.py @@ -3,6 +3,7 @@ import pytest import torch + from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, TrainingSpikeException from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks.spike import SpikeDetection diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index 8d3a1800e1fa2..e50eef7f258e1 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -20,17 +20,17 @@ import pytest import torch +from torch import nn +from torch.optim.lr_scheduler import LambdaLR +from torch.optim.swa_utils import SWALR +from torch.utils.data import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import StochasticWeightAveraging from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from lightning.pytorch.strategies import Strategy from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch import nn -from torch.optim.lr_scheduler import LambdaLR -from torch.optim.swa_utils import SWALR -from torch.utils.data import DataLoader - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/callbacks/test_throughput_monitor.py b/tests/tests_pytorch/callbacks/test_throughput_monitor.py index 4867134a85642..9f77e4371e69e 100644 --- a/tests/tests_pytorch/callbacks/test_throughput_monitor.py +++ b/tests/tests_pytorch/callbacks/test_throughput_monitor.py @@ -3,6 +3,7 @@ import pytest import torch + from lightning.fabric.utilities.throughput import measure_flops from lightning.pytorch import Trainer from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor diff --git a/tests/tests_pytorch/callbacks/test_timer.py b/tests/tests_pytorch/callbacks/test_timer.py index e6359a2e9a5e1..e91170f2096d3 100644 --- a/tests/tests_pytorch/callbacks/test_timer.py +++ b/tests/tests_pytorch/callbacks/test_timer.py @@ -17,12 +17,12 @@ from unittest.mock import Mock, patch import pytest + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.callbacks.timer import Timer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.exceptions import MisconfigurationException - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py b/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py index ff8f3c95e43c5..2e998c42ed2b7 100644 --- a/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py @@ -16,9 +16,9 @@ import pytest import torch + from lightning.pytorch import Trainer, callbacks from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index be754d3911ade..006a123356c98 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -16,11 +16,11 @@ import sys from unittest.mock import patch -import lightning.pytorch as pl import pytest import torch -from lightning.pytorch import Callback, Trainer +import lightning.pytorch as pl +from lightning.pytorch import Callback, Trainer from tests_pytorch import _PATH_LEGACY from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf @@ -75,6 +75,7 @@ def test_legacy_ckpt_threading(pl_version: str): def load_model(): import torch + from lightning.pytorch.utilities.migration import pl_legacy_patch with pl_legacy_patch(): diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index d43f07179e7bb..2351cc7548f79 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -25,11 +25,13 @@ from unittest.mock import Mock, call, patch import cloudpickle -import lightning.pytorch as pl import pytest import torch import yaml from jsonargparse import ArgumentParser +from torch import optim + +import lightning.pytorch as pl from lightning.fabric.utilities.cloud_io import _load as pl_load from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint @@ -37,8 +39,6 @@ from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE -from torch import optim - from tests_pytorch.helpers.runif import RunIf if _OMEGACONF_AVAILABLE: diff --git a/tests/tests_pytorch/checkpointing/test_torch_saving.py b/tests/tests_pytorch/checkpointing/test_torch_saving.py index 4422a4063c719..e49c45bf54ade 100644 --- a/tests/tests_pytorch/checkpointing/test_torch_saving.py +++ b/tests/tests_pytorch/checkpointing/test_torch_saving.py @@ -13,9 +13,9 @@ # limitations under the License. import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py b/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py index c07400eaf8446..b0f1528d42d28 100644 --- a/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py @@ -15,9 +15,10 @@ from unittest import mock from unittest.mock import ANY, Mock -import lightning.pytorch as pl import pytest import torch + +import lightning.pytorch as pl from lightning.fabric.plugins import TorchCheckpointIO, XLACheckpointIO from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index ea5207516cad1..ea7380be3a42c 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -21,18 +21,18 @@ from pathlib import Path from unittest.mock import Mock -import lightning.fabric -import lightning.pytorch import pytest import torch.distributed +from tqdm import TMonitor + +import lightning.fabric +import lightning.pytorch from lightning.fabric.plugins.environments.lightning import find_free_network_port from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver from lightning.fabric.utilities.distributed import _destroy_dist_connection, _distributed_is_initialized from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.pytorch.accelerators import XLAAccelerator from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector -from tqdm import TMonitor - from tests_pytorch import _PATH_DATASETS @@ -44,9 +44,10 @@ def datadir(): @pytest.fixture(autouse=True) def preserve_global_rank_variable(): """Ensures that the rank_zero_only.rank global variable gets reset in each test.""" + from lightning_utilities.core.rank_zero import rank_zero_only as rank_zero_only_utilities + from lightning.fabric.utilities.rank_zero import rank_zero_only as rank_zero_only_fabric from lightning.pytorch.utilities.rank_zero import rank_zero_only as rank_zero_only_pytorch - from lightning_utilities.core.rank_zero import rank_zero_only as rank_zero_only_utilities functions = (rank_zero_only_pytorch, rank_zero_only_fabric, rank_zero_only_utilities) ranks = [getattr(fn, "rank", None) for fn in functions] @@ -176,22 +177,22 @@ def mock_cuda_count(monkeypatch, n: int) -> None: monkeypatch.setattr(lightning.pytorch.accelerators.cuda, "num_cuda_devices", lambda: n) -@pytest.fixture() +@pytest.fixture def cuda_count_0(monkeypatch): mock_cuda_count(monkeypatch, 0) -@pytest.fixture() +@pytest.fixture def cuda_count_1(monkeypatch): mock_cuda_count(monkeypatch, 1) -@pytest.fixture() +@pytest.fixture def cuda_count_2(monkeypatch): mock_cuda_count(monkeypatch, 2) -@pytest.fixture() +@pytest.fixture def cuda_count_4(monkeypatch): mock_cuda_count(monkeypatch, 4) @@ -201,12 +202,12 @@ def mock_mps_count(monkeypatch, n: int) -> None: monkeypatch.setattr(lightning.fabric.accelerators.mps.MPSAccelerator, "is_available", lambda *_: n > 0) -@pytest.fixture() +@pytest.fixture def mps_count_0(monkeypatch): mock_mps_count(monkeypatch, 0) -@pytest.fixture() +@pytest.fixture def mps_count_1(monkeypatch): mock_mps_count(monkeypatch, 1) @@ -222,7 +223,7 @@ def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value) -@pytest.fixture() +@pytest.fixture def xla_available(monkeypatch: pytest.MonkeyPatch) -> None: mock_xla_available(monkeypatch) @@ -238,12 +239,12 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock()) -@pytest.fixture() +@pytest.fixture def tpu_available(monkeypatch) -> None: mock_tpu_available(monkeypatch) -@pytest.fixture() +@pytest.fixture def caplog(caplog): """Workaround for https://github.com/pytest-dev/pytest/issues/3697. @@ -271,7 +272,7 @@ def caplog(caplog): logging.getLogger(name).propagate = propagate -@pytest.fixture() +@pytest.fixture def tmpdir_server(tmp_path): Handler = partial(SimpleHTTPRequestHandler, directory=str(tmp_path)) from http.server import ThreadingHTTPServer @@ -285,7 +286,7 @@ def tmpdir_server(tmp_path): server.shutdown() -@pytest.fixture() +@pytest.fixture def single_process_pg(): """Initialize the default process group with only the current process for testing purposes. diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index b3ccd88aae704..49e8bf0f5b36d 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -21,6 +21,7 @@ import pytest import torch + from lightning.pytorch import LightningDataModule, Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import ( @@ -34,7 +35,6 @@ from lightning.pytorch.utilities import AttributeDict from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py index 5ee91e82689f4..2036014762ebf 100644 --- a/tests/tests_pytorch/core/test_lightning_module.py +++ b/tests/tests_pytorch/core/test_lightning_module.py @@ -17,15 +17,15 @@ import pytest import torch +from torch import nn +from torch.optim import SGD, Adam + from lightning.fabric import Fabric from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.core.module import _TrainerFabricShim from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch import nn -from torch.optim import SGD, Adam - from tests_pytorch.helpers.runif import RunIf @@ -336,9 +336,9 @@ def __init__(self, spec): ), "Expect the shards to be different before `m_1` loading `m_0`'s state dict" m_1.load_state_dict(m_0.state_dict(), strict=False) - assert torch.allclose( - m_1.sharded_tensor.local_shards()[0].tensor, m_0.sharded_tensor.local_shards()[0].tensor - ), "Expect the shards to be same after `m_1` loading `m_0`'s state dict" + assert torch.allclose(m_1.sharded_tensor.local_shards()[0].tensor, m_0.sharded_tensor.local_shards()[0].tensor), ( + "Expect the shards to be same after `m_1` loading `m_0`'s state dict" + ) def test_lightning_module_configure_gradient_clipping(tmp_path): diff --git a/tests/tests_pytorch/core/test_lightning_optimizer.py b/tests/tests_pytorch/core/test_lightning_optimizer.py index b25b7ae648a3a..ed1ca2b4db03f 100644 --- a/tests/tests_pytorch/core/test_lightning_optimizer.py +++ b/tests/tests_pytorch/core/test_lightning_optimizer.py @@ -15,13 +15,13 @@ from unittest.mock import DEFAULT, Mock, patch import torch +from torch.optim import SGD, Adam, Optimizer + from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.core.optimizer import LightningOptimizer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loops.optimization.automatic import Closure from lightning.pytorch.tuner.tuning import Tuner -from torch.optim import SGD, Adam, Optimizer - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py index dcb3f71c7499c..4a5df7d37cd7a 100644 --- a/tests/tests_pytorch/core/test_metric_result_integration.py +++ b/tests/tests_pytorch/core/test_metric_result_integration.py @@ -16,9 +16,15 @@ from contextlib import nullcontext, suppress from unittest import mock -import lightning.pytorch as pl import pytest import torch +from lightning_utilities.test.warning import no_warning_call +from torch import Tensor, tensor +from torch.nn import ModuleDict, ModuleList +from torchmetrics import Metric, MetricCollection +from torchmetrics.classification import Accuracy + +import lightning.pytorch as pl from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer from lightning.pytorch.callbacks import OnExceptionCheckpoint @@ -30,12 +36,6 @@ _Sync, ) from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11 -from lightning_utilities.test.warning import no_warning_call -from torch import Tensor, tensor -from torch.nn import ModuleDict, ModuleList -from torchmetrics import Metric, MetricCollection -from torchmetrics.classification import Accuracy - from tests_pytorch.core.test_results import spawn_launch from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/core/test_results.py b/tests/tests_pytorch/core/test_results.py index 006731f3e7c60..edc6efbfa4889 100644 --- a/tests/tests_pytorch/core/test_results.py +++ b/tests/tests_pytorch/core/test_results.py @@ -15,12 +15,12 @@ import torch import torch.distributed as dist + from lightning.fabric.plugins.environments import LightningEnvironment from lightning.pytorch.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator from lightning.pytorch.strategies import DDPStrategy from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher from lightning.pytorch.trainer.connectors.logger_connector.result import _Sync - from tests_pytorch.helpers.runif import RunIf from tests_pytorch.models.test_tpu import wrap_launch_function diff --git a/tests/tests_pytorch/core/test_saving.py b/tests/tests_pytorch/core/test_saving.py index c7e48239754c5..8e1e9584a7c68 100644 --- a/tests/tests_pytorch/core/test_saving.py +++ b/tests/tests_pytorch/core/test_saving.py @@ -1,11 +1,11 @@ from unittest.mock import ANY, Mock -import lightning.pytorch as pl import pytest import torch + +import lightning.pytorch as pl from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel - from tests_pytorch.conftest import mock_cuda_count, mock_mps_count from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/demos/transformer.py b/tests/tests_pytorch/demos/transformer.py index 47ecbb2083273..50de873511053 100644 --- a/tests/tests_pytorch/demos/transformer.py +++ b/tests/tests_pytorch/demos/transformer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch + from lightning.pytorch.demos import Transformer diff --git a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py index e6da72c777dbb..0b79b638534fa 100644 --- a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py +++ b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py @@ -1,9 +1,10 @@ import sys from unittest.mock import Mock -import lightning.fabric import pytest import torch.nn + +import lightning.fabric from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins.precision.double import LightningDoublePrecisionModule diff --git a/tests/tests_pytorch/helpers/__init__.py b/tests/tests_pytorch/helpers/__init__.py index 82a6332c56738..1299d7e542955 100644 --- a/tests/tests_pytorch/helpers/__init__.py +++ b/tests/tests_pytorch/helpers/__init__.py @@ -4,5 +4,4 @@ ManualOptimBoringModel, RandomDataset, ) - from tests_pytorch.helpers.datasets import TrialMNIST # noqa: F401 diff --git a/tests/tests_pytorch/helpers/advanced_models.py b/tests/tests_pytorch/helpers/advanced_models.py index 4fecf516018c1..959e6e5968d18 100644 --- a/tests/tests_pytorch/helpers/advanced_models.py +++ b/tests/tests_pytorch/helpers/advanced_models.py @@ -16,9 +16,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from lightning.pytorch.core.module import LightningModule from torch.utils.data import DataLoader +from lightning.pytorch.core.module import LightningModule from tests_pytorch import _PATH_DATASETS from tests_pytorch.helpers.datasets import MNIST, AverageDataset, TrialMNIST @@ -219,3 +219,54 @@ def configure_optimizers(self): def train_dataloader(self): return DataLoader(MNIST(root=_PATH_DATASETS, train=True, download=True), batch_size=128, num_workers=1) + + +class TBPTTModule(LightningModule): + def __init__(self): + super().__init__() + + self.batch_size = 10 + self.in_features = 10 + self.out_features = 5 + self.hidden_dim = 20 + + self.automatic_optimization = False + self.truncated_bptt_steps = 10 + + self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True) + self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features) + + def forward(self, x, hs): + seq, hs = self.rnn(x, hs) + return self.linear_out(seq), hs + + def training_step(self, batch, batch_idx): + x, y = batch + split_x, split_y = [ + x.tensor_split(self.truncated_bptt_steps, dim=1), + y.tensor_split(self.truncated_bptt_steps, dim=1), + ] + + hiddens = None + optimizer = self.optimizers() + losses = [] + + for x, y in zip(split_x, split_y): + y_pred, hiddens = self(x, hiddens) + loss = F.mse_loss(y_pred, y) + + optimizer.zero_grad() + self.manual_backward(loss) + optimizer.step() + + # "Truncate" + hiddens = [h.detach() for h in hiddens] + losses.append(loss.detach()) + + return + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.001) + + def train_dataloader(self): + return DataLoader(AverageDataset(), batch_size=self.batch_size) diff --git a/tests/tests_pytorch/helpers/datamodules.py b/tests/tests_pytorch/helpers/datamodules.py index 5a91d8ebb981d..6282acf3be547 100644 --- a/tests/tests_pytorch/helpers/datamodules.py +++ b/tests/tests_pytorch/helpers/datamodules.py @@ -13,10 +13,10 @@ # limitations under the License. import torch -from lightning.pytorch.core.datamodule import LightningDataModule from lightning_utilities.core.imports import RequirementCache from torch.utils.data import DataLoader +from lightning.pytorch.core.datamodule import LightningDataModule from tests_pytorch.helpers.datasets import MNIST, SklearnDataset, TrialMNIST _SKLEARN_AVAILABLE = RequirementCache("scikit-learn") diff --git a/tests/tests_pytorch/helpers/deterministic_model.py b/tests/tests_pytorch/helpers/deterministic_model.py index 95975e9ad1654..fbf3488158825 100644 --- a/tests/tests_pytorch/helpers/deterministic_model.py +++ b/tests/tests_pytorch/helpers/deterministic_model.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch -from lightning.pytorch.core.module import LightningModule from torch import Tensor, nn from torch.utils.data import DataLoader, Dataset +from lightning.pytorch.core.module import LightningModule + class DeterministicModel(LightningModule): def __init__(self, weights=None): diff --git a/tests/tests_pytorch/helpers/pipelines.py b/tests/tests_pytorch/helpers/pipelines.py index ab33878010123..b6c63a5702bfc 100644 --- a/tests/tests_pytorch/helpers/pipelines.py +++ b/tests/tests_pytorch/helpers/pipelines.py @@ -14,11 +14,11 @@ from functools import partial import torch +from torchmetrics.functional import accuracy + from lightning.pytorch import LightningDataModule, LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11 -from torchmetrics.functional import accuracy - from tests_pytorch.helpers.utils import get_default_logger, load_model_from_checkpoint diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py index 1c5b059d679a5..25fadd524adf8 100644 --- a/tests/tests_pytorch/helpers/runif.py +++ b/tests/tests_pytorch/helpers/runif.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest + from lightning.pytorch.utilities.testing import _runif_reasons diff --git a/tests/tests_pytorch/helpers/simple_models.py b/tests/tests_pytorch/helpers/simple_models.py index 940adc0ac49a6..a9dc635bba275 100644 --- a/tests/tests_pytorch/helpers/simple_models.py +++ b/tests/tests_pytorch/helpers/simple_models.py @@ -15,11 +15,12 @@ import torch import torch.nn.functional as F -from lightning.pytorch import LightningModule from lightning_utilities.core.imports import compare_version from torch import nn from torchmetrics import Accuracy, MeanSquaredError +from lightning.pytorch import LightningModule + # using new API with task _TM_GE_0_11 = compare_version("torchmetrics", operator.ge, "0.11.0") diff --git a/tests/tests_pytorch/helpers/test_models.py b/tests/tests_pytorch/helpers/test_models.py index 7e44f79413863..721641ae8343a 100644 --- a/tests/tests_pytorch/helpers/test_models.py +++ b/tests/tests_pytorch/helpers/test_models.py @@ -14,10 +14,10 @@ import os import pytest + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel - -from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN +from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN, TBPTTModule from tests_pytorch.helpers.datamodules import ClassifDataModule, RegressDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel, RegressionModel @@ -49,3 +49,10 @@ def test_models(tmp_path, data_class, model_class): model.to_torchscript() if data_class: model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample) + + +def test_tbptt(tmp_path): + model = TBPTTModule() + + trainer = Trainer(default_root_dir=tmp_path, max_epochs=1) + trainer.fit(model) diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py index 7cc5cc94fe8cc..ab1149ca9651a 100644 --- a/tests/tests_pytorch/loggers/conftest.py +++ b/tests/tests_pytorch/loggers/conftest.py @@ -18,7 +18,7 @@ import pytest -@pytest.fixture() +@pytest.fixture def mlflow_mock(monkeypatch): mlflow = ModuleType("mlflow") mlflow.set_tracking_uri = Mock() @@ -43,7 +43,7 @@ def mlflow_mock(monkeypatch): return mlflow -@pytest.fixture() +@pytest.fixture def wandb_mock(monkeypatch): class RunType: # to make isinstance checks pass pass @@ -89,7 +89,7 @@ class RunType: # to make isinstance checks pass return wandb -@pytest.fixture() +@pytest.fixture def comet_mock(monkeypatch): comet = ModuleType("comet_ml") monkeypatch.setitem(sys.modules, "comet_ml", comet) @@ -110,7 +110,7 @@ def comet_mock(monkeypatch): return comet -@pytest.fixture() +@pytest.fixture def neptune_mock(monkeypatch): class RunType: # to make isinstance checks pass def get_root_object(self): diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 1b845c57ec35d..4ac28ef487ad8 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -19,6 +19,7 @@ import pytest import torch + from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import ( @@ -32,7 +33,6 @@ from lightning.pytorch.loggers.logger import DummyExperiment, Logger from lightning.pytorch.loggers.tensorboard import _TENSORBOARD_AVAILABLE from lightning.pytorch.tuner.tuning import Tuner - from tests_pytorch.helpers.runif import RunIf from tests_pytorch.loggers.test_comet import _patch_comet_atexit from tests_pytorch.loggers.test_mlflow import mock_mlflow_run_creation diff --git a/tests/tests_pytorch/loggers/test_comet.py b/tests/tests_pytorch/loggers/test_comet.py index e467c63543ede..de99454548302 100644 --- a/tests/tests_pytorch/loggers/test_comet.py +++ b/tests/tests_pytorch/loggers/test_comet.py @@ -16,11 +16,12 @@ from unittest.mock import DEFAULT, Mock, patch import pytest +from torch import tensor + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import CometLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch import tensor def _patch_comet_atexit(monkeypatch): diff --git a/tests/tests_pytorch/loggers/test_csv.py b/tests/tests_pytorch/loggers/test_csv.py index 27b85bb4ad745..1b09302ffb74a 100644 --- a/tests/tests_pytorch/loggers/test_csv.py +++ b/tests/tests_pytorch/loggers/test_csv.py @@ -18,11 +18,11 @@ import fsspec import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.core.saving import load_hparams_from_yaml from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.loggers.csv_logs import ExperimentWriter - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/loggers/test_logger.py b/tests/tests_pytorch/loggers/test_logger.py index dcdd504fd4660..124a9120a9197 100644 --- a/tests/tests_pytorch/loggers/test_logger.py +++ b/tests/tests_pytorch/loggers/test_logger.py @@ -20,6 +20,7 @@ import numpy as np import pytest import torch + from lightning.fabric.utilities.logger import _convert_params, _sanitize_params from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index 14af36680904c..c7f9dbe1fe2c6 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, Mock import pytest + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers.mlflow import ( diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py index 0a39337ac5c16..b5e98fbe99113 100644 --- a/tests/tests_pytorch/loggers/test_neptune.py +++ b/tests/tests_pytorch/loggers/test_neptune.py @@ -17,9 +17,10 @@ from unittest import mock from unittest.mock import MagicMock, call -import lightning.pytorch as pl import pytest import torch + +import lightning.pytorch as pl from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import NeptuneLogger diff --git a/tests/tests_pytorch/loggers/test_tensorboard.py b/tests/tests_pytorch/loggers/test_tensorboard.py index 82ffff25cac7c..173805f1a6f3c 100644 --- a/tests/tests_pytorch/loggers/test_tensorboard.py +++ b/tests/tests_pytorch/loggers/test_tensorboard.py @@ -20,12 +20,12 @@ import pytest import torch import yaml + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.loggers.tensorboard import _TENSORBOARD_AVAILABLE from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE - from tests_pytorch.helpers.runif import RunIf if _OMEGACONF_AVAILABLE: diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index a8e70bfb6589d..f3d82b0582be2 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -18,14 +18,14 @@ import pytest import yaml +from lightning_utilities.test.warning import no_warning_call + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.cli import LightningCLI from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning_utilities.test.warning import no_warning_call - from tests_pytorch.test_cli import _xfail_python_ge_3_11_9 diff --git a/tests/tests_pytorch/loops/optimization/test_automatic_loop.py b/tests/tests_pytorch/loops/optimization/test_automatic_loop.py index 2fb04d0d9d8d1..e20c1789be023 100644 --- a/tests/tests_pytorch/loops/optimization/test_automatic_loop.py +++ b/tests/tests_pytorch/loops/optimization/test_automatic_loop.py @@ -17,6 +17,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loops.optimization.automatic import ClosureResult diff --git a/tests/tests_pytorch/loops/optimization/test_closure.py b/tests/tests_pytorch/loops/optimization/test_closure.py index d7d4e51794aca..7766a385c3057 100644 --- a/tests/tests_pytorch/loops/optimization/test_closure.py +++ b/tests/tests_pytorch/loops/optimization/test_closure.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.exceptions import MisconfigurationException diff --git a/tests/tests_pytorch/loops/optimization/test_manual_loop.py b/tests/tests_pytorch/loops/optimization/test_manual_loop.py index 67be30b24e159..cedfefb4791ea 100644 --- a/tests/tests_pytorch/loops/optimization/test_manual_loop.py +++ b/tests/tests_pytorch/loops/optimization/test_manual_loop.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loops.optimization.manual import ManualResult diff --git a/tests/tests_pytorch/loops/test_all.py b/tests/tests_pytorch/loops/test_all.py index 1eb67064fb300..51b7bbeedf90b 100644 --- a/tests/tests_pytorch/loops/test_all.py +++ b/tests/tests_pytorch/loops/test_all.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest + from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/loops/test_evaluation_loop.py b/tests/tests_pytorch/loops/test_evaluation_loop.py index 588672e19a05c..6d8c4be70d9c6 100644 --- a/tests/tests_pytorch/loops/test_evaluation_loop.py +++ b/tests/tests_pytorch/loops/test_evaluation_loop.py @@ -16,13 +16,13 @@ import pytest import torch +from torch.utils.data.dataloader import DataLoader +from torch.utils.data.sampler import BatchSampler, RandomSampler + from lightning.fabric.accelerators.cuda import _clear_cuda_memory from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.utilities import CombinedLoader -from torch.utils.data.dataloader import DataLoader -from torch.utils.data.sampler import BatchSampler, RandomSampler - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/loops/test_evaluation_loop_flow.py b/tests/tests_pytorch/loops/test_evaluation_loop_flow.py index 09ded9a777679..e698a15cb0fb2 100644 --- a/tests/tests_pytorch/loops/test_evaluation_loop_flow.py +++ b/tests/tests_pytorch/loops/test_evaluation_loop_flow.py @@ -14,11 +14,11 @@ """Tests the evaluation loop.""" import torch +from torch import Tensor + from lightning.pytorch import Trainer from lightning.pytorch.core.module import LightningModule from lightning.pytorch.trainer.states import RunningStage -from torch import Tensor - from tests_pytorch.helpers.deterministic_model import DeterministicModel diff --git a/tests/tests_pytorch/loops/test_fetchers.py b/tests/tests_pytorch/loops/test_fetchers.py index 75b25e3d98fd8..f66e9f9f3b16f 100644 --- a/tests/tests_pytorch/loops/test_fetchers.py +++ b/tests/tests_pytorch/loops/test_fetchers.py @@ -17,6 +17,9 @@ import pytest import torch +from torch import Tensor +from torch.utils.data import DataLoader, Dataset, IterableDataset + from lightning.pytorch import LightningDataModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.loops.fetchers import _DataLoaderIterDataFetcher, _PrefetchDataFetcher @@ -24,9 +27,6 @@ from lightning.pytorch.utilities.combined_loader import CombinedLoader from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.types import STEP_OUTPUT -from torch import Tensor -from torch.utils.data import DataLoader, Dataset, IterableDataset - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 1820ca3568173..384ae2b47859b 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -20,6 +20,8 @@ import pytest import torch +from torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter + from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import Callback, ModelCheckpoint, OnExceptionCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset @@ -27,8 +29,6 @@ from lightning.pytorch.loops.progress import _BaseProgress from lightning.pytorch.utilities import CombinedLoader from lightning.pytorch.utilities.types import STEP_OUTPUT -from torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/loops/test_prediction_loop.py b/tests/tests_pytorch/loops/test_prediction_loop.py index f27413955cae9..470cbcdc195f5 100644 --- a/tests/tests_pytorch/loops/test_prediction_loop.py +++ b/tests/tests_pytorch/loops/test_prediction_loop.py @@ -14,10 +14,11 @@ import itertools import pytest +from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler + from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper -from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler def test_prediction_loop_stores_predictions(tmp_path): diff --git a/tests/tests_pytorch/loops/test_progress.py b/tests/tests_pytorch/loops/test_progress.py index e7256d4504402..27184d7b17afb 100644 --- a/tests/tests_pytorch/loops/test_progress.py +++ b/tests/tests_pytorch/loops/test_progress.py @@ -14,6 +14,7 @@ from copy import deepcopy import pytest + from lightning.pytorch.loops.progress import ( _BaseProgress, _OptimizerProgress, diff --git a/tests/tests_pytorch/loops/test_training_epoch_loop.py b/tests/tests_pytorch/loops/test_training_epoch_loop.py index a110a20bfaf84..71252fe2011b8 100644 --- a/tests/tests_pytorch/loops/test_training_epoch_loop.py +++ b/tests/tests_pytorch/loops/test_training_epoch_loop.py @@ -16,11 +16,12 @@ import pytest import torch +from lightning_utilities.test.warning import no_warning_call + from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.trainer import Trainer -from lightning_utilities.test.warning import no_warning_call def test_no_val_on_train_epoch_loop_restart(tmp_path): diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py index 2afeb338fd9fd..29afd1ba1a250 100644 --- a/tests/tests_pytorch/loops/test_training_loop.py +++ b/tests/tests_pytorch/loops/test_training_loop.py @@ -16,6 +16,7 @@ import pytest import torch + from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loops import _FitLoop diff --git a/tests/tests_pytorch/loops/test_training_loop_flow_dict.py b/tests/tests_pytorch/loops/test_training_loop_flow_dict.py index c89913b29dbdb..44a0c85184d9e 100644 --- a/tests/tests_pytorch/loops/test_training_loop_flow_dict.py +++ b/tests/tests_pytorch/loops/test_training_loop_flow_dict.py @@ -14,9 +14,9 @@ """Tests to ensure that the training loop works with a dict (1.0)""" import torch + from lightning.pytorch import Trainer from lightning.pytorch.core.module import LightningModule - from tests_pytorch.helpers.deterministic_model import DeterministicModel diff --git a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py index ccdeb55d50ea8..4a4e8cb81c6a2 100644 --- a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py +++ b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest -from lightning.pytorch import Trainer -from lightning.pytorch.core.module import LightningModule -from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset -from lightning.pytorch.loops.optimization.automatic import Closure -from lightning.pytorch.trainer.states import RunningStage from lightning_utilities.test.warning import no_warning_call from torch import Tensor from torch.utils.data import DataLoader from torch.utils.data._utils.collate import default_collate +from lightning.pytorch import Trainer +from lightning.pytorch.core.module import LightningModule +from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset +from lightning.pytorch.loops.optimization.automatic import Closure +from lightning.pytorch.trainer.states import RunningStage from tests_pytorch.helpers.deterministic_model import DeterministicModel diff --git a/tests/tests_pytorch/models/test_amp.py b/tests/tests_pytorch/models/test_amp.py index c28d6300131ae..24323f5c1d691 100644 --- a/tests/tests_pytorch/models/test_amp.py +++ b/tests/tests_pytorch/models/test_amp.py @@ -16,12 +16,12 @@ import pytest import torch -from lightning.fabric.plugins.environments import SLURMEnvironment -from lightning.pytorch import Trainer -from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from torch.utils.data import DataLoader import tests_pytorch.helpers.utils as tutils +from lightning.fabric.plugins.environments import SLURMEnvironment +from lightning.pytorch import Trainer +from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/models/test_cpu.py b/tests/tests_pytorch/models/test_cpu.py index 435b8437760d2..a2d38aca7c56c 100644 --- a/tests/tests_pytorch/models/test_cpu.py +++ b/tests/tests_pytorch/models/test_cpu.py @@ -15,12 +15,12 @@ from unittest import mock import torch -from lightning.pytorch import Trainer, seed_everything -from lightning.pytorch.callbacks import Callback, EarlyStopping, ModelCheckpoint -from lightning.pytorch.demos.boring_classes import BoringModel import tests_pytorch.helpers.pipelines as tpipes import tests_pytorch.helpers.utils as tutils +from lightning.pytorch import Trainer, seed_everything +from lightning.pytorch.callbacks import Callback, EarlyStopping, ModelCheckpoint +from lightning.pytorch.demos.boring_classes import BoringModel from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/models/test_ddp_fork_amp.py b/tests/tests_pytorch/models/test_ddp_fork_amp.py index 54d394948eeee..a1d7c6a7f8ac1 100644 --- a/tests/tests_pytorch/models/test_ddp_fork_amp.py +++ b/tests/tests_pytorch/models/test_ddp_fork_amp.py @@ -14,8 +14,8 @@ import multiprocessing import torch -from lightning.pytorch.plugins import MixedPrecision +from lightning.pytorch.plugins import MixedPrecision from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/models/test_fabric_integration.py b/tests/tests_pytorch/models/test_fabric_integration.py index fd5ecd34e96be..3e2811e8a9413 100644 --- a/tests/tests_pytorch/models/test_fabric_integration.py +++ b/tests/tests_pytorch/models/test_fabric_integration.py @@ -16,6 +16,7 @@ from unittest.mock import Mock import torch + from lightning.fabric import Fabric from lightning.pytorch.demos.boring_classes import BoringModel, ManualOptimBoringModel diff --git a/tests/tests_pytorch/models/test_gpu.py b/tests/tests_pytorch/models/test_gpu.py index b411774c3e164..797120312436f 100644 --- a/tests/tests_pytorch/models/test_gpu.py +++ b/tests/tests_pytorch/models/test_gpu.py @@ -18,14 +18,14 @@ import pytest import torch + +import tests_pytorch.helpers.pipelines as tpipes from lightning.fabric.plugins.environments import TorchElasticEnvironment from lightning.fabric.utilities.device_parser import _parse_gpu_ids from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.accelerators import CPUAccelerator, CUDAAccelerator from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.exceptions import MisconfigurationException - -import tests_pytorch.helpers.pipelines as tpipes from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index 1a8aeb4b297a9..e943d0533cab5 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -18,12 +18,12 @@ import pytest import torch -from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__ -from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset -from lightning.pytorch.utilities.model_helpers import is_overridden from torch import Tensor from torch.utils.data import DataLoader +from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__ +from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset +from lightning.pytorch.utilities.model_helpers import is_overridden from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index 871b1cba673eb..6fd400aab2724 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -25,6 +25,10 @@ import pytest import torch from fsspec.implementations.local import LocalFileSystem +from lightning_utilities.core.imports import RequirementCache +from lightning_utilities.test.warning import no_warning_call +from torch.utils.data import DataLoader + from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.core.datamodule import LightningDataModule @@ -34,10 +38,6 @@ from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger from lightning.pytorch.utilities import AttributeDict, is_picklable from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE -from lightning_utilities.core.imports import RequirementCache -from lightning_utilities.test.warning import no_warning_call -from torch.utils.data import DataLoader - from tests_pytorch.helpers.runif import RunIf if _OMEGACONF_AVAILABLE: diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index ee670cd66e871..b3032bea5560d 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -21,11 +21,11 @@ import onnxruntime import pytest import torch -from lightning.pytorch import Trainer -from lightning.pytorch.demos.boring_classes import BoringModel from lightning_utilities import compare_version import tests_pytorch.helpers.pipelines as tpipes +from lightning.pytorch import Trainer +from lightning.pytorch.demos.boring_classes import BoringModel from tests_pytorch.helpers.runif import RunIf from tests_pytorch.utilities.test_model_summary import UnorderedModel diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index 64f70b176a971..099493890831d 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -22,16 +22,16 @@ import cloudpickle import pytest import torch -from lightning.fabric import seed_everything -from lightning.pytorch import Callback, Trainer -from lightning.pytorch.callbacks import ModelCheckpoint -from lightning.pytorch.demos.boring_classes import BoringModel -from lightning.pytorch.trainer.states import TrainerFn from lightning_utilities.test.warning import no_warning_call from torch import Tensor import tests_pytorch.helpers.pipelines as tpipes import tests_pytorch.helpers.utils as tutils +from lightning.fabric import seed_everything +from lightning.pytorch import Callback, Trainer +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.demos.boring_classes import BoringModel +from lightning.pytorch.trainer.states import TrainerFn from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/models/test_torchscript.py b/tests/tests_pytorch/models/test_torchscript.py index 993085729e545..8f9151265d21a 100644 --- a/tests/tests_pytorch/models/test_torchscript.py +++ b/tests/tests_pytorch/models/test_torchscript.py @@ -18,11 +18,11 @@ import pytest import torch from fsspec.implementations.local import LocalFileSystem + from lightning.fabric.utilities.cloud_io import get_filesystem from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch.core.module import LightningModule from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleRNN from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/models/test_tpu.py b/tests/tests_pytorch/models/test_tpu.py index af927e0d5596a..8067fd63b6562 100644 --- a/tests/tests_pytorch/models/test_tpu.py +++ b/tests/tests_pytorch/models/test_tpu.py @@ -17,6 +17,9 @@ import pytest import torch +from torch.utils.data import DataLoader + +import tests_pytorch.helpers.pipelines as tpipes from lightning.pytorch import Trainer from lightning.pytorch.accelerators import XLAAccelerator from lightning.pytorch.callbacks import EarlyStopping @@ -25,9 +28,6 @@ from lightning.pytorch.strategies.launchers.xla import _XLALauncher from lightning.pytorch.trainer.connectors.logger_connector.result import _Sync from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch.utils.data import DataLoader - -import tests_pytorch.helpers.pipelines as tpipes from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/overrides/test_distributed.py b/tests/tests_pytorch/overrides/test_distributed.py index 3e2fba54bcd03..dd5f8f1504af7 100644 --- a/tests/tests_pytorch/overrides/test_distributed.py +++ b/tests/tests_pytorch/overrides/test_distributed.py @@ -15,11 +15,11 @@ import pytest import torch +from torch.utils.data import BatchSampler, SequentialSampler + from lightning.fabric.utilities.data import has_len from lightning.pytorch import LightningModule, Trainer, seed_everything from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSampler, _IndexBatchSamplerWrapper -from torch.utils.data import BatchSampler, SequentialSampler - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/plugins/precision/test_all.py b/tests/tests_pytorch/plugins/precision/test_all.py index 2668311c8b452..2a11b7c66c772 100644 --- a/tests/tests_pytorch/plugins/precision/test_all.py +++ b/tests/tests_pytorch/plugins/precision/test_all.py @@ -1,5 +1,6 @@ import pytest import torch + from lightning.pytorch.plugins import ( DeepSpeedPrecision, DoublePrecision, diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py index 90ecc703c8945..cb061c540b2be 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp.py +++ b/tests/tests_pytorch/plugins/precision/test_amp.py @@ -14,9 +14,10 @@ from unittest.mock import Mock import pytest +from torch.optim import Optimizer + from lightning.pytorch.plugins import MixedPrecision from lightning.pytorch.utilities import GradClipAlgorithmType -from torch.optim import Optimizer def test_clip_gradients(): diff --git a/tests/tests_pytorch/plugins/precision/test_amp_integration.py b/tests/tests_pytorch/plugins/precision/test_amp_integration.py index bc9f77907919a..f231e3ce91e0b 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp_integration.py +++ b/tests/tests_pytorch/plugins/precision/test_amp_integration.py @@ -14,12 +14,12 @@ from unittest.mock import Mock import torch + from lightning.fabric import seed_everything from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins.precision import MixedPrecision - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/plugins/precision/test_bitsandbytes.py b/tests/tests_pytorch/plugins/precision/test_bitsandbytes.py index a88e38d6303e4..8f331e26f979d 100644 --- a/tests/tests_pytorch/plugins/precision/test_bitsandbytes.py +++ b/tests/tests_pytorch/plugins/precision/test_bitsandbytes.py @@ -14,10 +14,11 @@ import sys from unittest.mock import Mock -import lightning.fabric import pytest import torch import torch.distributed + +import lightning.fabric from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecision diff --git a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py index 3e1aaa17763e9..da4ce6b89aaab 100644 --- a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py +++ b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py @@ -14,6 +14,7 @@ import pytest import torch + from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision diff --git a/tests/tests_pytorch/plugins/precision/test_double.py b/tests/tests_pytorch/plugins/precision/test_double.py index 1ee89752fcbae..74c178c65d05d 100644 --- a/tests/tests_pytorch/plugins/precision/test_double.py +++ b/tests/tests_pytorch/plugins/precision/test_double.py @@ -16,11 +16,11 @@ import pytest import torch +from torch.utils.data import DataLoader, Dataset + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.plugins.precision.double import DoublePrecision -from torch.utils.data import DataLoader, Dataset - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/plugins/precision/test_fsdp.py b/tests/tests_pytorch/plugins/precision/test_fsdp.py index 3ad3af1f1b56b..0389d364dcb79 100644 --- a/tests/tests_pytorch/plugins/precision/test_fsdp.py +++ b/tests/tests_pytorch/plugins/precision/test_fsdp.py @@ -15,9 +15,9 @@ import pytest import torch + from lightning.fabric.plugins.precision.utils import _DtypeContextManager from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/plugins/precision/test_half.py b/tests/tests_pytorch/plugins/precision/test_half.py index d51392a00e3d0..9597e01ea428b 100644 --- a/tests/tests_pytorch/plugins/precision/test_half.py +++ b/tests/tests_pytorch/plugins/precision/test_half.py @@ -14,6 +14,7 @@ import pytest import torch + from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.plugins import HalfPrecision diff --git a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py index 7c92ff47d909a..a9967280e3f23 100644 --- a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py +++ b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py @@ -15,9 +15,10 @@ from contextlib import nullcontext from unittest.mock import ANY, Mock -import lightning.fabric import pytest import torch + +import lightning.fabric from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.plugins import TransformerEnginePrecision from lightning.pytorch.trainer.connectors.accelerator_connector import _AcceleratorConnector diff --git a/tests/tests_pytorch/plugins/precision/test_xla.py b/tests/tests_pytorch/plugins/precision/test_xla.py index 97990b6380dab..b456c49e8ff50 100644 --- a/tests/tests_pytorch/plugins/precision/test_xla.py +++ b/tests/tests_pytorch/plugins/precision/test_xla.py @@ -18,8 +18,8 @@ import pytest import torch -from lightning.pytorch.plugins import XLAPrecision +from lightning.pytorch.plugins import XLAPrecision from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/plugins/test_amp_plugins.py b/tests/tests_pytorch/plugins/test_amp_plugins.py index e8adb59c39e51..0b68c098cc713 100644 --- a/tests/tests_pytorch/plugins/test_amp_plugins.py +++ b/tests/tests_pytorch/plugins/test_amp_plugins.py @@ -18,11 +18,11 @@ import pytest import torch +from torch import Tensor + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins import MixedPrecision -from torch import Tensor - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py index 58baa47e7a620..cae26fc1fe775 100644 --- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py +++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock, Mock import torch + from lightning.fabric.plugins import CheckpointIO, TorchCheckpointIO from lightning.fabric.utilities.types import _PATH from lightning.pytorch import Trainer diff --git a/tests/tests_pytorch/plugins/test_cluster_integration.py b/tests/tests_pytorch/plugins/test_cluster_integration.py index 026465ac8b17b..08bd1707b5cfd 100644 --- a/tests/tests_pytorch/plugins/test_cluster_integration.py +++ b/tests/tests_pytorch/plugins/test_cluster_integration.py @@ -16,11 +16,11 @@ import pytest import torch + from lightning.fabric.plugins.environments import LightningEnvironment, SLURMEnvironment, TorchElasticEnvironment from lightning.pytorch import Trainer from lightning.pytorch.strategies import DDPStrategy, DeepSpeedStrategy from lightning.pytorch.utilities.rank_zero import rank_zero_only - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/profilers/test_profiler.py b/tests/tests_pytorch/profilers/test_profiler.py index 5b0c13e605ee2..df9bc16c284ad 100644 --- a/tests/tests_pytorch/profilers/test_profiler.py +++ b/tests/tests_pytorch/profilers/test_profiler.py @@ -22,6 +22,7 @@ import numpy as np import pytest import torch + from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch import Callback, Trainer from lightning.pytorch.callbacks import EarlyStopping, StochasticWeightAveraging @@ -30,7 +31,6 @@ from lightning.pytorch.profilers import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler from lightning.pytorch.profilers.pytorch import _KINETO_AVAILABLE, RegisterRecordFunction, warning_cache from lightning.pytorch.utilities.exceptions import MisconfigurationException - from tests_pytorch.helpers.runif import RunIf PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005 @@ -55,7 +55,7 @@ def _sleep_generator(durations): yield duration -@pytest.fixture() +@pytest.fixture def simple_profiler(): return SimpleProfiler() @@ -264,7 +264,7 @@ def test_simple_profiler_summary(tmp_path, extended): assert expected_text == summary -@pytest.fixture() +@pytest.fixture def advanced_profiler(tmp_path): return AdvancedProfiler(dirpath=tmp_path, filename="profiler") @@ -336,7 +336,7 @@ def test_advanced_profiler_deepcopy(advanced_profiler): assert deepcopy(advanced_profiler) -@pytest.fixture() +@pytest.fixture def pytorch_profiler(tmp_path): return PyTorchProfiler(dirpath=tmp_path, filename="profiler") diff --git a/tests/tests_pytorch/profilers/test_xla_profiler.py b/tests/tests_pytorch/profilers/test_xla_profiler.py index 980a4dac74731..80337382ddcb5 100644 --- a/tests/tests_pytorch/profilers/test_xla_profiler.py +++ b/tests/tests_pytorch/profilers/test_xla_profiler.py @@ -16,10 +16,10 @@ from unittest import mock import pytest + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.profilers import XLAProfiler - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/serve/test_servable_module_validator.py b/tests/tests_pytorch/serve/test_servable_module_validator.py index ec4dd8825c8ea..ba90949132ba2 100644 --- a/tests/tests_pytorch/serve/test_servable_module_validator.py +++ b/tests/tests_pytorch/serve/test_servable_module_validator.py @@ -1,9 +1,10 @@ import pytest import torch +from torch import Tensor + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.serve.servable_module_validator import ServableModule, ServableModuleValidator -from torch import Tensor class ServableBoringModel(BoringModel, ServableModule): diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py index 394d827058987..d26f6c4d2c3ef 100644 --- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py @@ -18,13 +18,13 @@ import pytest import torch + from lightning.fabric.plugins import ClusterEnvironment from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.strategies import DDPStrategy from lightning.pytorch.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher from lightning.pytorch.trainer.states import TrainerFn - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/launchers/test_subprocess_script.py b/tests/tests_pytorch/strategies/launchers/test_subprocess_script.py index b8a5ddb29de23..dd8576ec0cafe 100644 --- a/tests/tests_pytorch/strategies/launchers/test_subprocess_script.py +++ b/tests/tests_pytorch/strategies/launchers/test_subprocess_script.py @@ -4,9 +4,9 @@ from unittest.mock import Mock import pytest -from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from lightning_utilities.core.imports import RequirementCache +from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from tests_pytorch.helpers.runif import RunIf _HYDRA_WITH_RUN_PROCESS = RequirementCache("hydra-core>=1.0.7") diff --git a/tests/tests_pytorch/strategies/test_common.py b/tests/tests_pytorch/strategies/test_common.py index 699424b3c53b9..6ab4f49374c27 100644 --- a/tests/tests_pytorch/strategies/test_common.py +++ b/tests/tests_pytorch/strategies/test_common.py @@ -15,10 +15,10 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.plugins import DoublePrecision, HalfPrecision, Precision from lightning.pytorch.strategies import SingleDeviceStrategy - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/strategies/test_custom_strategy.py b/tests/tests_pytorch/strategies/test_custom_strategy.py index 347dacbd9a811..8a297db217943 100644 --- a/tests/tests_pytorch/strategies/test_custom_strategy.py +++ b/tests/tests_pytorch/strategies/test_custom_strategy.py @@ -17,6 +17,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.strategies import SingleDeviceStrategy diff --git a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py index b23d306b9d907..915e57440b40f 100644 --- a/tests/tests_pytorch/strategies/test_ddp.py +++ b/tests/tests_pytorch/strategies/test_ddp.py @@ -17,14 +17,14 @@ import pytest import torch +from torch.nn.parallel import DistributedDataParallel + from lightning.fabric.plugins.environments import LightningEnvironment from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins import DoublePrecision, HalfPrecision, Precision from lightning.pytorch.strategies import DDPStrategy from lightning.pytorch.trainer.states import TrainerFn -from torch.nn.parallel import DistributedDataParallel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/test_ddp_integration.py b/tests/tests_pytorch/strategies/test_ddp_integration.py index 836072d36be83..048403366ebc7 100644 --- a/tests/tests_pytorch/strategies/test_ddp_integration.py +++ b/tests/tests_pytorch/strategies/test_ddp_integration.py @@ -15,9 +15,14 @@ from unittest import mock from unittest.mock import Mock -import lightning.pytorch as pl import pytest import torch +from torch.distributed.optim import ZeroRedundancyOptimizer +from torch.multiprocessing import ProcessRaisedException +from torch.nn.parallel.distributed import DistributedDataParallel + +import lightning.pytorch as pl +import tests_pytorch.helpers.pipelines as tpipes from lightning.fabric.plugins.environments import ClusterEnvironment, LightningEnvironment from lightning.fabric.utilities.distributed import _distributed_is_initialized from lightning.pytorch import Trainer @@ -27,11 +32,6 @@ from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher from lightning.pytorch.strategies.launchers.multiprocessing import _MultiProcessingLauncher from lightning.pytorch.trainer import seed_everything -from torch.distributed.optim import ZeroRedundancyOptimizer -from torch.multiprocessing import ProcessRaisedException -from torch.nn.parallel.distributed import DistributedDataParallel - -import tests_pytorch.helpers.pipelines as tpipes from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/strategies/test_ddp_integration_comm_hook.py b/tests/tests_pytorch/strategies/test_ddp_integration_comm_hook.py index 3723ee2fc5a8f..c4d40e1dfa7d2 100644 --- a/tests/tests_pytorch/strategies/test_ddp_integration_comm_hook.py +++ b/tests/tests_pytorch/strategies/test_ddp_integration_comm_hook.py @@ -15,10 +15,10 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.strategies import DDPStrategy - from tests_pytorch.helpers.runif import RunIf if torch.distributed.is_available(): diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py index 73697ea131545..7e7d2eacd0617 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed.py +++ b/tests/tests_pytorch/strategies/test_deepspeed.py @@ -22,6 +22,10 @@ import pytest import torch import torch.nn.functional as F +from torch import Tensor, nn +from torch.utils.data import DataLoader +from torchmetrics import Accuracy + from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.accelerators import CUDAAccelerator from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint @@ -31,10 +35,6 @@ from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11 -from torch import Tensor, nn -from torch.utils.data import DataLoader -from torchmetrics import Accuracy - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf @@ -81,7 +81,7 @@ def automatic_optimization(self) -> bool: return False -@pytest.fixture() +@pytest.fixture def deepspeed_config(): return { "optimizer": {"type": "SGD", "params": {"lr": 3e-5}}, @@ -92,7 +92,7 @@ def deepspeed_config(): } -@pytest.fixture() +@pytest.fixture def deepspeed_zero_config(deepspeed_config): return {**deepspeed_config, "zero_allow_untested_optimizer": True, "zero_optimization": {"stage": 2}} diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 2aee68f7ae733..f3e88ca356764 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -12,6 +12,10 @@ import pytest import torch import torch.nn as nn +from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision +from torch.distributed.fsdp.wrap import ModuleWrapPolicy, always_wrap_policy, size_based_auto_wrap_policy, wrap +from torchmetrics import Accuracy + from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 @@ -24,10 +28,6 @@ from lightning.pytorch.strategies import FSDPStrategy from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.consolidate_checkpoint import _format_checkpoint -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision -from torch.distributed.fsdp.wrap import ModuleWrapPolicy, always_wrap_policy, size_based_auto_wrap_policy, wrap -from torchmetrics import Accuracy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/test_model_parallel.py b/tests/tests_pytorch/strategies/test_model_parallel.py index 731da66d4a61f..86a95944ac20d 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel.py +++ b/tests/tests_pytorch/strategies/test_model_parallel.py @@ -20,11 +20,11 @@ import pytest import torch import torch.nn as nn + from lightning.fabric.strategies.model_parallel import _is_sharded_checkpoint from lightning.pytorch import LightningModule from lightning.pytorch.plugins.environments import LightningEnvironment from lightning.pytorch.strategies import ModelParallelStrategy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/test_model_parallel_integration.py b/tests/tests_pytorch/strategies/test_model_parallel_integration.py index 9dcbcc802834b..00600183f4293 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel_integration.py +++ b/tests/tests_pytorch/strategies/test_model_parallel_integration.py @@ -18,12 +18,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from lightning.pytorch import LightningModule, Trainer, seed_everything -from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset -from lightning.pytorch.strategies import ModelParallelStrategy from torch.utils.data import DataLoader, DistributedSampler from torchmetrics.classification import Accuracy +from lightning.pytorch import LightningModule, Trainer, seed_everything +from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset +from lightning.pytorch.strategies import ModelParallelStrategy from tests_pytorch.helpers.runif import RunIf @@ -86,7 +86,7 @@ def fn(model, device_mesh): return fn -@pytest.fixture() +@pytest.fixture def distributed(): yield if torch.distributed.is_initialized(): diff --git a/tests/tests_pytorch/strategies/test_registry.py b/tests/tests_pytorch/strategies/test_registry.py index 90e15638bfd06..d2c580fd28c0f 100644 --- a/tests/tests_pytorch/strategies/test_registry.py +++ b/tests/tests_pytorch/strategies/test_registry.py @@ -14,10 +14,10 @@ from unittest import mock import pytest + from lightning.pytorch import Trainer from lightning.pytorch.plugins import CheckpointIO from lightning.pytorch.strategies import DDPStrategy, DeepSpeedStrategy, FSDPStrategy, StrategyRegistry, XLAStrategy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/test_single_device.py b/tests/tests_pytorch/strategies/test_single_device.py index 5b10c4a17d726..7582dfe86dd3c 100644 --- a/tests/tests_pytorch/strategies/test_single_device.py +++ b/tests/tests_pytorch/strategies/test_single_device.py @@ -16,12 +16,12 @@ import pytest import torch +from torch.utils.data import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.core.optimizer import LightningOptimizer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.strategies import SingleDeviceStrategy -from torch.utils.data import DataLoader - from tests_pytorch.helpers.dataloaders import CustomNotImplementedErrorDataloader from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/test_xla.py b/tests/tests_pytorch/strategies/test_xla.py index b4f0e2c37ec05..3fde2600c9483 100644 --- a/tests/tests_pytorch/strategies/test_xla.py +++ b/tests/tests_pytorch/strategies/test_xla.py @@ -16,11 +16,11 @@ from unittest.mock import Mock import torch + from lightning.pytorch import Trainer from lightning.pytorch.accelerators import XLAAccelerator from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.strategies import XLAStrategy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index de89d094cdfcf..b9decf781386b 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -27,6 +27,14 @@ import pytest import torch import yaml +from lightning_utilities import compare_version +from lightning_utilities.test.warning import no_warning_call +from packaging.version import Version +from tensorboard.backend.event_processing import event_accumulator +from tensorboard.plugins.hparams.plugin_data_pb2 import HParamsPluginData +from torch.optim import SGD +from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR + from lightning.fabric.plugins.environments import SLURMEnvironment from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__, seed_everything from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint @@ -46,14 +54,6 @@ from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE -from lightning_utilities import compare_version -from lightning_utilities.test.warning import no_warning_call -from packaging.version import Version -from tensorboard.backend.event_processing import event_accumulator -from tensorboard.plugins.hparams.plugin_data_pb2 import HParamsPluginData -from torch.optim import SGD -from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR - from tests_pytorch.helpers.runif import RunIf if _JSONARGPARSE_SIGNATURES_AVAILABLE: @@ -84,7 +84,7 @@ def mock_subclasses(baseclass, *subclasses): yield None -@pytest.fixture() +@pytest.fixture def cleandir(tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) return @@ -666,7 +666,7 @@ class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): parser.add_optimizer_args(torch.optim.Adam) - match = "BoringModel.configure_optimizers` will be overridden by " "`MyLightningCLI.configure_optimizers`" + match = "BoringModel.configure_optimizers` will be overridden by `MyLightningCLI.configure_optimizers`" argv = ["fit", "--trainer.fast_dev_run=1"] if run else [] with mock.patch("sys.argv", ["any.py"] + argv), pytest.warns(UserWarning, match=match): cli = MyLightningCLI(BoringModel, run=run) diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index 7de9470fbe186..9b4bc87205608 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -19,11 +19,12 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric -import lightning.pytorch import pytest import torch import torch.distributed + +import lightning.fabric +import lightning.pytorch from lightning.fabric.plugins.environments import ( KubeflowEnvironment, LightningEnvironment, @@ -61,7 +62,6 @@ from lightning.pytorch.utilities.imports import ( _LIGHTNING_HABANA_AVAILABLE, ) - from tests_pytorch.conftest import mock_cuda_count, mock_mps_count, mock_tpu_available, mock_xla_available from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py index eb09413cafcce..94b5fcba652be 100644 --- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py @@ -18,6 +18,7 @@ import pytest import torch + from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_10_0 from lightning.pytorch import Callback, LightningModule, Trainer from lightning.pytorch.callbacks import ( diff --git a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py index d29e2285e983c..722742a3ccae0 100644 --- a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py @@ -17,6 +17,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import Callback, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index ca5690ed20f41..ceb0418f2cb1d 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -16,8 +16,12 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric import pytest +from lightning_utilities.test.warning import no_warning_call +from torch import Tensor +from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler + +import lightning.fabric from lightning.fabric.utilities.distributed import DistributedSamplerWrapper from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer @@ -33,10 +37,6 @@ from lightning.pytorch.utilities.combined_loader import CombinedLoader from lightning.pytorch.utilities.data import _update_dataloader from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning_utilities.test.warning import no_warning_call -from torch import Tensor -from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/connectors/test_signal_connector.py b/tests/tests_pytorch/trainer/connectors/test_signal_connector.py index 8825db3727e86..83c5c2bb7e02b 100644 --- a/tests/tests_pytorch/trainer/connectors/test_signal_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_signal_connector.py @@ -18,13 +18,13 @@ from unittest.mock import Mock import pytest + from lightning.fabric.plugins.environments import SLURMEnvironment from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector from lightning.pytorch.utilities.exceptions import SIGTERMException - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py b/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py index f7b7d3c8ac8a7..89c842f2bbf0f 100644 --- a/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py +++ b/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py @@ -13,9 +13,10 @@ # limitations under the License. import pytest import torch +from torch.utils.data import Dataset + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel -from torch.utils.data import Dataset class RandomDatasetA(Dataset): diff --git a/tests/tests_pytorch/trainer/flags/test_barebones.py b/tests/tests_pytorch/trainer/flags/test_barebones.py index 329fcf915d751..875aaef40a123 100644 --- a/tests/tests_pytorch/trainer/flags/test_barebones.py +++ b/tests/tests_pytorch/trainer/flags/test_barebones.py @@ -14,6 +14,7 @@ import logging import pytest + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelSummary from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/trainer/flags/test_check_val_every_n_epoch.py b/tests/tests_pytorch/trainer/flags/test_check_val_every_n_epoch.py index fba938d82e761..301721e4b28f5 100644 --- a/tests/tests_pytorch/trainer/flags/test_check_val_every_n_epoch.py +++ b/tests/tests_pytorch/trainer/flags/test_check_val_every_n_epoch.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest +from torch.utils.data import DataLoader + from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.trainer.trainer import Trainer -from torch.utils.data import DataLoader @pytest.mark.parametrize( diff --git a/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py b/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py index 62087619ae42f..63b05a0a131f2 100644 --- a/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py +++ b/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py @@ -3,6 +3,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/trainer/flags/test_inference_mode.py b/tests/tests_pytorch/trainer/flags/test_inference_mode.py index bae7b66dbbd55..802d9bf30c59b 100644 --- a/tests/tests_pytorch/trainer/flags/test_inference_mode.py +++ b/tests/tests_pytorch/trainer/flags/test_inference_mode.py @@ -16,6 +16,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loops import _Loop diff --git a/tests/tests_pytorch/trainer/flags/test_limit_batches.py b/tests/tests_pytorch/trainer/flags/test_limit_batches.py index e190a0b380377..ef405a8ee95b1 100644 --- a/tests/tests_pytorch/trainer/flags/test_limit_batches.py +++ b/tests/tests_pytorch/trainer/flags/test_limit_batches.py @@ -14,6 +14,7 @@ import logging import pytest + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.states import TrainerFn diff --git a/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py b/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py index 25aaeb8cff77e..3315c328b6249 100644 --- a/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py +++ b/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py @@ -1,8 +1,9 @@ import pytest +from lightning_utilities.test.warning import no_warning_call + from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel -from lightning_utilities.test.warning import no_warning_call @pytest.mark.parametrize( diff --git a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py index dfe8a63a8bfc4..050818287ba45 100644 --- a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py +++ b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py @@ -16,11 +16,11 @@ import pytest import torch +from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.trainer.states import RunningStage -from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.datasets import SklearnDataset from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py index b776263e9953d..b6cc446cb0840 100644 --- a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py +++ b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py @@ -14,10 +14,11 @@ import logging import pytest +from torch.utils.data import DataLoader + from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from lightning.pytorch.trainer.trainer import Trainer from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch.utils.data import DataLoader @pytest.mark.parametrize("max_epochs", [1, 2, 3]) diff --git a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py index af7cecdb21a08..90f9a3e697535 100644 --- a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py @@ -19,7 +19,6 @@ from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers.logger import Logger - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py index 40c82bec2fd10..be6de37ddff3a 100644 --- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py @@ -24,6 +24,8 @@ import numpy as np import pytest import torch +from torch import Tensor + from lightning.pytorch import Trainer, callbacks from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset @@ -31,8 +33,6 @@ from lightning.pytorch.loops import _EvaluationLoop from lightning.pytorch.trainer.states import RunningStage from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch import Tensor - from tests_pytorch.helpers.runif import RunIf if _RICH_AVAILABLE: diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py index e96857a6c192d..faf88a09f6499 100644 --- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py +++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py @@ -18,6 +18,11 @@ import pytest import torch +from lightning_utilities.core.imports import compare_version +from torch.utils.data import DataLoader +from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError, MetricCollection +from torchmetrics import AveragePrecision as AvgPre + from lightning.pytorch import LightningModule from lightning.pytorch.callbacks.callback import Callback from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset @@ -28,11 +33,6 @@ from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1 from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11 -from lightning_utilities.core.imports import compare_version -from torch.utils.data import DataLoader -from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError, MetricCollection -from torchmetrics import AveragePrecision as AvgPre - from tests_pytorch.helpers.runif import RunIf from tests_pytorch.models.test_hooks import get_members diff --git a/tests/tests_pytorch/trainer/logging_/test_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_loop_logging.py index e43f3b8d6bffd..6bc9dd1de0587 100644 --- a/tests/tests_pytorch/trainer/logging_/test_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_loop_logging.py @@ -18,6 +18,7 @@ from unittest.mock import ANY import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.connectors.logger_connector.fx_validator import _FxValidator diff --git a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py index e48f80d2d1680..3981ddd64d773 100644 --- a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py @@ -22,6 +22,11 @@ import numpy as np import pytest import torch +from lightning_utilities.test.warning import no_warning_call +from torch import Tensor +from torch.utils.data import DataLoader +from torchmetrics import Accuracy + from lightning.pytorch import Trainer, callbacks from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar from lightning.pytorch.core.module import LightningModule @@ -30,11 +35,6 @@ from lightning.pytorch.trainer.states import RunningStage from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11 -from lightning_utilities.test.warning import no_warning_call -from torch import Tensor -from torch.utils.data import DataLoader -from torchmetrics import Accuracy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/optimization/test_backward_calls.py b/tests/tests_pytorch/trainer/optimization/test_backward_calls.py index b91dbff8c6d09..42332bc05580f 100644 --- a/tests/tests_pytorch/trainer/optimization/test_backward_calls.py +++ b/tests/tests_pytorch/trainer/optimization/test_backward_calls.py @@ -2,6 +2,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py index f0ab8fe401633..3f89e1459298d 100644 --- a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py +++ b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py @@ -21,11 +21,11 @@ import torch import torch.distributed as torch_distrib import torch.nn.functional as F + from lightning.fabric.utilities.exceptions import MisconfigurationException from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.demos.boring_classes import BoringModel, ManualOptimBoringModel from lightning.pytorch.strategies import Strategy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py index 319eafeb0d0bb..dcbae32827c50 100644 --- a/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py +++ b/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py @@ -13,9 +13,10 @@ # limitations under the License. """Tests to ensure that the behaviours related to multiple optimizers works.""" -import lightning.pytorch as pl import pytest import torch + +import lightning.pytorch as pl from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/trainer/optimization/test_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_optimizers.py index ac660b6651be5..6b88534f3430d 100644 --- a/tests/tests_pytorch/trainer/optimization/test_optimizers.py +++ b/tests/tests_pytorch/trainer/optimization/test_optimizers.py @@ -16,6 +16,8 @@ import pytest import torch +from torch import optim + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.core.optimizer import ( @@ -26,8 +28,6 @@ from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.types import LRSchedulerConfig -from torch import optim - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py index 76c0c695b3c02..aa60db594447d 100644 --- a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py +++ b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py @@ -18,11 +18,11 @@ import pytest import torch +from torch.utils.data import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomIterableDataset from lightning.pytorch.strategies import SingleDeviceXLAStrategy -from torch.utils.data import DataLoader - from tests_pytorch.conftest import mock_cuda_count from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/properties/test_get_model.py b/tests/tests_pytorch/trainer/properties/test_get_model.py index 72967ce929eea..47d2bffcfca22 100644 --- a/tests/tests_pytorch/trainer/properties/test_get_model.py +++ b/tests/tests_pytorch/trainer/properties/test_get_model.py @@ -13,9 +13,9 @@ # limitations under the License. import pytest + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/properties/test_log_dir.py b/tests/tests_pytorch/trainer/properties/test_log_dir.py index 1fc4f3454f9d0..0f045c2e815fd 100644 --- a/tests/tests_pytorch/trainer/properties/test_log_dir.py +++ b/tests/tests_pytorch/trainer/properties/test_log_dir.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/trainer/properties/test_loggers.py b/tests/tests_pytorch/trainer/properties/test_loggers.py index 7fcf663f59122..1d07f3e99d412 100644 --- a/tests/tests_pytorch/trainer/properties/test_loggers.py +++ b/tests/tests_pytorch/trainer/properties/test_loggers.py @@ -13,9 +13,9 @@ # limitations under the License. import pytest + from lightning.pytorch import Trainer from lightning.pytorch.loggers import TensorBoardLogger - from tests_pytorch.loggers.test_logger import CustomLogger diff --git a/tests/tests_pytorch/trainer/test_config_validator.py b/tests/tests_pytorch/trainer/test_config_validator.py index 8963b1f76186d..cfca98e04c8c8 100644 --- a/tests/tests_pytorch/trainer/test_config_validator.py +++ b/tests/tests_pytorch/trainer/test_config_validator.py @@ -15,6 +15,7 @@ import pytest import torch + from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import LightningDataModule, LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py index a2d29baa9fa6f..7fbe55030770e 100644 --- a/tests/tests_pytorch/trainer/test_dataloaders.py +++ b/tests/tests_pytorch/trainer/test_dataloaders.py @@ -14,10 +14,17 @@ import os from unittest.mock import Mock, call, patch -import lightning.pytorch import numpy import pytest import torch +from lightning_utilities.test.warning import no_warning_call +from torch.utils.data import RandomSampler +from torch.utils.data.dataloader import DataLoader +from torch.utils.data.dataset import Dataset, IterableDataset +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import SequentialSampler + +import lightning.pytorch from lightning.fabric.utilities.data import _auto_add_worker_init_fn, has_iterable_dataset from lightning.pytorch import Callback, Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint @@ -33,13 +40,6 @@ from lightning.pytorch.utilities.combined_loader import CombinedLoader from lightning.pytorch.utilities.data import has_len_all_ranks from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning_utilities.test.warning import no_warning_call -from torch.utils.data import RandomSampler -from torch.utils.data.dataloader import DataLoader -from torch.utils.data.dataset import Dataset, IterableDataset -from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.sampler import SequentialSampler - from tests_pytorch.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/test_states.py b/tests/tests_pytorch/trainer/test_states.py index d89e99c9319c6..fff2d0b464d42 100644 --- a/tests/tests_pytorch/trainer/test_states.py +++ b/tests/tests_pytorch/trainer/test_states.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest + from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index d66f3aafee5df..18ae7ce77bdfc 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -27,6 +27,12 @@ import pytest import torch import torch.nn as nn +from torch.multiprocessing import ProcessRaisedException +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.optim import SGD +from torch.utils.data import DataLoader, IterableDataset + +import tests_pytorch.helpers.utils as tutils from lightning.fabric.utilities.cloud_io import _load as pl_load from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.fabric.utilities.seed import seed_everything @@ -50,12 +56,6 @@ from lightning.pytorch.trainer.states import RunningStage, TrainerFn from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE -from torch.multiprocessing import ProcessRaisedException -from torch.nn.parallel.distributed import DistributedDataParallel -from torch.optim import SGD -from torch.utils.data import DataLoader, IterableDataset - -import tests_pytorch.helpers.utils as tutils from tests_pytorch.conftest import mock_cuda_count, mock_mps_count from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf @@ -335,9 +335,9 @@ def mock_save_function(filepath, *args): file_lists = set(os.listdir(tmp_path)) - assert len(file_lists) == len( - expected_files - ), f"Should save {len(expected_files)} models when save_top_k={save_top_k} but found={file_lists}" + assert len(file_lists) == len(expected_files), ( + f"Should save {len(expected_files)} models when save_top_k={save_top_k} but found={file_lists}" + ) # verify correct naming for fname in expected_files: diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py index a31be67911409..ec894688ccb6c 100644 --- a/tests/tests_pytorch/tuner/test_lr_finder.py +++ b/tests/tests_pytorch/tuner/test_lr_finder.py @@ -20,6 +20,8 @@ import pytest import torch +from lightning_utilities.test.warning import no_warning_call + from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks.lr_finder import LearningRateFinder from lightning.pytorch.demos.boring_classes import BoringModel @@ -27,8 +29,6 @@ from lightning.pytorch.tuner.tuning import Tuner from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.types import STEP_OUTPUT -from lightning_utilities.test.warning import no_warning_call - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel @@ -73,9 +73,9 @@ def test_model_reset_correctly(tmp_path): after_state_dict = model.state_dict() for key in before_state_dict: - assert torch.all( - torch.eq(before_state_dict[key], after_state_dict[key]) - ), "Model was not reset correctly after learning rate finder" + assert torch.all(torch.eq(before_state_dict[key], after_state_dict[key])), ( + "Model was not reset correctly after learning rate finder" + ) assert not any(f for f in os.listdir(tmp_path) if f.startswith(".lr_find")) diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py index 8dd66fe9bfcff..e4ed533c6fa83 100644 --- a/tests/tests_pytorch/tuner/test_scale_batch_size.py +++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py @@ -18,14 +18,14 @@ import pytest import torch +from lightning_utilities.test.warning import no_warning_call +from torch.utils.data import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset from lightning.pytorch.tuner.tuning import Tuner from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning_utilities.test.warning import no_warning_call -from torch.utils.data import DataLoader - from tests_pytorch.helpers.runif import RunIf @@ -114,9 +114,9 @@ def test_trainer_reset_correctly(tmp_path, trainer_fn): after_state_dict = model.state_dict() for key in before_state_dict: - assert torch.all( - torch.eq(before_state_dict[key], after_state_dict[key]) - ), "Model was not reset correctly after scaling batch size" + assert torch.all(torch.eq(before_state_dict[key], after_state_dict[key])), ( + "Model was not reset correctly after scaling batch size" + ) assert not any(f for f in os.listdir(tmp_path) if f.startswith(".scale_batch_size_temp_model")) diff --git a/tests/tests_pytorch/tuner/test_tuning.py b/tests/tests_pytorch/tuner/test_tuning.py index dda08354575c0..e3b24a69b7999 100644 --- a/tests/tests_pytorch/tuner/test_tuning.py +++ b/tests/tests_pytorch/tuner/test_tuning.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import BatchSizeFinder, LearningRateFinder from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/utilities/migration/test_migration.py b/tests/tests_pytorch/utilities/migration/test_migration.py index 9680c90a94c5b..f9c921f6f1bfd 100644 --- a/tests/tests_pytorch/utilities/migration/test_migration.py +++ b/tests/tests_pytorch/utilities/migration/test_migration.py @@ -13,9 +13,10 @@ # limitations under the License. from unittest.mock import ANY, MagicMock -import lightning.pytorch as pl import pytest import torch + +import lightning.pytorch as pl from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py index 56b8701cfcfc2..41ca1a779f8a5 100644 --- a/tests/tests_pytorch/utilities/migration/test_utils.py +++ b/tests/tests_pytorch/utilities/migration/test_utils.py @@ -18,16 +18,16 @@ import sys from unittest.mock import ANY -import lightning.pytorch as pl import pytest import torch -from lightning.fabric.utilities.warnings import PossibleUserWarning -from lightning.pytorch.utilities.migration import migrate_checkpoint, pl_legacy_patch -from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint, _RedirectingUnpickler from lightning_utilities.core.imports import module_available from lightning_utilities.test.warning import no_warning_call from packaging.version import Version +import lightning.pytorch as pl +from lightning.fabric.utilities.warnings import PossibleUserWarning +from lightning.pytorch.utilities.migration import migrate_checkpoint, pl_legacy_patch +from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint, _RedirectingUnpickler from tests_pytorch.checkpointing.test_legacy_checkpoints import ( CHECKPOINT_EXTENSION, LEGACY_BACK_COMPATIBLE_PL_VERSIONS, @@ -75,9 +75,9 @@ def _list_sys_modules(pattern: str) -> str: @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) @pytest.mark.skipif(module_available("lightning"), reason="This test is ONLY relevant for the STANDALONE package") def test_test_patch_legacy_imports_standalone(pl_version): - assert any( - key.startswith("pytorch_lightning") for key in sys.modules - ), f"Imported PL, so it has to be in sys.modules: {_list_sys_modules('pytorch_lightning')}" + assert any(key.startswith("pytorch_lightning") for key in sys.modules), ( + f"Imported PL, so it has to be in sys.modules: {_list_sys_modules('pytorch_lightning')}" + ) path_legacy = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) path_ckpts = sorted(glob.glob(os.path.join(path_legacy, f"*{CHECKPOINT_EXTENSION}"))) assert path_ckpts, f'No checkpoints found in folder "{path_legacy}"' @@ -86,9 +86,9 @@ def test_test_patch_legacy_imports_standalone(pl_version): with no_warning_call(match="Redirecting import of*"), pl_legacy_patch(): torch.load(path_ckpt, weights_only=False) - assert any( - key.startswith("pytorch_lightning") for key in sys.modules - ), f"Imported PL, so it has to be in sys.modules: {_list_sys_modules('pytorch_lightning')}" + assert any(key.startswith("pytorch_lightning") for key in sys.modules), ( + f"Imported PL, so it has to be in sys.modules: {_list_sys_modules('pytorch_lightning')}" + ) assert not any(key.startswith("lightning." + "pytorch") for key in sys.modules), ( "Did not import the unified package," f" so it should not be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}" @@ -98,9 +98,9 @@ def test_test_patch_legacy_imports_standalone(pl_version): @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) @pytest.mark.skipif(not module_available("lightning"), reason="This test is ONLY relevant for the UNIFIED package") def test_patch_legacy_imports_unified(pl_version): - assert any( - key.startswith("lightning." + "pytorch") for key in sys.modules - ), f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}" + assert any(key.startswith("lightning." + "pytorch") for key in sys.modules), ( + f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}" + ) assert not any(key.startswith("pytorch_lightning") for key in sys.modules), ( "Should not import standalone package, all imports should be redirected to the unified package;\n" f" environment: {_list_sys_modules('pytorch_lightning')}" @@ -119,9 +119,9 @@ def test_patch_legacy_imports_unified(pl_version): with context, pl_legacy_patch(): torch.load(path_ckpt, weights_only=False) - assert any( - key.startswith("lightning." + "pytorch") for key in sys.modules - ), f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}" + assert any(key.startswith("lightning." + "pytorch") for key in sys.modules), ( + f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}" + ) assert not any(key.startswith("pytorch_lightning") for key in sys.modules), ( "Should not import standalone package, all imports should be redirected to the unified package;\n" f" environment: {_list_sys_modules('pytorch_lightning')}" diff --git a/tests/tests_pytorch/utilities/test_all_gather_grad.py b/tests/tests_pytorch/utilities/test_all_gather_grad.py index 9b034cdcd34e2..82ca15fd87432 100644 --- a/tests/tests_pytorch/utilities/test_all_gather_grad.py +++ b/tests/tests_pytorch/utilities/test_all_gather_grad.py @@ -15,9 +15,9 @@ import numpy as np import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.core.test_results import spawn_launch from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_auto_restart.py b/tests/tests_pytorch/utilities/test_auto_restart.py index d4b2fdaf82834..4da7bfd098a0b 100644 --- a/tests/tests_pytorch/utilities/test_auto_restart.py +++ b/tests/tests_pytorch/utilities/test_auto_restart.py @@ -14,13 +14,13 @@ import inspect import pytest +from torch.utils.data.dataloader import DataLoader + from lightning.fabric.utilities.seed import seed_everything from lightning.pytorch import Callback, Trainer from lightning.pytorch.callbacks import OnExceptionCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.utilities.exceptions import SIGTERMException -from torch.utils.data.dataloader import DataLoader - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py index 43a146c6eb089..da168be1e3e8a 100644 --- a/tests/tests_pytorch/utilities/test_combined_loader.py +++ b/tests/tests_pytorch/utilities/test_combined_loader.py @@ -19,6 +19,13 @@ import pytest import torch +from torch import Tensor +from torch.utils._pytree import tree_flatten +from torch.utils.data import DataLoader, TensorDataset +from torch.utils.data.dataset import Dataset, IterableDataset +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import RandomSampler, SequentialSampler + from lightning.fabric.utilities.types import _Stateful from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset @@ -31,13 +38,6 @@ _MinSize, _Sequential, ) -from torch import Tensor -from torch.utils._pytree import tree_flatten -from torch.utils.data import DataLoader, TensorDataset -from torch.utils.data.dataset import Dataset, IterableDataset -from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.sampler import RandomSampler, SequentialSampler - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_compile.py b/tests/tests_pytorch/utilities/test_compile.py index 67f992421f7ce..a053c847dfd6c 100644 --- a/tests/tests_pytorch/utilities/test_compile.py +++ b/tests/tests_pytorch/utilities/test_compile.py @@ -18,12 +18,12 @@ import pytest import torch +from lightning_utilities.core.imports import RequirementCache + from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.compile import from_compiled, to_uncompiled -from lightning_utilities.core.imports import RequirementCache - from tests_pytorch.conftest import mock_cuda_count from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index d79e9e24383a0..65a3a47715bfe 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -4,6 +4,10 @@ import numpy as np import pytest import torch +from lightning_utilities.test.warning import no_warning_call +from torch import Tensor +from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler + from lightning.fabric.utilities.data import _replace_dunder_methods from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer @@ -19,9 +23,6 @@ warning_cache, ) from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning_utilities.test.warning import no_warning_call -from torch import Tensor -from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler def test_extract_batch_size(): diff --git a/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py b/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py index a44ed655c2e61..05657854eb0b1 100644 --- a/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py @@ -14,11 +14,11 @@ import os import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.strategies import DeepSpeedStrategy from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py b/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py index c8a138ba0a02a..256233e01fa98 100644 --- a/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py +++ b/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py @@ -17,7 +17,6 @@ from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.strategies import DeepSpeedStrategy from lightning.pytorch.utilities.model_summary import DeepSpeedSummary - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_dtype_device_mixin.py b/tests/tests_pytorch/utilities/test_dtype_device_mixin.py index 1ba3aff359609..171656d072076 100644 --- a/tests/tests_pytorch/utilities/test_dtype_device_mixin.py +++ b/tests/tests_pytorch/utilities/test_dtype_device_mixin.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch.nn as nn + from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_grads.py b/tests/tests_pytorch/utilities/test_grads.py index ade88f767670c..277971d64b180 100644 --- a/tests/tests_pytorch/utilities/test_grads.py +++ b/tests/tests_pytorch/utilities/test_grads.py @@ -16,6 +16,7 @@ import pytest import torch import torch.nn as nn + from lightning.pytorch.utilities import grad_norm diff --git a/tests/tests_pytorch/utilities/test_imports.py b/tests/tests_pytorch/utilities/test_imports.py index 56ee326f076dc..301c97d756899 100644 --- a/tests/tests_pytorch/utilities/test_imports.py +++ b/tests/tests_pytorch/utilities/test_imports.py @@ -19,10 +19,10 @@ from unittest import mock import pytest -from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE from lightning_utilities.core.imports import RequirementCache from torch.distributed import is_available +from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE from tests_pytorch.helpers.runif import RunIf @@ -60,7 +60,7 @@ def new_fn(*args, **kwargs): return new_fn -@pytest.fixture() +@pytest.fixture def clean_import(): """This fixture allows test to import {pytorch_}lightning* modules completely cleanly, regardless of the current state of the imported modules. diff --git a/tests/tests_pytorch/utilities/test_memory.py b/tests/tests_pytorch/utilities/test_memory.py index c1ebff03a3a4f..336a9fafa3243 100644 --- a/tests/tests_pytorch/utilities/test_memory.py +++ b/tests/tests_pytorch/utilities/test_memory.py @@ -13,6 +13,7 @@ # limitations under the License. import torch + from lightning.pytorch.utilities.memory import recursive_detach diff --git a/tests/tests_pytorch/utilities/test_model_helpers.py b/tests/tests_pytorch/utilities/test_model_helpers.py index 78a63a7e9d2a7..e7a9d9275a484 100644 --- a/tests/tests_pytorch/utilities/test_model_helpers.py +++ b/tests/tests_pytorch/utilities/test_model_helpers.py @@ -16,10 +16,11 @@ import pytest import torch.nn +from lightning_utilities import module_available + from lightning.pytorch import LightningDataModule from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel from lightning.pytorch.utilities.model_helpers import _ModuleMode, _restricted_classmethod, is_overridden -from lightning_utilities import module_available def test_is_overridden(): diff --git a/tests/tests_pytorch/utilities/test_model_summary.py b/tests/tests_pytorch/utilities/test_model_summary.py index cced6546aab75..54c5572d01767 100644 --- a/tests/tests_pytorch/utilities/test_model_summary.py +++ b/tests/tests_pytorch/utilities/test_model_summary.py @@ -18,6 +18,7 @@ import pytest import torch import torch.nn as nn + from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.model_summary.model_summary import ( @@ -27,7 +28,6 @@ ModelSummary, summarize, ) - from tests_pytorch.helpers.advanced_models import ParityModuleRNN from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_parameter_tying.py b/tests/tests_pytorch/utilities/test_parameter_tying.py index 9dc9b5648ff01..e45fb39f81b34 100644 --- a/tests/tests_pytorch/utilities/test_parameter_tying.py +++ b/tests/tests_pytorch/utilities/test_parameter_tying.py @@ -13,9 +13,10 @@ # limitations under the License. import pytest import torch +from torch import nn + from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters -from torch import nn class ParameterSharingModule(BoringModel): diff --git a/tests/tests_pytorch/utilities/test_parsing.py b/tests/tests_pytorch/utilities/test_parsing.py index 1c126723d89a6..a2671eb8790de 100644 --- a/tests/tests_pytorch/utilities/test_parsing.py +++ b/tests/tests_pytorch/utilities/test_parsing.py @@ -15,6 +15,8 @@ import threading import pytest +from torch.jit import ScriptModule + from lightning.pytorch import LightningDataModule, LightningModule, Trainer from lightning.pytorch.utilities.parsing import ( _get_init_args, @@ -26,7 +28,6 @@ lightning_setattr, parse_class_init_keys, ) -from torch.jit import ScriptModule unpicklable_function = lambda: None @@ -103,12 +104,12 @@ def test_lightning_hasattr(): assert lightning_hasattr(model3, "learning_rate"), "lightning_hasattr failed to find hparams dict variable" assert not lightning_hasattr(model4, "learning_rate"), "lightning_hasattr found variable when it should not" assert lightning_hasattr(model5, "batch_size"), "lightning_hasattr failed to find batch_size in datamodule" - assert lightning_hasattr( - model6, "batch_size" - ), "lightning_hasattr failed to find batch_size in datamodule w/ hparams present" - assert lightning_hasattr( - model7, "batch_size" - ), "lightning_hasattr failed to find batch_size in hparams w/ datamodule present" + assert lightning_hasattr(model6, "batch_size"), ( + "lightning_hasattr failed to find batch_size in datamodule w/ hparams present" + ) + assert lightning_hasattr(model7, "batch_size"), ( + "lightning_hasattr failed to find batch_size in hparams w/ datamodule present" + ) assert lightning_hasattr(model8, "batch_size") for m in models: diff --git a/tests/tests_pytorch/utilities/test_pytree.py b/tests/tests_pytorch/utilities/test_pytree.py index afd198919e23f..c87a83f85f6ea 100644 --- a/tests/tests_pytorch/utilities/test_pytree.py +++ b/tests/tests_pytorch/utilities/test_pytree.py @@ -1,7 +1,8 @@ import torch -from lightning.pytorch.utilities._pytree import _tree_flatten, tree_unflatten from torch.utils.data import DataLoader, TensorDataset +from lightning.pytorch.utilities._pytree import _tree_flatten, tree_unflatten + def assert_tree_flatten_unflatten(pytree, leaves): flat, spec = _tree_flatten(pytree) diff --git a/tests/tests_pytorch/utilities/test_seed.py b/tests/tests_pytorch/utilities/test_seed.py index 282009f6b93cf..00484009481e9 100644 --- a/tests/tests_pytorch/utilities/test_seed.py +++ b/tests/tests_pytorch/utilities/test_seed.py @@ -4,8 +4,8 @@ import numpy as np import pytest import torch -from lightning.pytorch.utilities.seed import isolate_rng +from lightning.pytorch.utilities.seed import isolate_rng from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_signature_utils.py b/tests/tests_pytorch/utilities/test_signature_utils.py index e453459670360..23f9258b0d56b 100644 --- a/tests/tests_pytorch/utilities/test_signature_utils.py +++ b/tests/tests_pytorch/utilities/test_signature_utils.py @@ -1,4 +1,5 @@ import torch + from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index 2ef1ecd4fe3e5..15db0becb9551 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -18,6 +18,7 @@ import pytest import torch + from lightning.pytorch.utilities.upgrade_checkpoint import main as upgrade_main diff --git a/tests/tests_pytorch/utilities/test_warnings.py b/tests/tests_pytorch/utilities/test_warnings.py index 8c385a2d2e49f..ab7b7d2c15a4d 100644 --- a/tests/tests_pytorch/utilities/test_warnings.py +++ b/tests/tests_pytorch/utilities/test_warnings.py @@ -24,11 +24,12 @@ from io import StringIO from unittest import mock -import lightning.pytorch import pytest -from lightning.pytorch.utilities.warnings import PossibleUserWarning from lightning_utilities.test.warning import no_warning_call +import lightning.pytorch +from lightning.pytorch.utilities.warnings import PossibleUserWarning + if __name__ == "__main__": # check that logging is properly configured import logging