diff --git a/src/pl_bolts/callbacks/__init__.py b/src/pl_bolts/callbacks/__init__.py index 2225372f48..ce72e9dd10 100644 --- a/src/pl_bolts/callbacks/__init__.py +++ b/src/pl_bolts/callbacks/__init__.py @@ -1,4 +1,5 @@ """Collection of PyTorchLightning callbacks.""" + from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate from pl_bolts.callbacks.data_monitor import ModuleDataMonitor, TrainingDataMonitor from pl_bolts.callbacks.printing import PrintTableMetricsCallback diff --git a/src/pl_bolts/callbacks/data_monitor.py b/src/pl_bolts/callbacks/data_monitor.py index 7a39a1e709..0058e07c28 100644 --- a/src/pl_bolts/callbacks/data_monitor.py +++ b/src/pl_bolts/callbacks/data_monitor.py @@ -73,6 +73,7 @@ def log_histograms(self, batch: Any, group: str = "") -> None: Otherwise the histograms get labelled with an integer index. Each label also has the tensors's shape as suffix. group: Name under which the histograms will be grouped. + """ if not self._log or (self._train_batch_idx + 1) % self._log_every_n_steps != 0: # type: ignore[operator] return @@ -112,7 +113,7 @@ def _is_logger_available(self, logger: Logger) -> bool: if not isinstance(logger, self.supported_loggers): rank_zero_warn( f"{self.__class__.__name__} does not support logging with {logger.__class__.__name__}." - f" Supported loggers are: {', '.join((str(x.__name__) for x in self.supported_loggers))}" + f" Supported loggers are: {', '.join(str(x.__name__) for x in self.supported_loggers)}" ) available = False return available @@ -220,6 +221,7 @@ def __init__(self, log_every_n_steps: int = None) -> None: # log histogram of training data passed to `LightningModule.training_step` trainer = Trainer(callbacks=[TrainingDataMonitor()]) + """ super().__init__(log_every_n_steps=log_every_n_steps) diff --git a/src/pl_bolts/callbacks/verification/base.py b/src/pl_bolts/callbacks/verification/base.py index 49e2e3593a..1f6d59c0c8 100644 --- a/src/pl_bolts/callbacks/verification/base.py +++ b/src/pl_bolts/callbacks/verification/base.py @@ -77,6 +77,7 @@ def _model_forward(self, input_array: Any) -> Any: Returns: The output of the model. + """ if isinstance(input_array, tuple): return self.model(*input_array) @@ -105,8 +106,8 @@ def __init__(self, warn: bool = True, error: bool = False) -> None: self._raise_error = error def message(self, *args: Any, **kwargs: Any) -> str: - """The message to be printed when the model does not pass the verification. If the message for warning and - error differ, override the :meth:`warning_message` and :meth:`error_message` methods directly. + """The message to be printed when the model does not pass the verification. If the message for warning and error + differ, override the :meth:`warning_message` and :meth:`error_message` methods directly. Arguments: *args: Any positional arguments that are needed to construct the message. @@ -114,6 +115,7 @@ def message(self, *args: Any, **kwargs: Any) -> str: Returns: The message as a string. + """ def warning_message(self, *args: Any, **kwargs: Any) -> str: diff --git a/src/pl_bolts/callbacks/verification/batch_gradient.py b/src/pl_bolts/callbacks/verification/batch_gradient.py index 834184c152..107c647d9f 100644 --- a/src/pl_bolts/callbacks/verification/batch_gradient.py +++ b/src/pl_bolts/callbacks/verification/batch_gradient.py @@ -91,6 +91,7 @@ class BatchGradientVerificationCallback(VerificationCallbackBase): """The callback version of the :class:`BatchGradientVerification` test. Verification is performed right before training begins. + """ def __init__( @@ -211,12 +212,13 @@ def collect_batches(tensor: Tensor) -> Tensor: @under_review() @contextmanager def selective_eval(model: nn.Module, layer_types: Iterable[Type[nn.Module]]) -> None: - """A context manager that sets all requested types of layers to eval mode. This method uses an ``isinstance`` - check, so all subclasses are also affected. + """A context manager that sets all requested types of layers to eval mode. This method uses an ``isinstance`` check, + so all subclasses are also affected. Args: model: A model which has layers that need to be set to eval mode. layer_types: The list of class objects for which all layers of that type will be set to eval mode. + """ to_revert = [] try: diff --git a/src/pl_bolts/callbacks/vision/image_generation.py b/src/pl_bolts/callbacks/vision/image_generation.py index a30e78972b..0e860292de 100644 --- a/src/pl_bolts/callbacks/vision/image_generation.py +++ b/src/pl_bolts/callbacks/vision/image_generation.py @@ -31,6 +31,7 @@ class TensorboardGenerativeModelImageSampler(Callback): from pl_bolts.callbacks import TensorboardGenerativeModelImageSampler trainer = Trainer(callbacks=[TensorboardGenerativeModelImageSampler()]) + """ def __init__( diff --git a/src/pl_bolts/callbacks/vision/sr_image_logger.py b/src/pl_bolts/callbacks/vision/sr_image_logger.py index f27bd294e2..4a96330013 100644 --- a/src/pl_bolts/callbacks/vision/sr_image_logger.py +++ b/src/pl_bolts/callbacks/vision/sr_image_logger.py @@ -17,8 +17,8 @@ @under_review() class SRImageLoggerCallback(Callback): - """Logs low-res, generated high-res, and ground truth high-res images to TensorBoard Your model must implement - the ``forward`` function for generation. + """Logs low-res, generated high-res, and ground truth high-res images to TensorBoard Your model must implement the + ``forward`` function for generation. Requirements:: @@ -30,6 +30,7 @@ class SRImageLoggerCallback(Callback): from pl_bolts.callbacks import SRImageLoggerCallback trainer = Trainer(callbacks=[SRImageLoggerCallback()]) + """ def __init__(self, log_interval: int = 1000, scale_factor: int = 4, num_samples: int = 5) -> None: diff --git a/src/pl_bolts/datamodules/experience_source.py b/src/pl_bolts/datamodules/experience_source.py index 2a0d4467e4..7967b5ca46 100644 --- a/src/pl_bolts/datamodules/experience_source.py +++ b/src/pl_bolts/datamodules/experience_source.py @@ -1,5 +1,6 @@ """Datamodules for RL models that rely on experiences generated during training Based on implementations found here: https://github.com/Shmuma/ptan/blob/master/ptan/experience.py.""" + from abc import ABC from collections import deque, namedtuple from typing import Callable, Iterator, List, Tuple diff --git a/src/pl_bolts/datamodules/imagenet_datamodule.py b/src/pl_bolts/datamodules/imagenet_datamodule.py index 90f2bb641d..3fec1bab68 100644 --- a/src/pl_bolts/datamodules/imagenet_datamodule.py +++ b/src/pl_bolts/datamodules/imagenet_datamodule.py @@ -167,12 +167,12 @@ def train_dataloader(self) -> DataLoader: return loader def val_dataloader(self) -> DataLoader: - """Uses the part of the train split of imagenet2012 that was not used for training via - `num_imgs_per_val_class` + """Uses the part of the train split of imagenet2012 that was not used for training via `num_imgs_per_val_class` Args: batch_size: the batch size transforms: the transforms + """ transforms = self.val_transform() if self.val_transforms is None else self.val_transforms diff --git a/src/pl_bolts/datamodules/stl10_datamodule.py b/src/pl_bolts/datamodules/stl10_datamodule.py index 158baee0bc..bff8113256 100644 --- a/src/pl_bolts/datamodules/stl10_datamodule.py +++ b/src/pl_bolts/datamodules/stl10_datamodule.py @@ -139,6 +139,7 @@ def train_dataloader_mixed(self) -> DataLoader: batch_size: the batch size transforms: a sequence of transforms + """ transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms diff --git a/src/pl_bolts/datamodules/vocdetection_datamodule.py b/src/pl_bolts/datamodules/vocdetection_datamodule.py index de8a84b0f3..d6435bcc91 100644 --- a/src/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/src/pl_bolts/datamodules/vocdetection_datamodule.py @@ -166,6 +166,7 @@ def train_dataloader(self, image_transforms: Optional[Callable] = None) -> DataL Args: image_transforms: custom image-only transforms + """ transforms = [ _prepare_voc_instance, @@ -181,6 +182,7 @@ def val_dataloader(self, image_transforms: Optional[Callable] = None) -> DataLoa Args: image_transforms: custom image-only transforms + """ transforms = [ _prepare_voc_instance, diff --git a/src/pl_bolts/datasets/kitti_dataset.py b/src/pl_bolts/datasets/kitti_dataset.py index 0e6674224d..92b08ba9c3 100644 --- a/src/pl_bolts/datasets/kitti_dataset.py +++ b/src/pl_bolts/datasets/kitti_dataset.py @@ -86,6 +86,7 @@ def encode_segmap(self, mask): It also sets all of the valid pixels to the appropriate value between 0 and `len(valid_labels)` (the number of valid classes), so it can be used properly by the loss function when comparing with the output. + """ for voidc in self.void_labels: mask[mask == voidc] = self.ignore_index diff --git a/src/pl_bolts/datasets/sr_dataset_mixin.py b/src/pl_bolts/datasets/sr_dataset_mixin.py index cdeddce054..07570a0e69 100644 --- a/src/pl_bolts/datasets/sr_dataset_mixin.py +++ b/src/pl_bolts/datasets/sr_dataset_mixin.py @@ -1,4 +1,5 @@ """Adapted from: https://github.com/https-deeplearning-ai/GANs-Public.""" + from typing import Any, Tuple import torch diff --git a/src/pl_bolts/datasets/utils.py b/src/pl_bolts/datasets/utils.py index e3b085fbc1..52a25d46ea 100644 --- a/src/pl_bolts/datasets/utils.py +++ b/src/pl_bolts/datasets/utils.py @@ -60,6 +60,7 @@ def to_tensor(arrays: TArrays) -> torch.Tensor: Returns: Tensor of the integers + """ return torch.tensor(arrays) diff --git a/src/pl_bolts/models/__init__.py b/src/pl_bolts/models/__init__.py index 66f237728d..7ddbd599bb 100644 --- a/src/pl_bolts/models/__init__.py +++ b/src/pl_bolts/models/__init__.py @@ -1,4 +1,5 @@ """Collection of PyTorchLightning models.""" + from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import AE from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import VAE from pl_bolts.models.mnist_module import LitMNIST diff --git a/src/pl_bolts/models/detection/yolo/darknet_network.py b/src/pl_bolts/models/detection/yolo/darknet_network.py index 7a38a0f5d2..78f87d62d9 100644 --- a/src/pl_bolts/models/detection/yolo/darknet_network.py +++ b/src/pl_bolts/models/detection/yolo/darknet_network.py @@ -145,6 +145,7 @@ def read(tensor: Tensor) -> int: """Reads the contents of ``tensor`` from the current position of ``weight_file``. Returns the number of elements read. If there's no more data in ``weight_file``, returns 0. + """ np_array = np.fromfile(weight_file, count=tensor.numel(), dtype=np.float32) num_elements = np_array.size @@ -275,8 +276,8 @@ def convert(key: str, value: str) -> Union[str, int, float, List[Union[str, int, def _create_layer(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: - """Calls one of the ``_create_(config, num_inputs)`` functions to create a PyTorch module from the - layer config. + """Calls one of the ``_create_(config, num_inputs)`` functions to create a PyTorch module from the layer + config. Args: config: Dictionary of configuration options for this layer. @@ -285,6 +286,7 @@ def _create_layer(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREAT Returns: module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in its output. + """ create_func: Dict[str, Callable[..., CREATE_LAYER_OUTPUT]] = { "convolutional": _create_convolutional, diff --git a/src/pl_bolts/models/detection/yolo/loss.py b/src/pl_bolts/models/detection/yolo/loss.py index 44ac5b0f11..6bcaadd0de 100644 --- a/src/pl_bolts/models/detection/yolo/loss.py +++ b/src/pl_bolts/models/detection/yolo/loss.py @@ -205,6 +205,7 @@ def _target_labels_to_probs( Returns: An ``[M, C]`` matrix of target class probabilities. + """ if targets.ndim == 1: # The data may contain a different number of classes than what the model predicts. In case a label is diff --git a/src/pl_bolts/models/detection/yolo/torch_networks.py b/src/pl_bolts/models/detection/yolo/torch_networks.py index ee5358ac7f..9e59eec796 100644 --- a/src/pl_bolts/models/detection/yolo/torch_networks.py +++ b/src/pl_bolts/models/detection/yolo/torch_networks.py @@ -31,6 +31,7 @@ def run_detection( detections: A list where a tensor containing the detections will be appended to. losses: A list where a tensor containing the losses will be appended to, if ``targets`` is given. hits: A list where the number of targets that matched this layer will be appended to, if ``targets`` is given. + """ output, preds = detection_layer(layer_input, image_size) detections.append(output) @@ -69,6 +70,7 @@ def run_detection_with_aux_head( detections: A list where a tensor containing the detections will be appended to. losses: A list where a tensor containing the losses will be appended to, if ``targets`` is given. hits: A list where the number of targets that matched this layer will be appended to, if ``targets`` is given. + """ output, preds = detection_layer(layer_input, image_size) detections.append(output) @@ -1132,8 +1134,8 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU class YOLOV5Network(nn.Module): - """The YOLOv5 network architecture. Different variants (n/s/m/l/x) can be achieved by adjusting the ``depth`` - and ``width`` parameters. + """The YOLOv5 network architecture. Different variants (n/s/m/l/x) can be achieved by adjusting the ``depth`` and + ``width`` parameters. Args: num_classes: Number of different classes that this model predicts. @@ -1176,6 +1178,7 @@ class YOLOV5Network(nn.Module): class_loss_multiplier: Classification loss will be scaled by this value. xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps to produce coordinate values close to one. + """ def __init__( @@ -1613,8 +1616,8 @@ def forward(self, x: Tensor) -> Tensor: class YOLOXNetwork(nn.Module): - """The YOLOX network architecture. Different variants (nano/tiny/s/m/l/x) can be achieved by adjusting the - ``depth`` and ``width`` parameters. + """The YOLOX network architecture. Different variants (nano/tiny/s/m/l/x) can be achieved by adjusting the ``depth`` + and ``width`` parameters. Args: num_classes: Number of different classes that this model predicts. @@ -1657,6 +1660,7 @@ class YOLOXNetwork(nn.Module): class_loss_multiplier: Classification loss will be scaled by this value. xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps to produce coordinate values close to one. + """ def __init__( diff --git a/src/pl_bolts/models/detection/yolo/utils.py b/src/pl_bolts/models/detection/yolo/utils.py index d981fadceb..66996fc4d8 100644 --- a/src/pl_bolts/models/detection/yolo/utils.py +++ b/src/pl_bolts/models/detection/yolo/utils.py @@ -102,8 +102,8 @@ def aligned_iou(wh1: Tensor, wh2: Tensor) -> Tensor: def iou_below(pred_boxes: Tensor, target_boxes: Tensor, threshold: float) -> Tensor: - """Creates a binary mask whose value will be ``True``, unless the predicted box overlaps any target - significantly (IoU greater than ``threshold``). + """Creates a binary mask whose value will be ``True``, unless the predicted box overlaps any target significantly + (IoU greater than ``threshold``). Args: pred_boxes: The predicted corner coordinates. Tensor of size ``[height, width, boxes_per_cell, 4]``. @@ -112,6 +112,7 @@ def iou_below(pred_boxes: Tensor, target_boxes: Tensor, threshold: float) -> Ten Returns: A boolean tensor sized ``[height, width, boxes_per_cell]``, with ``False`` where the predicted box overlaps a target significantly and ``True`` elsewhere. + """ shape = pred_boxes.shape[:-1] pred_boxes = pred_boxes.view(-1, 4) diff --git a/src/pl_bolts/models/detection/yolo/yolo_module.py b/src/pl_bolts/models/detection/yolo/yolo_module.py index 429ed087b7..558cb496b7 100644 --- a/src/pl_bolts/models/detection/yolo/yolo_module.py +++ b/src/pl_bolts/models/detection/yolo/yolo_module.py @@ -144,8 +144,8 @@ def __init__( def forward( self, images: Union[Tensor, IMAGES], targets: Optional[TARGETS] = None ) -> Union[Tensor, Tuple[Tensor, Tensor]]: - """Runs a forward pass through the network (all layers listed in ``self.network``), and if training targets - are provided, computes the losses from the detection layers. + """Runs a forward pass through the network (all layers listed in ``self.network``), and if training targets are + provided, computes the losses from the detection layers. Detections are concatenated from the detection layers. Each detection layer will produce a number of detections that depends on the size of the feature map and the number of anchors per feature map cell. @@ -161,6 +161,7 @@ def forward( provided, a dictionary of losses. Detections are shaped ``[batch_size, anchors, classes + 5]``, where ``anchors`` is the feature map size (width * height) times the number of anchors per cell. The predicted box coordinates are in `(x1, y1, x2, y2)` format and scaled to the input image size. + """ self.validate_batch(images, targets) images_tensor = images if isinstance(images, Tensor) else torch.stack(images) @@ -185,6 +186,7 @@ def configure_optimizers(self) -> Tuple[List[optim.Optimizer], List[LRScheduler] If weight decay is specified, it will be applied only to convolutional layer weights, as they contain much more parameters than the biases and batch normalization parameters. Regularizing all parameters could lead to underfitting. + """ if ("weight_decay" in self.optimizer_params) and (self.optimizer_params["weight_decay"] != 0): defaults = copy(self.optimizer_params) @@ -574,12 +576,13 @@ def __init__( class ResizedVOCDetectionDataModule(VOCDetectionDataModule): - """A subclass of ``VOCDetectionDataModule`` that resizes the images to a specific size. YOLO expectes the image - size to be divisible by the ratio in which the network downsamples the image. + """A subclass of ``VOCDetectionDataModule`` that resizes the images to a specific size. YOLO expectes the image size + to be divisible by the ratio in which the network downsamples the image. Args: width: Resize images to this width. height: Resize images to this height. + """ def __init__(self, width: int = 608, height: int = 608, **kwargs: Any): @@ -609,6 +612,7 @@ def _resize(self, image: Tensor, target: TARGET) -> Tuple[Tensor, TARGET]: Returns: Resized image tensor. + """ device = target["boxes"].device height, width = image.shape[-2:] diff --git a/src/pl_bolts/models/gans/srgan/components.py b/src/pl_bolts/models/gans/srgan/components.py index 99ad9d2e6a..63a531c46a 100644 --- a/src/pl_bolts/models/gans/srgan/components.py +++ b/src/pl_bolts/models/gans/srgan/components.py @@ -1,4 +1,5 @@ """Adapted from: https://github.com/https-deeplearning-ai/GANs-Public.""" + import torch import torch.nn as nn diff --git a/src/pl_bolts/models/gans/srgan/srgan_module.py b/src/pl_bolts/models/gans/srgan/srgan_module.py index ef11f10dc2..2799aa5c29 100644 --- a/src/pl_bolts/models/gans/srgan/srgan_module.py +++ b/src/pl_bolts/models/gans/srgan/srgan_module.py @@ -1,4 +1,5 @@ """Adapted from: https://github.com/https-deeplearning-ai/GANs-Public.""" + from argparse import ArgumentParser from pathlib import Path from typing import Any, List, Optional, Tuple diff --git a/src/pl_bolts/models/gans/srgan/srresnet_module.py b/src/pl_bolts/models/gans/srgan/srresnet_module.py index fc6ba2498b..1545e63391 100644 --- a/src/pl_bolts/models/gans/srgan/srresnet_module.py +++ b/src/pl_bolts/models/gans/srgan/srresnet_module.py @@ -1,4 +1,5 @@ """Adapted from: https://github.com/https-deeplearning-ai/GANs-Public.""" + from argparse import ArgumentParser from typing import Any, Tuple diff --git a/src/pl_bolts/models/rl/advantage_actor_critic_model.py b/src/pl_bolts/models/rl/advantage_actor_critic_model.py index e4863e32fe..30c73e446d 100644 --- a/src/pl_bolts/models/rl/advantage_actor_critic_model.py +++ b/src/pl_bolts/models/rl/advantage_actor_critic_model.py @@ -1,4 +1,5 @@ """Advantage Actor Critic (A2C)""" + from argparse import ArgumentParser from collections import OrderedDict from typing import Any, Iterator, List, Tuple diff --git a/src/pl_bolts/models/rl/common/agents.py b/src/pl_bolts/models/rl/common/agents.py index 116b0b89dd..ad3746bbaf 100644 --- a/src/pl_bolts/models/rl/common/agents.py +++ b/src/pl_bolts/models/rl/common/agents.py @@ -2,6 +2,7 @@ https://github.com/Shmuma/ptan/blob/master/ptan/agent.py. """ + from abc import ABC from typing import List diff --git a/src/pl_bolts/models/rl/common/distributions.py b/src/pl_bolts/models/rl/common/distributions.py index c589c2db3a..495fbb0818 100644 --- a/src/pl_bolts/models/rl/common/distributions.py +++ b/src/pl_bolts/models/rl/common/distributions.py @@ -1,4 +1,5 @@ """Distributions used in some continuous RL algorithms.""" + import torch from pl_bolts.utils.stability import under_review diff --git a/src/pl_bolts/models/rl/common/gym_wrappers.py b/src/pl_bolts/models/rl/common/gym_wrappers.py index 605b498a7a..573def6fa5 100644 --- a/src/pl_bolts/models/rl/common/gym_wrappers.py +++ b/src/pl_bolts/models/rl/common/gym_wrappers.py @@ -1,5 +1,6 @@ """Set of wrapper functions for gym environments taken from https://github.com/Shmuma/ptan/blob/master/ptan/common/wrappers.py.""" + import collections import numpy as np diff --git a/src/pl_bolts/models/rl/common/networks.py b/src/pl_bolts/models/rl/common/networks.py index 63aad43a11..b920ae24ff 100644 --- a/src/pl_bolts/models/rl/common/networks.py +++ b/src/pl_bolts/models/rl/common/networks.py @@ -1,4 +1,5 @@ """Series of networks used Based on implementations found here:""" + import math from typing import Tuple diff --git a/src/pl_bolts/models/rl/double_dqn_model.py b/src/pl_bolts/models/rl/double_dqn_model.py index 2d76279c87..b2d36ca0c2 100644 --- a/src/pl_bolts/models/rl/double_dqn_model.py +++ b/src/pl_bolts/models/rl/double_dqn_model.py @@ -1,4 +1,5 @@ """Double DQN.""" + import argparse from collections import OrderedDict from typing import Tuple diff --git a/src/pl_bolts/models/rl/dqn_model.py b/src/pl_bolts/models/rl/dqn_model.py index 567aa8d185..bfafce3997 100644 --- a/src/pl_bolts/models/rl/dqn_model.py +++ b/src/pl_bolts/models/rl/dqn_model.py @@ -1,4 +1,5 @@ """Deep Q Network.""" + import argparse from collections import OrderedDict from typing import Dict, List, Optional, Tuple diff --git a/src/pl_bolts/models/rl/dueling_dqn_model.py b/src/pl_bolts/models/rl/dueling_dqn_model.py index 1e072d5ff8..d7e2b939e3 100644 --- a/src/pl_bolts/models/rl/dueling_dqn_model.py +++ b/src/pl_bolts/models/rl/dueling_dqn_model.py @@ -1,4 +1,5 @@ """Dueling DQN.""" + import argparse from pytorch_lightning import Trainer diff --git a/src/pl_bolts/models/rl/noisy_dqn_model.py b/src/pl_bolts/models/rl/noisy_dqn_model.py index 76b4531c5b..bfb877cd8e 100644 --- a/src/pl_bolts/models/rl/noisy_dqn_model.py +++ b/src/pl_bolts/models/rl/noisy_dqn_model.py @@ -1,4 +1,5 @@ """Noisy DQN.""" + import argparse from typing import Tuple diff --git a/src/pl_bolts/models/rl/per_dqn_model.py b/src/pl_bolts/models/rl/per_dqn_model.py index a864afb51b..1440587421 100644 --- a/src/pl_bolts/models/rl/per_dqn_model.py +++ b/src/pl_bolts/models/rl/per_dqn_model.py @@ -1,4 +1,5 @@ """Prioritized Experience Replay DQN.""" + import argparse from collections import OrderedDict from typing import Tuple diff --git a/src/pl_bolts/models/rl/ppo_model.py b/src/pl_bolts/models/rl/ppo_model.py index 21bc0873c0..b861a66619 100644 --- a/src/pl_bolts/models/rl/ppo_model.py +++ b/src/pl_bolts/models/rl/ppo_model.py @@ -319,8 +319,7 @@ def configure_optimizers(self) -> List[Optimizer]: return optimizer_actor, optimizer_critic def optimizer_step(self, *args, **kwargs): - """Run ``num_optim_iters`` number of iterations of gradient descent on actor and critic for each data - sample.""" + """Run ``num_optim_iters`` number of iterations of gradient descent on actor and critic for each data sample.""" for _ in range(self.num_optim_iters): super().optimizer_step(*args, **kwargs) diff --git a/src/pl_bolts/models/rl/sac_model.py b/src/pl_bolts/models/rl/sac_model.py index 8c0bb2b712..56aba5f530 100644 --- a/src/pl_bolts/models/rl/sac_model.py +++ b/src/pl_bolts/models/rl/sac_model.py @@ -1,4 +1,5 @@ """Soft Actor Critic.""" + import argparse from typing import Dict, List, Tuple diff --git a/src/pl_bolts/models/self_supervised/__init__.py b/src/pl_bolts/models/self_supervised/__init__.py index f501ee73a0..ede1cbf848 100644 --- a/src/pl_bolts/models/self_supervised/__init__.py +++ b/src/pl_bolts/models/self_supervised/__init__.py @@ -17,6 +17,7 @@ classifications = classifier(representations) """ + from pl_bolts.models.self_supervised.amdim.amdim_module import AMDIM from pl_bolts.models.self_supervised.byol.byol_module import BYOL from pl_bolts.models.self_supervised.cpc.cpc_module import CPC_v2 diff --git a/src/pl_bolts/models/self_supervised/cpc/cpc_module.py b/src/pl_bolts/models/self_supervised/cpc/cpc_module.py index 6f60b7a267..33317c9706 100644 --- a/src/pl_bolts/models/self_supervised/cpc/cpc_module.py +++ b/src/pl_bolts/models/self_supervised/cpc/cpc_module.py @@ -1,4 +1,5 @@ """CPC V2.""" + import math from argparse import ArgumentParser from typing import Optional diff --git a/src/pl_bolts/models/self_supervised/moco/moco_module.py b/src/pl_bolts/models/self_supervised/moco/moco_module.py index 86d01147c3..19f6d99c63 100644 --- a/src/pl_bolts/models/self_supervised/moco/moco_module.py +++ b/src/pl_bolts/models/self_supervised/moco/moco_module.py @@ -8,6 +8,7 @@ You may obtain a copy of the License from the LICENSE file present in this folder. """ + from copy import copy, deepcopy from typing import Any, Dict, List, Optional, Tuple, Type, Union @@ -228,6 +229,7 @@ def configure_optimizers(self) -> Tuple[List[optim.Optimizer], List[optim.lr_sch ``self.lr_scheduler_params``. If weight decay is specified, it will be applied only to convolutional layer weights. + """ if ( ("weight_decay" in self.optimizer_params) diff --git a/src/pl_bolts/models/self_supervised/moco/utils.py b/src/pl_bolts/models/self_supervised/moco/utils.py index 116b52f979..030ec7079d 100644 --- a/src/pl_bolts/models/self_supervised/moco/utils.py +++ b/src/pl_bolts/models/self_supervised/moco/utils.py @@ -101,6 +101,7 @@ def concatenate_all(tensor: Tensor) -> Tensor: """Performs ``all_gather`` operation to concatenate the provided tensor from all devices. This function has no gradient. + """ gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] torch.distributed.all_gather(gathered_tensor, tensor.contiguous()) diff --git a/src/pl_bolts/models/self_supervised/swav/swav_module.py b/src/pl_bolts/models/self_supervised/swav/swav_module.py index e212358c40..fe5990676d 100644 --- a/src/pl_bolts/models/self_supervised/swav/swav_module.py +++ b/src/pl_bolts/models/self_supervised/swav/swav_module.py @@ -1,4 +1,5 @@ """Adapted from official swav implementation: https://github.com/facebookresearch/swav.""" + import os from argparse import ArgumentParser diff --git a/src/pl_bolts/models/self_supervised/swav/swav_resnet.py b/src/pl_bolts/models/self_supervised/swav/swav_resnet.py index 65b09dbe9a..2c2ffe5d96 100644 --- a/src/pl_bolts/models/self_supervised/swav/swav_resnet.py +++ b/src/pl_bolts/models/self_supervised/swav/swav_resnet.py @@ -1,4 +1,5 @@ """Adapted from: https://github.com/facebookresearch/swav/blob/master/src/resnet50.py.""" + import torch from torch import nn diff --git a/src/pl_bolts/models/vision/pixel_cnn.py b/src/pl_bolts/models/vision/pixel_cnn.py index 70be9dcd19..64edc93775 100644 --- a/src/pl_bolts/models/vision/pixel_cnn.py +++ b/src/pl_bolts/models/vision/pixel_cnn.py @@ -4,6 +4,7 @@ : https: //arxiv.org/pdf/1905.09272.pdf (page 15 Accessed: May 14, 2020. """ + from torch import nn from torch.nn import functional as F # noqa: N812 diff --git a/src/pl_bolts/optimizers/lars.py b/src/pl_bolts/optimizers/lars.py index 58ed202f24..7b9e41c45c 100644 --- a/src/pl_bolts/optimizers/lars.py +++ b/src/pl_bolts/optimizers/lars.py @@ -3,6 +3,7 @@ - https://arxiv.org/pdf/1708.03888.pdf - https://github.com/pytorch/pytorch/blob/1.6/torch/optim/sgd.py """ + import torch from torch.optim.optimizer import Optimizer, required diff --git a/tests/callbacks/test_data_monitor.py b/tests/callbacks/test_data_monitor.py index 3d3ddc0c4c..de5b24c4e2 100644 --- a/tests/callbacks/test_data_monitor.py +++ b/tests/callbacks/test_data_monitor.py @@ -3,11 +3,12 @@ import pytest import torch +from pytorch_lightning import Trainer +from torch import nn + from pl_bolts.callbacks import ModuleDataMonitor, TrainingDataMonitor from pl_bolts.datamodules import MNISTDataModule from pl_bolts.models import LitMNIST -from pytorch_lightning import Trainer -from torch import nn # @pytest.mark.parametrize(("log_every_n_steps", "max_steps", "expected_calls"), [pytest.param(3, 10, 3)]) diff --git a/tests/callbacks/test_ort.py b/tests/callbacks/test_ort.py index 9186c3ad3e..aba587bf3b 100644 --- a/tests/callbacks/test_ort.py +++ b/tests/callbacks/test_ort.py @@ -13,12 +13,12 @@ # limitations under the License. import pytest -from pl_bolts.callbacks import ORTCallback -from pl_bolts.utils import _TORCH_ORT_AVAILABLE from pytorch_lightning import Callback, Trainer from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pl_bolts.callbacks import ORTCallback +from pl_bolts.utils import _TORCH_ORT_AVAILABLE from tests.helpers.boring_model import BoringModel if _TORCH_ORT_AVAILABLE: diff --git a/tests/callbacks/test_param_update_callbacks.py b/tests/callbacks/test_param_update_callbacks.py index 20b6714257..a7c545b1bb 100644 --- a/tests/callbacks/test_param_update_callbacks.py +++ b/tests/callbacks/test_param_update_callbacks.py @@ -2,9 +2,10 @@ import pytest import torch -from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate from torch import nn +from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate + @pytest.mark.parametrize("initial_tau", [-0.1, 0.0, 0.996, 1.0, 1.1]) def test_byol_ma_weight_single_update_callback(initial_tau, catch_warnings): diff --git a/tests/callbacks/test_sparseml.py b/tests/callbacks/test_sparseml.py index fdef958086..56f5a274c0 100644 --- a/tests/callbacks/test_sparseml.py +++ b/tests/callbacks/test_sparseml.py @@ -16,18 +16,18 @@ import pytest import torch -from pl_bolts.callbacks import SparseMLCallback -from pl_bolts.utils import _SPARSEML_TORCH_SATISFIED from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pl_bolts.callbacks import SparseMLCallback +from pl_bolts.utils import _SPARSEML_TORCH_SATISFIED from tests.helpers.boring_model import BoringModel if _SPARSEML_TORCH_SATISFIED: from sparseml.pytorch.optim import RecipeManagerStepWrapper -@pytest.fixture() +@pytest.fixture def recipe(): return """ version: 0.1.0 diff --git a/tests/callbacks/verification/test_base.py b/tests/callbacks/verification/test_base.py index 624d2a9206..ed022c4c8a 100644 --- a/tests/callbacks/verification/test_base.py +++ b/tests/callbacks/verification/test_base.py @@ -3,11 +3,11 @@ import pytest import torch import torch.nn as nn -from pl_bolts.callbacks.verification.base import VerificationBase -from pl_bolts.utils import _PL_GREATER_EQUAL_1_4 from pytorch_lightning import LightningModule from pytorch_lightning.utilities import move_data_to_device +from pl_bolts.callbacks.verification.base import VerificationBase +from pl_bolts.utils import _PL_GREATER_EQUAL_1_4 from tests import _MARK_REQUIRE_GPU diff --git a/tests/callbacks/verification/test_batch_gradient.py b/tests/callbacks/verification/test_batch_gradient.py index bc25bf3b49..5df3e91f4a 100644 --- a/tests/callbacks/verification/test_batch_gradient.py +++ b/tests/callbacks/verification/test_batch_gradient.py @@ -2,13 +2,13 @@ import pytest import torch -from pl_bolts.callbacks import BatchGradientVerificationCallback -from pl_bolts.callbacks.verification.batch_gradient import default_input_mapping, default_output_mapping, selective_eval -from pl_bolts.utils import BatchGradientVerification from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import Tensor, nn +from pl_bolts.callbacks import BatchGradientVerificationCallback +from pl_bolts.callbacks.verification.batch_gradient import default_input_mapping, default_output_mapping, selective_eval +from pl_bolts.utils import BatchGradientVerification from tests import _MARK_REQUIRE_GPU diff --git a/tests/conftest.py b/tests/conftest.py index b1340ac335..bd69eb5e10 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,9 +5,10 @@ import pytest import torch +from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector + from pl_bolts.utils import _IS_WINDOWS, _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_13 from pl_bolts.utils.stability import UnderReviewWarning -from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector # GitHub Actions use this path to cache datasets. # Use `datadir` fixture where possible and use `DATASETS_PATH` in @@ -21,7 +22,7 @@ def datadir(): return Path(DATASETS_PATH) -@pytest.fixture() +@pytest.fixture def catch_warnings(): # noqa: PT004 with warnings.catch_warnings(): warnings.simplefilter("error") diff --git a/tests/datamodules/test_dataloader.py b/tests/datamodules/test_dataloader.py index 3d7d05a52f..7471d80f4f 100644 --- a/tests/datamodules/test_dataloader.py +++ b/tests/datamodules/test_dataloader.py @@ -1,7 +1,8 @@ import torch +from torch.utils.data import DataLoader + from pl_bolts.datamodules.async_dataloader import AsynchronousLoader from pl_bolts.datasets.cifar10_dataset import CIFAR10 -from torch.utils.data import DataLoader def test_async_dataloader(datadir): diff --git a/tests/datamodules/test_datamodules.py b/tests/datamodules/test_datamodules.py index 7dce4e3f01..76695b7804 100644 --- a/tests/datamodules/test_datamodules.py +++ b/tests/datamodules/test_datamodules.py @@ -4,6 +4,7 @@ import pytest import torch from PIL import Image + from pl_bolts.datamodules import ( BinaryEMNISTDataModule, BinaryMNISTDataModule, diff --git a/tests/datamodules/test_experience_sources.py b/tests/datamodules/test_experience_sources.py index 78aeed9045..41bddda1d5 100644 --- a/tests/datamodules/test_experience_sources.py +++ b/tests/datamodules/test_experience_sources.py @@ -4,6 +4,8 @@ import gym import numpy as np import torch +from torch.utils.data import DataLoader + from pl_bolts.datamodules.experience_source import ( BaseExperienceSource, DiscountedExperienceSource, @@ -12,7 +14,6 @@ ExperienceSourceDataset, ) from pl_bolts.models.rl.common.agents import Agent -from torch.utils.data import DataLoader class DummyAgent(Agent): diff --git a/tests/datamodules/test_sklearn_dataloaders.py b/tests/datamodules/test_sklearn_dataloaders.py index f8bad22ebc..89abe88370 100644 --- a/tests/datamodules/test_sklearn_dataloaders.py +++ b/tests/datamodules/test_sklearn_dataloaders.py @@ -2,9 +2,10 @@ import numpy as np import pytest -from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule from pytorch_lightning import seed_everything +from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule + try: from sklearn.utils import shuffle as sk_shuffle diff --git a/tests/datasets/test_array_dataset.py b/tests/datasets/test_array_dataset.py index 0bf87bd9c0..9f1b0dba26 100644 --- a/tests/datasets/test_array_dataset.py +++ b/tests/datasets/test_array_dataset.py @@ -1,13 +1,14 @@ import numpy as np import pytest import torch +from pytorch_lightning.utilities import exceptions + from pl_bolts.datasets import ArrayDataset, DataModel from pl_bolts.datasets.utils import to_tensor -from pytorch_lightning.utilities import exceptions class TestArrayDataset: - @pytest.fixture() + @pytest.fixture def array_dataset(self): features_1 = DataModel(data=[[1, 0, -1, 2], [1, 0, -2, -1], [2, 5, 0, 3], [-7, 1, 2, 2]], transform=to_tensor) target_1 = DataModel(data=[1, 0, 0, 1], transform=to_tensor) diff --git a/tests/datasets/test_base_dataset.py b/tests/datasets/test_base_dataset.py index 71bd7eeb17..65ed0379df 100644 --- a/tests/datasets/test_base_dataset.py +++ b/tests/datasets/test_base_dataset.py @@ -1,13 +1,14 @@ import numpy as np import pytest import torch + from pl_bolts.datasets.base_dataset import DataModel from pl_bolts.datasets.utils import to_tensor from pl_bolts.utils import _IS_WINDOWS class TestDataModel: - @pytest.fixture() + @pytest.fixture def data(self): return np.array([[1, 0, 0, 1], [0, 1, 1, 0]]) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 62da40798d..0378c54870 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -3,6 +3,9 @@ import numpy as np import pytest import torch +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms as transform_lib + from pl_bolts.datasets import ( BinaryEMNIST, BinaryMNIST, @@ -16,8 +19,6 @@ from pl_bolts.datasets.sr_mnist_dataset import SRMNIST from pl_bolts.utils import _PIL_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg -from torch.utils.data import DataLoader, Dataset -from torchvision import transforms as transform_lib if _PIL_AVAILABLE: from PIL import Image diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index 47603a2961..8f47578544 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -1,5 +1,6 @@ import numpy as np import torch.testing + from pl_bolts.datasets.utils import to_tensor diff --git a/tests/losses/test_rl_loss.py b/tests/losses/test_rl_loss.py index 81859b99ec..6ce05b3081 100644 --- a/tests/losses/test_rl_loss.py +++ b/tests/losses/test_rl_loss.py @@ -4,10 +4,11 @@ import numpy as np import torch +from torch import Tensor + from pl_bolts.losses.rl import double_dqn_loss, dqn_loss, per_dqn_loss from pl_bolts.models.rl.common.gym_wrappers import make_environment from pl_bolts.models.rl.common.networks import CNN -from torch import Tensor class TestRLLoss(TestCase): diff --git a/tests/metrics/test_aggregation.py b/tests/metrics/test_aggregation.py index 84d4ee9b39..1aa65e0ed8 100644 --- a/tests/metrics/test_aggregation.py +++ b/tests/metrics/test_aggregation.py @@ -2,6 +2,7 @@ import pytest import torch + from pl_bolts.metrics.aggregation import accuracy, mean, precision_at_k diff --git a/tests/models/gans/integration/test_gans.py b/tests/models/gans/integration/test_gans.py index b04172a0e0..60cf5b4342 100644 --- a/tests/models/gans/integration/test_gans.py +++ b/tests/models/gans/integration/test_gans.py @@ -1,15 +1,16 @@ import warnings import pytest -from pl_bolts.datamodules import CIFAR10DataModule, MNISTDataModule -from pl_bolts.datasets.sr_mnist_dataset import SRMNIST -from pl_bolts.models.gans import DCGAN, GAN, SRGAN, SRResNet -from pl_bolts.utils import _IS_WINDOWS from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data.dataloader import DataLoader from torchvision import transforms as transform_lib +from pl_bolts.datamodules import CIFAR10DataModule, MNISTDataModule +from pl_bolts.datasets.sr_mnist_dataset import SRMNIST +from pl_bolts.models.gans import DCGAN, GAN, SRGAN, SRResNet +from pl_bolts.utils import _IS_WINDOWS + @pytest.mark.parametrize( "dm_cls", diff --git a/tests/models/gans/unit/test_basic_components.py b/tests/models/gans/unit/test_basic_components.py index e761678e09..153f483aec 100644 --- a/tests/models/gans/unit/test_basic_components.py +++ b/tests/models/gans/unit/test_basic_components.py @@ -1,8 +1,9 @@ import pytest import torch -from pl_bolts.models.gans.basic.components import Discriminator, Generator from pytorch_lightning import seed_everything +from pl_bolts.models.gans.basic.components import Discriminator, Generator + @pytest.mark.parametrize( ("latent_dim", "img_shape"), diff --git a/tests/models/regression/test_logistic_regression.py b/tests/models/regression/test_logistic_regression.py index e03405b690..ab625246b6 100644 --- a/tests/models/regression/test_logistic_regression.py +++ b/tests/models/regression/test_logistic_regression.py @@ -2,6 +2,7 @@ import operator import pytorch_lightning as pl + from pl_bolts import datamodules from pl_bolts.models import regression diff --git a/tests/models/rl/integration/test_actor_critic_models.py b/tests/models/rl/integration/test_actor_critic_models.py index 041da16722..f96eb9a203 100644 --- a/tests/models/rl/integration/test_actor_critic_models.py +++ b/tests/models/rl/integration/test_actor_critic_models.py @@ -2,10 +2,11 @@ import pytest import torch.cuda +from pytorch_lightning import Trainer + from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic from pl_bolts.models.rl.sac_model import SAC from pl_bolts.utils import _GYM_GREATER_EQUAL_0_20 -from pytorch_lightning import Trainer def test_a2c_cli(): diff --git a/tests/models/rl/integration/test_policy_models.py b/tests/models/rl/integration/test_policy_models.py index 62a8fff3db..0639006545 100644 --- a/tests/models/rl/integration/test_policy_models.py +++ b/tests/models/rl/integration/test_policy_models.py @@ -2,9 +2,10 @@ from unittest import TestCase import torch +from pytorch_lightning import Trainer + from pl_bolts.models.rl.reinforce_model import Reinforce from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient -from pytorch_lightning import Trainer class TestPolicyModels(TestCase): diff --git a/tests/models/rl/integration/test_value_models.py b/tests/models/rl/integration/test_value_models.py index ac1f5adb1b..a340bf2d89 100644 --- a/tests/models/rl/integration/test_value_models.py +++ b/tests/models/rl/integration/test_value_models.py @@ -3,13 +3,14 @@ import pytest import torch +from pytorch_lightning import Trainer + from pl_bolts.models.rl.double_dqn_model import DoubleDQN from pl_bolts.models.rl.dqn_model import DQN from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN from pl_bolts.models.rl.per_dqn_model import PERDQN from pl_bolts.utils import _IS_WINDOWS -from pytorch_lightning import Trainer class TestValueModels(TestCase): diff --git a/tests/models/rl/unit/test_a2c.py b/tests/models/rl/unit/test_a2c.py index 34bce5fd7c..adfb5cf833 100644 --- a/tests/models/rl/unit/test_a2c.py +++ b/tests/models/rl/unit/test_a2c.py @@ -1,9 +1,10 @@ import argparse import torch -from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic from torch import Tensor +from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic + def test_a2c_loss(): """Test the reinforce loss function.""" diff --git a/tests/models/rl/unit/test_agents.py b/tests/models/rl/unit/test_agents.py index e414f439f5..91a7a708a2 100644 --- a/tests/models/rl/unit/test_agents.py +++ b/tests/models/rl/unit/test_agents.py @@ -1,13 +1,15 @@ """Tests that the agent module works correctly.""" + from unittest import TestCase from unittest.mock import Mock import gym import numpy as np import torch -from pl_bolts.models.rl.common.agents import ActorCriticAgent, Agent, PolicyAgent, ValueAgent from torch import Tensor +from pl_bolts.models.rl.common.agents import ActorCriticAgent, Agent, PolicyAgent, ValueAgent + class TestAgents(TestCase): def setUp(self) -> None: diff --git a/tests/models/rl/unit/test_memory.py b/tests/models/rl/unit/test_memory.py index c199acde39..d753da2516 100644 --- a/tests/models/rl/unit/test_memory.py +++ b/tests/models/rl/unit/test_memory.py @@ -3,6 +3,7 @@ import numpy as np import torch + from pl_bolts.models.rl.common.memory import Buffer, Experience, MultiStepBuffer, PERBuffer, ReplayBuffer diff --git a/tests/models/rl/unit/test_ppo.py b/tests/models/rl/unit/test_ppo.py index eac100dc02..7685f2dbcb 100644 --- a/tests/models/rl/unit/test_ppo.py +++ b/tests/models/rl/unit/test_ppo.py @@ -1,9 +1,10 @@ import numpy as np import torch -from pl_bolts.models.rl.ppo_model import PPO from pytorch_lightning import Trainer from torch import Tensor +from pl_bolts.models.rl.ppo_model import PPO + def test_discount_rewards(): """Test calculation of discounted rewards.""" diff --git a/tests/models/rl/unit/test_reinforce.py b/tests/models/rl/unit/test_reinforce.py index aa6703f761..1346665569 100644 --- a/tests/models/rl/unit/test_reinforce.py +++ b/tests/models/rl/unit/test_reinforce.py @@ -4,12 +4,13 @@ import gym import numpy as np import torch +from torch import Tensor + from pl_bolts.datamodules.experience_source import DiscountedExperienceSource from pl_bolts.models.rl.common.agents import Agent from pl_bolts.models.rl.common.gym_wrappers import ToTensor from pl_bolts.models.rl.common.networks import MLP from pl_bolts.models.rl.reinforce_model import Reinforce -from torch import Tensor class TestReinforce(TestCase): diff --git a/tests/models/rl/unit/test_sac.py b/tests/models/rl/unit/test_sac.py index 2380e375a7..b0e85d814e 100644 --- a/tests/models/rl/unit/test_sac.py +++ b/tests/models/rl/unit/test_sac.py @@ -2,9 +2,10 @@ import pytest import torch +from torch import Tensor + from pl_bolts.models.rl.sac_model import SAC from pl_bolts.utils import _GYM_GREATER_EQUAL_0_20 -from torch import Tensor @pytest.mark.skipif(_GYM_GREATER_EQUAL_0_20, reason="gym.error.DeprecatedEnv: Env Pendulum-v0 not found") diff --git a/tests/models/rl/unit/test_vpg.py b/tests/models/rl/unit/test_vpg.py index 4acc1bdddd..1919eb6351 100644 --- a/tests/models/rl/unit/test_vpg.py +++ b/tests/models/rl/unit/test_vpg.py @@ -4,12 +4,13 @@ import gym import pytest import torch +from torch import Tensor + from pl_bolts.models.rl.common.agents import Agent from pl_bolts.models.rl.common.gym_wrappers import ToTensor from pl_bolts.models.rl.common.networks import MLP from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient from pl_bolts.utils import _IS_WINDOWS -from torch import Tensor class TestPolicyGradient(TestCase): diff --git a/tests/models/rl/unit/test_wrappers.py b/tests/models/rl/unit/test_wrappers.py index e61d7031f9..f246a257c9 100644 --- a/tests/models/rl/unit/test_wrappers.py +++ b/tests/models/rl/unit/test_wrappers.py @@ -1,9 +1,10 @@ from unittest import TestCase import gym -from pl_bolts.models.rl.common.gym_wrappers import ToTensor from torch import Tensor +from pl_bolts.models.rl.common.gym_wrappers import ToTensor + class TestToTensor(TestCase): def setUp(self) -> None: diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py index d5c0108486..e8b0ef198b 100644 --- a/tests/models/self_supervised/test_models.py +++ b/tests/models/self_supervised/test_models.py @@ -2,6 +2,9 @@ import pytest import torch +from pytorch_lightning import Trainer +from pytorch_lightning.utilities.warnings import PossibleUserWarning + from pl_bolts.datamodules import CIFAR10DataModule from pl_bolts.models.self_supervised import AMDIM, BYOL, CPC_v2, MoCo, SimCLR, SimSiam, SwAV from pl_bolts.models.self_supervised.cpc import CPCEvalTransformsCIFAR10, CPCTrainTransformsCIFAR10 @@ -11,9 +14,6 @@ from pl_bolts.transforms.self_supervised.simclr_transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform from pl_bolts.transforms.self_supervised.swav_transforms import SwAVEvalDataTransform, SwAVTrainDataTransform from pl_bolts.utils import _IS_WINDOWS -from pytorch_lightning import Trainer -from pytorch_lightning.utilities.warnings import PossibleUserWarning - from tests import _MARK_REQUIRE_GPU diff --git a/tests/models/self_supervised/test_resnets.py b/tests/models/self_supervised/test_resnets.py index 45657b1b51..07ffb2d1d9 100644 --- a/tests/models/self_supervised/test_resnets.py +++ b/tests/models/self_supervised/test_resnets.py @@ -1,5 +1,6 @@ import pytest import torch + from pl_bolts.models.self_supervised.amdim import AMDIMEncoder from pl_bolts.models.self_supervised.cpc import cpc_resnet50 from pl_bolts.models.self_supervised.resnets import ( diff --git a/tests/models/self_supervised/unit/test_transforms.py b/tests/models/self_supervised/unit/test_transforms.py index 729f259ce2..ed3909c773 100644 --- a/tests/models/self_supervised/unit/test_transforms.py +++ b/tests/models/self_supervised/unit/test_transforms.py @@ -2,6 +2,7 @@ import pytest import torch from PIL import Image + from pl_bolts.transforms.self_supervised.simclr_transforms import ( SimCLREvalDataTransform, SimCLRFinetuneTransform, diff --git a/tests/models/test_autoencoders.py b/tests/models/test_autoencoders.py index bf3162b252..81885af57c 100644 --- a/tests/models/test_autoencoders.py +++ b/tests/models/test_autoencoders.py @@ -1,8 +1,9 @@ import pytest import torch +from pytorch_lightning import Trainer, seed_everything + from pl_bolts.datamodules import CIFAR10DataModule from pl_bolts.models.autoencoders import AE, VAE, resnet18_decoder, resnet18_encoder, resnet50_encoder -from pytorch_lightning import Trainer, seed_everything @pytest.mark.parametrize("dm_cls", [pytest.param(CIFAR10DataModule, id="cifar10")]) diff --git a/tests/models/test_classic_ml.py b/tests/models/test_classic_ml.py index f4eac09559..39abf76fb7 100644 --- a/tests/models/test_classic_ml.py +++ b/tests/models/test_classic_ml.py @@ -1,10 +1,11 @@ import numpy as np import pytest -from pl_bolts.datamodules.sklearn_datamodule import SklearnDataset -from pl_bolts.models.regression import LinearRegression from pytorch_lightning import Trainer, seed_everything from torch.utils.data import DataLoader +from pl_bolts.datamodules.sklearn_datamodule import SklearnDataset +from pl_bolts.models.regression import LinearRegression + @pytest.mark.flaky(reruns=3) def test_linear_regression_model(tmpdir): diff --git a/tests/models/test_detection.py b/tests/models/test_detection.py index fb207b343d..8548decb5d 100644 --- a/tests/models/test_detection.py +++ b/tests/models/test_detection.py @@ -3,6 +3,10 @@ import pytest import torch +from pytorch_lightning import Trainer +from pytorch_lightning.utilities.warnings import PossibleUserWarning +from torch.utils.data import DataLoader + from pl_bolts.datasets import DummyDetectionDataset from pl_bolts.models.detection import ( YOLO, @@ -18,10 +22,6 @@ ) from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone from pl_bolts.utils import _IS_WINDOWS -from pytorch_lightning import Trainer -from pytorch_lightning.utilities.warnings import PossibleUserWarning -from torch.utils.data import DataLoader - from tests import TEST_ROOT diff --git a/tests/models/test_mnist_templates.py b/tests/models/test_mnist_templates.py index 58323cf356..40bf50ed33 100644 --- a/tests/models/test_mnist_templates.py +++ b/tests/models/test_mnist_templates.py @@ -1,10 +1,11 @@ import warnings -from pl_bolts.datamodules import MNISTDataModule -from pl_bolts.models import LitMNIST from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning +from pl_bolts.datamodules import MNISTDataModule +from pl_bolts.models import LitMNIST + def test_mnist(tmpdir, datadir, catch_warnings): warnings.filterwarnings( diff --git a/tests/models/test_scripts.py b/tests/models/test_scripts.py index e4124cfd65..42ed0837f7 100644 --- a/tests/models/test_scripts.py +++ b/tests/models/test_scripts.py @@ -3,8 +3,8 @@ import pytest import torch -from pl_bolts.utils import _GYM_GREATER_EQUAL_0_20, _IS_WINDOWS, _JSONARGPARSE_GREATER_THAN_4_16_0 +from pl_bolts.utils import _GYM_GREATER_EQUAL_0_20, _IS_WINDOWS, _JSONARGPARSE_GREATER_THAN_4_16_0 from tests import _MARK_REQUIRE_GPU, DATASETS_PATH _DEFAULT_ARGS = f" --data_dir={DATASETS_PATH} --max_epochs=1 --max_steps=2 --batch_size=4" diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py index 3f0b9a7d02..65f92ffb72 100644 --- a/tests/models/test_vision.py +++ b/tests/models/test_vision.py @@ -3,17 +3,18 @@ import pytest import torch from packaging import version -from pl_bolts.datamodules import FashionMNISTDataModule, MNISTDataModule -from pl_bolts.datasets import DummyDataset -from pl_bolts.models.vision import GPT2, ImageGPT, SemSegment, UNet -from pl_bolts.models.vision.unet import DoubleConv, Down, Up -from pl_bolts.utils import _IS_WINDOWS from pytorch_lightning import LightningDataModule, Trainer, seed_everything from pytorch_lightning import __version__ as pl_version from pytorch_lightning.callbacks.progress import TQDMProgressBar from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader +from pl_bolts.datamodules import FashionMNISTDataModule, MNISTDataModule +from pl_bolts.datasets import DummyDataset +from pl_bolts.models.vision import GPT2, ImageGPT, SemSegment, UNet +from pl_bolts.models.vision.unet import DoubleConv, Down, Up +from pl_bolts.utils import _IS_WINDOWS + class DummyDataModule(LightningDataModule): def train_dataloader(self): diff --git a/tests/models/yolo/unit/test_darknet_network.py b/tests/models/yolo/unit/test_darknet_network.py index 18020e27be..9f493f8d64 100644 --- a/tests/models/yolo/unit/test_darknet_network.py +++ b/tests/models/yolo/unit/test_darknet_network.py @@ -2,13 +2,14 @@ import pytest import torch.nn as nn +from pytorch_lightning.utilities.warnings import PossibleUserWarning + from pl_bolts.models.detection.yolo.darknet_network import ( _create_convolutional, _create_maxpool, _create_shortcut, _create_upsample, ) -from pytorch_lightning.utilities.warnings import PossibleUserWarning @pytest.mark.parametrize( diff --git a/tests/models/yolo/unit/test_target_matching.py b/tests/models/yolo/unit/test_target_matching.py index 9fe78bcce5..ffc0ed3617 100644 --- a/tests/models/yolo/unit/test_target_matching.py +++ b/tests/models/yolo/unit/test_target_matching.py @@ -1,4 +1,5 @@ import torch + from pl_bolts.models.detection.yolo.target_matching import _sim_ota_match diff --git a/tests/models/yolo/unit/test_utils.py b/tests/models/yolo/unit/test_utils.py index b883d2d60c..75dc5eac5c 100644 --- a/tests/models/yolo/unit/test_utils.py +++ b/tests/models/yolo/unit/test_utils.py @@ -2,6 +2,8 @@ import pytest import torch +from pytorch_lightning.utilities.warnings import PossibleUserWarning + from pl_bolts.models.detection.yolo.utils import ( aligned_iou, box_size_ratio, @@ -11,7 +13,6 @@ iou_below, is_inside_box, ) -from pytorch_lightning.utilities.warnings import PossibleUserWarning @pytest.mark.parametrize(("width", "height"), [(10, 5)]) diff --git a/tests/optimizers/test_lr_scheduler.py b/tests/optimizers/test_lr_scheduler.py index 477aae5b72..36309b67d1 100644 --- a/tests/optimizers/test_lr_scheduler.py +++ b/tests/optimizers/test_lr_scheduler.py @@ -2,12 +2,13 @@ import numpy as np import torch -from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR from pytorch_lightning import seed_everything from torch.nn import functional as F # noqa: N812 from torch.optim import SGD from torch.optim.lr_scheduler import _LRScheduler +from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR + EPSILON = 1e-12 diff --git a/tests/transforms/test_normalizations.py b/tests/transforms/test_normalizations.py index b0c22cd07c..bd3af643e3 100644 --- a/tests/transforms/test_normalizations.py +++ b/tests/transforms/test_normalizations.py @@ -1,12 +1,13 @@ import pytest import torch +from pytorch_lightning import seed_everything + from pl_bolts.transforms.dataset_normalizations import ( cifar10_normalization, emnist_normalization, imagenet_normalization, stl10_normalization, ) -from pytorch_lightning import seed_everything @pytest.mark.parametrize( diff --git a/tests/utils/test_arguments.py b/tests/utils/test_arguments.py index e4e293a22f..5e95bb24b3 100644 --- a/tests/utils/test_arguments.py +++ b/tests/utils/test_arguments.py @@ -1,9 +1,10 @@ from dataclasses import FrozenInstanceError import pytest -from pl_bolts.utils.arguments import LightningArgumentParser, LitArg, gather_lit_args from pytorch_lightning import LightningDataModule, LightningModule +from pl_bolts.utils.arguments import LightningArgumentParser, LitArg, gather_lit_args + class DummyParentModel(LightningModule): name = "parent-model" @@ -32,7 +33,7 @@ def test_lightning_argument_parser(): assert parser.ignore_required_init_args is True -@pytest.mark.xfail() +@pytest.mark.xfail def test_parser_bad_argument(): parser = LightningArgumentParser() parser.add_object_args("dm", DummyParentDataModule) diff --git a/tests/utils/test_dependency.py b/tests/utils/test_dependency.py index a4fd42da33..0adf8b7a87 100644 --- a/tests/utils/test_dependency.py +++ b/tests/utils/test_dependency.py @@ -1,4 +1,5 @@ import pytest + from pl_bolts.utils._dependency import requires diff --git a/tests/utils/test_semi_supervised.py b/tests/utils/test_semi_supervised.py index 3d19bf5d4f..b9b3b47862 100644 --- a/tests/utils/test_semi_supervised.py +++ b/tests/utils/test_semi_supervised.py @@ -4,6 +4,7 @@ import numpy as np import pytest import torch + from pl_bolts.utils.semi_supervised import balance_classes, generate_half_labeled_batches try: