
Commit

Rename PrecisionPlugin -> Precision (#18840)
awaelchli authored Oct 30, 2023
1 parent 6de4916 commit 079544a
Showing 56 changed files with 566 additions and 338 deletions.
18 changes: 9 additions & 9 deletions docs/source-pytorch/api_references.rst
@@ -114,15 +114,15 @@ precision
:nosignatures:
:template: classtemplate.rst

-DeepSpeedPrecisionPlugin
-DoublePrecisionPlugin
-HalfPrecisionPlugin
-FSDPPrecisionPlugin
-MixedPrecisionPlugin
-PrecisionPlugin
-XLAPrecisionPlugin
-TransformerEnginePrecisionPlugin
-BitsandbytesPrecisionPlugin
+DeepSpeedPrecision
+DoublePrecision
+HalfPrecision
+FSDPPrecision
+MixedPrecision
+Precision
+XLAPrecision
+TransformerEnginePrecision
+BitsandbytesPrecision

environments
""""""""""""
6 changes: 3 additions & 3 deletions docs/source-pytorch/common/precision_expert.rst
@@ -12,17 +12,17 @@ N-Bit Precision (Expert)
Precision Plugins
*****************

-You can also customize and pass your own Precision Plugin by subclassing the :class:`~lightning.pytorch.plugins.precision.precision_plugin.PrecisionPlugin` class.
+You can also customize and pass your own Precision Plugin by subclassing the :class:`~lightning.pytorch.plugins.precision.precision.Precision` class.

- Perform pre and post backward/optimizer step operations such as scaling gradients.
- Provide context managers for forward, training_step, etc.

.. code-block:: python
-class CustomPrecisionPlugin(PrecisionPlugin):
+class CustomPrecision(Precision):
precision = "16-mixed"
...
-trainer = Trainer(plugins=[CustomPrecisionPlugin()])
+trainer = Trainer(plugins=[CustomPrecision()])
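For orientation, here is a slightly fuller, self-contained sketch of the pattern this doc describes under the new name. The `forward_context` override and the autocast settings are assumptions chosen for illustration, not part of this commit:

# Hedged sketch of a custom precision class under the renamed base class.
# The `forward_context` hook and autocast dtype below are assumptions;
# consult the Precision API for the exact extension points.
import torch

from lightning.pytorch import Trainer
from lightning.pytorch.plugins.precision import Precision


class CustomPrecision(Precision):
    precision = "16-mixed"

    def forward_context(self) -> torch.autocast:
        # Run forward/training_step under autocast (assumed hook; sketch only).
        return torch.autocast(device_type="cuda", dtype=torch.float16)


trainer = Trainer(plugins=[CustomPrecision()])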
8 changes: 4 additions & 4 deletions docs/source-pytorch/common/precision_intermediate.rst
@@ -182,18 +182,18 @@ This is configurable via the dtype argument in the plugin.

Quantizing the model will dramatically reduce the weight's memory requirements but may have a negative impact on the model's performance or runtime.

-The :class:`~lightning.pytorch.plugins.precision.bitsandbytes.BitsandbytesPrecisionPlugin` automatically replaces the :class:`torch.nn.Linear` layers in your model with their BNB alternatives.
+The :class:`~lightning.pytorch.plugins.precision.bitsandbytes.BitsandbytesPrecision` automatically replaces the :class:`torch.nn.Linear` layers in your model with their BNB alternatives.

.. code-block:: python
-from lightning.pytorch.plugins import BitsandbytesPrecisionPlugin
+from lightning.pytorch.plugins import BitsandbytesPrecision
# this will pick out the compute dtype automatically, by default `bfloat16`
-precision = BitsandbytesPrecisionPlugin("nf4-dq")
+precision = BitsandbytesPrecision("nf4-dq")
trainer = Trainer(plugins=precision)
# Customize the dtype, or skip some modules
-precision = BitsandbytesPrecisionPlugin("int8-training", dtype=torch.float16, ignore_modules={"lm_head"})
+precision = BitsandbytesPrecision("int8-training", dtype=torch.float16, ignore_modules={"lm_head"})
trainer = Trainer(plugins=precision)
18 changes: 9 additions & 9 deletions docs/source-pytorch/extensions/plugins.rst
@@ -52,15 +52,15 @@ The full list of built-in precision plugins is listed below.
:nosignatures:
:template: classtemplate.rst

-DeepSpeedPrecisionPlugin
-DoublePrecisionPlugin
-HalfPrecisionPlugin
-FSDPPrecisionPlugin
-MixedPrecisionPlugin
-PrecisionPlugin
-XLAPrecisionPlugin
-TransformerEnginePrecisionPlugin
-BitsandbytesPrecisionPlugin
+DeepSpeedPrecision
+DoublePrecision
+HalfPrecision
+FSDPPrecision
+MixedPrecision
+Precision
+XLAPrecision
+TransformerEnginePrecision
+BitsandbytesPrecision

More information regarding precision with Lightning can be found :ref:`here <precision>`

2 changes: 1 addition & 1 deletion src/lightning/fabric/plugins/precision/bitsandbytes.py
@@ -61,7 +61,7 @@ class BitsandbytesPrecision(Precision):

# TODO: we could implement optimizer replacement with
# - Fabric: Add `Precision.convert_optimizer` from `Strategy.setup_optimizer`
-# - Trainer: Use `PrecisionPlugin.connect`
+# - Trainer: Use `Precision.connect`

def __init__(
self,
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

--
+- Deprecated all precision plugin classes under `lightning.pytorch.plugins` with the suffix `Plugin` in the name ([#18840](https://github.com/Lightning-AI/lightning/pull/18840))


### Removed
1 change: 1 addition & 0 deletions src/lightning/pytorch/_graveyard/__init__.py
@@ -14,4 +14,5 @@
import lightning.pytorch._graveyard._torchmetrics
import lightning.pytorch._graveyard.hpu
import lightning.pytorch._graveyard.ipu
+import lightning.pytorch._graveyard.precision
import lightning.pytorch._graveyard.tpu # noqa: F401
86 changes: 86 additions & 0 deletions src/lightning/pytorch/_graveyard/precision.py
@@ -0,0 +1,86 @@
import sys
from typing import TYPE_CHECKING, Any, Literal, Optional

import lightning.pytorch as pl
from lightning.fabric.utilities.rank_zero import rank_zero_deprecation
from lightning.pytorch.plugins.precision import (
BitsandbytesPrecision,
DeepSpeedPrecision,
DoublePrecision,
FSDPPrecision,
HalfPrecision,
MixedPrecision,
Precision,
TransformerEnginePrecision,
XLAPrecision,
)

if TYPE_CHECKING:
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


def _patch_sys_modules() -> None:
sys.modules["lightning.pytorch.plugins.precision.precision_plugin"] = sys.modules[
"lightning.pytorch.plugins.precision.precision"
]


class FSDPMixedPrecisionPlugin(FSDPPrecision):
"""AMP for Fully Sharded Data Parallel (FSDP) Training.
.. deprecated:: Use :class:`FSDPPrecision` instead.
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
"""

def __init__(
self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional["ShardedGradScaler"] = None
) -> None:
rank_zero_deprecation(
f"The `{type(self).__name__}` is deprecated."
" Use `lightning.pytorch.plugins.precision.FSDPPrecision` instead."
)
super().__init__(precision=precision, scaler=scaler)


def _create_class(deprecated_name: str, new_class: type) -> type:
def init(self: type, *args: Any, **kwargs: Any) -> None:
rank_zero_deprecation(
f"The `{deprecated_name}` is deprecated."
f" Use `lightning.pytorch.plugins.precision.{new_class.__name__}` instead."
)
super(type(self), self).__init__(*args, **kwargs)

return type(deprecated_name, (new_class,), {"__init__": init})


def _patch_classes() -> None:
classes_map = (
# module name, old name, new class
("bitsandbytes", "BitsandbytesPrecisionPlugin", BitsandbytesPrecision),
("deepspeed", "DeepSpeedPrecisionPlugin", DeepSpeedPrecision),
("double", "DoublePrecisionPlugin", DoublePrecision),
("fsdp", "FSDPPrecisionPlugin", FSDPPrecision),
("fsdp", "FSDPMixedPrecisionPlugin", FSDPPrecision),
("half", "HalfPrecisionPlugin", HalfPrecision),
("amp", "MixedPrecisionPlugin", MixedPrecision),
("precision", "PrecisionPlugin", Precision),
("transformer_engine", "TransformerEnginePrecisionPlugin", TransformerEnginePrecision),
("xla", "XLAPrecisionPlugin", XLAPrecision),
)

for module_name, deprecated_name, new_class in classes_map:
deprecated_class = _create_class(deprecated_name, new_class)
setattr(getattr(pl.plugins.precision, module_name), deprecated_name, deprecated_class)
setattr(pl.plugins.precision, deprecated_name, deprecated_class)
setattr(pl.plugins, deprecated_name, deprecated_class)

# special treatment for `FSDPMixedPrecisionPlugin` because it has a different signature
setattr(pl.plugins.precision.fsdp, "FSDPMixedPrecisionPlugin", FSDPMixedPrecisionPlugin)
setattr(pl.plugins.precision, "FSDPMixedPrecisionPlugin", FSDPMixedPrecisionPlugin)
setattr(pl.plugins, "FSDPMixedPrecisionPlugin", FSDPMixedPrecisionPlugin)


_patch_sys_modules()
_patch_classes()
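As far as I can tell, the net effect of this shim is that the old `*Plugin` names still resolve from their usual import locations, but constructing one emits a deprecation warning and returns an instance of a subclass of the renamed class. A rough illustration (assuming the graveyard module is imported along with `lightning.pytorch`, which is what the `_graveyard/__init__.py` change above arranges):

# Illustration of the shim behaviour; not part of this commit's diff.
from lightning.pytorch.plugins import DoublePrecision, DoublePrecisionPlugin

plugin = DoublePrecisionPlugin()  # warns: deprecated, use `...precision.DoublePrecision` instead
assert isinstance(plugin, DoublePrecision)  # the deprecated name subclasses the new class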
20 changes: 10 additions & 10 deletions src/lightning/pytorch/_graveyard/tpu.py
@@ -18,7 +18,7 @@
import lightning.pytorch as pl
from lightning.fabric.strategies import _StrategyRegistry
from lightning.pytorch.accelerators.xla import XLAAccelerator
-from lightning.pytorch.plugins.precision import XLAPrecisionPlugin
+from lightning.pytorch.plugins.precision import XLAPrecision
from lightning.pytorch.strategies.single_xla import SingleDeviceXLAStrategy
from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation

@@ -63,47 +63,47 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)


-class TPUPrecisionPlugin(XLAPrecisionPlugin):
+class TPUPrecisionPlugin(XLAPrecision):
"""Legacy class.
-Use :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecisionPlugin` instead.
+Use :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecision` instead.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
rank_zero_deprecation(
"The `TPUPrecisionPlugin` class is deprecated. Use `lightning.pytorch.plugins.precision.XLAPrecisionPlugin`"
"The `TPUPrecisionPlugin` class is deprecated. Use `lightning.pytorch.plugins.precision.XLAPrecision`"
" instead."
)
super().__init__(precision="32-true")


-class TPUBf16PrecisionPlugin(XLAPrecisionPlugin):
+class TPUBf16PrecisionPlugin(XLAPrecision):
"""Legacy class.
-Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLAPrecisionPlugin` instead.
+Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLAPrecision` instead.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
rank_zero_deprecation(
"The `TPUBf16PrecisionPlugin` class is deprecated. Use"
" `lightning.pytorch.plugins.precision.XLAPrecisionPlugin` instead."
" `lightning.pytorch.plugins.precision.XLAPrecision` instead."
)
super().__init__(precision="bf16-true")


-class XLABf16PrecisionPlugin(XLAPrecisionPlugin):
+class XLABf16PrecisionPlugin(XLAPrecision):
"""Legacy class.
-Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLAPrecisionPlugin` instead.
+Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLAPrecision` instead.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
rank_zero_deprecation(
"The `XLABf16PrecisionPlugin` class is deprecated. Use"
" `lightning.pytorch.plugins.precision.XLAPrecisionPlugin` instead."
" `lightning.pytorch.plugins.precision.XLAPrecision` instead."
)
super().__init__(precision="bf16-true")

40 changes: 20 additions & 20 deletions src/lightning/pytorch/callbacks/throughput_monitor.py
@@ -16,21 +16,21 @@

import torch

-from lightning.fabric.plugins import Precision
+from lightning.fabric.plugins import Precision as FabricPrecision
from lightning.fabric.utilities.throughput import Throughput, get_available_flops
from lightning.fabric.utilities.throughput import _plugin_to_compute_dtype as fabric_plugin_to_compute_dtype
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.plugins import (
-DoublePrecisionPlugin,
-FSDPPrecisionPlugin,
-MixedPrecisionPlugin,
-PrecisionPlugin,
-TransformerEnginePrecisionPlugin,
+BitsandbytesPrecision,
+DeepSpeedPrecision,
+DoublePrecision,
+FSDPPrecision,
+HalfPrecision,
+MixedPrecision,
+Precision,
+TransformerEnginePrecision,
+XLAPrecision,
)
-from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecisionPlugin
-from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
-from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin
-from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn

@@ -227,24 +227,24 @@ def on_predict_batch_end(
self._compute(trainer, iter_num)


-def _plugin_to_compute_dtype(plugin: Union[Precision, PrecisionPlugin]) -> torch.dtype:
+def _plugin_to_compute_dtype(plugin: Union[FabricPrecision, Precision]) -> torch.dtype:
# TODO: integrate this into the precision plugins
-if not isinstance(plugin, PrecisionPlugin):
+if not isinstance(plugin, Precision):
return fabric_plugin_to_compute_dtype(plugin)
-if isinstance(plugin, BitsandbytesPrecisionPlugin):
+if isinstance(plugin, BitsandbytesPrecision):
return plugin.dtype
-if isinstance(plugin, HalfPrecisionPlugin):
+if isinstance(plugin, HalfPrecision):
return plugin._desired_input_dtype
-if isinstance(plugin, MixedPrecisionPlugin):
+if isinstance(plugin, MixedPrecision):
return torch.bfloat16 if plugin.precision == "bf16-mixed" else torch.half
-if isinstance(plugin, DoublePrecisionPlugin):
+if isinstance(plugin, DoublePrecision):
return torch.double
-if isinstance(plugin, (XLAPrecisionPlugin, DeepSpeedPrecisionPlugin)):
+if isinstance(plugin, (XLAPrecision, DeepSpeedPrecision)):
return plugin._desired_dtype
-if isinstance(plugin, TransformerEnginePrecisionPlugin):
+if isinstance(plugin, TransformerEnginePrecision):
return torch.int8
-if isinstance(plugin, FSDPPrecisionPlugin):
+if isinstance(plugin, FSDPPrecision):
return plugin.mixed_precision_config.reduce_dtype or torch.float32
-if isinstance(plugin, PrecisionPlugin):
+if isinstance(plugin, Precision):
return torch.float32
raise NotImplementedError(plugin)
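For reference, this dispatch boils down to a plugin-to-dtype mapping. A small sketch of two branches (it calls a private helper, so this is illustrative only; the module path is taken from the file header above and the expected dtypes are read off the branches):

# Illustrative only: exercises the private helper shown above.
import torch

from lightning.pytorch.callbacks.throughput_monitor import _plugin_to_compute_dtype
from lightning.pytorch.plugins import DoublePrecision, HalfPrecision

assert _plugin_to_compute_dtype(DoublePrecision()) is torch.double  # DoublePrecision branch
assert _plugin_to_compute_dtype(HalfPrecision()) is torch.half      # HalfPrecision branch (assumes default "16-true")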
39 changes: 19 additions & 20 deletions src/lightning/pytorch/plugins/__init__.py
@@ -3,34 +3,33 @@
from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment, TorchCheckpointIO, XLACheckpointIO
from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO
from lightning.pytorch.plugins.layer_sync import LayerSync, TorchSyncBatchNorm
-from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin
-from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecisionPlugin
-from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
-from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin
-from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin, FSDPPrecisionPlugin
-from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin
-from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
-from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecisionPlugin
-from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin
+from lightning.pytorch.plugins.precision.amp import MixedPrecision
+from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecision
+from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision
+from lightning.pytorch.plugins.precision.double import DoublePrecision
+from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision
+from lightning.pytorch.plugins.precision.half import HalfPrecision
+from lightning.pytorch.plugins.precision.precision import Precision
+from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecision
+from lightning.pytorch.plugins.precision.xla import XLAPrecision

-PLUGIN = Union[PrecisionPlugin, ClusterEnvironment, CheckpointIO, LayerSync]
+PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO, LayerSync]
PLUGIN_INPUT = Union[PLUGIN, str]

__all__ = [
"AsyncCheckpointIO",
"CheckpointIO",
"TorchCheckpointIO",
"XLACheckpointIO",
"BitsandbytesPrecisionPlugin",
"DeepSpeedPrecisionPlugin",
"DoublePrecisionPlugin",
"HalfPrecisionPlugin",
"MixedPrecisionPlugin",
"PrecisionPlugin",
"TransformerEnginePrecisionPlugin",
"FSDPMixedPrecisionPlugin",
"FSDPPrecisionPlugin",
"XLAPrecisionPlugin",
"BitsandbytesPrecision",
"DeepSpeedPrecision",
"DoublePrecision",
"HalfPrecision",
"MixedPrecision",
"Precision",
"TransformerEnginePrecision",
"FSDPPrecision",
"XLAPrecision",
"LayerSync",
"TorchSyncBatchNorm",
]
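For downstream code, the rename reflected in this export list amounts to a one-to-one substitution of class names, roughly as follows (names taken from the old and new `__all__` above; the old names keep working through the graveyard shim for now):

# Before this commit (name now deprecated):
#   from lightning.pytorch.plugins import HalfPrecisionPlugin
#   trainer = Trainer(plugins=[HalfPrecisionPlugin("bf16-true")])

# After this commit:
from lightning.pytorch import Trainer
from lightning.pytorch.plugins import HalfPrecision

trainer = Trainer(plugins=[HalfPrecision("bf16-true")])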