diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 775dc5dee77dc..51bd2056177bb 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -182,7 +182,7 @@ We welcome any useful contribution! For your convenience here's a recommended wo 1. Use tags in PR name for the following cases: - **\[blocked by #\]** if your work is dependent on other PRs. - - **\[wip\]** when you start to re-edit your work, mark it so no one will accidentally merge it in meantime. + - **[wip]** when you start to re-edit your work, mark it so no one will accidentally merge it in meantime. ### Question & Answer diff --git a/.github/workflows/README.md b/.github/workflows/README.md index 58f4afe529509..3bdc8f9a0b07f 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -16,7 +16,7 @@ Brief description of all our automation tools used for boosting development perf | .azure-pipelines/gpu-benchmarks.yml | Run speed/memory benchmarks for parity with vanila PyTorch. | GPU | | .github/workflows/ci-flagship-apps.yml | Run end-2-end tests with full applications, including deployment to the production cloud. | CPU | | .github/workflows/ci-tests-pytorch.yml | Run all tests except for accelerator-specific, standalone and slow tests. | CPU | -| .github/workflows/tpu-tests.yml | Run only TPU-specific tests. Requires that the PR title contains '\[TPU\]' | TPU | +| .github/workflows/tpu-tests.yml | Run only TPU-specific tests. Requires that the PR title contains '[TPU]' | TPU | \* Each standalone test needs to be run in separate processes to avoid unwanted interactions between test cases. diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c5e65de1d7eb7..f2e475f602913 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace @@ -65,12 +65,12 @@ repos: args: ["--in-place"] - repo: https://github.com/sphinx-contrib/sphinx-lint - rev: v0.9.1 + rev: v1.0.0 hooks: - id: sphinx-lint - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.0 + rev: v0.8.6 hooks: # try to fix what is possible - id: ruff @@ -81,7 +81,7 @@ repos: - id: ruff - repo: https://github.com/executablebooks/mdformat - rev: 0.7.17 + rev: 0.7.21 hooks: - id: mdformat additional_dependencies: diff --git a/examples/fabric/build_your_own_trainer/run.py b/examples/fabric/build_your_own_trainer/run.py index c0c2ff28ddc41..936b590f5041a 100644 --- a/examples/fabric/build_your_own_trainer/run.py +++ b/examples/fabric/build_your_own_trainer/run.py @@ -1,8 +1,9 @@ -import lightning as L import torch from torchmetrics.functional.classification.accuracy import accuracy from trainer import MyCustomTrainer +import lightning as L + class MNISTModule(L.LightningModule): def __init__(self) -> None: diff --git a/examples/fabric/build_your_own_trainer/trainer.py b/examples/fabric/build_your_own_trainer/trainer.py index f4f31c114f084..d9d081a2aea69 100644 --- a/examples/fabric/build_your_own_trainer/trainer.py +++ b/examples/fabric/build_your_own_trainer/trainer.py @@ -3,15 +3,16 @@ from functools import partial from typing import Any, Literal, Optional, Union, cast -import lightning as L import torch +from lightning_utilities import apply_to_collection +from tqdm import tqdm + +import lightning as L from lightning.fabric.accelerators import Accelerator from lightning.fabric.loggers import Logger from lightning.fabric.strategies import Strategy from lightning.fabric.wrappers import _unwrap_objects from lightning.pytorch.utilities.model_helpers import is_overridden -from lightning_utilities import apply_to_collection -from tqdm import tqdm class MyCustomTrainer: diff --git a/examples/fabric/dcgan/train_fabric.py b/examples/fabric/dcgan/train_fabric.py index f7a18b2b5bc17..66f11e1c6fcfe 100644 --- a/examples/fabric/dcgan/train_fabric.py +++ b/examples/fabric/dcgan/train_fabric.py @@ -16,9 +16,10 @@ import torch.utils.data import torchvision.transforms as transforms import torchvision.utils -from lightning.fabric import Fabric, seed_everything from torchvision.datasets import CelebA +from lightning.fabric import Fabric, seed_everything + # Root directory for dataset dataroot = "data/" # Number of workers for dataloader diff --git a/examples/fabric/fp8_distributed_transformer/train.py b/examples/fabric/fp8_distributed_transformer/train.py index ba88603268945..a30e2de2fc5ed 100644 --- a/examples/fabric/fp8_distributed_transformer/train.py +++ b/examples/fabric/fp8_distributed_transformer/train.py @@ -1,15 +1,16 @@ -import lightning as L import torch import torch.nn as nn import torch.nn.functional as F -from lightning.fabric.strategies import ModelParallelStrategy -from lightning.pytorch.demos import Transformer, WikiText2 from torch.distributed._composable.fsdp.fully_shard import fully_shard from torch.distributed.device_mesh import DeviceMesh from torch.utils.data import DataLoader from torchao.float8 import Float8LinearConfig, convert_to_float8_training from tqdm import tqdm +import lightning as L +from lightning.fabric.strategies import ModelParallelStrategy +from lightning.pytorch.demos import Transformer, WikiText2 + def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module: float8_config = Float8LinearConfig( diff --git a/examples/fabric/image_classifier/train_fabric.py b/examples/fabric/image_classifier/train_fabric.py index 02487a65e3989..d207595e9d2ba 100644 --- a/examples/fabric/image_classifier/train_fabric.py +++ b/examples/fabric/image_classifier/train_fabric.py @@ -36,11 +36,12 @@ import torch.nn.functional as F import torch.optim as optim import torchvision.transforms as T -from lightning.fabric import Fabric, seed_everything from torch.optim.lr_scheduler import StepLR from torchmetrics.classification import Accuracy from torchvision.datasets import MNIST +from lightning.fabric import Fabric, seed_everything + DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "..", "Datasets") diff --git a/examples/fabric/kfold_cv/train_fabric.py b/examples/fabric/kfold_cv/train_fabric.py index b3aa08e9aae9b..05d9885190dbc 100644 --- a/examples/fabric/kfold_cv/train_fabric.py +++ b/examples/fabric/kfold_cv/train_fabric.py @@ -20,12 +20,13 @@ import torch.nn.functional as F import torch.optim as optim import torchvision.transforms as T -from lightning.fabric import Fabric, seed_everything from sklearn import model_selection from torch.utils.data import DataLoader, SubsetRandomSampler from torchmetrics.classification import Accuracy from torchvision.datasets import MNIST +from lightning.fabric import Fabric, seed_everything + DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "..", "Datasets") diff --git a/examples/fabric/language_model/train.py b/examples/fabric/language_model/train.py index cafe6ceeb18b1..01947893be926 100644 --- a/examples/fabric/language_model/train.py +++ b/examples/fabric/language_model/train.py @@ -1,9 +1,10 @@ -import lightning as L import torch import torch.nn.functional as F -from lightning.pytorch.demos import Transformer, WikiText2 from torch.utils.data import DataLoader, random_split +import lightning as L +from lightning.pytorch.demos import Transformer, WikiText2 + def main(): L.seed_everything(42) diff --git a/examples/fabric/meta_learning/train_fabric.py b/examples/fabric/meta_learning/train_fabric.py index 3cf1390477aeb..203155f7b2ada 100644 --- a/examples/fabric/meta_learning/train_fabric.py +++ b/examples/fabric/meta_learning/train_fabric.py @@ -18,6 +18,7 @@ import cherry import learn2learn as l2l import torch + from lightning.fabric import Fabric, seed_everything diff --git a/examples/fabric/reinforcement_learning/rl/agent.py b/examples/fabric/reinforcement_learning/rl/agent.py index 16a4cd6d86c73..b3d024d720d11 100644 --- a/examples/fabric/reinforcement_learning/rl/agent.py +++ b/examples/fabric/reinforcement_learning/rl/agent.py @@ -3,11 +3,11 @@ import gymnasium as gym import torch import torch.nn.functional as F -from lightning.pytorch import LightningModule from torch import Tensor from torch.distributions import Categorical from torchmetrics import MeanMetric +from lightning.pytorch import LightningModule from rl.loss import entropy_loss, policy_loss, value_loss from rl.utils import layer_init diff --git a/examples/fabric/reinforcement_learning/train_fabric.py b/examples/fabric/reinforcement_learning/train_fabric.py index 4df52d7cd0455..7c9536fbd9532 100644 --- a/examples/fabric/reinforcement_learning/train_fabric.py +++ b/examples/fabric/reinforcement_learning/train_fabric.py @@ -25,13 +25,14 @@ import gymnasium as gym import torch import torchmetrics -from lightning.fabric import Fabric -from lightning.fabric.loggers import TensorBoardLogger from rl.agent import PPOLightningAgent from rl.utils import linear_annealing, make_env, parse_args, test from torch import Tensor from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler +from lightning.fabric import Fabric +from lightning.fabric.loggers import TensorBoardLogger + def train( fabric: Fabric, diff --git a/examples/fabric/reinforcement_learning/train_fabric_decoupled.py b/examples/fabric/reinforcement_learning/train_fabric_decoupled.py index 7150ac3a12529..1d8967ec6f1dc 100644 --- a/examples/fabric/reinforcement_learning/train_fabric_decoupled.py +++ b/examples/fabric/reinforcement_learning/train_fabric_decoupled.py @@ -25,17 +25,18 @@ import gymnasium as gym import torch -from lightning.fabric import Fabric -from lightning.fabric.loggers import TensorBoardLogger -from lightning.fabric.plugins.collectives import TorchCollective -from lightning.fabric.plugins.collectives.collective import CollectibleGroup -from lightning.fabric.strategies import DDPStrategy from rl.agent import PPOLightningAgent from rl.utils import linear_annealing, make_env, parse_args, test from torch.distributed.algorithms.join import Join from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler from torchmetrics import MeanMetric +from lightning.fabric import Fabric +from lightning.fabric.loggers import TensorBoardLogger +from lightning.fabric.plugins.collectives import TorchCollective +from lightning.fabric.plugins.collectives.collective import CollectibleGroup +from lightning.fabric.strategies import DDPStrategy + @torch.no_grad() def player(args, world_collective: TorchCollective, player_trainer_collective: TorchCollective): diff --git a/examples/fabric/tensor_parallel/README.md b/examples/fabric/tensor_parallel/README.md index e66d9acd2848b..1f551109cc5e7 100644 --- a/examples/fabric/tensor_parallel/README.md +++ b/examples/fabric/tensor_parallel/README.md @@ -41,5 +41,5 @@ Training successfully completed! Peak memory usage: 17.95 GB ``` -> \[!NOTE\] +> [!NOTE] > The `ModelParallelStrategy` is experimental and subject to change. Report issues on [GitHub](https://github.com/Lightning-AI/pytorch-lightning/issues). diff --git a/examples/fabric/tensor_parallel/train.py b/examples/fabric/tensor_parallel/train.py index 4a98f12cf6168..35ee9074f18a8 100644 --- a/examples/fabric/tensor_parallel/train.py +++ b/examples/fabric/tensor_parallel/train.py @@ -1,13 +1,14 @@ -import lightning as L import torch import torch.nn.functional as F from data import RandomTokenDataset -from lightning.fabric.strategies import ModelParallelStrategy from model import ModelArgs, Transformer from parallelism import parallelize from torch.distributed.tensor.parallel import loss_parallel from torch.utils.data import DataLoader +import lightning as L +from lightning.fabric.strategies import ModelParallelStrategy + def train(): strategy = ModelParallelStrategy( diff --git a/examples/pytorch/basics/autoencoder.py b/examples/pytorch/basics/autoencoder.py index d6f594b12f57b..332c9a811e3e4 100644 --- a/examples/pytorch/basics/autoencoder.py +++ b/examples/pytorch/basics/autoencoder.py @@ -22,13 +22,14 @@ import torch import torch.nn.functional as F +from torch import nn +from torch.utils.data import DataLoader, random_split + from lightning.pytorch import LightningDataModule, LightningModule, Trainer, callbacks, cli_lightning_logo from lightning.pytorch.cli import LightningCLI from lightning.pytorch.demos.mnist_datamodule import MNIST from lightning.pytorch.utilities import rank_zero_only from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE -from torch import nn -from torch.utils.data import DataLoader, random_split if _TORCHVISION_AVAILABLE: import torchvision diff --git a/examples/pytorch/basics/backbone_image_classifier.py b/examples/pytorch/basics/backbone_image_classifier.py index fceb97dc41cff..965f636d7fc0b 100644 --- a/examples/pytorch/basics/backbone_image_classifier.py +++ b/examples/pytorch/basics/backbone_image_classifier.py @@ -21,12 +21,13 @@ from typing import Optional import torch +from torch.nn import functional as F +from torch.utils.data import DataLoader, random_split + from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo from lightning.pytorch.cli import LightningCLI from lightning.pytorch.demos.mnist_datamodule import MNIST from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE -from torch.nn import functional as F -from torch.utils.data import DataLoader, random_split if _TORCHVISION_AVAILABLE: from torchvision import transforms diff --git a/examples/pytorch/basics/profiler_example.py b/examples/pytorch/basics/profiler_example.py index 5b9ba08fec761..366aaecefe7e4 100644 --- a/examples/pytorch/basics/profiler_example.py +++ b/examples/pytorch/basics/profiler_example.py @@ -28,6 +28,7 @@ import torch import torchvision import torchvision.transforms as T + from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo from lightning.pytorch.cli import LightningCLI from lightning.pytorch.profilers.pytorch import PyTorchProfiler diff --git a/examples/pytorch/basics/transformer.py b/examples/pytorch/basics/transformer.py index 93cb39d829acc..dbd990d7f2759 100644 --- a/examples/pytorch/basics/transformer.py +++ b/examples/pytorch/basics/transformer.py @@ -1,9 +1,10 @@ -import lightning as L import torch import torch.nn.functional as F -from lightning.pytorch.demos import Transformer, WikiText2 from torch.utils.data import DataLoader, random_split +import lightning as L +from lightning.pytorch.demos import Transformer, WikiText2 + class LanguageModel(L.LightningModule): def __init__(self, vocab_size): diff --git a/examples/pytorch/bug_report/bug_report_model.py b/examples/pytorch/bug_report/bug_report_model.py index aa3f4cad710fe..551ea21721754 100644 --- a/examples/pytorch/bug_report/bug_report_model.py +++ b/examples/pytorch/bug_report/bug_report_model.py @@ -1,9 +1,10 @@ import os import torch -from lightning.pytorch import LightningModule, Trainer from torch.utils.data import DataLoader, Dataset +from lightning.pytorch import LightningModule, Trainer + class RandomDataset(Dataset): def __init__(self, size, length): diff --git a/examples/pytorch/domain_templates/computer_vision_fine_tuning.py b/examples/pytorch/domain_templates/computer_vision_fine_tuning.py index d03bb0c4edd16..69721214748ee 100644 --- a/examples/pytorch/domain_templates/computer_vision_fine_tuning.py +++ b/examples/pytorch/domain_templates/computer_vision_fine_tuning.py @@ -46,11 +46,6 @@ import torch import torch.nn.functional as F -from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo -from lightning.pytorch.callbacks.finetuning import BaseFinetuning -from lightning.pytorch.cli import LightningCLI -from lightning.pytorch.utilities import rank_zero_info -from lightning.pytorch.utilities.model_helpers import get_torchvision_model from torch import nn, optim from torch.optim.lr_scheduler import MultiStepLR from torch.optim.optimizer import Optimizer @@ -60,6 +55,12 @@ from torchvision.datasets import ImageFolder from torchvision.datasets.utils import download_and_extract_archive +from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo +from lightning.pytorch.callbacks.finetuning import BaseFinetuning +from lightning.pytorch.cli import LightningCLI +from lightning.pytorch.utilities import rank_zero_info +from lightning.pytorch.utilities.model_helpers import get_torchvision_model + log = logging.getLogger(__name__) DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip" diff --git a/examples/pytorch/domain_templates/generative_adversarial_net.py b/examples/pytorch/domain_templates/generative_adversarial_net.py index 417e167df0d93..7ce7682d82c76 100644 --- a/examples/pytorch/domain_templates/generative_adversarial_net.py +++ b/examples/pytorch/domain_templates/generative_adversarial_net.py @@ -25,6 +25,7 @@ import torch import torch.nn as nn import torch.nn.functional as F + from lightning.pytorch import cli_lightning_logo from lightning.pytorch.core import LightningModule from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule diff --git a/examples/pytorch/domain_templates/imagenet.py b/examples/pytorch/domain_templates/imagenet.py index 49f9b509ce6e7..fd2050e2ed38b 100644 --- a/examples/pytorch/domain_templates/imagenet.py +++ b/examples/pytorch/domain_templates/imagenet.py @@ -43,13 +43,14 @@ import torch.utils.data.distributed import torchvision.datasets as datasets import torchvision.transforms as transforms +from torch.utils.data import Dataset +from torchmetrics import Accuracy + from lightning.pytorch import LightningModule from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar from lightning.pytorch.cli import LightningCLI from lightning.pytorch.strategies import ParallelStrategy from lightning.pytorch.utilities.model_helpers import get_torchvision_model -from torch.utils.data import Dataset -from torchmetrics import Accuracy class ImageNetLightningModel(LightningModule): diff --git a/examples/pytorch/domain_templates/reinforce_learn_Qnet.py b/examples/pytorch/domain_templates/reinforce_learn_Qnet.py index b3bfaaea93e7f..193e6495a4182 100644 --- a/examples/pytorch/domain_templates/reinforce_learn_Qnet.py +++ b/examples/pytorch/domain_templates/reinforce_learn_Qnet.py @@ -41,11 +41,12 @@ import torch import torch.nn as nn import torch.optim as optim -from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo, seed_everything from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader from torch.utils.data.dataset import IterableDataset +from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo, seed_everything + class DQN(nn.Module): """Simple MLP network. diff --git a/examples/pytorch/domain_templates/reinforce_learn_ppo.py b/examples/pytorch/domain_templates/reinforce_learn_ppo.py index 1fb083894c284..af503dbb925cd 100644 --- a/examples/pytorch/domain_templates/reinforce_learn_ppo.py +++ b/examples/pytorch/domain_templates/reinforce_learn_ppo.py @@ -35,12 +35,13 @@ import gym import torch -from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo, seed_everything from torch import nn from torch.distributions import Categorical, Normal from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader, IterableDataset +from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo, seed_everything + def create_mlp(input_shape: tuple[int], n_actions: int, hidden_size: int = 128): """Simple Multi-Layer Perceptron network.""" diff --git a/examples/pytorch/domain_templates/semantic_segmentation.py b/examples/pytorch/domain_templates/semantic_segmentation.py index 12ecbeeb5f0a9..0f19349f7a0fc 100644 --- a/examples/pytorch/domain_templates/semantic_segmentation.py +++ b/examples/pytorch/domain_templates/semantic_segmentation.py @@ -19,11 +19,12 @@ import torch import torch.nn.functional as F import torchvision.transforms as transforms -from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo from PIL import Image from torch import nn from torch.utils.data import DataLoader, Dataset +from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo + DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1) DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33) diff --git a/examples/pytorch/fp8_distributed_transformer/train.py b/examples/pytorch/fp8_distributed_transformer/train.py index 6c7be98ee7dbd..78aa6f13be6c2 100644 --- a/examples/pytorch/fp8_distributed_transformer/train.py +++ b/examples/pytorch/fp8_distributed_transformer/train.py @@ -1,13 +1,14 @@ -import lightning as L import torch import torch.nn as nn import torch.nn.functional as F -from lightning.pytorch.demos import Transformer, WikiText2 -from lightning.pytorch.strategies import ModelParallelStrategy from torch.distributed._composable.fsdp.fully_shard import fully_shard from torch.utils.data import DataLoader from torchao.float8 import Float8LinearConfig, convert_to_float8_training +import lightning as L +from lightning.pytorch.demos import Transformer, WikiText2 +from lightning.pytorch.strategies import ModelParallelStrategy + class LanguageModel(L.LightningModule): def __init__(self, vocab_size): diff --git a/examples/pytorch/hpu/mnist_sample.py b/examples/pytorch/hpu/mnist_sample.py index 4d2e22c03fe7e..0d04074519c8c 100644 --- a/examples/pytorch/hpu/mnist_sample.py +++ b/examples/pytorch/hpu/mnist_sample.py @@ -13,11 +13,12 @@ # limitations under the License. import torch from jsonargparse import lazy_instance +from lightning_habana import HPUPrecisionPlugin +from torch.nn import functional as F + from lightning.pytorch import LightningModule from lightning.pytorch.cli import LightningCLI from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule -from lightning_habana import HPUPrecisionPlugin -from torch.nn import functional as F class LitClassifier(LightningModule): diff --git a/examples/pytorch/servable_module/production.py b/examples/pytorch/servable_module/production.py index da0c42d12a865..854ff1176b619 100644 --- a/examples/pytorch/servable_module/production.py +++ b/examples/pytorch/servable_module/production.py @@ -8,11 +8,12 @@ import torch import torchvision import torchvision.transforms as T +from PIL import Image as PILImage + from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo from lightning.pytorch.cli import LightningCLI from lightning.pytorch.serve import ServableModule, ServableModuleValidator from lightning.pytorch.utilities.model_helpers import get_torchvision_model -from PIL import Image as PILImage DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") diff --git a/examples/pytorch/tensor_parallel/README.md b/examples/pytorch/tensor_parallel/README.md index d8b81b6de1bff..92e6fcb038268 100644 --- a/examples/pytorch/tensor_parallel/README.md +++ b/examples/pytorch/tensor_parallel/README.md @@ -45,5 +45,5 @@ Training successfully completed! Peak memory usage: 36.73 GB ``` -> \[!NOTE\] +> [!NOTE] > The `ModelParallelStrategy` is experimental and subject to change. Report issues on [GitHub](https://github.com/Lightning-AI/pytorch-lightning/issues). diff --git a/examples/pytorch/tensor_parallel/train.py b/examples/pytorch/tensor_parallel/train.py index 6a91e1242e4af..f9b971fd39f82 100644 --- a/examples/pytorch/tensor_parallel/train.py +++ b/examples/pytorch/tensor_parallel/train.py @@ -1,13 +1,14 @@ -import lightning as L import torch import torch.nn.functional as F from data import RandomTokenDataset -from lightning.pytorch.strategies import ModelParallelStrategy from model import ModelArgs, Transformer from parallelism import parallelize from torch.distributed.tensor.parallel import loss_parallel from torch.utils.data import DataLoader +import lightning as L +from lightning.pytorch.strategies import ModelParallelStrategy + class Llama3(L.LightningModule): def __init__(self): diff --git a/requirements/collect_env_details.py b/requirements/collect_env_details.py index 5e6f9ba3dd350..392dd637fa8fd 100644 --- a/requirements/collect_env_details.py +++ b/requirements/collect_env_details.py @@ -70,7 +70,7 @@ def nice_print(details: dict, level: int = 0) -> list: lines += [level * LEVEL_OFFSET + key] lines += [(level + 1) * LEVEL_OFFSET + "- " + v for v in details[k]] else: - template = "{:%is} {}" % KEY_PADDING + template = "{:%is} {}" % KEY_PADDING # noqa: UP031 key_val = template.format(key, details[k]) lines += [(level * LEVEL_OFFSET) + key_val] return lines diff --git a/src/lightning/data/README.md b/src/lightning/data/README.md index 525a7e14f894d..c61e7eacf26f2 100644 --- a/src/lightning/data/README.md +++ b/src/lightning/data/README.md @@ -31,11 +31,11 @@ Find the reproducible [Studio Benchmark](https://lightning.ai/lightning-ai/studi ### Imagenet-1.2M Streaming from AWS S3 -| Framework | Images / sec 1st Epoch (float32) | Images / sec 2nd Epoch (float32) | Images / sec 1st Epoch (torch16) | Images / sec 2nd Epoch (torch16) | -| ----------- | --------------------------------- | ---------------------------------- | -------------------------------- | -------------------------------- | -| PL Data | **5800.34** | **6589.98** | **6282.17** | **7221.88** | -| Web Dataset | 3134.42 | 3924.95 | 3343.40 | 4424.62 | -| Mosaic ML | 2898.61 | 5099.93 | 2809.69 | 5158.98 | +| Framework | Images / sec 1st Epoch (float32) | Images / sec 2nd Epoch (float32) | Images / sec 1st Epoch (torch16) | Images / sec 2nd Epoch (torch16) | +| ----------- | -------------------------------- | -------------------------------- | -------------------------------- | -------------------------------- | +| PL Data | **5800.34** | **6589.98** | **6282.17** | **7221.88** | +| Web Dataset | 3134.42 | 3924.95 | 3343.40 | 4424.62 | +| Mosaic ML | 2898.61 | 5099.93 | 2809.69 | 5158.98 | Higher is better. diff --git a/src/lightning/fabric/_graveyard/tpu.py b/src/lightning/fabric/_graveyard/tpu.py index c537ffc032322..138830e4e3b1b 100644 --- a/src/lightning/fabric/_graveyard/tpu.py +++ b/src/lightning/fabric/_graveyard/tpu.py @@ -71,7 +71,7 @@ class TPUPrecision(XLAPrecision): def __init__(self, *args: Any, **kwargs: Any) -> None: rank_zero_deprecation( - "The `TPUPrecision` class is deprecated. Use `lightning.fabric.plugins.precision.XLAPrecision`" " instead." + "The `TPUPrecision` class is deprecated. Use `lightning.fabric.plugins.precision.XLAPrecision` instead." ) super().__init__(precision="32-true") @@ -85,8 +85,7 @@ class XLABf16Precision(XLAPrecision): def __init__(self, *args: Any, **kwargs: Any) -> None: rank_zero_deprecation( - "The `XLABf16Precision` class is deprecated. Use" - " `lightning.fabric.plugins.precision.XLAPrecision` instead." + "The `XLABf16Precision` class is deprecated. Use `lightning.fabric.plugins.precision.XLAPrecision` instead." ) super().__init__(precision="bf16-true") @@ -100,8 +99,7 @@ class TPUBf16Precision(XLABf16Precision): def __init__(self, *args: Any, **kwargs: Any) -> None: rank_zero_deprecation( - "The `TPUBf16Precision` class is deprecated. Use" - " `lightning.fabric.plugins.precision.XLAPrecision` instead." + "The `TPUBf16Precision` class is deprecated. Use `lightning.fabric.plugins.precision.XLAPrecision` instead." ) super().__init__(*args, **kwargs) diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 0ade7f69c3629..85d30a07ce207 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -251,8 +251,7 @@ def _check_config_and_set_final_flags( if plugins_flags_types.get(Precision.__name__) and precision_input is not None: raise ValueError( - f"Received both `precision={precision_input}` and `plugins={self._precision_instance}`." - f" Choose one." + f"Received both `precision={precision_input}` and `plugins={self._precision_instance}`. Choose one." ) self._precision_input = "32-true" if precision_input is None else precision_input diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 4af5ec65949c9..41820c1cc433f 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -329,8 +329,7 @@ def setup_module_and_optimizers( """ if len(optimizers) != 1: raise ValueError( - f"Currently only one optimizer is supported with DeepSpeed." - f" Got {len(optimizers)} optimizers instead." + f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead." ) self._deepspeed_engine, optimizer = self._initialize_engine(module, optimizers[0]) diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index ad1fc19074d06..ace23a9c7a2c5 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -292,8 +292,7 @@ def load_checkpoint( if isinstance(state, Optimizer): raise NotImplementedError( - f"Loading a single optimizer object from a checkpoint is not supported yet with" - f" {type(self).__name__}." + f"Loading a single optimizer object from a checkpoint is not supported yet with {type(self).__name__}." ) return _load_checkpoint(path=path, state=state, strict=strict) diff --git a/src/lightning/pytorch/callbacks/early_stopping.py b/src/lightning/pytorch/callbacks/early_stopping.py index 78c4215f9ce23..d108894f614e6 100644 --- a/src/lightning/pytorch/callbacks/early_stopping.py +++ b/src/lightning/pytorch/callbacks/early_stopping.py @@ -145,7 +145,7 @@ def _validate_condition_metric(self, logs: dict[str, Tensor]) -> bool: error_msg = ( f"Early stopping conditioned on metric `{self.monitor}` which is not available." " Pass in or modify your `EarlyStopping` callback to use any of the following:" - f' `{"`, `".join(list(logs.keys()))}`' + f" `{'`, `'.join(list(logs.keys()))}`" ) if monitor_val is None: diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index e1752c67d9183..ced8a6f1f2bd3 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -74,7 +74,7 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None: continue lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key] if ( - type(lm_val) != type(dm_val) + type(lm_val) != type(dm_val) # noqa: E721 or (isinstance(lm_val, Tensor) and id(lm_val) != id(dm_val)) or lm_val != dm_val ): diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index d2d380f8788f8..dabfde70242b9 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -400,8 +400,7 @@ def _setup_model_and_optimizers( """ if len(optimizers) != 1: raise ValueError( - f"Currently only one optimizer is supported with DeepSpeed." - f" Got {len(optimizers)} optimizers instead." + f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead." ) # train_micro_batch_size_per_gpu is used for throughput logging purposes diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py index fdde19aa80eea..0881ac0b3fa08 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py @@ -157,7 +157,7 @@ def forked(self) -> bool: def forked_name(self, on_step: bool) -> str: if self.forked: - return f'{self.name}_{"step" if on_step else "epoch"}' + return f"{self.name}_{'step' if on_step else 'epoch'}" return self.name @property diff --git a/src/lightning_fabric/README.md b/src/lightning_fabric/README.md index d842c0d19118d..076800caeb71e 100644 --- a/src/lightning_fabric/README.md +++ b/src/lightning_fabric/README.md @@ -215,7 +215,7 @@ Lightning is rigorously tested across multiple CPUs and GPUs and against major P | System / PyTorch ver. | 1.12 | 1.13 | 2.0 | 2.1 | | :--------------------------------: | :-------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------: | -| Linux py3.9 \[GPUs\] | | | | ![Build Status](https://dev.azure.com/Lightning-AI/lightning/_apis/build/status%2Flightning-fabric%20%28GPUs%29) | +| Linux py3.9 [GPUs] | | | | ![Build Status](https://dev.azure.com/Lightning-AI/lightning/_apis/build/status%2Flightning-fabric%20%28GPUs%29) | | Linux (multiple Python versions) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | | OSX (multiple Python versions) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | | Windows (multiple Python versions) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | diff --git a/src/pytorch_lightning/README.md b/src/pytorch_lightning/README.md index ae9339dfb2b0d..f3fb8cb2fd2b3 100644 --- a/src/pytorch_lightning/README.md +++ b/src/pytorch_lightning/README.md @@ -79,7 +79,7 @@ Lightning is rigorously tested across multiple CPUs, GPUs and TPUs and against m | System / PyTorch ver. | 1.12 | 1.13 | 2.0 | 2.1 | | :--------------------------------: | :---------------------------------------------------------------------------------------------------------: | ----------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------- | -| Linux py3.9 \[GPUs\] | | | | ![Build Status](https://dev.azure.com/Lightning-AI/lightning/_apis/build/status%2Fpytorch-lightning%20%28GPUs%29) | +| Linux py3.9 [GPUs] | | | | ![Build Status](https://dev.azure.com/Lightning-AI/lightning/_apis/build/status%2Fpytorch-lightning%20%28GPUs%29) | | Linux (multiple Python versions) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | | OSX (multiple Python versions) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | | Windows (multiple Python versions) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | diff --git a/tests/legacy/simple_classif_training.py b/tests/legacy/simple_classif_training.py index d2cf4cd2166f3..dd767ee9075f2 100644 --- a/tests/legacy/simple_classif_training.py +++ b/tests/legacy/simple_classif_training.py @@ -14,13 +14,14 @@ import os import sys -import lightning.pytorch as pl import torch -from lightning.pytorch import seed_everything -from lightning.pytorch.callbacks import EarlyStopping from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.simple_models import ClassificationModel +import lightning.pytorch as pl +from lightning.pytorch import seed_everything +from lightning.pytorch.callbacks import EarlyStopping + PATH_LEGACY = os.path.dirname(__file__) diff --git a/tests/parity_fabric/conftest.py b/tests/parity_fabric/conftest.py index ceb19e061c774..9fc6f9d908d81 100644 --- a/tests/parity_fabric/conftest.py +++ b/tests/parity_fabric/conftest.py @@ -17,7 +17,7 @@ import torch.distributed -@pytest.fixture() +@pytest.fixture def reset_deterministic_algorithm(): """Ensures that torch determinism settings are reset before the next test runs.""" yield @@ -25,7 +25,7 @@ def reset_deterministic_algorithm(): torch.use_deterministic_algorithms(False) -@pytest.fixture() +@pytest.fixture def reset_cudnn_benchmark(): """Ensures that the `torch.backends.cudnn.benchmark` setting gets reset before the next test runs.""" yield diff --git a/tests/parity_fabric/test_parity_ddp.py b/tests/parity_fabric/test_parity_ddp.py index 217d401ad6fba..d30d2b6233886 100644 --- a/tests/parity_fabric/test_parity_ddp.py +++ b/tests/parity_fabric/test_parity_ddp.py @@ -18,11 +18,11 @@ import torch import torch.distributed import torch.nn.functional -from lightning.fabric.fabric import Fabric from torch.nn.parallel.distributed import DistributedDataParallel from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from lightning.fabric.fabric import Fabric from parity_fabric.models import ConvNet from parity_fabric.utils import ( cuda_reset, diff --git a/tests/parity_fabric/test_parity_simple.py b/tests/parity_fabric/test_parity_simple.py index a97d39dfa1035..54c0de7297ac5 100644 --- a/tests/parity_fabric/test_parity_simple.py +++ b/tests/parity_fabric/test_parity_simple.py @@ -19,9 +19,9 @@ import torch import torch.distributed import torch.nn.functional -from lightning.fabric.fabric import Fabric from tests_fabric.helpers.runif import RunIf +from lightning.fabric.fabric import Fabric from parity_fabric.models import ConvNet from parity_fabric.utils import ( cuda_reset, diff --git a/tests/parity_fabric/utils.py b/tests/parity_fabric/utils.py index 7f0028dc23421..7d7a14732bb0e 100644 --- a/tests/parity_fabric/utils.py +++ b/tests/parity_fabric/utils.py @@ -14,6 +14,7 @@ import os import torch + from lightning.fabric.accelerators.cuda import _clear_cuda_memory diff --git a/tests/parity_pytorch/__init__.py b/tests/parity_pytorch/__init__.py index 148237ab9c718..6d7cadefc20fa 100644 --- a/tests/parity_pytorch/__init__.py +++ b/tests/parity_pytorch/__init__.py @@ -1,4 +1,5 @@ import pytest + from lightning.pytorch.utilities.testing import _runif_reasons diff --git a/tests/parity_pytorch/models.py b/tests/parity_pytorch/models.py index f55b0d6f1f36e..17cbef6a76faa 100644 --- a/tests/parity_pytorch/models.py +++ b/tests/parity_pytorch/models.py @@ -14,11 +14,12 @@ import torch import torch.nn.functional as F +from tests_pytorch import _PATH_DATASETS +from torch.utils.data import DataLoader + from lightning.pytorch.core.module import LightningModule from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE from lightning.pytorch.utilities.model_helpers import get_torchvision_model -from tests_pytorch import _PATH_DATASETS -from torch.utils.data import DataLoader if _TORCHVISION_AVAILABLE: from torchvision import transforms diff --git a/tests/parity_pytorch/test_basic_parity.py b/tests/parity_pytorch/test_basic_parity.py index d7f15815c831f..6b413dada8659 100644 --- a/tests/parity_pytorch/test_basic_parity.py +++ b/tests/parity_pytorch/test_basic_parity.py @@ -16,9 +16,9 @@ import numpy as np import pytest import torch -from lightning.pytorch import LightningModule, Trainer, seed_everything from tests_pytorch.helpers.advanced_models import ParityModuleMNIST, ParityModuleRNN +from lightning.pytorch import LightningModule, Trainer, seed_everything from parity_pytorch.measure import measure_loops from parity_pytorch.models import ParityModuleCIFAR diff --git a/tests/parity_pytorch/test_sync_batchnorm_parity.py b/tests/parity_pytorch/test_sync_batchnorm_parity.py index 4d2300cf15670..af22d7470e524 100644 --- a/tests/parity_pytorch/test_sync_batchnorm_parity.py +++ b/tests/parity_pytorch/test_sync_batchnorm_parity.py @@ -14,9 +14,9 @@ import torch import torch.nn as nn -from lightning.pytorch import LightningModule, Trainer, seed_everything from torch.utils.data import DataLoader, DistributedSampler +from lightning.pytorch import LightningModule, Trainer, seed_everything from parity_pytorch import RunIf diff --git a/tests/tests_fabric/accelerators/test_cpu.py b/tests/tests_fabric/accelerators/test_cpu.py index 5efb5d6afddbc..7c7029f9ec9e3 100644 --- a/tests/tests_fabric/accelerators/test_cpu.py +++ b/tests/tests_fabric/accelerators/test_cpu.py @@ -14,6 +14,7 @@ import pytest import torch + from lightning.fabric.accelerators.cpu import CPUAccelerator, _parse_cpu_cores diff --git a/tests/tests_fabric/accelerators/test_cuda.py b/tests/tests_fabric/accelerators/test_cuda.py index 0aed3675d93e1..037eb8d400825 100644 --- a/tests/tests_fabric/accelerators/test_cuda.py +++ b/tests/tests_fabric/accelerators/test_cuda.py @@ -18,15 +18,15 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric import pytest import torch + +import lightning.fabric from lightning.fabric.accelerators.cuda import ( CUDAAccelerator, _check_cuda_matmul_precision, find_usable_cuda_devices, ) - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/accelerators/test_mps.py b/tests/tests_fabric/accelerators/test_mps.py index 20dc9b6a93581..612bd8f74a640 100644 --- a/tests/tests_fabric/accelerators/test_mps.py +++ b/tests/tests_fabric/accelerators/test_mps.py @@ -13,9 +13,9 @@ # limitations under the License. import pytest import torch + from lightning.fabric.accelerators.mps import MPSAccelerator from lightning.fabric.utilities.exceptions import MisconfigurationException - from tests_fabric.helpers.runif import RunIf _MAYBE_MPS = "mps" if MPSAccelerator.is_available() else "cpu" diff --git a/tests/tests_fabric/accelerators/test_registry.py b/tests/tests_fabric/accelerators/test_registry.py index 2544df1e01ff8..28bfbb8ffd97c 100644 --- a/tests/tests_fabric/accelerators/test_registry.py +++ b/tests/tests_fabric/accelerators/test_registry.py @@ -14,6 +14,7 @@ from typing import Any import torch + from lightning.fabric.accelerators import ACCELERATOR_REGISTRY, Accelerator diff --git a/tests/tests_fabric/accelerators/test_xla.py b/tests/tests_fabric/accelerators/test_xla.py index 7a906c8ae0c54..95d5cab90a5f2 100644 --- a/tests/tests_fabric/accelerators/test_xla.py +++ b/tests/tests_fabric/accelerators/test_xla.py @@ -13,8 +13,8 @@ # limitations under the License import pytest -from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, XLAAccelerator +from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, XLAAccelerator from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index dd272257b3923..ad16c96305135 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -18,9 +18,10 @@ from pathlib import Path from unittest.mock import Mock -import lightning.fabric import pytest import torch.distributed + +import lightning.fabric from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver from lightning.fabric.utilities.distributed import _destroy_dist_connection @@ -29,9 +30,10 @@ @pytest.fixture(autouse=True) def preserve_global_rank_variable(): """Ensures that the rank_zero_only.rank global variable gets reset in each test.""" - from lightning.fabric.utilities.rank_zero import rank_zero_only as rank_zero_only_fabric from lightning_utilities.core.rank_zero import rank_zero_only as rank_zero_only_utilities + from lightning.fabric.utilities.rank_zero import rank_zero_only as rank_zero_only_fabric + functions = (rank_zero_only_fabric, rank_zero_only_utilities) ranks = [getattr(fn, "rank", None) for fn in functions] yield @@ -126,7 +128,7 @@ def reset_in_fabric_backward(): wrappers._in_fabric_backward = False -@pytest.fixture() +@pytest.fixture def reset_deterministic_algorithm(): """Ensures that torch determinism settings are reset before the next test runs.""" yield @@ -134,7 +136,7 @@ def reset_deterministic_algorithm(): torch.use_deterministic_algorithms(False) -@pytest.fixture() +@pytest.fixture def reset_cudnn_benchmark(): """Ensures that the `torch.backends.cudnn.benchmark` setting gets reset before the next test runs.""" yield @@ -155,7 +157,7 @@ def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N monkeypatch.setitem(sys.modules, "torch_xla.distributed.fsdp.wrap", Mock()) -@pytest.fixture() +@pytest.fixture def xla_available(monkeypatch: pytest.MonkeyPatch) -> None: mock_xla_available(monkeypatch) @@ -166,12 +168,12 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8) -@pytest.fixture() +@pytest.fixture def tpu_available(monkeypatch: pytest.MonkeyPatch) -> None: mock_tpu_available(monkeypatch) -@pytest.fixture() +@pytest.fixture def caplog(caplog): """Workaround for https://github.com/pytest-dev/pytest/issues/3697. diff --git a/tests/tests_fabric/helpers/runif.py b/tests/tests_fabric/helpers/runif.py index 813c4f93b1788..23a620295bcbf 100644 --- a/tests/tests_fabric/helpers/runif.py +++ b/tests/tests_fabric/helpers/runif.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest + from lightning.fabric.utilities.testing import _runif_reasons diff --git a/tests/tests_fabric/loggers/test_csv.py b/tests/tests_fabric/loggers/test_csv.py index 784bb6fb45aba..f03e8c5125e54 100644 --- a/tests/tests_fabric/loggers/test_csv.py +++ b/tests/tests_fabric/loggers/test_csv.py @@ -17,6 +17,7 @@ import pytest import torch + from lightning.fabric.loggers import CSVLogger from lightning.fabric.loggers.csv_logs import _ExperimentWriter diff --git a/tests/tests_fabric/loggers/test_tensorboard.py b/tests/tests_fabric/loggers/test_tensorboard.py index fa685241ea1b5..4dcb86f0e7406 100644 --- a/tests/tests_fabric/loggers/test_tensorboard.py +++ b/tests/tests_fabric/loggers/test_tensorboard.py @@ -19,10 +19,10 @@ import numpy as np import pytest import torch + from lightning.fabric.loggers import TensorBoardLogger from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE from lightning.fabric.wrappers import _FabricModule - from tests_fabric.test_fabric import BoringModel diff --git a/tests/tests_fabric/plugins/collectives/test_single_device.py b/tests/tests_fabric/plugins/collectives/test_single_device.py index e7aefdb6078b1..a5a909da64b8a 100644 --- a/tests/tests_fabric/plugins/collectives/test_single_device.py +++ b/tests/tests_fabric/plugins/collectives/test_single_device.py @@ -1,6 +1,7 @@ from unittest import mock import pytest + from lightning.fabric.plugins.collectives import SingleDeviceCollective diff --git a/tests/tests_fabric/plugins/collectives/test_torch_collective.py b/tests/tests_fabric/plugins/collectives/test_torch_collective.py index b4c223e770282..5cef70f3b91ba 100644 --- a/tests/tests_fabric/plugins/collectives/test_torch_collective.py +++ b/tests/tests_fabric/plugins/collectives/test_torch_collective.py @@ -6,12 +6,12 @@ import pytest import torch + from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator from lightning.fabric.plugins.collectives import TorchCollective from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies.ddp import DDPStrategy from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher - from tests_fabric.helpers.runif import RunIf if TorchCollective.is_available(): diff --git a/tests/tests_fabric/plugins/environments/test_kubeflow.py b/tests/tests_fabric/plugins/environments/test_kubeflow.py index 3c44273825510..3436adc9ce2aa 100644 --- a/tests/tests_fabric/plugins/environments/test_kubeflow.py +++ b/tests/tests_fabric/plugins/environments/test_kubeflow.py @@ -16,6 +16,7 @@ from unittest import mock import pytest + from lightning.fabric.plugins.environments import KubeflowEnvironment diff --git a/tests/tests_fabric/plugins/environments/test_lightning.py b/tests/tests_fabric/plugins/environments/test_lightning.py index 02100329d173c..cc0179af4c3f5 100644 --- a/tests/tests_fabric/plugins/environments/test_lightning.py +++ b/tests/tests_fabric/plugins/environments/test_lightning.py @@ -15,6 +15,7 @@ from unittest import mock import pytest + from lightning.fabric.plugins.environments import LightningEnvironment diff --git a/tests/tests_fabric/plugins/environments/test_lsf.py b/tests/tests_fabric/plugins/environments/test_lsf.py index 4e60d968dc953..31cc5976cfe09 100644 --- a/tests/tests_fabric/plugins/environments/test_lsf.py +++ b/tests/tests_fabric/plugins/environments/test_lsf.py @@ -15,6 +15,7 @@ from unittest import mock import pytest + from lightning.fabric.plugins.environments import LSFEnvironment diff --git a/tests/tests_fabric/plugins/environments/test_mpi.py b/tests/tests_fabric/plugins/environments/test_mpi.py index 649d4dcb1dab2..3df0000cf2766 100644 --- a/tests/tests_fabric/plugins/environments/test_mpi.py +++ b/tests/tests_fabric/plugins/environments/test_mpi.py @@ -16,8 +16,9 @@ from unittest import mock from unittest.mock import MagicMock -import lightning.fabric.plugins.environments.mpi import pytest + +import lightning.fabric.plugins.environments.mpi from lightning.fabric.plugins.environments import MPIEnvironment diff --git a/tests/tests_fabric/plugins/environments/test_slurm.py b/tests/tests_fabric/plugins/environments/test_slurm.py index f237478a533f4..75ca43577d579 100644 --- a/tests/tests_fabric/plugins/environments/test_slurm.py +++ b/tests/tests_fabric/plugins/environments/test_slurm.py @@ -18,10 +18,10 @@ from unittest import mock import pytest -from lightning.fabric.plugins.environments import SLURMEnvironment -from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning_utilities.test.warning import no_warning_call +from lightning.fabric.plugins.environments import SLURMEnvironment +from lightning.fabric.utilities.warnings import PossibleUserWarning from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/environments/test_torchelastic.py b/tests/tests_fabric/plugins/environments/test_torchelastic.py index 3cf8619a00c22..161d42894df30 100644 --- a/tests/tests_fabric/plugins/environments/test_torchelastic.py +++ b/tests/tests_fabric/plugins/environments/test_torchelastic.py @@ -17,6 +17,7 @@ from unittest import mock import pytest + from lightning.fabric.plugins.environments import TorchElasticEnvironment diff --git a/tests/tests_fabric/plugins/environments/test_xla.py b/tests/tests_fabric/plugins/environments/test_xla.py index 7a3610b65b4bb..7e33d5db87dd4 100644 --- a/tests/tests_fabric/plugins/environments/test_xla.py +++ b/tests/tests_fabric/plugins/environments/test_xla.py @@ -14,11 +14,11 @@ import os from unittest import mock -import lightning.fabric import pytest + +import lightning.fabric from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1 from lightning.fabric.plugins.environments import XLAEnvironment - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_all.py b/tests/tests_fabric/plugins/precision/test_all.py index 5e86a35647489..94e5efaa74eed 100644 --- a/tests/tests_fabric/plugins/precision/test_all.py +++ b/tests/tests_fabric/plugins/precision/test_all.py @@ -1,5 +1,6 @@ import pytest import torch + from lightning.fabric.plugins import DeepSpeedPrecision, DoublePrecision, FSDPPrecision, HalfPrecision diff --git a/tests/tests_fabric/plugins/precision/test_amp.py b/tests/tests_fabric/plugins/precision/test_amp.py index 93d53eb406f71..73507f085936b 100644 --- a/tests/tests_fabric/plugins/precision/test_amp.py +++ b/tests/tests_fabric/plugins/precision/test_amp.py @@ -16,6 +16,7 @@ import pytest import torch + from lightning.fabric.plugins.precision.amp import MixedPrecision from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 diff --git a/tests/tests_fabric/plugins/precision/test_amp_integration.py b/tests/tests_fabric/plugins/precision/test_amp_integration.py index aa6c6cfce4504..bcbd9435d47ac 100644 --- a/tests/tests_fabric/plugins/precision/test_amp_integration.py +++ b/tests/tests_fabric/plugins/precision/test_amp_integration.py @@ -16,9 +16,9 @@ import pytest import torch import torch.nn as nn + from lightning.fabric import Fabric, seed_everything from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py index b8b9020b201a7..152f9a1c01fe9 100644 --- a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py +++ b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py @@ -14,16 +14,16 @@ import sys from unittest.mock import Mock -import lightning.fabric import pytest import torch import torch.distributed + +import lightning.fabric from lightning.fabric import Fabric from lightning.fabric.connector import _Connector from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision from lightning.fabric.utilities.init import _materialize_meta_tensors from lightning.fabric.utilities.load import _lazy_load - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_deepspeed.py b/tests/tests_fabric/plugins/precision/test_deepspeed.py index 248f616646842..170f15afaa2ad 100644 --- a/tests/tests_fabric/plugins/precision/test_deepspeed.py +++ b/tests/tests_fabric/plugins/precision/test_deepspeed.py @@ -16,9 +16,9 @@ import pytest import torch + from lightning.fabric.plugins.precision.deepspeed import DeepSpeedPrecision from lightning.fabric.utilities.types import Steppable - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py b/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py index 27e44398fc095..e989534343b8c 100644 --- a/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py +++ b/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py @@ -14,10 +14,10 @@ from unittest import mock import pytest + from lightning.fabric.connector import _Connector from lightning.fabric.plugins import DeepSpeedPrecision from lightning.fabric.strategies import DeepSpeedStrategy - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_double.py b/tests/tests_fabric/plugins/precision/test_double.py index 97d4a2303100e..4921e0f4e659b 100644 --- a/tests/tests_fabric/plugins/precision/test_double.py +++ b/tests/tests_fabric/plugins/precision/test_double.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch + from lightning.fabric.plugins.precision.double import DoublePrecision diff --git a/tests/tests_fabric/plugins/precision/test_double_integration.py b/tests/tests_fabric/plugins/precision/test_double_integration.py index 6701bc1a80e59..8f96f75ad1c67 100644 --- a/tests/tests_fabric/plugins/precision/test_double_integration.py +++ b/tests/tests_fabric/plugins/precision/test_double_integration.py @@ -15,8 +15,8 @@ import torch import torch.nn as nn -from lightning.fabric import Fabric +from lightning.fabric import Fabric from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_fsdp.py b/tests/tests_fabric/plugins/precision/test_fsdp.py index 6a4968736ea86..3b8d916e20c8f 100644 --- a/tests/tests_fabric/plugins/precision/test_fsdp.py +++ b/tests/tests_fabric/plugins/precision/test_fsdp.py @@ -15,9 +15,9 @@ import pytest import torch + from lightning.fabric.plugins import FSDPPrecision from lightning.fabric.plugins.precision.utils import _DtypeContextManager - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_half.py b/tests/tests_fabric/plugins/precision/test_half.py index 4037feebbd178..00d23df4ae5b6 100644 --- a/tests/tests_fabric/plugins/precision/test_half.py +++ b/tests/tests_fabric/plugins/precision/test_half.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest import torch + from lightning.fabric.plugins.precision import HalfPrecision diff --git a/tests/tests_fabric/plugins/precision/test_transformer_engine.py b/tests/tests_fabric/plugins/precision/test_transformer_engine.py index c003715dead8a..033484aca9c90 100644 --- a/tests/tests_fabric/plugins/precision/test_transformer_engine.py +++ b/tests/tests_fabric/plugins/precision/test_transformer_engine.py @@ -14,10 +14,11 @@ import sys from unittest.mock import Mock -import lightning.fabric import pytest import torch import torch.distributed + +import lightning.fabric from lightning.fabric.connector import _Connector from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision diff --git a/tests/tests_fabric/plugins/precision/test_utils.py b/tests/tests_fabric/plugins/precision/test_utils.py index 74899c86e9e2d..6e459c3fe6637 100644 --- a/tests/tests_fabric/plugins/precision/test_utils.py +++ b/tests/tests_fabric/plugins/precision/test_utils.py @@ -1,5 +1,6 @@ import pytest import torch + from lightning.fabric.plugins.precision.utils import _ClassReplacementContextManager, _DtypeContextManager diff --git a/tests/tests_fabric/plugins/precision/test_xla.py b/tests/tests_fabric/plugins/precision/test_xla.py index 0cdc11b00b99a..cfdc32112a957 100644 --- a/tests/tests_fabric/plugins/precision/test_xla.py +++ b/tests/tests_fabric/plugins/precision/test_xla.py @@ -17,6 +17,7 @@ import pytest import torch + from lightning.fabric.plugins import XLAPrecision diff --git a/tests/tests_fabric/plugins/precision/test_xla_integration.py b/tests/tests_fabric/plugins/precision/test_xla_integration.py index 14a5cd1442e4a..75dede49e2fe0 100644 --- a/tests/tests_fabric/plugins/precision/test_xla_integration.py +++ b/tests/tests_fabric/plugins/precision/test_xla_integration.py @@ -17,9 +17,9 @@ import pytest import torch import torch.nn as nn + from lightning.fabric import Fabric from lightning.fabric.plugins import XLAPrecision - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing.py index 6c595fba7acab..5bb85e070f17d 100644 --- a/tests/tests_fabric/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing.py @@ -16,8 +16,8 @@ import pytest import torch -from lightning.fabric.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher +from lightning.fabric.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py index c3bc7d3c2c6cd..6ae96b9bcafc6 100644 --- a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py +++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py @@ -15,8 +15,8 @@ import pytest import torch import torch.nn as nn -from lightning.fabric import Fabric +from lightning.fabric import Fabric from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/launchers/test_subprocess_script.py b/tests/tests_fabric/strategies/launchers/test_subprocess_script.py index a2d04e29bc0d6..70587ca2877a6 100644 --- a/tests/tests_fabric/strategies/launchers/test_subprocess_script.py +++ b/tests/tests_fabric/strategies/launchers/test_subprocess_script.py @@ -17,8 +17,9 @@ from unittest import mock from unittest.mock import ANY, Mock -import lightning.fabric import pytest + +import lightning.fabric from lightning.fabric.strategies.launchers.subprocess_script import ( _HYDRA_AVAILABLE, _ChildProcessObserver, diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py index b98d5f8226dc2..fa5c975228a5e 100644 --- a/tests/tests_fabric/strategies/test_ddp.py +++ b/tests/tests_fabric/strategies/test_ddp.py @@ -19,12 +19,12 @@ import pytest import torch +from torch.nn.parallel import DistributedDataParallel + from lightning.fabric.plugins import DoublePrecision, HalfPrecision, Precision from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import DDPStrategy from lightning.fabric.strategies.ddp import _DDPBackwardSyncControl -from torch.nn.parallel import DistributedDataParallel - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/test_ddp_integration.py b/tests/tests_fabric/strategies/test_ddp_integration.py index a7ed09b00b09e..70dd25aa99603 100644 --- a/tests/tests_fabric/strategies/test_ddp_integration.py +++ b/tests/tests_fabric/strategies/test_ddp_integration.py @@ -18,11 +18,11 @@ import pytest import torch -from lightning.fabric import Fabric from lightning_utilities.core.imports import RequirementCache from torch._dynamo import OptimizedModule from torch.nn.parallel.distributed import DistributedDataParallel +from lightning.fabric import Fabric from tests_fabric.helpers.runif import RunIf from tests_fabric.strategies.test_single_device import _run_test_clip_gradients from tests_fabric.test_fabric import BoringModel diff --git a/tests/tests_fabric/strategies/test_deepspeed.py b/tests/tests_fabric/strategies/test_deepspeed.py index 4ee87b265b086..032ee63cd4721 100644 --- a/tests/tests_fabric/strategies/test_deepspeed.py +++ b/tests/tests_fabric/strategies/test_deepspeed.py @@ -18,14 +18,14 @@ import pytest import torch -from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator -from lightning.fabric.strategies import DeepSpeedStrategy from torch.optim import Optimizer +from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator +from lightning.fabric.strategies import DeepSpeedStrategy from tests_fabric.helpers.runif import RunIf -@pytest.fixture() +@pytest.fixture def deepspeed_config(): return { "optimizer": {"type": "SGD", "params": {"lr": 3e-5}}, @@ -36,7 +36,7 @@ def deepspeed_config(): } -@pytest.fixture() +@pytest.fixture def deepspeed_zero_config(deepspeed_config): return {**deepspeed_config, "zero_allow_untested_optimizer": True, "zero_optimization": {"stage": 2}} diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py index 4811599ed05ab..5970b673cee5f 100644 --- a/tests/tests_fabric/strategies/test_deepspeed_integration.py +++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py @@ -20,11 +20,11 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.utils.data import DataLoader + from lightning.fabric import Fabric from lightning.fabric.plugins import DeepSpeedPrecision from lightning.fabric.strategies import DeepSpeedStrategy -from torch.utils.data import DataLoader - from tests_fabric.helpers.datasets import RandomDataset, RandomIterableDataset from tests_fabric.helpers.runif import RunIf from tests_fabric.test_fabric import BoringModel diff --git a/tests/tests_fabric/strategies/test_dp.py b/tests/tests_fabric/strategies/test_dp.py index e50abb1882870..ff470c646f4b3 100644 --- a/tests/tests_fabric/strategies/test_dp.py +++ b/tests/tests_fabric/strategies/test_dp.py @@ -16,9 +16,9 @@ import pytest import torch + from lightning.fabric import Fabric from lightning.fabric.strategies import DataParallelStrategy - from tests_fabric.helpers.runif import RunIf from tests_fabric.strategies.test_single_device import _run_test_clip_gradients diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py index cb6542cdb6243..d5f82752a9176 100644 --- a/tests/tests_fabric/strategies/test_fsdp.py +++ b/tests/tests_fabric/strategies/test_fsdp.py @@ -19,6 +19,10 @@ import pytest import torch import torch.nn as nn +from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.optim import Adam + from lightning.fabric.plugins import HalfPrecision from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import FSDPStrategy @@ -28,9 +32,6 @@ _is_sharded_checkpoint, ) from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision -from torch.distributed.fsdp.wrap import ModuleWrapPolicy -from torch.optim import Adam def test_custom_mixed_precision(): diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py index 0697c3043d496..11a7a1a6f8f7f 100644 --- a/tests/tests_fabric/strategies/test_fsdp_integration.py +++ b/tests/tests_fabric/strategies/test_fsdp_integration.py @@ -20,17 +20,17 @@ import pytest import torch import torch.nn as nn -from lightning.fabric import Fabric -from lightning.fabric.plugins import FSDPPrecision -from lightning.fabric.strategies import FSDPStrategy -from lightning.fabric.utilities.load import _load_distributed_checkpoint -from lightning.fabric.wrappers import _FabricOptimizer from torch._dynamo import OptimizedModule from torch.distributed.fsdp import FlatParameter, FullyShardedDataParallel, OptimStateKeyType from torch.distributed.fsdp.wrap import always_wrap_policy, wrap from torch.nn import Parameter from torch.utils.data import DataLoader +from lightning.fabric import Fabric +from lightning.fabric.plugins import FSDPPrecision +from lightning.fabric.strategies import FSDPStrategy +from lightning.fabric.utilities.load import _load_distributed_checkpoint +from lightning.fabric.wrappers import _FabricOptimizer from tests_fabric.helpers.datasets import RandomDataset from tests_fabric.helpers.runif import RunIf from tests_fabric.test_fabric import BoringModel diff --git a/tests/tests_fabric/strategies/test_model_parallel.py b/tests/tests_fabric/strategies/test_model_parallel.py index 1f8b5b783b73e..78622adf66fa6 100644 --- a/tests/tests_fabric/strategies/test_model_parallel.py +++ b/tests/tests_fabric/strategies/test_model_parallel.py @@ -19,12 +19,12 @@ import pytest import torch import torch.nn as nn +from torch.optim import Adam + from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import ModelParallelStrategy from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint from lightning.fabric.strategies.model_parallel import _ParallelBackwardSyncControl -from torch.optim import Adam - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/test_model_parallel_integration.py b/tests/tests_fabric/strategies/test_model_parallel_integration.py index b04a29b691529..bddfadd9a2c54 100644 --- a/tests/tests_fabric/strategies/test_model_parallel_integration.py +++ b/tests/tests_fabric/strategies/test_model_parallel_integration.py @@ -20,16 +20,16 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.utils.data import DataLoader, DistributedSampler + from lightning.fabric import Fabric from lightning.fabric.strategies.model_parallel import ModelParallelStrategy, _load_raw_module_state from lightning.fabric.utilities.load import _load_distributed_checkpoint -from torch.utils.data import DataLoader, DistributedSampler - from tests_fabric.helpers.datasets import RandomDataset from tests_fabric.helpers.runif import RunIf -@pytest.fixture() +@pytest.fixture def distributed(): yield if torch.distributed.is_initialized(): diff --git a/tests/tests_fabric/strategies/test_single_device.py b/tests/tests_fabric/strategies/test_single_device.py index 95ed9787f40a2..fff7175909222 100644 --- a/tests/tests_fabric/strategies/test_single_device.py +++ b/tests/tests_fabric/strategies/test_single_device.py @@ -15,9 +15,9 @@ import pytest import torch + from lightning.fabric import Fabric from lightning.fabric.strategies import SingleDeviceStrategy - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/test_strategy.py b/tests/tests_fabric/strategies/test_strategy.py index a7a1dba87cb97..37ccbea5e6c95 100644 --- a/tests/tests_fabric/strategies/test_strategy.py +++ b/tests/tests_fabric/strategies/test_strategy.py @@ -16,10 +16,10 @@ import pytest import torch + from lightning.fabric.plugins import DoublePrecision, HalfPrecision, Precision from lightning.fabric.strategies import SingleDeviceStrategy from lightning.fabric.utilities.types import _Stateful - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/test_xla.py b/tests/tests_fabric/strategies/test_xla.py index f711eb3470b45..a260b3f231e1d 100644 --- a/tests/tests_fabric/strategies/test_xla.py +++ b/tests/tests_fabric/strategies/test_xla.py @@ -18,13 +18,13 @@ import pytest import torch +from torch.utils.data import DataLoader + from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1, XLAAccelerator from lightning.fabric.strategies import XLAStrategy from lightning.fabric.strategies.launchers.xla import _XLALauncher from lightning.fabric.utilities.distributed import ReduceOp from lightning.fabric.utilities.seed import seed_everything -from torch.utils.data import DataLoader - from tests_fabric.helpers.datasets import RandomDataset from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/test_xla_fsdp.py b/tests/tests_fabric/strategies/test_xla_fsdp.py index 879a55cf77f34..c2634283ad110 100644 --- a/tests/tests_fabric/strategies/test_xla_fsdp.py +++ b/tests/tests_fabric/strategies/test_xla_fsdp.py @@ -18,12 +18,12 @@ import pytest import torch.nn import torch.nn as nn +from torch.optim import Adam + from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.plugins import XLAPrecision from lightning.fabric.strategies import XLAFSDPStrategy from lightning.fabric.strategies.xla_fsdp import _activation_checkpointing_auto_wrapper, _XLAFSDPBackwardSyncControl -from torch.optim import Adam - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/test_xla_fsdp_integration.py b/tests/tests_fabric/strategies/test_xla_fsdp_integration.py index 20c2ef042272e..b77803744b6c4 100644 --- a/tests/tests_fabric/strategies/test_xla_fsdp_integration.py +++ b/tests/tests_fabric/strategies/test_xla_fsdp_integration.py @@ -18,10 +18,10 @@ import pytest import torch -from lightning.fabric import Fabric -from lightning.fabric.strategies import XLAFSDPStrategy from torch.utils.data import DataLoader +from lightning.fabric import Fabric +from lightning.fabric.strategies import XLAFSDPStrategy from tests_fabric.helpers.datasets import RandomDataset from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/test_cli.py b/tests/tests_fabric/test_cli.py index a57f413ff6081..ec3160111e6ee 100644 --- a/tests/tests_fabric/test_cli.py +++ b/tests/tests_fabric/test_cli.py @@ -20,12 +20,12 @@ from unittest.mock import Mock import pytest -from lightning.fabric.cli import _consolidate, _get_supported_strategies, _run +from lightning.fabric.cli import _consolidate, _get_supported_strategies, _run from tests_fabric.helpers.runif import RunIf -@pytest.fixture() +@pytest.fixture def fake_script(tmp_path): script = tmp_path / "script.py" script.touch() @@ -184,8 +184,7 @@ def test_run_through_lightning_entry_point(): result = subprocess.run("lightning run model --help", capture_output=True, text=True, shell=True) deprecation_message = ( - "`lightning run model` is deprecated and will be removed in future versions. " - "Please call `fabric run` instead" + "`lightning run model` is deprecated and will be removed in future versions. Please call `fabric run` instead" ) message = "Usage: lightning run [OPTIONS] SCRIPT [SCRIPT_ARGS]" assert deprecation_message in result.stdout diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index 8a6e9206b3df5..9bb9fa1d7d145 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -19,10 +19,12 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric import pytest import torch import torch.distributed +from lightning_utilities.test.warning import no_warning_call + +import lightning.fabric from lightning.fabric import Fabric from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.accelerators.accelerator import Accelerator @@ -63,8 +65,6 @@ from lightning.fabric.strategies.ddp import _DDP_FORK_ALIASES from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from lightning.fabric.utilities.imports import _IS_WINDOWS -from lightning_utilities.test.warning import no_warning_call - from tests_fabric.conftest import mock_tpu_available from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index 7bb6b29eceaf2..ee002b5d8061c 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -17,11 +17,15 @@ from unittest import mock from unittest.mock import ANY, MagicMock, Mock, PropertyMock, call -import lightning.fabric import pytest import torch import torch.distributed import torch.nn.functional +from lightning_utilities.test.warning import no_warning_call +from torch import nn +from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler, TensorDataset + +import lightning.fabric from lightning.fabric.fabric import Fabric from lightning.fabric.strategies import ( DataParallelStrategy, @@ -37,10 +41,6 @@ from lightning.fabric.utilities.seed import pl_worker_init_function, seed_everything from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer -from lightning_utilities.test.warning import no_warning_call -from torch import nn -from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler, TensorDataset - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py index 363da022d285b..de8ae208e4dc4 100644 --- a/tests/tests_fabric/test_wrappers.py +++ b/tests/tests_fabric/test_wrappers.py @@ -16,6 +16,10 @@ import pytest import torch +from torch._dynamo import OptimizedModule +from torch.utils.data import BatchSampler, DistributedSampler +from torch.utils.data.dataloader import DataLoader + from lightning.fabric.fabric import Fabric from lightning.fabric.plugins import Precision from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin @@ -28,10 +32,6 @@ _unwrap_objects, is_wrapped, ) -from torch._dynamo import OptimizedModule -from torch.utils.data import BatchSampler, DistributedSampler -from torch.utils.data.dataloader import DataLoader - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/utilities/test_apply_func.py b/tests/tests_fabric/utilities/test_apply_func.py index 9e137561aa525..055fa89101c96 100644 --- a/tests/tests_fabric/utilities/test_apply_func.py +++ b/tests/tests_fabric/utilities/test_apply_func.py @@ -13,9 +13,10 @@ # limitations under the License. import pytest import torch -from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars, move_data_to_device from torch import Tensor +from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars, move_data_to_device + @pytest.mark.parametrize("should_return", [False, True]) def test_wrongly_implemented_transferable_data_type(should_return): diff --git a/tests/tests_fabric/utilities/test_cloud_io.py b/tests/tests_fabric/utilities/test_cloud_io.py index d502199da1493..e1333ddff87f3 100644 --- a/tests/tests_fabric/utilities/test_cloud_io.py +++ b/tests/tests_fabric/utilities/test_cloud_io.py @@ -16,6 +16,7 @@ import fsspec from fsspec.implementations.local import LocalFileSystem from fsspec.spec import AbstractFileSystem + from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem diff --git a/tests/tests_fabric/utilities/test_consolidate_checkpoint.py b/tests/tests_fabric/utilities/test_consolidate_checkpoint.py index 2584aab8bdc2e..78690a9870982 100644 --- a/tests/tests_fabric/utilities/test_consolidate_checkpoint.py +++ b/tests/tests_fabric/utilities/test_consolidate_checkpoint.py @@ -16,8 +16,9 @@ from pathlib import Path from unittest import mock -import lightning.fabric import pytest + +import lightning.fabric from lightning.fabric.utilities.consolidate_checkpoint import _parse_cli_args, _process_cli_args from lightning.fabric.utilities.load import _METADATA_FILENAME diff --git a/tests/tests_fabric/utilities/test_data.py b/tests/tests_fabric/utilities/test_data.py index 656b9cac3d77e..faff6e182a06f 100644 --- a/tests/tests_fabric/utilities/test_data.py +++ b/tests/tests_fabric/utilities/test_data.py @@ -4,10 +4,14 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric import numpy as np import pytest import torch +from lightning_utilities.test.warning import no_warning_call +from torch import Tensor +from torch.utils.data import BatchSampler, DataLoader, RandomSampler + +import lightning.fabric from lightning.fabric.utilities.data import ( AttributeDict, _get_dataloader_init_args_and_kwargs, @@ -21,10 +25,6 @@ suggested_max_num_workers, ) from lightning.fabric.utilities.exceptions import MisconfigurationException -from lightning_utilities.test.warning import no_warning_call -from torch import Tensor -from torch.utils.data import BatchSampler, DataLoader, RandomSampler - from tests_fabric.helpers.datasets import RandomDataset, RandomIterableDataset diff --git a/tests/tests_fabric/utilities/test_device_dtype_mixin.py b/tests/tests_fabric/utilities/test_device_dtype_mixin.py index 9958e48c624ee..1261ca5e0accb 100644 --- a/tests/tests_fabric/utilities/test_device_dtype_mixin.py +++ b/tests/tests_fabric/utilities/test_device_dtype_mixin.py @@ -1,8 +1,8 @@ import pytest import torch -from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from torch import nn as nn +from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/utilities/test_device_parser.py b/tests/tests_fabric/utilities/test_device_parser.py index 9b5a09e370860..f9f9e49cf6afe 100644 --- a/tests/tests_fabric/utilities/test_device_parser.py +++ b/tests/tests_fabric/utilities/test_device_parser.py @@ -14,6 +14,7 @@ from unittest import mock import pytest + from lightning.fabric.utilities import device_parser from lightning.fabric.utilities.exceptions import MisconfigurationException diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py index f5a78a1529a52..3450ed89f6cc7 100644 --- a/tests/tests_fabric/utilities/test_distributed.py +++ b/tests/tests_fabric/utilities/test_distributed.py @@ -5,9 +5,11 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric import pytest import torch +from lightning_utilities.core.imports import RequirementCache + +import lightning.fabric from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import DDPStrategy, SingleDeviceStrategy @@ -23,8 +25,6 @@ _sync_ddp, is_shared_filesystem, ) -from lightning_utilities.core.imports import RequirementCache - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/utilities/test_init.py b/tests/tests_fabric/utilities/test_init.py index dd08dec020669..69758c5a3e17e 100644 --- a/tests/tests_fabric/utilities/test_init.py +++ b/tests/tests_fabric/utilities/test_init.py @@ -16,12 +16,12 @@ import pytest import torch.nn + from lightning.fabric.utilities.init import ( _EmptyInit, _has_meta_device_parameters_or_buffers, _materialize_meta_tensors, ) - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/utilities/test_load.py b/tests/tests_fabric/utilities/test_load.py index 39d257f8b685b..ed38aa2459af7 100644 --- a/tests/tests_fabric/utilities/test_load.py +++ b/tests/tests_fabric/utilities/test_load.py @@ -14,6 +14,7 @@ import pytest import torch import torch.nn as nn + from lightning.fabric.utilities.load import ( _lazy_load, _materialize_tensors, @@ -55,7 +56,7 @@ def test_lazy_load_tensor(tmp_path): for t0, t1 in zip(expected.values(), loaded.values()): assert isinstance(t1, _NotYetLoadedTensor) t1_materialized = _materialize_tensors(t1) - assert type(t0) == type(t1_materialized) + assert type(t0) == type(t1_materialized) # noqa: E721 assert torch.equal(t0, t1_materialized) @@ -91,7 +92,7 @@ def test_materialize_tensors(tmp_path): loaded = _lazy_load(tmp_path / "tensor.pt") materialized = _materialize_tensors(loaded) assert torch.equal(materialized, tensor) - assert type(tensor) == type(materialized) + assert type(tensor) == type(materialized) # noqa: E721 # Collection of tensors collection = { diff --git a/tests/tests_fabric/utilities/test_logger.py b/tests/tests_fabric/utilities/test_logger.py index 0f6500cb42be1..26823143102a7 100644 --- a/tests/tests_fabric/utilities/test_logger.py +++ b/tests/tests_fabric/utilities/test_logger.py @@ -17,6 +17,7 @@ import numpy as np import torch + from lightning.fabric.utilities.logger import ( _add_prefix, _convert_json_serializable, diff --git a/tests/tests_fabric/utilities/test_optimizer.py b/tests/tests_fabric/utilities/test_optimizer.py index 83c7ed44120b9..d96c58049ed3a 100644 --- a/tests/tests_fabric/utilities/test_optimizer.py +++ b/tests/tests_fabric/utilities/test_optimizer.py @@ -2,9 +2,9 @@ import pytest import torch -from lightning.fabric.utilities.optimizer import _optimizer_to_device from torch import Tensor +from lightning.fabric.utilities.optimizer import _optimizer_to_device from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/utilities/test_rank_zero.py b/tests/tests_fabric/utilities/test_rank_zero.py index 0c1b39fe9d9b8..b8ea54f90625d 100644 --- a/tests/tests_fabric/utilities/test_rank_zero.py +++ b/tests/tests_fabric/utilities/test_rank_zero.py @@ -3,6 +3,7 @@ from unittest import mock import pytest + from lightning.fabric.utilities.rank_zero import _get_rank diff --git a/tests/tests_fabric/utilities/test_seed.py b/tests/tests_fabric/utilities/test_seed.py index be2ecba3294b1..0973709bf84fd 100644 --- a/tests/tests_fabric/utilities/test_seed.py +++ b/tests/tests_fabric/utilities/test_seed.py @@ -6,6 +6,7 @@ import numpy import pytest import torch + from lightning.fabric.utilities.seed import ( _collect_rng_states, _set_rng_states, diff --git a/tests/tests_fabric/utilities/test_spike.py b/tests/tests_fabric/utilities/test_spike.py index 9739540af7f18..6054bf224d3df 100644 --- a/tests/tests_fabric/utilities/test_spike.py +++ b/tests/tests_fabric/utilities/test_spike.py @@ -3,6 +3,7 @@ import pytest import torch + from lightning.fabric import Fabric from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, SpikeDetection, TrainingSpikeException diff --git a/tests/tests_fabric/utilities/test_throughput.py b/tests/tests_fabric/utilities/test_throughput.py index d410d0766d97b..00dafbb72cb8f 100644 --- a/tests/tests_fabric/utilities/test_throughput.py +++ b/tests/tests_fabric/utilities/test_throughput.py @@ -3,6 +3,7 @@ import pytest import torch + from lightning.fabric import Fabric from lightning.fabric.plugins import Precision from lightning.fabric.utilities.throughput import ( @@ -12,7 +13,6 @@ get_available_flops, measure_flops, ) - from tests_fabric.test_fabric import BoringModel diff --git a/tests/tests_fabric/utilities/test_warnings.py b/tests/tests_fabric/utilities/test_warnings.py index b7989d85b5932..bfccaaa8481d9 100644 --- a/tests/tests_fabric/utilities/test_warnings.py +++ b/tests/tests_fabric/utilities/test_warnings.py @@ -27,16 +27,17 @@ from pathlib import Path from unittest import mock -import lightning.fabric import pytest +from lightning_utilities.core.rank_zero import WarningCache, _warn +from lightning_utilities.test.warning import no_warning_call + +import lightning.fabric from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn from lightning.fabric.utilities.warnings import ( PossibleUserWarning, _is_path_in_lightning, disable_possible_user_warnings, ) -from lightning_utilities.core.rank_zero import WarningCache, _warn -from lightning_utilities.test.warning import no_warning_call def line_number(): diff --git a/tests/tests_pytorch/__init__.py b/tests/tests_pytorch/__init__.py index a43ffae6a83b4..efbfa8bb14c76 100644 --- a/tests/tests_pytorch/__init__.py +++ b/tests/tests_pytorch/__init__.py @@ -25,7 +25,7 @@ # todo: this setting `PYTHONPATH` may not be used by other evns like Conda for import packages if str(_PROJECT_ROOT) not in os.getenv("PYTHONPATH", ""): splitter = ":" if os.environ.get("PYTHONPATH", "") else "" - os.environ["PYTHONPATH"] = f'{_PROJECT_ROOT}{splitter}{os.environ.get("PYTHONPATH", "")}' + os.environ["PYTHONPATH"] = f"{_PROJECT_ROOT}{splitter}{os.environ.get('PYTHONPATH', '')}" # Ignore cleanup warnings from pytest (rarely happens due to a race condition when executing pytest in parallel) warnings.filterwarnings("ignore", category=pytest.PytestWarning, message=r".*\(rm_rf\) error removing.*") diff --git a/tests/tests_pytorch/accelerators/test_common.py b/tests/tests_pytorch/accelerators/test_common.py index 6967bffd9ffa2..42fd66a247d1d 100644 --- a/tests/tests_pytorch/accelerators/test_common.py +++ b/tests/tests_pytorch/accelerators/test_common.py @@ -14,6 +14,7 @@ from typing import Any import torch + from lightning.pytorch import Trainer from lightning.pytorch.accelerators import Accelerator from lightning.pytorch.strategies import DDPStrategy diff --git a/tests/tests_pytorch/accelerators/test_cpu.py b/tests/tests_pytorch/accelerators/test_cpu.py index 844556064621d..ec2aabf559dc7 100644 --- a/tests/tests_pytorch/accelerators/test_cpu.py +++ b/tests/tests_pytorch/accelerators/test_cpu.py @@ -3,16 +3,16 @@ from typing import Any, Union from unittest.mock import Mock -import lightning.pytorch as pl import pytest import torch + +import lightning.pytorch as pl from lightning.fabric.plugins import TorchCheckpointIO from lightning.pytorch import Trainer from lightning.pytorch.accelerators import CPUAccelerator from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins.precision.precision import Precision from lightning.pytorch.strategies import SingleDeviceStrategy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/accelerators/test_gpu.py b/tests/tests_pytorch/accelerators/test_gpu.py index e175f8aa7647c..5a71887e17eec 100644 --- a/tests/tests_pytorch/accelerators/test_gpu.py +++ b/tests/tests_pytorch/accelerators/test_gpu.py @@ -15,11 +15,11 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.accelerators import CUDAAccelerator from lightning.pytorch.accelerators.cuda import get_nvidia_gpu_stats from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/accelerators/test_mps.py b/tests/tests_pytorch/accelerators/test_mps.py index 73d0785f9592d..c0a28840f0ef6 100644 --- a/tests/tests_pytorch/accelerators/test_mps.py +++ b/tests/tests_pytorch/accelerators/test_mps.py @@ -16,11 +16,11 @@ import pytest import torch + +import tests_pytorch.helpers.pipelines as tpipes from lightning.pytorch import Trainer from lightning.pytorch.accelerators import MPSAccelerator from lightning.pytorch.demos.boring_classes import BoringModel - -import tests_pytorch.helpers.pipelines as tpipes from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/accelerators/test_xla.py b/tests/tests_pytorch/accelerators/test_xla.py index 48b346006786f..6322a36a6cbc1 100644 --- a/tests/tests_pytorch/accelerators/test_xla.py +++ b/tests/tests_pytorch/accelerators/test_xla.py @@ -17,9 +17,12 @@ from unittest import mock from unittest.mock import MagicMock, call, patch -import lightning.fabric import pytest import torch +from torch import nn +from torch.utils.data import DataLoader + +import lightning.fabric from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.pytorch import Trainer from lightning.pytorch.accelerators import CPUAccelerator, XLAAccelerator @@ -27,9 +30,6 @@ from lightning.pytorch.plugins import Precision, XLACheckpointIO, XLAPrecision from lightning.pytorch.strategies import DDPStrategy, XLAStrategy from lightning.pytorch.utilities import find_shared_parameters -from torch import nn -from torch.utils.data import DataLoader - from tests_pytorch.helpers.runif import RunIf from tests_pytorch.trainer.connectors.test_accelerator_connector import DeviceMock from tests_pytorch.trainer.optimization.test_manual_optimization import assert_emtpy_grad diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 89c1effe839a8..430fb9842cddc 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -17,14 +17,15 @@ from unittest.mock import DEFAULT, Mock import pytest +from tests_pytorch.helpers.runif import RunIf +from torch.utils.data import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ProgressBar, RichProgressBar from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.loggers.logger import DummyLogger -from tests_pytorch.helpers.runif import RunIf -from torch.utils.data import DataLoader @RunIf(rich=True) diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index d5187d5a1e325..d93bf1cf60e9c 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -22,6 +22,9 @@ import pytest import torch +from tests_pytorch.helpers.runif import RunIf +from torch.utils.data.dataloader import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint, ProgressBar, TQDMProgressBar from lightning.pytorch.callbacks.progress.tqdm_progress import Tqdm @@ -30,8 +33,6 @@ from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.loggers.logger import DummyLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException -from tests_pytorch.helpers.runif import RunIf -from torch.utils.data.dataloader import DataLoader class MockTqdm(Tqdm): diff --git a/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py b/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py index 8f32c756881da..366a924a5867c 100644 --- a/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py +++ b/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest + from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/callbacks/test_callbacks.py b/tests/tests_pytorch/callbacks/test_callbacks.py index 38b9428526505..53ea109b6ddf3 100644 --- a/tests/tests_pytorch/callbacks/test_callbacks.py +++ b/tests/tests_pytorch/callbacks/test_callbacks.py @@ -16,10 +16,11 @@ from unittest.mock import Mock import pytest +from lightning_utilities.test.warning import no_warning_call + from lightning.pytorch import Callback, Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel -from lightning_utilities.test.warning import no_warning_call def test_callbacks_configured_in_model(tmp_path): diff --git a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py index f1d999f1df61a..290a0921cb06d 100644 --- a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py +++ b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py @@ -20,6 +20,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.accelerators.cpu import _CPU_PERCENT, _CPU_SWAP_PERCENT, _CPU_VM_PERCENT, get_cpu_stats from lightning.pytorch.callbacks import DeviceStatsMonitor @@ -28,7 +29,6 @@ from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_zero_only - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index a3d56bb0135c3..c11aa37b456dc 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -22,11 +22,11 @@ import cloudpickle import pytest import torch + from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.exceptions import MisconfigurationException - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/callbacks/test_finetuning_callback.py b/tests/tests_pytorch/callbacks/test_finetuning_callback.py index 0c09ae5d5042a..07343c1ecc12a 100644 --- a/tests/tests_pytorch/callbacks/test_finetuning_callback.py +++ b/tests/tests_pytorch/callbacks/test_finetuning_callback.py @@ -15,14 +15,14 @@ import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 -from lightning.pytorch import LightningModule, Trainer, seed_everything -from lightning.pytorch.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint -from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from torch import nn from torch.optim import SGD, Optimizer from torch.utils.data import DataLoader +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 +from lightning.pytorch import LightningModule, Trainer, seed_everything +from lightning.pytorch.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint +from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/callbacks/test_gradient_accumulation_scheduler.py b/tests/tests_pytorch/callbacks/test_gradient_accumulation_scheduler.py index cc584b2da624c..9ad9759ff248c 100644 --- a/tests/tests_pytorch/callbacks/test_gradient_accumulation_scheduler.py +++ b/tests/tests_pytorch/callbacks/test_gradient_accumulation_scheduler.py @@ -15,6 +15,7 @@ from unittest.mock import Mock, patch import pytest + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import GradientAccumulationScheduler from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/callbacks/test_lambda_function.py b/tests/tests_pytorch/callbacks/test_lambda_function.py index 40d694bb35ebc..2b5e025653940 100644 --- a/tests/tests_pytorch/callbacks/test_lambda_function.py +++ b/tests/tests_pytorch/callbacks/test_lambda_function.py @@ -14,10 +14,10 @@ from functools import partial import pytest + from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import Callback, LambdaCallback from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.models.test_hooks import get_members diff --git a/tests/tests_pytorch/callbacks/test_lr_monitor.py b/tests/tests_pytorch/callbacks/test_lr_monitor.py index 4aedb4f23fa14..4f3a20ce82de5 100644 --- a/tests/tests_pytorch/callbacks/test_lr_monitor.py +++ b/tests/tests_pytorch/callbacks/test_lr_monitor.py @@ -13,6 +13,8 @@ # limitations under the License. import pytest import torch +from torch import optim + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor from lightning.pytorch.callbacks.callback import Callback @@ -20,8 +22,6 @@ from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch import optim - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel @@ -44,9 +44,9 @@ def test_lr_monitor_single_lr(tmp_path): assert lr_monitor.lrs, "No learning rates logged" assert all(v is None for v in lr_monitor.last_momentum_values.values()), "Momentum should not be logged by default" - assert all( - v is None for v in lr_monitor.last_weight_decay_values.values() - ), "Weight decay should not be logged by default" + assert all(v is None for v in lr_monitor.last_weight_decay_values.values()), ( + "Weight decay should not be logged by default" + ) assert len(lr_monitor.lrs) == len(trainer.lr_scheduler_configs) assert list(lr_monitor.lrs) == ["lr-SGD"] @@ -87,9 +87,9 @@ def configure_optimizers(self): assert len(lr_monitor.last_momentum_values) == len(trainer.lr_scheduler_configs) assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values) - assert all( - v is not None for v in lr_monitor.last_weight_decay_values.values() - ), "Expected weight decay to be logged" + assert all(v is not None for v in lr_monitor.last_weight_decay_values.values()), ( + "Expected weight decay to be logged" + ) assert len(lr_monitor.last_weight_decay_values) == len(trainer.lr_scheduler_configs) assert all(k == f"lr-{opt}-weight_decay" for k in lr_monitor.last_weight_decay_values) diff --git a/tests/tests_pytorch/callbacks/test_prediction_writer.py b/tests/tests_pytorch/callbacks/test_prediction_writer.py index 02604f5a195fe..7249d5739686b 100644 --- a/tests/tests_pytorch/callbacks/test_prediction_writer.py +++ b/tests/tests_pytorch/callbacks/test_prediction_writer.py @@ -14,11 +14,12 @@ from unittest.mock import ANY, Mock, call import pytest +from torch.utils.data import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import BasePredictionWriter from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch.utils.data import DataLoader class DummyPredictionWriter(BasePredictionWriter): diff --git a/tests/tests_pytorch/callbacks/test_pruning.py b/tests/tests_pytorch/callbacks/test_pruning.py index f3ec7e2ccc029..d70ab68b78b32 100644 --- a/tests/tests_pytorch/callbacks/test_pruning.py +++ b/tests/tests_pytorch/callbacks/test_pruning.py @@ -19,13 +19,13 @@ import pytest import torch import torch.nn.utils.prune as pytorch_prune +from torch import nn +from torch.nn import Sequential + from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint, ModelPruning from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch import nn -from torch.nn import Sequential - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/callbacks/test_rich_model_summary.py b/tests/tests_pytorch/callbacks/test_rich_model_summary.py index 73709fd80a833..7534c23d5679c 100644 --- a/tests/tests_pytorch/callbacks/test_rich_model_summary.py +++ b/tests/tests_pytorch/callbacks/test_rich_model_summary.py @@ -16,11 +16,11 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import RichModelSummary, RichProgressBar from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.model_summary import summarize - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/callbacks/test_spike.py b/tests/tests_pytorch/callbacks/test_spike.py index 5634feaf221cd..692a28dcc38c4 100644 --- a/tests/tests_pytorch/callbacks/test_spike.py +++ b/tests/tests_pytorch/callbacks/test_spike.py @@ -3,6 +3,7 @@ import pytest import torch + from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, TrainingSpikeException from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks.spike import SpikeDetection diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index 8d3a1800e1fa2..e50eef7f258e1 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -20,17 +20,17 @@ import pytest import torch +from torch import nn +from torch.optim.lr_scheduler import LambdaLR +from torch.optim.swa_utils import SWALR +from torch.utils.data import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import StochasticWeightAveraging from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from lightning.pytorch.strategies import Strategy from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch import nn -from torch.optim.lr_scheduler import LambdaLR -from torch.optim.swa_utils import SWALR -from torch.utils.data import DataLoader - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/callbacks/test_throughput_monitor.py b/tests/tests_pytorch/callbacks/test_throughput_monitor.py index 4867134a85642..9f77e4371e69e 100644 --- a/tests/tests_pytorch/callbacks/test_throughput_monitor.py +++ b/tests/tests_pytorch/callbacks/test_throughput_monitor.py @@ -3,6 +3,7 @@ import pytest import torch + from lightning.fabric.utilities.throughput import measure_flops from lightning.pytorch import Trainer from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor diff --git a/tests/tests_pytorch/callbacks/test_timer.py b/tests/tests_pytorch/callbacks/test_timer.py index e6359a2e9a5e1..e91170f2096d3 100644 --- a/tests/tests_pytorch/callbacks/test_timer.py +++ b/tests/tests_pytorch/callbacks/test_timer.py @@ -17,12 +17,12 @@ from unittest.mock import Mock, patch import pytest + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.callbacks.timer import Timer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.exceptions import MisconfigurationException - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py b/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py index ff8f3c95e43c5..2e998c42ed2b7 100644 --- a/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py @@ -16,9 +16,9 @@ import pytest import torch + from lightning.pytorch import Trainer, callbacks from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index be754d3911ade..006a123356c98 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -16,11 +16,11 @@ import sys from unittest.mock import patch -import lightning.pytorch as pl import pytest import torch -from lightning.pytorch import Callback, Trainer +import lightning.pytorch as pl +from lightning.pytorch import Callback, Trainer from tests_pytorch import _PATH_LEGACY from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf @@ -75,6 +75,7 @@ def test_legacy_ckpt_threading(pl_version: str): def load_model(): import torch + from lightning.pytorch.utilities.migration import pl_legacy_patch with pl_legacy_patch(): diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index d43f07179e7bb..2351cc7548f79 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -25,11 +25,13 @@ from unittest.mock import Mock, call, patch import cloudpickle -import lightning.pytorch as pl import pytest import torch import yaml from jsonargparse import ArgumentParser +from torch import optim + +import lightning.pytorch as pl from lightning.fabric.utilities.cloud_io import _load as pl_load from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint @@ -37,8 +39,6 @@ from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE -from torch import optim - from tests_pytorch.helpers.runif import RunIf if _OMEGACONF_AVAILABLE: diff --git a/tests/tests_pytorch/checkpointing/test_torch_saving.py b/tests/tests_pytorch/checkpointing/test_torch_saving.py index 4422a4063c719..e49c45bf54ade 100644 --- a/tests/tests_pytorch/checkpointing/test_torch_saving.py +++ b/tests/tests_pytorch/checkpointing/test_torch_saving.py @@ -13,9 +13,9 @@ # limitations under the License. import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py b/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py index c07400eaf8446..b0f1528d42d28 100644 --- a/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py @@ -15,9 +15,10 @@ from unittest import mock from unittest.mock import ANY, Mock -import lightning.pytorch as pl import pytest import torch + +import lightning.pytorch as pl from lightning.fabric.plugins import TorchCheckpointIO, XLACheckpointIO from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index ea5207516cad1..ea7380be3a42c 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -21,18 +21,18 @@ from pathlib import Path from unittest.mock import Mock -import lightning.fabric -import lightning.pytorch import pytest import torch.distributed +from tqdm import TMonitor + +import lightning.fabric +import lightning.pytorch from lightning.fabric.plugins.environments.lightning import find_free_network_port from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver from lightning.fabric.utilities.distributed import _destroy_dist_connection, _distributed_is_initialized from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.pytorch.accelerators import XLAAccelerator from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector -from tqdm import TMonitor - from tests_pytorch import _PATH_DATASETS @@ -44,9 +44,10 @@ def datadir(): @pytest.fixture(autouse=True) def preserve_global_rank_variable(): """Ensures that the rank_zero_only.rank global variable gets reset in each test.""" + from lightning_utilities.core.rank_zero import rank_zero_only as rank_zero_only_utilities + from lightning.fabric.utilities.rank_zero import rank_zero_only as rank_zero_only_fabric from lightning.pytorch.utilities.rank_zero import rank_zero_only as rank_zero_only_pytorch - from lightning_utilities.core.rank_zero import rank_zero_only as rank_zero_only_utilities functions = (rank_zero_only_pytorch, rank_zero_only_fabric, rank_zero_only_utilities) ranks = [getattr(fn, "rank", None) for fn in functions] @@ -176,22 +177,22 @@ def mock_cuda_count(monkeypatch, n: int) -> None: monkeypatch.setattr(lightning.pytorch.accelerators.cuda, "num_cuda_devices", lambda: n) -@pytest.fixture() +@pytest.fixture def cuda_count_0(monkeypatch): mock_cuda_count(monkeypatch, 0) -@pytest.fixture() +@pytest.fixture def cuda_count_1(monkeypatch): mock_cuda_count(monkeypatch, 1) -@pytest.fixture() +@pytest.fixture def cuda_count_2(monkeypatch): mock_cuda_count(monkeypatch, 2) -@pytest.fixture() +@pytest.fixture def cuda_count_4(monkeypatch): mock_cuda_count(monkeypatch, 4) @@ -201,12 +202,12 @@ def mock_mps_count(monkeypatch, n: int) -> None: monkeypatch.setattr(lightning.fabric.accelerators.mps.MPSAccelerator, "is_available", lambda *_: n > 0) -@pytest.fixture() +@pytest.fixture def mps_count_0(monkeypatch): mock_mps_count(monkeypatch, 0) -@pytest.fixture() +@pytest.fixture def mps_count_1(monkeypatch): mock_mps_count(monkeypatch, 1) @@ -222,7 +223,7 @@ def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value) -@pytest.fixture() +@pytest.fixture def xla_available(monkeypatch: pytest.MonkeyPatch) -> None: mock_xla_available(monkeypatch) @@ -238,12 +239,12 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock()) -@pytest.fixture() +@pytest.fixture def tpu_available(monkeypatch) -> None: mock_tpu_available(monkeypatch) -@pytest.fixture() +@pytest.fixture def caplog(caplog): """Workaround for https://github.com/pytest-dev/pytest/issues/3697. @@ -271,7 +272,7 @@ def caplog(caplog): logging.getLogger(name).propagate = propagate -@pytest.fixture() +@pytest.fixture def tmpdir_server(tmp_path): Handler = partial(SimpleHTTPRequestHandler, directory=str(tmp_path)) from http.server import ThreadingHTTPServer @@ -285,7 +286,7 @@ def tmpdir_server(tmp_path): server.shutdown() -@pytest.fixture() +@pytest.fixture def single_process_pg(): """Initialize the default process group with only the current process for testing purposes. diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index b3ccd88aae704..49e8bf0f5b36d 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -21,6 +21,7 @@ import pytest import torch + from lightning.pytorch import LightningDataModule, Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import ( @@ -34,7 +35,6 @@ from lightning.pytorch.utilities import AttributeDict from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py index 5ee91e82689f4..2036014762ebf 100644 --- a/tests/tests_pytorch/core/test_lightning_module.py +++ b/tests/tests_pytorch/core/test_lightning_module.py @@ -17,15 +17,15 @@ import pytest import torch +from torch import nn +from torch.optim import SGD, Adam + from lightning.fabric import Fabric from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.core.module import _TrainerFabricShim from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch import nn -from torch.optim import SGD, Adam - from tests_pytorch.helpers.runif import RunIf @@ -336,9 +336,9 @@ def __init__(self, spec): ), "Expect the shards to be different before `m_1` loading `m_0`'s state dict" m_1.load_state_dict(m_0.state_dict(), strict=False) - assert torch.allclose( - m_1.sharded_tensor.local_shards()[0].tensor, m_0.sharded_tensor.local_shards()[0].tensor - ), "Expect the shards to be same after `m_1` loading `m_0`'s state dict" + assert torch.allclose(m_1.sharded_tensor.local_shards()[0].tensor, m_0.sharded_tensor.local_shards()[0].tensor), ( + "Expect the shards to be same after `m_1` loading `m_0`'s state dict" + ) def test_lightning_module_configure_gradient_clipping(tmp_path): diff --git a/tests/tests_pytorch/core/test_lightning_optimizer.py b/tests/tests_pytorch/core/test_lightning_optimizer.py index b25b7ae648a3a..ed1ca2b4db03f 100644 --- a/tests/tests_pytorch/core/test_lightning_optimizer.py +++ b/tests/tests_pytorch/core/test_lightning_optimizer.py @@ -15,13 +15,13 @@ from unittest.mock import DEFAULT, Mock, patch import torch +from torch.optim import SGD, Adam, Optimizer + from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.core.optimizer import LightningOptimizer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loops.optimization.automatic import Closure from lightning.pytorch.tuner.tuning import Tuner -from torch.optim import SGD, Adam, Optimizer - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py index dcb3f71c7499c..4a5df7d37cd7a 100644 --- a/tests/tests_pytorch/core/test_metric_result_integration.py +++ b/tests/tests_pytorch/core/test_metric_result_integration.py @@ -16,9 +16,15 @@ from contextlib import nullcontext, suppress from unittest import mock -import lightning.pytorch as pl import pytest import torch +from lightning_utilities.test.warning import no_warning_call +from torch import Tensor, tensor +from torch.nn import ModuleDict, ModuleList +from torchmetrics import Metric, MetricCollection +from torchmetrics.classification import Accuracy + +import lightning.pytorch as pl from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer from lightning.pytorch.callbacks import OnExceptionCheckpoint @@ -30,12 +36,6 @@ _Sync, ) from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11 -from lightning_utilities.test.warning import no_warning_call -from torch import Tensor, tensor -from torch.nn import ModuleDict, ModuleList -from torchmetrics import Metric, MetricCollection -from torchmetrics.classification import Accuracy - from tests_pytorch.core.test_results import spawn_launch from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/core/test_results.py b/tests/tests_pytorch/core/test_results.py index 006731f3e7c60..edc6efbfa4889 100644 --- a/tests/tests_pytorch/core/test_results.py +++ b/tests/tests_pytorch/core/test_results.py @@ -15,12 +15,12 @@ import torch import torch.distributed as dist + from lightning.fabric.plugins.environments import LightningEnvironment from lightning.pytorch.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator from lightning.pytorch.strategies import DDPStrategy from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher from lightning.pytorch.trainer.connectors.logger_connector.result import _Sync - from tests_pytorch.helpers.runif import RunIf from tests_pytorch.models.test_tpu import wrap_launch_function diff --git a/tests/tests_pytorch/core/test_saving.py b/tests/tests_pytorch/core/test_saving.py index c7e48239754c5..8e1e9584a7c68 100644 --- a/tests/tests_pytorch/core/test_saving.py +++ b/tests/tests_pytorch/core/test_saving.py @@ -1,11 +1,11 @@ from unittest.mock import ANY, Mock -import lightning.pytorch as pl import pytest import torch + +import lightning.pytorch as pl from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel - from tests_pytorch.conftest import mock_cuda_count, mock_mps_count from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/demos/transformer.py b/tests/tests_pytorch/demos/transformer.py index 47ecbb2083273..50de873511053 100644 --- a/tests/tests_pytorch/demos/transformer.py +++ b/tests/tests_pytorch/demos/transformer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch + from lightning.pytorch.demos import Transformer diff --git a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py index e6da72c777dbb..0b79b638534fa 100644 --- a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py +++ b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py @@ -1,9 +1,10 @@ import sys from unittest.mock import Mock -import lightning.fabric import pytest import torch.nn + +import lightning.fabric from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins.precision.double import LightningDoublePrecisionModule diff --git a/tests/tests_pytorch/helpers/__init__.py b/tests/tests_pytorch/helpers/__init__.py index 82a6332c56738..1299d7e542955 100644 --- a/tests/tests_pytorch/helpers/__init__.py +++ b/tests/tests_pytorch/helpers/__init__.py @@ -4,5 +4,4 @@ ManualOptimBoringModel, RandomDataset, ) - from tests_pytorch.helpers.datasets import TrialMNIST # noqa: F401 diff --git a/tests/tests_pytorch/helpers/advanced_models.py b/tests/tests_pytorch/helpers/advanced_models.py index ade21004dc635..959e6e5968d18 100644 --- a/tests/tests_pytorch/helpers/advanced_models.py +++ b/tests/tests_pytorch/helpers/advanced_models.py @@ -16,9 +16,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from lightning.pytorch.core.module import LightningModule from torch.utils.data import DataLoader +from lightning.pytorch.core.module import LightningModule from tests_pytorch import _PATH_DATASETS from tests_pytorch.helpers.datasets import MNIST, AverageDataset, TrialMNIST diff --git a/tests/tests_pytorch/helpers/datamodules.py b/tests/tests_pytorch/helpers/datamodules.py index 5a91d8ebb981d..6282acf3be547 100644 --- a/tests/tests_pytorch/helpers/datamodules.py +++ b/tests/tests_pytorch/helpers/datamodules.py @@ -13,10 +13,10 @@ # limitations under the License. import torch -from lightning.pytorch.core.datamodule import LightningDataModule from lightning_utilities.core.imports import RequirementCache from torch.utils.data import DataLoader +from lightning.pytorch.core.datamodule import LightningDataModule from tests_pytorch.helpers.datasets import MNIST, SklearnDataset, TrialMNIST _SKLEARN_AVAILABLE = RequirementCache("scikit-learn") diff --git a/tests/tests_pytorch/helpers/deterministic_model.py b/tests/tests_pytorch/helpers/deterministic_model.py index 95975e9ad1654..fbf3488158825 100644 --- a/tests/tests_pytorch/helpers/deterministic_model.py +++ b/tests/tests_pytorch/helpers/deterministic_model.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch -from lightning.pytorch.core.module import LightningModule from torch import Tensor, nn from torch.utils.data import DataLoader, Dataset +from lightning.pytorch.core.module import LightningModule + class DeterministicModel(LightningModule): def __init__(self, weights=None): diff --git a/tests/tests_pytorch/helpers/pipelines.py b/tests/tests_pytorch/helpers/pipelines.py index ab33878010123..b6c63a5702bfc 100644 --- a/tests/tests_pytorch/helpers/pipelines.py +++ b/tests/tests_pytorch/helpers/pipelines.py @@ -14,11 +14,11 @@ from functools import partial import torch +from torchmetrics.functional import accuracy + from lightning.pytorch import LightningDataModule, LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11 -from torchmetrics.functional import accuracy - from tests_pytorch.helpers.utils import get_default_logger, load_model_from_checkpoint diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py index 1c5b059d679a5..25fadd524adf8 100644 --- a/tests/tests_pytorch/helpers/runif.py +++ b/tests/tests_pytorch/helpers/runif.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest + from lightning.pytorch.utilities.testing import _runif_reasons diff --git a/tests/tests_pytorch/helpers/simple_models.py b/tests/tests_pytorch/helpers/simple_models.py index 940adc0ac49a6..a9dc635bba275 100644 --- a/tests/tests_pytorch/helpers/simple_models.py +++ b/tests/tests_pytorch/helpers/simple_models.py @@ -15,11 +15,12 @@ import torch import torch.nn.functional as F -from lightning.pytorch import LightningModule from lightning_utilities.core.imports import compare_version from torch import nn from torchmetrics import Accuracy, MeanSquaredError +from lightning.pytorch import LightningModule + # using new API with task _TM_GE_0_11 = compare_version("torchmetrics", operator.ge, "0.11.0") diff --git a/tests/tests_pytorch/helpers/test_models.py b/tests/tests_pytorch/helpers/test_models.py index cca2fbdc2e3e0..721641ae8343a 100644 --- a/tests/tests_pytorch/helpers/test_models.py +++ b/tests/tests_pytorch/helpers/test_models.py @@ -14,9 +14,9 @@ import os import pytest + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN, TBPTTModule from tests_pytorch.helpers.datamodules import ClassifDataModule, RegressDataModule from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py index 7cc5cc94fe8cc..ab1149ca9651a 100644 --- a/tests/tests_pytorch/loggers/conftest.py +++ b/tests/tests_pytorch/loggers/conftest.py @@ -18,7 +18,7 @@ import pytest -@pytest.fixture() +@pytest.fixture def mlflow_mock(monkeypatch): mlflow = ModuleType("mlflow") mlflow.set_tracking_uri = Mock() @@ -43,7 +43,7 @@ def mlflow_mock(monkeypatch): return mlflow -@pytest.fixture() +@pytest.fixture def wandb_mock(monkeypatch): class RunType: # to make isinstance checks pass pass @@ -89,7 +89,7 @@ class RunType: # to make isinstance checks pass return wandb -@pytest.fixture() +@pytest.fixture def comet_mock(monkeypatch): comet = ModuleType("comet_ml") monkeypatch.setitem(sys.modules, "comet_ml", comet) @@ -110,7 +110,7 @@ def comet_mock(monkeypatch): return comet -@pytest.fixture() +@pytest.fixture def neptune_mock(monkeypatch): class RunType: # to make isinstance checks pass def get_root_object(self): diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 1b845c57ec35d..4ac28ef487ad8 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -19,6 +19,7 @@ import pytest import torch + from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import ( @@ -32,7 +33,6 @@ from lightning.pytorch.loggers.logger import DummyExperiment, Logger from lightning.pytorch.loggers.tensorboard import _TENSORBOARD_AVAILABLE from lightning.pytorch.tuner.tuning import Tuner - from tests_pytorch.helpers.runif import RunIf from tests_pytorch.loggers.test_comet import _patch_comet_atexit from tests_pytorch.loggers.test_mlflow import mock_mlflow_run_creation diff --git a/tests/tests_pytorch/loggers/test_comet.py b/tests/tests_pytorch/loggers/test_comet.py index e467c63543ede..de99454548302 100644 --- a/tests/tests_pytorch/loggers/test_comet.py +++ b/tests/tests_pytorch/loggers/test_comet.py @@ -16,11 +16,12 @@ from unittest.mock import DEFAULT, Mock, patch import pytest +from torch import tensor + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import CometLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch import tensor def _patch_comet_atexit(monkeypatch): diff --git a/tests/tests_pytorch/loggers/test_csv.py b/tests/tests_pytorch/loggers/test_csv.py index 27b85bb4ad745..1b09302ffb74a 100644 --- a/tests/tests_pytorch/loggers/test_csv.py +++ b/tests/tests_pytorch/loggers/test_csv.py @@ -18,11 +18,11 @@ import fsspec import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.core.saving import load_hparams_from_yaml from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.loggers.csv_logs import ExperimentWriter - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/loggers/test_logger.py b/tests/tests_pytorch/loggers/test_logger.py index dcdd504fd4660..124a9120a9197 100644 --- a/tests/tests_pytorch/loggers/test_logger.py +++ b/tests/tests_pytorch/loggers/test_logger.py @@ -20,6 +20,7 @@ import numpy as np import pytest import torch + from lightning.fabric.utilities.logger import _convert_params, _sanitize_params from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index 14af36680904c..c7f9dbe1fe2c6 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, Mock import pytest + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers.mlflow import ( diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py index 0a39337ac5c16..b5e98fbe99113 100644 --- a/tests/tests_pytorch/loggers/test_neptune.py +++ b/tests/tests_pytorch/loggers/test_neptune.py @@ -17,9 +17,10 @@ from unittest import mock from unittest.mock import MagicMock, call -import lightning.pytorch as pl import pytest import torch + +import lightning.pytorch as pl from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import NeptuneLogger diff --git a/tests/tests_pytorch/loggers/test_tensorboard.py b/tests/tests_pytorch/loggers/test_tensorboard.py index 82ffff25cac7c..173805f1a6f3c 100644 --- a/tests/tests_pytorch/loggers/test_tensorboard.py +++ b/tests/tests_pytorch/loggers/test_tensorboard.py @@ -20,12 +20,12 @@ import pytest import torch import yaml + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.loggers.tensorboard import _TENSORBOARD_AVAILABLE from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE - from tests_pytorch.helpers.runif import RunIf if _OMEGACONF_AVAILABLE: diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index a8e70bfb6589d..f3d82b0582be2 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -18,14 +18,14 @@ import pytest import yaml +from lightning_utilities.test.warning import no_warning_call + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.cli import LightningCLI from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning_utilities.test.warning import no_warning_call - from tests_pytorch.test_cli import _xfail_python_ge_3_11_9 diff --git a/tests/tests_pytorch/loops/optimization/test_automatic_loop.py b/tests/tests_pytorch/loops/optimization/test_automatic_loop.py index 2fb04d0d9d8d1..e20c1789be023 100644 --- a/tests/tests_pytorch/loops/optimization/test_automatic_loop.py +++ b/tests/tests_pytorch/loops/optimization/test_automatic_loop.py @@ -17,6 +17,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loops.optimization.automatic import ClosureResult diff --git a/tests/tests_pytorch/loops/optimization/test_closure.py b/tests/tests_pytorch/loops/optimization/test_closure.py index d7d4e51794aca..7766a385c3057 100644 --- a/tests/tests_pytorch/loops/optimization/test_closure.py +++ b/tests/tests_pytorch/loops/optimization/test_closure.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.exceptions import MisconfigurationException diff --git a/tests/tests_pytorch/loops/optimization/test_manual_loop.py b/tests/tests_pytorch/loops/optimization/test_manual_loop.py index 67be30b24e159..cedfefb4791ea 100644 --- a/tests/tests_pytorch/loops/optimization/test_manual_loop.py +++ b/tests/tests_pytorch/loops/optimization/test_manual_loop.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loops.optimization.manual import ManualResult diff --git a/tests/tests_pytorch/loops/test_all.py b/tests/tests_pytorch/loops/test_all.py index 1eb67064fb300..51b7bbeedf90b 100644 --- a/tests/tests_pytorch/loops/test_all.py +++ b/tests/tests_pytorch/loops/test_all.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest + from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/loops/test_evaluation_loop.py b/tests/tests_pytorch/loops/test_evaluation_loop.py index 588672e19a05c..6d8c4be70d9c6 100644 --- a/tests/tests_pytorch/loops/test_evaluation_loop.py +++ b/tests/tests_pytorch/loops/test_evaluation_loop.py @@ -16,13 +16,13 @@ import pytest import torch +from torch.utils.data.dataloader import DataLoader +from torch.utils.data.sampler import BatchSampler, RandomSampler + from lightning.fabric.accelerators.cuda import _clear_cuda_memory from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.utilities import CombinedLoader -from torch.utils.data.dataloader import DataLoader -from torch.utils.data.sampler import BatchSampler, RandomSampler - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/loops/test_evaluation_loop_flow.py b/tests/tests_pytorch/loops/test_evaluation_loop_flow.py index 09ded9a777679..e698a15cb0fb2 100644 --- a/tests/tests_pytorch/loops/test_evaluation_loop_flow.py +++ b/tests/tests_pytorch/loops/test_evaluation_loop_flow.py @@ -14,11 +14,11 @@ """Tests the evaluation loop.""" import torch +from torch import Tensor + from lightning.pytorch import Trainer from lightning.pytorch.core.module import LightningModule from lightning.pytorch.trainer.states import RunningStage -from torch import Tensor - from tests_pytorch.helpers.deterministic_model import DeterministicModel diff --git a/tests/tests_pytorch/loops/test_fetchers.py b/tests/tests_pytorch/loops/test_fetchers.py index 75b25e3d98fd8..f66e9f9f3b16f 100644 --- a/tests/tests_pytorch/loops/test_fetchers.py +++ b/tests/tests_pytorch/loops/test_fetchers.py @@ -17,6 +17,9 @@ import pytest import torch +from torch import Tensor +from torch.utils.data import DataLoader, Dataset, IterableDataset + from lightning.pytorch import LightningDataModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.loops.fetchers import _DataLoaderIterDataFetcher, _PrefetchDataFetcher @@ -24,9 +27,6 @@ from lightning.pytorch.utilities.combined_loader import CombinedLoader from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.types import STEP_OUTPUT -from torch import Tensor -from torch.utils.data import DataLoader, Dataset, IterableDataset - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 1820ca3568173..384ae2b47859b 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -20,6 +20,8 @@ import pytest import torch +from torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter + from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import Callback, ModelCheckpoint, OnExceptionCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset @@ -27,8 +29,6 @@ from lightning.pytorch.loops.progress import _BaseProgress from lightning.pytorch.utilities import CombinedLoader from lightning.pytorch.utilities.types import STEP_OUTPUT -from torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/loops/test_prediction_loop.py b/tests/tests_pytorch/loops/test_prediction_loop.py index f27413955cae9..470cbcdc195f5 100644 --- a/tests/tests_pytorch/loops/test_prediction_loop.py +++ b/tests/tests_pytorch/loops/test_prediction_loop.py @@ -14,10 +14,11 @@ import itertools import pytest +from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler + from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper -from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler def test_prediction_loop_stores_predictions(tmp_path): diff --git a/tests/tests_pytorch/loops/test_progress.py b/tests/tests_pytorch/loops/test_progress.py index e7256d4504402..27184d7b17afb 100644 --- a/tests/tests_pytorch/loops/test_progress.py +++ b/tests/tests_pytorch/loops/test_progress.py @@ -14,6 +14,7 @@ from copy import deepcopy import pytest + from lightning.pytorch.loops.progress import ( _BaseProgress, _OptimizerProgress, diff --git a/tests/tests_pytorch/loops/test_training_epoch_loop.py b/tests/tests_pytorch/loops/test_training_epoch_loop.py index a110a20bfaf84..71252fe2011b8 100644 --- a/tests/tests_pytorch/loops/test_training_epoch_loop.py +++ b/tests/tests_pytorch/loops/test_training_epoch_loop.py @@ -16,11 +16,12 @@ import pytest import torch +from lightning_utilities.test.warning import no_warning_call + from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.trainer import Trainer -from lightning_utilities.test.warning import no_warning_call def test_no_val_on_train_epoch_loop_restart(tmp_path): diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py index 2afeb338fd9fd..29afd1ba1a250 100644 --- a/tests/tests_pytorch/loops/test_training_loop.py +++ b/tests/tests_pytorch/loops/test_training_loop.py @@ -16,6 +16,7 @@ import pytest import torch + from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loops import _FitLoop diff --git a/tests/tests_pytorch/loops/test_training_loop_flow_dict.py b/tests/tests_pytorch/loops/test_training_loop_flow_dict.py index c89913b29dbdb..44a0c85184d9e 100644 --- a/tests/tests_pytorch/loops/test_training_loop_flow_dict.py +++ b/tests/tests_pytorch/loops/test_training_loop_flow_dict.py @@ -14,9 +14,9 @@ """Tests to ensure that the training loop works with a dict (1.0)""" import torch + from lightning.pytorch import Trainer from lightning.pytorch.core.module import LightningModule - from tests_pytorch.helpers.deterministic_model import DeterministicModel diff --git a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py index ccdeb55d50ea8..4a4e8cb81c6a2 100644 --- a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py +++ b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest -from lightning.pytorch import Trainer -from lightning.pytorch.core.module import LightningModule -from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset -from lightning.pytorch.loops.optimization.automatic import Closure -from lightning.pytorch.trainer.states import RunningStage from lightning_utilities.test.warning import no_warning_call from torch import Tensor from torch.utils.data import DataLoader from torch.utils.data._utils.collate import default_collate +from lightning.pytorch import Trainer +from lightning.pytorch.core.module import LightningModule +from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset +from lightning.pytorch.loops.optimization.automatic import Closure +from lightning.pytorch.trainer.states import RunningStage from tests_pytorch.helpers.deterministic_model import DeterministicModel diff --git a/tests/tests_pytorch/models/test_amp.py b/tests/tests_pytorch/models/test_amp.py index c28d6300131ae..24323f5c1d691 100644 --- a/tests/tests_pytorch/models/test_amp.py +++ b/tests/tests_pytorch/models/test_amp.py @@ -16,12 +16,12 @@ import pytest import torch -from lightning.fabric.plugins.environments import SLURMEnvironment -from lightning.pytorch import Trainer -from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from torch.utils.data import DataLoader import tests_pytorch.helpers.utils as tutils +from lightning.fabric.plugins.environments import SLURMEnvironment +from lightning.pytorch import Trainer +from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/models/test_cpu.py b/tests/tests_pytorch/models/test_cpu.py index 435b8437760d2..a2d38aca7c56c 100644 --- a/tests/tests_pytorch/models/test_cpu.py +++ b/tests/tests_pytorch/models/test_cpu.py @@ -15,12 +15,12 @@ from unittest import mock import torch -from lightning.pytorch import Trainer, seed_everything -from lightning.pytorch.callbacks import Callback, EarlyStopping, ModelCheckpoint -from lightning.pytorch.demos.boring_classes import BoringModel import tests_pytorch.helpers.pipelines as tpipes import tests_pytorch.helpers.utils as tutils +from lightning.pytorch import Trainer, seed_everything +from lightning.pytorch.callbacks import Callback, EarlyStopping, ModelCheckpoint +from lightning.pytorch.demos.boring_classes import BoringModel from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/models/test_ddp_fork_amp.py b/tests/tests_pytorch/models/test_ddp_fork_amp.py index 54d394948eeee..a1d7c6a7f8ac1 100644 --- a/tests/tests_pytorch/models/test_ddp_fork_amp.py +++ b/tests/tests_pytorch/models/test_ddp_fork_amp.py @@ -14,8 +14,8 @@ import multiprocessing import torch -from lightning.pytorch.plugins import MixedPrecision +from lightning.pytorch.plugins import MixedPrecision from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/models/test_fabric_integration.py b/tests/tests_pytorch/models/test_fabric_integration.py index fd5ecd34e96be..3e2811e8a9413 100644 --- a/tests/tests_pytorch/models/test_fabric_integration.py +++ b/tests/tests_pytorch/models/test_fabric_integration.py @@ -16,6 +16,7 @@ from unittest.mock import Mock import torch + from lightning.fabric import Fabric from lightning.pytorch.demos.boring_classes import BoringModel, ManualOptimBoringModel diff --git a/tests/tests_pytorch/models/test_gpu.py b/tests/tests_pytorch/models/test_gpu.py index b411774c3e164..797120312436f 100644 --- a/tests/tests_pytorch/models/test_gpu.py +++ b/tests/tests_pytorch/models/test_gpu.py @@ -18,14 +18,14 @@ import pytest import torch + +import tests_pytorch.helpers.pipelines as tpipes from lightning.fabric.plugins.environments import TorchElasticEnvironment from lightning.fabric.utilities.device_parser import _parse_gpu_ids from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.accelerators import CPUAccelerator, CUDAAccelerator from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.exceptions import MisconfigurationException - -import tests_pytorch.helpers.pipelines as tpipes from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index 1a8aeb4b297a9..e943d0533cab5 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -18,12 +18,12 @@ import pytest import torch -from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__ -from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset -from lightning.pytorch.utilities.model_helpers import is_overridden from torch import Tensor from torch.utils.data import DataLoader +from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__ +from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset +from lightning.pytorch.utilities.model_helpers import is_overridden from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index 871b1cba673eb..6fd400aab2724 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -25,6 +25,10 @@ import pytest import torch from fsspec.implementations.local import LocalFileSystem +from lightning_utilities.core.imports import RequirementCache +from lightning_utilities.test.warning import no_warning_call +from torch.utils.data import DataLoader + from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.core.datamodule import LightningDataModule @@ -34,10 +38,6 @@ from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger from lightning.pytorch.utilities import AttributeDict, is_picklable from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE -from lightning_utilities.core.imports import RequirementCache -from lightning_utilities.test.warning import no_warning_call -from torch.utils.data import DataLoader - from tests_pytorch.helpers.runif import RunIf if _OMEGACONF_AVAILABLE: diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index ee670cd66e871..b3032bea5560d 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -21,11 +21,11 @@ import onnxruntime import pytest import torch -from lightning.pytorch import Trainer -from lightning.pytorch.demos.boring_classes import BoringModel from lightning_utilities import compare_version import tests_pytorch.helpers.pipelines as tpipes +from lightning.pytorch import Trainer +from lightning.pytorch.demos.boring_classes import BoringModel from tests_pytorch.helpers.runif import RunIf from tests_pytorch.utilities.test_model_summary import UnorderedModel diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index 64f70b176a971..099493890831d 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -22,16 +22,16 @@ import cloudpickle import pytest import torch -from lightning.fabric import seed_everything -from lightning.pytorch import Callback, Trainer -from lightning.pytorch.callbacks import ModelCheckpoint -from lightning.pytorch.demos.boring_classes import BoringModel -from lightning.pytorch.trainer.states import TrainerFn from lightning_utilities.test.warning import no_warning_call from torch import Tensor import tests_pytorch.helpers.pipelines as tpipes import tests_pytorch.helpers.utils as tutils +from lightning.fabric import seed_everything +from lightning.pytorch import Callback, Trainer +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.demos.boring_classes import BoringModel +from lightning.pytorch.trainer.states import TrainerFn from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/models/test_torchscript.py b/tests/tests_pytorch/models/test_torchscript.py index 993085729e545..8f9151265d21a 100644 --- a/tests/tests_pytorch/models/test_torchscript.py +++ b/tests/tests_pytorch/models/test_torchscript.py @@ -18,11 +18,11 @@ import pytest import torch from fsspec.implementations.local import LocalFileSystem + from lightning.fabric.utilities.cloud_io import get_filesystem from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch.core.module import LightningModule from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleRNN from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/models/test_tpu.py b/tests/tests_pytorch/models/test_tpu.py index af927e0d5596a..8067fd63b6562 100644 --- a/tests/tests_pytorch/models/test_tpu.py +++ b/tests/tests_pytorch/models/test_tpu.py @@ -17,6 +17,9 @@ import pytest import torch +from torch.utils.data import DataLoader + +import tests_pytorch.helpers.pipelines as tpipes from lightning.pytorch import Trainer from lightning.pytorch.accelerators import XLAAccelerator from lightning.pytorch.callbacks import EarlyStopping @@ -25,9 +28,6 @@ from lightning.pytorch.strategies.launchers.xla import _XLALauncher from lightning.pytorch.trainer.connectors.logger_connector.result import _Sync from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch.utils.data import DataLoader - -import tests_pytorch.helpers.pipelines as tpipes from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/overrides/test_distributed.py b/tests/tests_pytorch/overrides/test_distributed.py index 3e2fba54bcd03..dd5f8f1504af7 100644 --- a/tests/tests_pytorch/overrides/test_distributed.py +++ b/tests/tests_pytorch/overrides/test_distributed.py @@ -15,11 +15,11 @@ import pytest import torch +from torch.utils.data import BatchSampler, SequentialSampler + from lightning.fabric.utilities.data import has_len from lightning.pytorch import LightningModule, Trainer, seed_everything from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSampler, _IndexBatchSamplerWrapper -from torch.utils.data import BatchSampler, SequentialSampler - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/plugins/precision/test_all.py b/tests/tests_pytorch/plugins/precision/test_all.py index 2668311c8b452..2a11b7c66c772 100644 --- a/tests/tests_pytorch/plugins/precision/test_all.py +++ b/tests/tests_pytorch/plugins/precision/test_all.py @@ -1,5 +1,6 @@ import pytest import torch + from lightning.pytorch.plugins import ( DeepSpeedPrecision, DoublePrecision, diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py index 90ecc703c8945..cb061c540b2be 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp.py +++ b/tests/tests_pytorch/plugins/precision/test_amp.py @@ -14,9 +14,10 @@ from unittest.mock import Mock import pytest +from torch.optim import Optimizer + from lightning.pytorch.plugins import MixedPrecision from lightning.pytorch.utilities import GradClipAlgorithmType -from torch.optim import Optimizer def test_clip_gradients(): diff --git a/tests/tests_pytorch/plugins/precision/test_amp_integration.py b/tests/tests_pytorch/plugins/precision/test_amp_integration.py index bc9f77907919a..f231e3ce91e0b 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp_integration.py +++ b/tests/tests_pytorch/plugins/precision/test_amp_integration.py @@ -14,12 +14,12 @@ from unittest.mock import Mock import torch + from lightning.fabric import seed_everything from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins.precision import MixedPrecision - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/plugins/precision/test_bitsandbytes.py b/tests/tests_pytorch/plugins/precision/test_bitsandbytes.py index a88e38d6303e4..8f331e26f979d 100644 --- a/tests/tests_pytorch/plugins/precision/test_bitsandbytes.py +++ b/tests/tests_pytorch/plugins/precision/test_bitsandbytes.py @@ -14,10 +14,11 @@ import sys from unittest.mock import Mock -import lightning.fabric import pytest import torch import torch.distributed + +import lightning.fabric from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecision diff --git a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py index 3e1aaa17763e9..da4ce6b89aaab 100644 --- a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py +++ b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py @@ -14,6 +14,7 @@ import pytest import torch + from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision diff --git a/tests/tests_pytorch/plugins/precision/test_double.py b/tests/tests_pytorch/plugins/precision/test_double.py index 1ee89752fcbae..74c178c65d05d 100644 --- a/tests/tests_pytorch/plugins/precision/test_double.py +++ b/tests/tests_pytorch/plugins/precision/test_double.py @@ -16,11 +16,11 @@ import pytest import torch +from torch.utils.data import DataLoader, Dataset + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.plugins.precision.double import DoublePrecision -from torch.utils.data import DataLoader, Dataset - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/plugins/precision/test_fsdp.py b/tests/tests_pytorch/plugins/precision/test_fsdp.py index 3ad3af1f1b56b..0389d364dcb79 100644 --- a/tests/tests_pytorch/plugins/precision/test_fsdp.py +++ b/tests/tests_pytorch/plugins/precision/test_fsdp.py @@ -15,9 +15,9 @@ import pytest import torch + from lightning.fabric.plugins.precision.utils import _DtypeContextManager from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/plugins/precision/test_half.py b/tests/tests_pytorch/plugins/precision/test_half.py index d51392a00e3d0..9597e01ea428b 100644 --- a/tests/tests_pytorch/plugins/precision/test_half.py +++ b/tests/tests_pytorch/plugins/precision/test_half.py @@ -14,6 +14,7 @@ import pytest import torch + from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.plugins import HalfPrecision diff --git a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py index 7c92ff47d909a..a9967280e3f23 100644 --- a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py +++ b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py @@ -15,9 +15,10 @@ from contextlib import nullcontext from unittest.mock import ANY, Mock -import lightning.fabric import pytest import torch + +import lightning.fabric from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.plugins import TransformerEnginePrecision from lightning.pytorch.trainer.connectors.accelerator_connector import _AcceleratorConnector diff --git a/tests/tests_pytorch/plugins/precision/test_xla.py b/tests/tests_pytorch/plugins/precision/test_xla.py index 97990b6380dab..b456c49e8ff50 100644 --- a/tests/tests_pytorch/plugins/precision/test_xla.py +++ b/tests/tests_pytorch/plugins/precision/test_xla.py @@ -18,8 +18,8 @@ import pytest import torch -from lightning.pytorch.plugins import XLAPrecision +from lightning.pytorch.plugins import XLAPrecision from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/plugins/test_amp_plugins.py b/tests/tests_pytorch/plugins/test_amp_plugins.py index e8adb59c39e51..0b68c098cc713 100644 --- a/tests/tests_pytorch/plugins/test_amp_plugins.py +++ b/tests/tests_pytorch/plugins/test_amp_plugins.py @@ -18,11 +18,11 @@ import pytest import torch +from torch import Tensor + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins import MixedPrecision -from torch import Tensor - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py index 58baa47e7a620..cae26fc1fe775 100644 --- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py +++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock, Mock import torch + from lightning.fabric.plugins import CheckpointIO, TorchCheckpointIO from lightning.fabric.utilities.types import _PATH from lightning.pytorch import Trainer diff --git a/tests/tests_pytorch/plugins/test_cluster_integration.py b/tests/tests_pytorch/plugins/test_cluster_integration.py index 026465ac8b17b..08bd1707b5cfd 100644 --- a/tests/tests_pytorch/plugins/test_cluster_integration.py +++ b/tests/tests_pytorch/plugins/test_cluster_integration.py @@ -16,11 +16,11 @@ import pytest import torch + from lightning.fabric.plugins.environments import LightningEnvironment, SLURMEnvironment, TorchElasticEnvironment from lightning.pytorch import Trainer from lightning.pytorch.strategies import DDPStrategy, DeepSpeedStrategy from lightning.pytorch.utilities.rank_zero import rank_zero_only - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/profilers/test_profiler.py b/tests/tests_pytorch/profilers/test_profiler.py index 5b0c13e605ee2..df9bc16c284ad 100644 --- a/tests/tests_pytorch/profilers/test_profiler.py +++ b/tests/tests_pytorch/profilers/test_profiler.py @@ -22,6 +22,7 @@ import numpy as np import pytest import torch + from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch import Callback, Trainer from lightning.pytorch.callbacks import EarlyStopping, StochasticWeightAveraging @@ -30,7 +31,6 @@ from lightning.pytorch.profilers import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler from lightning.pytorch.profilers.pytorch import _KINETO_AVAILABLE, RegisterRecordFunction, warning_cache from lightning.pytorch.utilities.exceptions import MisconfigurationException - from tests_pytorch.helpers.runif import RunIf PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005 @@ -55,7 +55,7 @@ def _sleep_generator(durations): yield duration -@pytest.fixture() +@pytest.fixture def simple_profiler(): return SimpleProfiler() @@ -264,7 +264,7 @@ def test_simple_profiler_summary(tmp_path, extended): assert expected_text == summary -@pytest.fixture() +@pytest.fixture def advanced_profiler(tmp_path): return AdvancedProfiler(dirpath=tmp_path, filename="profiler") @@ -336,7 +336,7 @@ def test_advanced_profiler_deepcopy(advanced_profiler): assert deepcopy(advanced_profiler) -@pytest.fixture() +@pytest.fixture def pytorch_profiler(tmp_path): return PyTorchProfiler(dirpath=tmp_path, filename="profiler") diff --git a/tests/tests_pytorch/profilers/test_xla_profiler.py b/tests/tests_pytorch/profilers/test_xla_profiler.py index 980a4dac74731..80337382ddcb5 100644 --- a/tests/tests_pytorch/profilers/test_xla_profiler.py +++ b/tests/tests_pytorch/profilers/test_xla_profiler.py @@ -16,10 +16,10 @@ from unittest import mock import pytest + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.profilers import XLAProfiler - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/serve/test_servable_module_validator.py b/tests/tests_pytorch/serve/test_servable_module_validator.py index ec4dd8825c8ea..ba90949132ba2 100644 --- a/tests/tests_pytorch/serve/test_servable_module_validator.py +++ b/tests/tests_pytorch/serve/test_servable_module_validator.py @@ -1,9 +1,10 @@ import pytest import torch +from torch import Tensor + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.serve.servable_module_validator import ServableModule, ServableModuleValidator -from torch import Tensor class ServableBoringModel(BoringModel, ServableModule): diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py index 394d827058987..d26f6c4d2c3ef 100644 --- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py @@ -18,13 +18,13 @@ import pytest import torch + from lightning.fabric.plugins import ClusterEnvironment from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.strategies import DDPStrategy from lightning.pytorch.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher from lightning.pytorch.trainer.states import TrainerFn - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/launchers/test_subprocess_script.py b/tests/tests_pytorch/strategies/launchers/test_subprocess_script.py index b8a5ddb29de23..dd8576ec0cafe 100644 --- a/tests/tests_pytorch/strategies/launchers/test_subprocess_script.py +++ b/tests/tests_pytorch/strategies/launchers/test_subprocess_script.py @@ -4,9 +4,9 @@ from unittest.mock import Mock import pytest -from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from lightning_utilities.core.imports import RequirementCache +from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from tests_pytorch.helpers.runif import RunIf _HYDRA_WITH_RUN_PROCESS = RequirementCache("hydra-core>=1.0.7") diff --git a/tests/tests_pytorch/strategies/test_common.py b/tests/tests_pytorch/strategies/test_common.py index 699424b3c53b9..6ab4f49374c27 100644 --- a/tests/tests_pytorch/strategies/test_common.py +++ b/tests/tests_pytorch/strategies/test_common.py @@ -15,10 +15,10 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.plugins import DoublePrecision, HalfPrecision, Precision from lightning.pytorch.strategies import SingleDeviceStrategy - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/strategies/test_custom_strategy.py b/tests/tests_pytorch/strategies/test_custom_strategy.py index 347dacbd9a811..8a297db217943 100644 --- a/tests/tests_pytorch/strategies/test_custom_strategy.py +++ b/tests/tests_pytorch/strategies/test_custom_strategy.py @@ -17,6 +17,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.strategies import SingleDeviceStrategy diff --git a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py index b23d306b9d907..915e57440b40f 100644 --- a/tests/tests_pytorch/strategies/test_ddp.py +++ b/tests/tests_pytorch/strategies/test_ddp.py @@ -17,14 +17,14 @@ import pytest import torch +from torch.nn.parallel import DistributedDataParallel + from lightning.fabric.plugins.environments import LightningEnvironment from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins import DoublePrecision, HalfPrecision, Precision from lightning.pytorch.strategies import DDPStrategy from lightning.pytorch.trainer.states import TrainerFn -from torch.nn.parallel import DistributedDataParallel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/test_ddp_integration.py b/tests/tests_pytorch/strategies/test_ddp_integration.py index 836072d36be83..048403366ebc7 100644 --- a/tests/tests_pytorch/strategies/test_ddp_integration.py +++ b/tests/tests_pytorch/strategies/test_ddp_integration.py @@ -15,9 +15,14 @@ from unittest import mock from unittest.mock import Mock -import lightning.pytorch as pl import pytest import torch +from torch.distributed.optim import ZeroRedundancyOptimizer +from torch.multiprocessing import ProcessRaisedException +from torch.nn.parallel.distributed import DistributedDataParallel + +import lightning.pytorch as pl +import tests_pytorch.helpers.pipelines as tpipes from lightning.fabric.plugins.environments import ClusterEnvironment, LightningEnvironment from lightning.fabric.utilities.distributed import _distributed_is_initialized from lightning.pytorch import Trainer @@ -27,11 +32,6 @@ from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher from lightning.pytorch.strategies.launchers.multiprocessing import _MultiProcessingLauncher from lightning.pytorch.trainer import seed_everything -from torch.distributed.optim import ZeroRedundancyOptimizer -from torch.multiprocessing import ProcessRaisedException -from torch.nn.parallel.distributed import DistributedDataParallel - -import tests_pytorch.helpers.pipelines as tpipes from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/strategies/test_ddp_integration_comm_hook.py b/tests/tests_pytorch/strategies/test_ddp_integration_comm_hook.py index 3723ee2fc5a8f..c4d40e1dfa7d2 100644 --- a/tests/tests_pytorch/strategies/test_ddp_integration_comm_hook.py +++ b/tests/tests_pytorch/strategies/test_ddp_integration_comm_hook.py @@ -15,10 +15,10 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.strategies import DDPStrategy - from tests_pytorch.helpers.runif import RunIf if torch.distributed.is_available(): diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py index 73697ea131545..7e7d2eacd0617 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed.py +++ b/tests/tests_pytorch/strategies/test_deepspeed.py @@ -22,6 +22,10 @@ import pytest import torch import torch.nn.functional as F +from torch import Tensor, nn +from torch.utils.data import DataLoader +from torchmetrics import Accuracy + from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.accelerators import CUDAAccelerator from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint @@ -31,10 +35,6 @@ from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11 -from torch import Tensor, nn -from torch.utils.data import DataLoader -from torchmetrics import Accuracy - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf @@ -81,7 +81,7 @@ def automatic_optimization(self) -> bool: return False -@pytest.fixture() +@pytest.fixture def deepspeed_config(): return { "optimizer": {"type": "SGD", "params": {"lr": 3e-5}}, @@ -92,7 +92,7 @@ def deepspeed_config(): } -@pytest.fixture() +@pytest.fixture def deepspeed_zero_config(deepspeed_config): return {**deepspeed_config, "zero_allow_untested_optimizer": True, "zero_optimization": {"stage": 2}} diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 2aee68f7ae733..f3e88ca356764 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -12,6 +12,10 @@ import pytest import torch import torch.nn as nn +from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision +from torch.distributed.fsdp.wrap import ModuleWrapPolicy, always_wrap_policy, size_based_auto_wrap_policy, wrap +from torchmetrics import Accuracy + from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 @@ -24,10 +28,6 @@ from lightning.pytorch.strategies import FSDPStrategy from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.consolidate_checkpoint import _format_checkpoint -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision -from torch.distributed.fsdp.wrap import ModuleWrapPolicy, always_wrap_policy, size_based_auto_wrap_policy, wrap -from torchmetrics import Accuracy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/test_model_parallel.py b/tests/tests_pytorch/strategies/test_model_parallel.py index 731da66d4a61f..86a95944ac20d 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel.py +++ b/tests/tests_pytorch/strategies/test_model_parallel.py @@ -20,11 +20,11 @@ import pytest import torch import torch.nn as nn + from lightning.fabric.strategies.model_parallel import _is_sharded_checkpoint from lightning.pytorch import LightningModule from lightning.pytorch.plugins.environments import LightningEnvironment from lightning.pytorch.strategies import ModelParallelStrategy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/test_model_parallel_integration.py b/tests/tests_pytorch/strategies/test_model_parallel_integration.py index 9dcbcc802834b..00600183f4293 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel_integration.py +++ b/tests/tests_pytorch/strategies/test_model_parallel_integration.py @@ -18,12 +18,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from lightning.pytorch import LightningModule, Trainer, seed_everything -from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset -from lightning.pytorch.strategies import ModelParallelStrategy from torch.utils.data import DataLoader, DistributedSampler from torchmetrics.classification import Accuracy +from lightning.pytorch import LightningModule, Trainer, seed_everything +from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset +from lightning.pytorch.strategies import ModelParallelStrategy from tests_pytorch.helpers.runif import RunIf @@ -86,7 +86,7 @@ def fn(model, device_mesh): return fn -@pytest.fixture() +@pytest.fixture def distributed(): yield if torch.distributed.is_initialized(): diff --git a/tests/tests_pytorch/strategies/test_registry.py b/tests/tests_pytorch/strategies/test_registry.py index 90e15638bfd06..d2c580fd28c0f 100644 --- a/tests/tests_pytorch/strategies/test_registry.py +++ b/tests/tests_pytorch/strategies/test_registry.py @@ -14,10 +14,10 @@ from unittest import mock import pytest + from lightning.pytorch import Trainer from lightning.pytorch.plugins import CheckpointIO from lightning.pytorch.strategies import DDPStrategy, DeepSpeedStrategy, FSDPStrategy, StrategyRegistry, XLAStrategy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/test_single_device.py b/tests/tests_pytorch/strategies/test_single_device.py index 5b10c4a17d726..7582dfe86dd3c 100644 --- a/tests/tests_pytorch/strategies/test_single_device.py +++ b/tests/tests_pytorch/strategies/test_single_device.py @@ -16,12 +16,12 @@ import pytest import torch +from torch.utils.data import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.core.optimizer import LightningOptimizer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.strategies import SingleDeviceStrategy -from torch.utils.data import DataLoader - from tests_pytorch.helpers.dataloaders import CustomNotImplementedErrorDataloader from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/test_xla.py b/tests/tests_pytorch/strategies/test_xla.py index b4f0e2c37ec05..3fde2600c9483 100644 --- a/tests/tests_pytorch/strategies/test_xla.py +++ b/tests/tests_pytorch/strategies/test_xla.py @@ -16,11 +16,11 @@ from unittest.mock import Mock import torch + from lightning.pytorch import Trainer from lightning.pytorch.accelerators import XLAAccelerator from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.strategies import XLAStrategy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index de89d094cdfcf..b9decf781386b 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -27,6 +27,14 @@ import pytest import torch import yaml +from lightning_utilities import compare_version +from lightning_utilities.test.warning import no_warning_call +from packaging.version import Version +from tensorboard.backend.event_processing import event_accumulator +from tensorboard.plugins.hparams.plugin_data_pb2 import HParamsPluginData +from torch.optim import SGD +from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR + from lightning.fabric.plugins.environments import SLURMEnvironment from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__, seed_everything from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint @@ -46,14 +54,6 @@ from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE -from lightning_utilities import compare_version -from lightning_utilities.test.warning import no_warning_call -from packaging.version import Version -from tensorboard.backend.event_processing import event_accumulator -from tensorboard.plugins.hparams.plugin_data_pb2 import HParamsPluginData -from torch.optim import SGD -from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR - from tests_pytorch.helpers.runif import RunIf if _JSONARGPARSE_SIGNATURES_AVAILABLE: @@ -84,7 +84,7 @@ def mock_subclasses(baseclass, *subclasses): yield None -@pytest.fixture() +@pytest.fixture def cleandir(tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) return @@ -666,7 +666,7 @@ class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): parser.add_optimizer_args(torch.optim.Adam) - match = "BoringModel.configure_optimizers` will be overridden by " "`MyLightningCLI.configure_optimizers`" + match = "BoringModel.configure_optimizers` will be overridden by `MyLightningCLI.configure_optimizers`" argv = ["fit", "--trainer.fast_dev_run=1"] if run else [] with mock.patch("sys.argv", ["any.py"] + argv), pytest.warns(UserWarning, match=match): cli = MyLightningCLI(BoringModel, run=run) diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index 9e947e0723dcd..b8517a0303015 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -19,11 +19,12 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric -import lightning.pytorch import pytest import torch import torch.distributed + +import lightning.fabric +import lightning.pytorch from lightning.fabric.plugins.environments import ( KubeflowEnvironment, LightningEnvironment, @@ -61,7 +62,6 @@ from lightning.pytorch.utilities.imports import ( _LIGHTNING_HABANA_AVAILABLE, ) - from tests_pytorch.conftest import mock_cuda_count, mock_mps_count, mock_tpu_available, mock_xla_available from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py index eb09413cafcce..94b5fcba652be 100644 --- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py @@ -18,6 +18,7 @@ import pytest import torch + from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_10_0 from lightning.pytorch import Callback, LightningModule, Trainer from lightning.pytorch.callbacks import ( diff --git a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py index d29e2285e983c..722742a3ccae0 100644 --- a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py @@ -17,6 +17,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import Callback, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index ca5690ed20f41..ceb0418f2cb1d 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -16,8 +16,12 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric import pytest +from lightning_utilities.test.warning import no_warning_call +from torch import Tensor +from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler + +import lightning.fabric from lightning.fabric.utilities.distributed import DistributedSamplerWrapper from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer @@ -33,10 +37,6 @@ from lightning.pytorch.utilities.combined_loader import CombinedLoader from lightning.pytorch.utilities.data import _update_dataloader from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning_utilities.test.warning import no_warning_call -from torch import Tensor -from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/connectors/test_signal_connector.py b/tests/tests_pytorch/trainer/connectors/test_signal_connector.py index 8825db3727e86..83c5c2bb7e02b 100644 --- a/tests/tests_pytorch/trainer/connectors/test_signal_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_signal_connector.py @@ -18,13 +18,13 @@ from unittest.mock import Mock import pytest + from lightning.fabric.plugins.environments import SLURMEnvironment from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector from lightning.pytorch.utilities.exceptions import SIGTERMException - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py b/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py index f7b7d3c8ac8a7..89c842f2bbf0f 100644 --- a/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py +++ b/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py @@ -13,9 +13,10 @@ # limitations under the License. import pytest import torch +from torch.utils.data import Dataset + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel -from torch.utils.data import Dataset class RandomDatasetA(Dataset): diff --git a/tests/tests_pytorch/trainer/flags/test_barebones.py b/tests/tests_pytorch/trainer/flags/test_barebones.py index 329fcf915d751..875aaef40a123 100644 --- a/tests/tests_pytorch/trainer/flags/test_barebones.py +++ b/tests/tests_pytorch/trainer/flags/test_barebones.py @@ -14,6 +14,7 @@ import logging import pytest + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelSummary from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/trainer/flags/test_check_val_every_n_epoch.py b/tests/tests_pytorch/trainer/flags/test_check_val_every_n_epoch.py index fba938d82e761..301721e4b28f5 100644 --- a/tests/tests_pytorch/trainer/flags/test_check_val_every_n_epoch.py +++ b/tests/tests_pytorch/trainer/flags/test_check_val_every_n_epoch.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest +from torch.utils.data import DataLoader + from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.trainer.trainer import Trainer -from torch.utils.data import DataLoader @pytest.mark.parametrize( diff --git a/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py b/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py index 62087619ae42f..63b05a0a131f2 100644 --- a/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py +++ b/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py @@ -3,6 +3,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/trainer/flags/test_inference_mode.py b/tests/tests_pytorch/trainer/flags/test_inference_mode.py index bae7b66dbbd55..802d9bf30c59b 100644 --- a/tests/tests_pytorch/trainer/flags/test_inference_mode.py +++ b/tests/tests_pytorch/trainer/flags/test_inference_mode.py @@ -16,6 +16,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loops import _Loop diff --git a/tests/tests_pytorch/trainer/flags/test_limit_batches.py b/tests/tests_pytorch/trainer/flags/test_limit_batches.py index e190a0b380377..ef405a8ee95b1 100644 --- a/tests/tests_pytorch/trainer/flags/test_limit_batches.py +++ b/tests/tests_pytorch/trainer/flags/test_limit_batches.py @@ -14,6 +14,7 @@ import logging import pytest + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.states import TrainerFn diff --git a/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py b/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py index 25aaeb8cff77e..3315c328b6249 100644 --- a/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py +++ b/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py @@ -1,8 +1,9 @@ import pytest +from lightning_utilities.test.warning import no_warning_call + from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel -from lightning_utilities.test.warning import no_warning_call @pytest.mark.parametrize( diff --git a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py index dfe8a63a8bfc4..050818287ba45 100644 --- a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py +++ b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py @@ -16,11 +16,11 @@ import pytest import torch +from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.trainer.states import RunningStage -from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.datasets import SklearnDataset from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py index b776263e9953d..b6cc446cb0840 100644 --- a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py +++ b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py @@ -14,10 +14,11 @@ import logging import pytest +from torch.utils.data import DataLoader + from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from lightning.pytorch.trainer.trainer import Trainer from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch.utils.data import DataLoader @pytest.mark.parametrize("max_epochs", [1, 2, 3]) diff --git a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py index af7cecdb21a08..90f9a3e697535 100644 --- a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py @@ -19,7 +19,6 @@ from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers.logger import Logger - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py index 40c82bec2fd10..be6de37ddff3a 100644 --- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py @@ -24,6 +24,8 @@ import numpy as np import pytest import torch +from torch import Tensor + from lightning.pytorch import Trainer, callbacks from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset @@ -31,8 +33,6 @@ from lightning.pytorch.loops import _EvaluationLoop from lightning.pytorch.trainer.states import RunningStage from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch import Tensor - from tests_pytorch.helpers.runif import RunIf if _RICH_AVAILABLE: diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py index e96857a6c192d..faf88a09f6499 100644 --- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py +++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py @@ -18,6 +18,11 @@ import pytest import torch +from lightning_utilities.core.imports import compare_version +from torch.utils.data import DataLoader +from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError, MetricCollection +from torchmetrics import AveragePrecision as AvgPre + from lightning.pytorch import LightningModule from lightning.pytorch.callbacks.callback import Callback from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset @@ -28,11 +33,6 @@ from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1 from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11 -from lightning_utilities.core.imports import compare_version -from torch.utils.data import DataLoader -from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError, MetricCollection -from torchmetrics import AveragePrecision as AvgPre - from tests_pytorch.helpers.runif import RunIf from tests_pytorch.models.test_hooks import get_members diff --git a/tests/tests_pytorch/trainer/logging_/test_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_loop_logging.py index e43f3b8d6bffd..6bc9dd1de0587 100644 --- a/tests/tests_pytorch/trainer/logging_/test_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_loop_logging.py @@ -18,6 +18,7 @@ from unittest.mock import ANY import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.connectors.logger_connector.fx_validator import _FxValidator diff --git a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py index e48f80d2d1680..3981ddd64d773 100644 --- a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py @@ -22,6 +22,11 @@ import numpy as np import pytest import torch +from lightning_utilities.test.warning import no_warning_call +from torch import Tensor +from torch.utils.data import DataLoader +from torchmetrics import Accuracy + from lightning.pytorch import Trainer, callbacks from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar from lightning.pytorch.core.module import LightningModule @@ -30,11 +35,6 @@ from lightning.pytorch.trainer.states import RunningStage from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11 -from lightning_utilities.test.warning import no_warning_call -from torch import Tensor -from torch.utils.data import DataLoader -from torchmetrics import Accuracy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/optimization/test_backward_calls.py b/tests/tests_pytorch/trainer/optimization/test_backward_calls.py index b91dbff8c6d09..42332bc05580f 100644 --- a/tests/tests_pytorch/trainer/optimization/test_backward_calls.py +++ b/tests/tests_pytorch/trainer/optimization/test_backward_calls.py @@ -2,6 +2,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py index f0ab8fe401633..3f89e1459298d 100644 --- a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py +++ b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py @@ -21,11 +21,11 @@ import torch import torch.distributed as torch_distrib import torch.nn.functional as F + from lightning.fabric.utilities.exceptions import MisconfigurationException from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.demos.boring_classes import BoringModel, ManualOptimBoringModel from lightning.pytorch.strategies import Strategy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py index 319eafeb0d0bb..dcbae32827c50 100644 --- a/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py +++ b/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py @@ -13,9 +13,10 @@ # limitations under the License. """Tests to ensure that the behaviours related to multiple optimizers works.""" -import lightning.pytorch as pl import pytest import torch + +import lightning.pytorch as pl from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/trainer/optimization/test_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_optimizers.py index ac660b6651be5..6b88534f3430d 100644 --- a/tests/tests_pytorch/trainer/optimization/test_optimizers.py +++ b/tests/tests_pytorch/trainer/optimization/test_optimizers.py @@ -16,6 +16,8 @@ import pytest import torch +from torch import optim + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.core.optimizer import ( @@ -26,8 +28,6 @@ from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.types import LRSchedulerConfig -from torch import optim - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py index 76c0c695b3c02..aa60db594447d 100644 --- a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py +++ b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py @@ -18,11 +18,11 @@ import pytest import torch +from torch.utils.data import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomIterableDataset from lightning.pytorch.strategies import SingleDeviceXLAStrategy -from torch.utils.data import DataLoader - from tests_pytorch.conftest import mock_cuda_count from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/properties/test_get_model.py b/tests/tests_pytorch/trainer/properties/test_get_model.py index 72967ce929eea..47d2bffcfca22 100644 --- a/tests/tests_pytorch/trainer/properties/test_get_model.py +++ b/tests/tests_pytorch/trainer/properties/test_get_model.py @@ -13,9 +13,9 @@ # limitations under the License. import pytest + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/properties/test_log_dir.py b/tests/tests_pytorch/trainer/properties/test_log_dir.py index 1fc4f3454f9d0..0f045c2e815fd 100644 --- a/tests/tests_pytorch/trainer/properties/test_log_dir.py +++ b/tests/tests_pytorch/trainer/properties/test_log_dir.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/trainer/properties/test_loggers.py b/tests/tests_pytorch/trainer/properties/test_loggers.py index 7fcf663f59122..1d07f3e99d412 100644 --- a/tests/tests_pytorch/trainer/properties/test_loggers.py +++ b/tests/tests_pytorch/trainer/properties/test_loggers.py @@ -13,9 +13,9 @@ # limitations under the License. import pytest + from lightning.pytorch import Trainer from lightning.pytorch.loggers import TensorBoardLogger - from tests_pytorch.loggers.test_logger import CustomLogger diff --git a/tests/tests_pytorch/trainer/test_config_validator.py b/tests/tests_pytorch/trainer/test_config_validator.py index 8963b1f76186d..cfca98e04c8c8 100644 --- a/tests/tests_pytorch/trainer/test_config_validator.py +++ b/tests/tests_pytorch/trainer/test_config_validator.py @@ -15,6 +15,7 @@ import pytest import torch + from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import LightningDataModule, LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py index a2d29baa9fa6f..7fbe55030770e 100644 --- a/tests/tests_pytorch/trainer/test_dataloaders.py +++ b/tests/tests_pytorch/trainer/test_dataloaders.py @@ -14,10 +14,17 @@ import os from unittest.mock import Mock, call, patch -import lightning.pytorch import numpy import pytest import torch +from lightning_utilities.test.warning import no_warning_call +from torch.utils.data import RandomSampler +from torch.utils.data.dataloader import DataLoader +from torch.utils.data.dataset import Dataset, IterableDataset +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import SequentialSampler + +import lightning.pytorch from lightning.fabric.utilities.data import _auto_add_worker_init_fn, has_iterable_dataset from lightning.pytorch import Callback, Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint @@ -33,13 +40,6 @@ from lightning.pytorch.utilities.combined_loader import CombinedLoader from lightning.pytorch.utilities.data import has_len_all_ranks from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning_utilities.test.warning import no_warning_call -from torch.utils.data import RandomSampler -from torch.utils.data.dataloader import DataLoader -from torch.utils.data.dataset import Dataset, IterableDataset -from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.sampler import SequentialSampler - from tests_pytorch.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/test_states.py b/tests/tests_pytorch/trainer/test_states.py index d89e99c9319c6..fff2d0b464d42 100644 --- a/tests/tests_pytorch/trainer/test_states.py +++ b/tests/tests_pytorch/trainer/test_states.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest + from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index d66f3aafee5df..18ae7ce77bdfc 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -27,6 +27,12 @@ import pytest import torch import torch.nn as nn +from torch.multiprocessing import ProcessRaisedException +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.optim import SGD +from torch.utils.data import DataLoader, IterableDataset + +import tests_pytorch.helpers.utils as tutils from lightning.fabric.utilities.cloud_io import _load as pl_load from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.fabric.utilities.seed import seed_everything @@ -50,12 +56,6 @@ from lightning.pytorch.trainer.states import RunningStage, TrainerFn from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE -from torch.multiprocessing import ProcessRaisedException -from torch.nn.parallel.distributed import DistributedDataParallel -from torch.optim import SGD -from torch.utils.data import DataLoader, IterableDataset - -import tests_pytorch.helpers.utils as tutils from tests_pytorch.conftest import mock_cuda_count, mock_mps_count from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf @@ -335,9 +335,9 @@ def mock_save_function(filepath, *args): file_lists = set(os.listdir(tmp_path)) - assert len(file_lists) == len( - expected_files - ), f"Should save {len(expected_files)} models when save_top_k={save_top_k} but found={file_lists}" + assert len(file_lists) == len(expected_files), ( + f"Should save {len(expected_files)} models when save_top_k={save_top_k} but found={file_lists}" + ) # verify correct naming for fname in expected_files: diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py index a31be67911409..ec894688ccb6c 100644 --- a/tests/tests_pytorch/tuner/test_lr_finder.py +++ b/tests/tests_pytorch/tuner/test_lr_finder.py @@ -20,6 +20,8 @@ import pytest import torch +from lightning_utilities.test.warning import no_warning_call + from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks.lr_finder import LearningRateFinder from lightning.pytorch.demos.boring_classes import BoringModel @@ -27,8 +29,6 @@ from lightning.pytorch.tuner.tuning import Tuner from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.types import STEP_OUTPUT -from lightning_utilities.test.warning import no_warning_call - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel @@ -73,9 +73,9 @@ def test_model_reset_correctly(tmp_path): after_state_dict = model.state_dict() for key in before_state_dict: - assert torch.all( - torch.eq(before_state_dict[key], after_state_dict[key]) - ), "Model was not reset correctly after learning rate finder" + assert torch.all(torch.eq(before_state_dict[key], after_state_dict[key])), ( + "Model was not reset correctly after learning rate finder" + ) assert not any(f for f in os.listdir(tmp_path) if f.startswith(".lr_find")) diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py index 8dd66fe9bfcff..e4ed533c6fa83 100644 --- a/tests/tests_pytorch/tuner/test_scale_batch_size.py +++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py @@ -18,14 +18,14 @@ import pytest import torch +from lightning_utilities.test.warning import no_warning_call +from torch.utils.data import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset from lightning.pytorch.tuner.tuning import Tuner from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning_utilities.test.warning import no_warning_call -from torch.utils.data import DataLoader - from tests_pytorch.helpers.runif import RunIf @@ -114,9 +114,9 @@ def test_trainer_reset_correctly(tmp_path, trainer_fn): after_state_dict = model.state_dict() for key in before_state_dict: - assert torch.all( - torch.eq(before_state_dict[key], after_state_dict[key]) - ), "Model was not reset correctly after scaling batch size" + assert torch.all(torch.eq(before_state_dict[key], after_state_dict[key])), ( + "Model was not reset correctly after scaling batch size" + ) assert not any(f for f in os.listdir(tmp_path) if f.startswith(".scale_batch_size_temp_model")) diff --git a/tests/tests_pytorch/tuner/test_tuning.py b/tests/tests_pytorch/tuner/test_tuning.py index dda08354575c0..e3b24a69b7999 100644 --- a/tests/tests_pytorch/tuner/test_tuning.py +++ b/tests/tests_pytorch/tuner/test_tuning.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import BatchSizeFinder, LearningRateFinder from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/utilities/migration/test_migration.py b/tests/tests_pytorch/utilities/migration/test_migration.py index 9680c90a94c5b..f9c921f6f1bfd 100644 --- a/tests/tests_pytorch/utilities/migration/test_migration.py +++ b/tests/tests_pytorch/utilities/migration/test_migration.py @@ -13,9 +13,10 @@ # limitations under the License. from unittest.mock import ANY, MagicMock -import lightning.pytorch as pl import pytest import torch + +import lightning.pytorch as pl from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py index 56b8701cfcfc2..41ca1a779f8a5 100644 --- a/tests/tests_pytorch/utilities/migration/test_utils.py +++ b/tests/tests_pytorch/utilities/migration/test_utils.py @@ -18,16 +18,16 @@ import sys from unittest.mock import ANY -import lightning.pytorch as pl import pytest import torch -from lightning.fabric.utilities.warnings import PossibleUserWarning -from lightning.pytorch.utilities.migration import migrate_checkpoint, pl_legacy_patch -from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint, _RedirectingUnpickler from lightning_utilities.core.imports import module_available from lightning_utilities.test.warning import no_warning_call from packaging.version import Version +import lightning.pytorch as pl +from lightning.fabric.utilities.warnings import PossibleUserWarning +from lightning.pytorch.utilities.migration import migrate_checkpoint, pl_legacy_patch +from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint, _RedirectingUnpickler from tests_pytorch.checkpointing.test_legacy_checkpoints import ( CHECKPOINT_EXTENSION, LEGACY_BACK_COMPATIBLE_PL_VERSIONS, @@ -75,9 +75,9 @@ def _list_sys_modules(pattern: str) -> str: @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) @pytest.mark.skipif(module_available("lightning"), reason="This test is ONLY relevant for the STANDALONE package") def test_test_patch_legacy_imports_standalone(pl_version): - assert any( - key.startswith("pytorch_lightning") for key in sys.modules - ), f"Imported PL, so it has to be in sys.modules: {_list_sys_modules('pytorch_lightning')}" + assert any(key.startswith("pytorch_lightning") for key in sys.modules), ( + f"Imported PL, so it has to be in sys.modules: {_list_sys_modules('pytorch_lightning')}" + ) path_legacy = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) path_ckpts = sorted(glob.glob(os.path.join(path_legacy, f"*{CHECKPOINT_EXTENSION}"))) assert path_ckpts, f'No checkpoints found in folder "{path_legacy}"' @@ -86,9 +86,9 @@ def test_test_patch_legacy_imports_standalone(pl_version): with no_warning_call(match="Redirecting import of*"), pl_legacy_patch(): torch.load(path_ckpt, weights_only=False) - assert any( - key.startswith("pytorch_lightning") for key in sys.modules - ), f"Imported PL, so it has to be in sys.modules: {_list_sys_modules('pytorch_lightning')}" + assert any(key.startswith("pytorch_lightning") for key in sys.modules), ( + f"Imported PL, so it has to be in sys.modules: {_list_sys_modules('pytorch_lightning')}" + ) assert not any(key.startswith("lightning." + "pytorch") for key in sys.modules), ( "Did not import the unified package," f" so it should not be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}" @@ -98,9 +98,9 @@ def test_test_patch_legacy_imports_standalone(pl_version): @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) @pytest.mark.skipif(not module_available("lightning"), reason="This test is ONLY relevant for the UNIFIED package") def test_patch_legacy_imports_unified(pl_version): - assert any( - key.startswith("lightning." + "pytorch") for key in sys.modules - ), f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}" + assert any(key.startswith("lightning." + "pytorch") for key in sys.modules), ( + f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}" + ) assert not any(key.startswith("pytorch_lightning") for key in sys.modules), ( "Should not import standalone package, all imports should be redirected to the unified package;\n" f" environment: {_list_sys_modules('pytorch_lightning')}" @@ -119,9 +119,9 @@ def test_patch_legacy_imports_unified(pl_version): with context, pl_legacy_patch(): torch.load(path_ckpt, weights_only=False) - assert any( - key.startswith("lightning." + "pytorch") for key in sys.modules - ), f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}" + assert any(key.startswith("lightning." + "pytorch") for key in sys.modules), ( + f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}" + ) assert not any(key.startswith("pytorch_lightning") for key in sys.modules), ( "Should not import standalone package, all imports should be redirected to the unified package;\n" f" environment: {_list_sys_modules('pytorch_lightning')}" diff --git a/tests/tests_pytorch/utilities/test_all_gather_grad.py b/tests/tests_pytorch/utilities/test_all_gather_grad.py index 9b034cdcd34e2..82ca15fd87432 100644 --- a/tests/tests_pytorch/utilities/test_all_gather_grad.py +++ b/tests/tests_pytorch/utilities/test_all_gather_grad.py @@ -15,9 +15,9 @@ import numpy as np import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.core.test_results import spawn_launch from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_auto_restart.py b/tests/tests_pytorch/utilities/test_auto_restart.py index d4b2fdaf82834..4da7bfd098a0b 100644 --- a/tests/tests_pytorch/utilities/test_auto_restart.py +++ b/tests/tests_pytorch/utilities/test_auto_restart.py @@ -14,13 +14,13 @@ import inspect import pytest +from torch.utils.data.dataloader import DataLoader + from lightning.fabric.utilities.seed import seed_everything from lightning.pytorch import Callback, Trainer from lightning.pytorch.callbacks import OnExceptionCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.utilities.exceptions import SIGTERMException -from torch.utils.data.dataloader import DataLoader - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py index 43a146c6eb089..da168be1e3e8a 100644 --- a/tests/tests_pytorch/utilities/test_combined_loader.py +++ b/tests/tests_pytorch/utilities/test_combined_loader.py @@ -19,6 +19,13 @@ import pytest import torch +from torch import Tensor +from torch.utils._pytree import tree_flatten +from torch.utils.data import DataLoader, TensorDataset +from torch.utils.data.dataset import Dataset, IterableDataset +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import RandomSampler, SequentialSampler + from lightning.fabric.utilities.types import _Stateful from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset @@ -31,13 +38,6 @@ _MinSize, _Sequential, ) -from torch import Tensor -from torch.utils._pytree import tree_flatten -from torch.utils.data import DataLoader, TensorDataset -from torch.utils.data.dataset import Dataset, IterableDataset -from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.sampler import RandomSampler, SequentialSampler - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_compile.py b/tests/tests_pytorch/utilities/test_compile.py index 67f992421f7ce..a053c847dfd6c 100644 --- a/tests/tests_pytorch/utilities/test_compile.py +++ b/tests/tests_pytorch/utilities/test_compile.py @@ -18,12 +18,12 @@ import pytest import torch +from lightning_utilities.core.imports import RequirementCache + from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.compile import from_compiled, to_uncompiled -from lightning_utilities.core.imports import RequirementCache - from tests_pytorch.conftest import mock_cuda_count from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index d79e9e24383a0..65a3a47715bfe 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -4,6 +4,10 @@ import numpy as np import pytest import torch +from lightning_utilities.test.warning import no_warning_call +from torch import Tensor +from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler + from lightning.fabric.utilities.data import _replace_dunder_methods from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer @@ -19,9 +23,6 @@ warning_cache, ) from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning_utilities.test.warning import no_warning_call -from torch import Tensor -from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler def test_extract_batch_size(): diff --git a/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py b/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py index a44ed655c2e61..05657854eb0b1 100644 --- a/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py @@ -14,11 +14,11 @@ import os import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.strategies import DeepSpeedStrategy from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py b/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py index c8a138ba0a02a..256233e01fa98 100644 --- a/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py +++ b/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py @@ -17,7 +17,6 @@ from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.strategies import DeepSpeedStrategy from lightning.pytorch.utilities.model_summary import DeepSpeedSummary - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_dtype_device_mixin.py b/tests/tests_pytorch/utilities/test_dtype_device_mixin.py index 1ba3aff359609..171656d072076 100644 --- a/tests/tests_pytorch/utilities/test_dtype_device_mixin.py +++ b/tests/tests_pytorch/utilities/test_dtype_device_mixin.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch.nn as nn + from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_grads.py b/tests/tests_pytorch/utilities/test_grads.py index ade88f767670c..277971d64b180 100644 --- a/tests/tests_pytorch/utilities/test_grads.py +++ b/tests/tests_pytorch/utilities/test_grads.py @@ -16,6 +16,7 @@ import pytest import torch import torch.nn as nn + from lightning.pytorch.utilities import grad_norm diff --git a/tests/tests_pytorch/utilities/test_imports.py b/tests/tests_pytorch/utilities/test_imports.py index 56ee326f076dc..301c97d756899 100644 --- a/tests/tests_pytorch/utilities/test_imports.py +++ b/tests/tests_pytorch/utilities/test_imports.py @@ -19,10 +19,10 @@ from unittest import mock import pytest -from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE from lightning_utilities.core.imports import RequirementCache from torch.distributed import is_available +from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE from tests_pytorch.helpers.runif import RunIf @@ -60,7 +60,7 @@ def new_fn(*args, **kwargs): return new_fn -@pytest.fixture() +@pytest.fixture def clean_import(): """This fixture allows test to import {pytorch_}lightning* modules completely cleanly, regardless of the current state of the imported modules. diff --git a/tests/tests_pytorch/utilities/test_memory.py b/tests/tests_pytorch/utilities/test_memory.py index c1ebff03a3a4f..336a9fafa3243 100644 --- a/tests/tests_pytorch/utilities/test_memory.py +++ b/tests/tests_pytorch/utilities/test_memory.py @@ -13,6 +13,7 @@ # limitations under the License. import torch + from lightning.pytorch.utilities.memory import recursive_detach diff --git a/tests/tests_pytorch/utilities/test_model_helpers.py b/tests/tests_pytorch/utilities/test_model_helpers.py index 78a63a7e9d2a7..e7a9d9275a484 100644 --- a/tests/tests_pytorch/utilities/test_model_helpers.py +++ b/tests/tests_pytorch/utilities/test_model_helpers.py @@ -16,10 +16,11 @@ import pytest import torch.nn +from lightning_utilities import module_available + from lightning.pytorch import LightningDataModule from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel from lightning.pytorch.utilities.model_helpers import _ModuleMode, _restricted_classmethod, is_overridden -from lightning_utilities import module_available def test_is_overridden(): diff --git a/tests/tests_pytorch/utilities/test_model_summary.py b/tests/tests_pytorch/utilities/test_model_summary.py index cced6546aab75..54c5572d01767 100644 --- a/tests/tests_pytorch/utilities/test_model_summary.py +++ b/tests/tests_pytorch/utilities/test_model_summary.py @@ -18,6 +18,7 @@ import pytest import torch import torch.nn as nn + from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.model_summary.model_summary import ( @@ -27,7 +28,6 @@ ModelSummary, summarize, ) - from tests_pytorch.helpers.advanced_models import ParityModuleRNN from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_parameter_tying.py b/tests/tests_pytorch/utilities/test_parameter_tying.py index 9dc9b5648ff01..e45fb39f81b34 100644 --- a/tests/tests_pytorch/utilities/test_parameter_tying.py +++ b/tests/tests_pytorch/utilities/test_parameter_tying.py @@ -13,9 +13,10 @@ # limitations under the License. import pytest import torch +from torch import nn + from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters -from torch import nn class ParameterSharingModule(BoringModel): diff --git a/tests/tests_pytorch/utilities/test_parsing.py b/tests/tests_pytorch/utilities/test_parsing.py index 1c126723d89a6..a2671eb8790de 100644 --- a/tests/tests_pytorch/utilities/test_parsing.py +++ b/tests/tests_pytorch/utilities/test_parsing.py @@ -15,6 +15,8 @@ import threading import pytest +from torch.jit import ScriptModule + from lightning.pytorch import LightningDataModule, LightningModule, Trainer from lightning.pytorch.utilities.parsing import ( _get_init_args, @@ -26,7 +28,6 @@ lightning_setattr, parse_class_init_keys, ) -from torch.jit import ScriptModule unpicklable_function = lambda: None @@ -103,12 +104,12 @@ def test_lightning_hasattr(): assert lightning_hasattr(model3, "learning_rate"), "lightning_hasattr failed to find hparams dict variable" assert not lightning_hasattr(model4, "learning_rate"), "lightning_hasattr found variable when it should not" assert lightning_hasattr(model5, "batch_size"), "lightning_hasattr failed to find batch_size in datamodule" - assert lightning_hasattr( - model6, "batch_size" - ), "lightning_hasattr failed to find batch_size in datamodule w/ hparams present" - assert lightning_hasattr( - model7, "batch_size" - ), "lightning_hasattr failed to find batch_size in hparams w/ datamodule present" + assert lightning_hasattr(model6, "batch_size"), ( + "lightning_hasattr failed to find batch_size in datamodule w/ hparams present" + ) + assert lightning_hasattr(model7, "batch_size"), ( + "lightning_hasattr failed to find batch_size in hparams w/ datamodule present" + ) assert lightning_hasattr(model8, "batch_size") for m in models: diff --git a/tests/tests_pytorch/utilities/test_pytree.py b/tests/tests_pytorch/utilities/test_pytree.py index afd198919e23f..c87a83f85f6ea 100644 --- a/tests/tests_pytorch/utilities/test_pytree.py +++ b/tests/tests_pytorch/utilities/test_pytree.py @@ -1,7 +1,8 @@ import torch -from lightning.pytorch.utilities._pytree import _tree_flatten, tree_unflatten from torch.utils.data import DataLoader, TensorDataset +from lightning.pytorch.utilities._pytree import _tree_flatten, tree_unflatten + def assert_tree_flatten_unflatten(pytree, leaves): flat, spec = _tree_flatten(pytree) diff --git a/tests/tests_pytorch/utilities/test_seed.py b/tests/tests_pytorch/utilities/test_seed.py index 282009f6b93cf..00484009481e9 100644 --- a/tests/tests_pytorch/utilities/test_seed.py +++ b/tests/tests_pytorch/utilities/test_seed.py @@ -4,8 +4,8 @@ import numpy as np import pytest import torch -from lightning.pytorch.utilities.seed import isolate_rng +from lightning.pytorch.utilities.seed import isolate_rng from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_signature_utils.py b/tests/tests_pytorch/utilities/test_signature_utils.py index e453459670360..23f9258b0d56b 100644 --- a/tests/tests_pytorch/utilities/test_signature_utils.py +++ b/tests/tests_pytorch/utilities/test_signature_utils.py @@ -1,4 +1,5 @@ import torch + from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index 2ef1ecd4fe3e5..15db0becb9551 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -18,6 +18,7 @@ import pytest import torch + from lightning.pytorch.utilities.upgrade_checkpoint import main as upgrade_main diff --git a/tests/tests_pytorch/utilities/test_warnings.py b/tests/tests_pytorch/utilities/test_warnings.py index 8c385a2d2e49f..ab7b7d2c15a4d 100644 --- a/tests/tests_pytorch/utilities/test_warnings.py +++ b/tests/tests_pytorch/utilities/test_warnings.py @@ -24,11 +24,12 @@ from io import StringIO from unittest import mock -import lightning.pytorch import pytest -from lightning.pytorch.utilities.warnings import PossibleUserWarning from lightning_utilities.test.warning import no_warning_call +import lightning.pytorch +from lightning.pytorch.utilities.warnings import PossibleUserWarning + if __name__ == "__main__": # check that logging is properly configured import logging