Skip to content

Commit

Permalink
Add throughput utilities to Fabric and the Trainer (#18848)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Oct 30, 2023
1 parent fd48627 commit 800b87e
Show file tree
Hide file tree
Showing 25 changed files with 1,327 additions and 34 deletions.
6 changes: 6 additions & 0 deletions docs/source-fabric/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,9 @@ lightning.fabric.utilities
.. autofunction:: lightning.fabric.utilities.distributed.is_shared_filesystem

.. autofunction:: lightning.fabric.utilities.warnings.disable_possible_user_warnings

.. autofunction:: lightning.fabric.utilities.throughput.measure_flops

.. autoclass:: lightning.fabric.utilities.throughput.ThroughputMonitor

.. autoclass:: lightning.fabric.utilities.throughput.Throughput
1 change: 1 addition & 0 deletions docs/source-fabric/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@
("py:class", "lightning.fabric.wrappers._FabricOptimizer"),
("py:class", "lightning.fabric.loggers.csv_logs._ExperimentWriter"),
("py:class", "lightning.fabric.strategies.strategy._Sharded"),
("py:class", "lightning.fabric.utilities.throughput.Throughput"),
# Nitpick does not see abstract API
("py:meth", "lightning.fabric.plugins.collectives.Collective.init_group"),
# These seem to be missing in reference generated API
Expand Down
4 changes: 4 additions & 0 deletions docs/source-pytorch/api_references.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ callbacks
RichModelSummary
RichProgressBar
StochasticWeightAveraging
SpikeDetection
ThroughputMonitor
Timer
TQDMProgressBar

Expand Down Expand Up @@ -248,3 +250,5 @@ utilities
rank_zero
seed
warnings

.. autofunction:: lightning.pytorch.utilities.measure_flops
7 changes: 5 additions & 2 deletions docs/source-pytorch/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,9 @@ def _load_py_module(name: str, location: str) -> ModuleType:
("py:class", "lightning.fabric.utilities.types.ReduceLROnPlateau"),
("py:class", "lightning.fabric.utilities.types.Steppable"),
("py:class", "lightning.fabric.wrappers._FabricOptimizer"),
("py:class", "lightning.fabric.utilities.throughput.Throughput"),
("py:func", "lightning.fabric.utilities.throughput.measure_flops"),
("py:class", "lightning.fabric.utilities.spike.SpikeDetection"),
("py:meth", "lightning.pytorch.Callback.on_exception"),
("py:class", "lightning.pytorch.LightningModule"),
("py:meth", "lightning.pytorch.LightningModule.on_train_epoch_end"),
Expand Down Expand Up @@ -450,7 +453,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
("py:meth", "optimizer_step"),
("py:class", "out_dict"),
("py:meth", "prepare_data"),
("py:class", "pytorch_lightning.callbacks.device_stats_monitor.DeviceStatsMonitor"),
("py:class", "lightning.pytorch.callbacks.device_stats_monitor.DeviceStatsMonitor"),
("py:meth", "setup"),
("py:meth", "test_step"),
("py:meth", "toggle_optimizer"),
Expand Down Expand Up @@ -585,7 +588,7 @@ def package_list_from_file(file):
from lightning.pytorch import LightningDataModule, LightningModule, Trainer, seed_everything
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.cli import _JSONARGPARSE_SIGNATURES_AVAILABLE as _JSONARGPARSE_AVAILABLE
from lightning.pytorch.utilities import _TORCHVISION_AVAILABLE
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE
from lightning.pytorch.loggers.comet import _COMET_AVAILABLE
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- Added `lightning.fabric.utilities.ThroughputMonitor` and `lightning.fabric.utilities.Throughput` to track throughput and log it ([#18848](https://github.com/Lightning-AI/lightning/pull/18848))


### Changed
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/fabric/accelerators/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
from typing import Generator, List, Optional, Union, cast

import torch
from lightning_utilities.core.rank_zero import rank_zero_info

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.rank_zero import rank_zero_info


class CUDAAccelerator(Accelerator):
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/fabric/fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import torch.nn as nn
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_utilities.core.overrides import is_overridden
from lightning_utilities.core.rank_zero import rank_zero_deprecation, rank_zero_warn
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler
Expand Down Expand Up @@ -67,6 +66,7 @@
)
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from lightning.fabric.utilities.registry import _load_external_callbacks
from lightning.fabric.utilities.seed import seed_everything
from lightning.fabric.utilities.types import ReduceOp
Expand Down
28 changes: 22 additions & 6 deletions src/lightning/fabric/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,30 @@
# limitations under the License.
"""General utilities."""

from lightning.fabric.utilities.apply_func import move_data_to_device # noqa: F401
from lightning.fabric.utilities.data import suggested_max_num_workers # noqa: F401
from lightning.fabric.utilities.distributed import is_shared_filesystem # noqa: F401
from lightning.fabric.utilities.enums import LightningEnum # noqa: F401
from lightning.fabric.utilities.rank_zero import ( # noqa: F401
from lightning.fabric.utilities.apply_func import move_data_to_device
from lightning.fabric.utilities.data import suggested_max_num_workers
from lightning.fabric.utilities.distributed import is_shared_filesystem
from lightning.fabric.utilities.enums import LightningEnum
from lightning.fabric.utilities.rank_zero import (
rank_zero_deprecation,
rank_zero_info,
rank_zero_only,
rank_zero_warn,
)
from lightning.fabric.utilities.warnings import disable_possible_user_warnings # noqa: F401
from lightning.fabric.utilities.throughput import Throughput, ThroughputMonitor, measure_flops
from lightning.fabric.utilities.warnings import disable_possible_user_warnings

__all__ = [
"disable_possible_user_warnings",
"is_shared_filesystem",
"LightningEnum",
"measure_flops",
"move_data_to_device",
"rank_zero_deprecation",
"rank_zero_info",
"rank_zero_only",
"rank_zero_warn",
"suggested_max_num_workers",
"Throughput",
"ThroughputMonitor",
]
2 changes: 2 additions & 0 deletions src/lightning/fabric/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@

# Version guards, evaluated once at import time.
_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)

# True when the installed `lightning_utilities` is at least 0.10.0 — presumably the
# first release whose `rank_zero_only` supports a `default` return value; used in
# `lightning/fabric/utilities/rank_zero.py` to decide whether a backport is needed.
_UTILITIES_GREATER_EQUAL_0_10 = compare_version("lightning_utilities", operator.ge, "0.10.0")
34 changes: 32 additions & 2 deletions src/lightning/fabric/utilities/rank_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
"""Utilities that can be used for calling functions on a particular rank."""
import logging
import os
from typing import Optional
from functools import wraps
from typing import Callable, Optional, TypeVar, overload

import lightning_utilities.core.rank_zero as rank_zero_module

Expand All @@ -25,11 +26,12 @@
rank_zero_debug,
rank_zero_deprecation,
rank_zero_info,
rank_zero_only,
rank_zero_warn,
)
from typing_extensions import ParamSpec

import lightning.fabric
from lightning.fabric.utilities.imports import _UTILITIES_GREATER_EQUAL_0_10

rank_zero_module.log = logging.getLogger(__name__)

Expand All @@ -50,6 +52,34 @@ def _get_rank(
return None


# Backport: on `lightning_utilities < 0.10` define a local `rank_zero_only` that
# additionally accepts a `default` value to return on non-zero ranks; on newer
# versions simply re-export the upstream implementation.
if not _UTILITIES_GREATER_EQUAL_0_10:
    T = TypeVar("T")
    P = ParamSpec("P")

    @overload
    def rank_zero_only(fn: Callable[P, T]) -> Callable[P, Optional[T]]:
        ...

    @overload
    def rank_zero_only(fn: Callable[P, T], default: T) -> Callable[P, T]:
        ...

    def rank_zero_only(fn: Callable[P, T], default: Optional[T] = None) -> Callable[P, Optional[T]]:
        """Wrap ``fn`` so that it only executes on global rank 0.

        On any other rank the wrapper skips ``fn`` entirely and returns ``default``.

        Raises:
            RuntimeError: If ``rank_zero_only.rank`` has not been set before the wrapper is called.
        """

        @wraps(fn)
        def wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
            # The rank is read lazily at call time so it can be (re)assigned after wrapping.
            rank = getattr(rank_zero_only, "rank", None)
            if rank is None:
                raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
            if rank == 0:
                return fn(*args, **kwargs)
            return default

        return wrapped_fn

    # Also initialize the `rank` attribute on the upstream function (without
    # overwriting an existing value) so callers importing it directly from
    # `lightning_utilities` see a consistent rank.
    rank_zero_module.rank_zero_only.rank = getattr(rank_zero_module.rank_zero_only, "rank", _get_rank() or 0)
else:
    # New enough `lightning_utilities`: re-export its implementation unchanged.
    rank_zero_only = rank_zero_module.rank_zero_only  # type: ignore[assignment]

# add the attribute to the function but don't overwrite in case Trainer has already set it
rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank() or 0)

Expand Down
Loading

0 comments on commit 800b87e

Please sign in to comment.