From ad74bb391df8b59542657c217aadc262de702f08 Mon Sep 17 00:00:00 2001
From: Jianing Yang <jed970610@gmail.com>
Date: Sat, 7 Dec 2024 17:11:32 +0000
Subject: [PATCH 1/8] allow user to pass kwargs to DeepSpeedStrategy

---
 src/lightning/fabric/strategies/deepspeed.py  | 2 ++
 src/lightning/pytorch/strategies/deepspeed.py | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index 03d90cd5df057..a584c9b85c39b 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -97,6 +97,7 @@ def __init__(
         load_full_weights: bool = False,
         precision: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
+        **kwargs: Any,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -239,6 +240,7 @@ def __init__(
             cluster_environment=cluster_environment,
             precision=precision,
             process_group_backend=process_group_backend,
+            **kwargs,
         )
         self._backward_sync_control = None  # DeepSpeed handles gradient accumulation internally
 
diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py
index 4fa771114768d..cafa8e830c7c7 100644
--- a/src/lightning/pytorch/strategies/deepspeed.py
+++ b/src/lightning/pytorch/strategies/deepspeed.py
@@ -119,6 +119,7 @@ def __init__(
         load_full_weights: bool = False,
         precision_plugin: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
+        **kwargs: Any,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -263,6 +264,7 @@ def __init__(
             cluster_environment=cluster_environment,
             precision_plugin=precision_plugin,
             process_group_backend=process_group_backend,
+            **kwargs,
         )
 
         self.config = self._load_config(config)

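A hedged sketch of what this first patch enables: since **kwargs are now
forwarded to the DDPStrategy parent, keywords accepted there, such as
`timeout`, which DDPStrategy stores as `self._timeout`, can be passed
straight through DeepSpeedStrategy. The stage and device values below are
illustrative only, not part of the patch.

    from datetime import timedelta

    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import DeepSpeedStrategy

    # `timeout` is not a DeepSpeedStrategy parameter yet; it rides through
    # **kwargs into DDPStrategy, which stores it as `self._timeout`.
    strategy = DeepSpeedStrategy(stage=3, timeout=timedelta(minutes=30))
    trainer = Trainer(accelerator="cuda", devices=8, strategy=strategy)
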
From 763843982b4fe3ce1e6792861cad0066fafca313 Mon Sep 17 00:00:00 2001
From: Jianing Yang <jed970610@gmail.com>
Date: Sat, 7 Dec 2024 15:43:34 -0500
Subject: [PATCH 2/8] pass timeout to deepspeed.init_distributed in pytorch
 strategy

---
 src/lightning/pytorch/strategies/deepspeed.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py
index cafa8e830c7c7..455e7d26de29d 100644
--- a/src/lightning/pytorch/strategies/deepspeed.py
+++ b/src/lightning/pytorch/strategies/deepspeed.py
@@ -366,7 +366,7 @@ def _init_deepspeed_distributed(self) -> None:
                 f"MEMBER: {self.global_rank + 1}/{self.world_size}"
             )
         self._process_group_backend = self._get_process_group_backend()
-        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
+        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout)
 
     def _set_node_environment_variables(self) -> None:
         assert self.cluster_environment is not None

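For context, `deepspeed.init_distributed` already accepts a `timeout`
keyword (a `datetime.timedelta`), so this patch simply forwards the
strategy's stored `self._timeout` instead of relying on DeepSpeed's
default. A standalone sketch of the equivalent call; the backend, port,
and timeout values are illustrative, since in the patched code they come
from `self._process_group_backend`, `self.cluster_environment.main_port`,
and `self._timeout`:

    from datetime import timedelta

    import deepspeed

    # Equivalent shape of the patched call, with illustrative values.
    deepspeed.init_distributed(
        "nccl",
        distributed_port=29500,
        timeout=timedelta(minutes=45),
    )
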
From 60e0e562e021da60a81817cffe8ceaacff592d53 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 7 Dec 2024 20:43:54 +0000
Subject: [PATCH 3/8] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/pytorch/strategies/deepspeed.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py
index 455e7d26de29d..456d6428cbe87 100644
--- a/src/lightning/pytorch/strategies/deepspeed.py
+++ b/src/lightning/pytorch/strategies/deepspeed.py
@@ -366,7 +366,9 @@ def _init_deepspeed_distributed(self) -> None:
                 f"MEMBER: {self.global_rank + 1}/{self.world_size}"
             )
         self._process_group_backend = self._get_process_group_backend()
-        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout)
+        deepspeed.init_distributed(
+            self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
+        )
 
     def _set_node_environment_variables(self) -> None:
         assert self.cluster_environment is not None

From 5872eb911c61fb1658b102480337946af2aa1c8d Mon Sep 17 00:00:00 2001
From: Jianing Yang <jed970610@gmail.com>
Date: Sat, 7 Dec 2024 15:44:17 -0500
Subject: [PATCH 4/8] pass timeout to deepspeed.init_distributed in fabric
 strategy

---
 src/lightning/fabric/strategies/deepspeed.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index a584c9b85c39b..39b40bca92893 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -650,7 +650,7 @@ def _init_deepspeed_distributed(self) -> None:
                 f"MEMBER: {self.global_rank + 1}/{self.world_size}"
             )
         self._process_group_backend = self._get_process_group_backend()
-        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
+        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout)
 
     def _set_node_environment_variables(self) -> None:
         assert self.cluster_environment is not None

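The same forwarding on the Fabric side. A hedged usage sketch, assuming
Fabric's DDPStrategy likewise accepts a `timeout` keyword and stores it as
`self._timeout`; stage and device counts are illustrative:

    from datetime import timedelta

    import lightning as L
    from lightning.fabric.strategies import DeepSpeedStrategy

    # `timeout` passes through **kwargs to Fabric's DDPStrategy parent.
    strategy = DeepSpeedStrategy(stage=3, timeout=timedelta(minutes=30))
    fabric = L.Fabric(accelerator="cuda", devices=8, strategy=strategy)
    fabric.launch()
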
From 5ebcaac03791904836ed4fa206b61651524947bb Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 7 Dec 2024 20:44:38 +0000
Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/fabric/strategies/deepspeed.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index 39b40bca92893..0bb5be4d8adbb 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -650,7 +650,9 @@ def _init_deepspeed_distributed(self) -> None:
                 f"MEMBER: {self.global_rank + 1}/{self.world_size}"
             )
         self._process_group_backend = self._get_process_group_backend()
-        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout)
+        deepspeed.init_distributed(
+            self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
+        )
 
     def _set_node_environment_variables(self) -> None:
         assert self.cluster_environment is not None

From 09ec21c85bff6aae7f102f9cf0b968b2d3d48284 Mon Sep 17 00:00:00 2001
From: Jianing Yang <jed970610@gmail.com>
Date: Sun, 8 Dec 2024 20:59:12 +0000
Subject: [PATCH 6/8] make timeout explicit in DeepSpeedStrategy

---
 src/lightning/fabric/strategies/deepspeed.py  |  8 +++++---
 src/lightning/pytorch/strategies/deepspeed.py | 10 +++++++---
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index a584c9b85c39b..526a49034d8f0 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -18,6 +18,7 @@
 import platform
 from collections.abc import Mapping
 from contextlib import AbstractContextManager, ExitStack
+from datetime import timedelta
 from itertools import chain
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
@@ -31,6 +32,7 @@
 from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning.fabric.plugins.precision import Precision
+from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.fabric.strategies.ddp import DDPStrategy
 from lightning.fabric.strategies.registry import _StrategyRegistry
 from lightning.fabric.strategies.strategy import _Sharded
@@ -97,7 +99,7 @@ def __init__(
         load_full_weights: bool = False,
         precision: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
-        **kwargs: Any,
+        timeout: Optional[timedelta] = default_pg_timeout,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -240,9 +242,9 @@ def __init__(
             cluster_environment=cluster_environment,
             precision=precision,
             process_group_backend=process_group_backend,
-            **kwargs,
         )
         self._backward_sync_control = None  # DeepSpeed handles gradient accumulation internally
+        self._timeout: Optional[timedelta] = timeout
 
         self.config = self._load_config(config)
         if self.config is None:
@@ -650,7 +652,7 @@ def _init_deepspeed_distributed(self) -> None:
                 f"MEMBER: {self.global_rank + 1}/{self.world_size}"
             )
         self._process_group_backend = self._get_process_group_backend()
-        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
+        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout)
 
     def _set_node_environment_variables(self) -> None:
         assert self.cluster_environment is not None
diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py
index cafa8e830c7c7..94bed46af6f8b 100644
--- a/src/lightning/pytorch/strategies/deepspeed.py
+++ b/src/lightning/pytorch/strategies/deepspeed.py
@@ -19,6 +19,7 @@
 from collections import OrderedDict
 from collections.abc import Generator, Mapping
 from contextlib import contextmanager
+from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional, Union
 
@@ -30,6 +31,9 @@
 
 import lightning.pytorch as pl
 from lightning.fabric.plugins import ClusterEnvironment
+from lightning.fabric.plugins.collectives.torch_collective import (
+    default_pg_timeout
+)
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.strategies.deepspeed import (
     _DEEPSPEED_AVAILABLE,
@@ -119,7 +123,7 @@ def __init__(
         load_full_weights: bool = False,
         precision_plugin: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
-        **kwargs: Any,
+        timeout: Optional[timedelta] = default_pg_timeout,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -264,8 +268,8 @@ def __init__(
             cluster_environment=cluster_environment,
             precision_plugin=precision_plugin,
             process_group_backend=process_group_backend,
-            **kwargs,
         )
+        self._timeout: Optional[timedelta] = timeout
 
         self.config = self._load_config(config)
         if self.config is None:
@@ -366,7 +370,7 @@ def _init_deepspeed_distributed(self) -> None:
                 f"MEMBER: {self.global_rank + 1}/{self.world_size}"
             )
         self._process_group_backend = self._get_process_group_backend()
-        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
+        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout)
 
     def _set_node_environment_variables(self) -> None:
         assert self.cluster_environment is not None

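With the keyword now explicit, defaulting to `default_pg_timeout` to mirror
DDPStrategy, the argument is discoverable in the signature instead of hiding
behind **kwargs. A minimal end-to-end sketch of the resulting API; the stage,
timeout, and device values are illustrative:

    from datetime import timedelta

    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import DeepSpeedStrategy

    # `timeout` is now a named parameter; it is stored as self._timeout and
    # reaches deepspeed.init_distributed(..., timeout=self._timeout).
    strategy = DeepSpeedStrategy(stage=2, timeout=timedelta(minutes=60))
    trainer = Trainer(strategy=strategy, accelerator="cuda", devices=4)
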
From 5948705094f4d320135389ebc6c01266110fc225 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 8 Dec 2024 21:02:22 +0000
Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/fabric/strategies/deepspeed.py  | 2 +-
 src/lightning/pytorch/strategies/deepspeed.py | 4 +---
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index 3195eb470c262..1e94fa1166f93 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -30,9 +30,9 @@
 from typing_extensions import override
 
 from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
+from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning.fabric.plugins.precision import Precision
-from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.fabric.strategies.ddp import DDPStrategy
 from lightning.fabric.strategies.registry import _StrategyRegistry
 from lightning.fabric.strategies.strategy import _Sharded
diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py
index de67d8715bbe5..e17377d4464b0 100644
--- a/src/lightning/pytorch/strategies/deepspeed.py
+++ b/src/lightning/pytorch/strategies/deepspeed.py
@@ -31,9 +31,7 @@
 
 import lightning.pytorch as pl
 from lightning.fabric.plugins import ClusterEnvironment
-from lightning.fabric.plugins.collectives.torch_collective import (
-    default_pg_timeout
-)
+from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.strategies.deepspeed import (
     _DEEPSPEED_AVAILABLE,

From dbc472665d3b8f7bc1a1a567428d892870f388f8 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 9 Dec 2024 15:21:05 +0000
Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 examples/fabric/tensor_parallel/train.py  | 3 +--
 examples/pytorch/tensor_parallel/train.py | 3 +--
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/examples/fabric/tensor_parallel/train.py b/examples/fabric/tensor_parallel/train.py
index 1435e5c2003e1..4a98f12cf6168 100644
--- a/examples/fabric/tensor_parallel/train.py
+++ b/examples/fabric/tensor_parallel/train.py
@@ -1,14 +1,13 @@
 import lightning as L
 import torch
 import torch.nn.functional as F
+from data import RandomTokenDataset
 from lightning.fabric.strategies import ModelParallelStrategy
 from model import ModelArgs, Transformer
 from parallelism import parallelize
 from torch.distributed.tensor.parallel import loss_parallel
 from torch.utils.data import DataLoader
 
-from data import RandomTokenDataset
-
 
 def train():
     strategy = ModelParallelStrategy(
diff --git a/examples/pytorch/tensor_parallel/train.py b/examples/pytorch/tensor_parallel/train.py
index 37c620f4582f0..6a91e1242e4af 100644
--- a/examples/pytorch/tensor_parallel/train.py
+++ b/examples/pytorch/tensor_parallel/train.py
@@ -1,14 +1,13 @@
 import lightning as L
 import torch
 import torch.nn.functional as F
+from data import RandomTokenDataset
 from lightning.pytorch.strategies import ModelParallelStrategy
 from model import ModelArgs, Transformer
 from parallelism import parallelize
 from torch.distributed.tensor.parallel import loss_parallel
 from torch.utils.data import DataLoader
 
-from data import RandomTokenDataset
-
 
 class Llama3(L.LightningModule):
     def __init__(self):