diff --git a/zeus/util/gpu.py b/zeus/device/gpu.py similarity index 52% rename from zeus/util/gpu.py rename to zeus/device/gpu.py index 4bba19a0..799b6a13 100644 --- a/zeus/util/gpu.py +++ b/zeus/device/gpu.py @@ -25,7 +25,14 @@ """ +from __future__ import annotations + import abc +import pynvml +import os + + +from amdsmi import amdsmi_init class GPU(abc.ABC): def __init__(self, gpu_index: int) -> None: @@ -39,20 +46,16 @@ def get_total_energy_consumption(self) -> float: def set_power_limit(self, value: int) -> None: pass -class NVIDIALocalGPU(GPU): - def __init__(self, gpu_index: int) -> None: - super().__init__(gpu_index) - def get_total_energy_consumption(self) -> float: - # Call NVML - pass +""" NVIDIA GPUs """ + +class NativeNVIDIAGPU(GPU): + """Alternatve names: DirectNVIDIAGPU, LocalNVIDIAGPU, NativeNVIDIAGPU, IntegratedNVIDIAGPU""" + - def set_power_limit(self, value: int) -> None: - # Call NVML - pass -class NVIDIARemoteGPU(GPU): +class RemoteNVIDIAGPU(NativeNVIDIAGPU): def __init__(self, gpu_index: int, server_address: str) -> None: super().__init__(gpu_index) self.server_address = server_address @@ -60,11 +63,76 @@ def __init__(self, gpu_index: int, server_address: str) -> None: def set_power_limit(self, value: int) -> None: # Call server since SYS_ADMIN is required pass + def set_freq_mem(self, value: int) -> None: + # Call server since SYS_ADMIN is required + pass + def set_freq_core(self, value: int) -> None: + # Call server since SYS_ADMIN is required + pass + + +""" AMD GPUs """ + +class NativeAMDGPU(GPU): + pass + +class RemoteAMDGPU(NativeAMDGPU): + pass + + + +# Managing all GPUs - Base Abstract Class +class GPUs(abc.ABC): + # Abstract class for managing all GPUs + @abc.abstractmethod + def __init__(self) -> None: + pass + + @abc.abstractmethod + def Init(self) -> None: + pass + + @abc.abstractmethod + def Shutdown(self) -> None: + pass + + +# ?: Enfore singleton pattern? + +# NVIDIA GPUs +class NVIDIAGPUs(GPUs): + def __init__(self, ensure_homogeneuous: bool = True) -> None: + pynvml.nvmlInit() + + def __del__(self) -> None: + pynvml.nvmlShutdown() + + +#AMD GPUs + +class AMDGPUs(GPUs): + def __init__(self, ensure_homogeneuous: bool = True) -> None: + amdsmi_init() + + def Init(self) -> None: + pass + + def Shutdown(self) -> None: + pass + + def get_gpu_count(self) -> int: + pass + + def get_total_energy_consumption(self) -> float: + pass + + def set_power_limit(self, value: int) -> None: + pass + + +# Ensure only one instance of GPUs is created +def get_gpus() -> GPUs: + return NVIDIAGPUs() if pynvml.is_initialized() else AMDGPUs() +GPUManager = get_gpus() -# Factory function to get GPU -def get_gpu(gpu_index: int, server_address: str | None = None) -> GPU: - if server_address is None: - return NVIDIALocalGPU(gpu_index) - else: - return NVIDIARemoteGPU(gpu_index, server_address) \ No newline at end of file