Skip to content

Commit

Permalink
apply fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Nov 13, 2024
1 parent 4f22163 commit c542968
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/lightning/fabric/accelerators/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
from typing import Union

import torch
from typing_extensions import override
Expand Down Expand Up @@ -45,7 +45,7 @@ def parse_devices(devices: Union[int, str]) -> int:

@staticmethod
@override
def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]:
def get_parallel_devices(devices: Union[int, str]) -> list[torch.device]:
"""Gets parallel devices for the Accelerator."""
devices = _parse_cpu_cores(devices)
return [torch.device("cpu")] * devices
Expand Down
8 changes: 4 additions & 4 deletions src/lightning/pytorch/accelerators/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Union
from typing import Any, Union

import torch
from lightning_utilities.core.imports import RequirementCache
Expand All @@ -38,7 +38,7 @@ def setup_device(self, device: torch.device) -> None:
raise MisconfigurationException(f"Device should be CPU, got {device} instead.")

@override
def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
"""Get CPU stats from ``psutil`` package."""
return get_cpu_stats()

Expand All @@ -54,7 +54,7 @@ def parse_devices(devices: Union[int, str]) -> int:

@staticmethod
@override
def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]:
def get_parallel_devices(devices: Union[int, str]) -> list[torch.device]:
"""Gets parallel devices for the Accelerator."""
devices = _parse_cpu_cores(devices)
return [torch.device("cpu")] * devices
Expand Down Expand Up @@ -89,7 +89,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
_PSUTIL_AVAILABLE = RequirementCache("psutil")


def get_cpu_stats() -> Dict[str, float]:
def get_cpu_stats() -> dict[str, float]:
if not _PSUTIL_AVAILABLE:
raise ModuleNotFoundError(
f"Fetching CPU device stats requires `psutil` to be installed. {str(_PSUTIL_AVAILABLE)}"
Expand Down

0 comments on commit c542968

Please sign in to comment.