Skip to content

Commit

Permalink
stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
parthraut committed Mar 8, 2024
1 parent 9d12f3c commit daa5a57
Showing 1 changed file with 84 additions and 16 deletions.
100 changes: 84 additions & 16 deletions zeus/util/gpu.py → zeus/device/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,14 @@
"""

from __future__ import annotations

import abc
import pynvml
import os


from amdsmi import amdsmi_init

class GPU(abc.ABC):
def __init__(self, gpu_index: int) -> None:
Expand All @@ -39,32 +46,93 @@ def get_total_energy_consumption(self) -> float:
def set_power_limit(self, value: int) -> None:
pass

class NVIDIALocalGPU(GPU):
    """NVIDIA GPU variant whose operations are intended to call NVML."""

    def __init__(self, gpu_index: int) -> None:
        """Forward `gpu_index` to the `GPU` base initializer.

        Args:
            gpu_index: Index of the GPU this object represents.
        """
        super().__init__(gpu_index)

    def get_total_energy_consumption(self) -> float:
        """Placeholder for the NVML energy query.

        Not implemented yet: the body is empty, so the call currently
        returns `None` despite the `float` annotation.
        """
        # TODO: Call NVML
""" NVIDIA GPUs """

class NativeNVIDIAGPU(GPU):
    """NVIDIA GPU variant whose operations are intended to call NVML.

    Alternative names under consideration: DirectNVIDIAGPU, LocalNVIDIAGPU,
    NativeNVIDIAGPU, IntegratedNVIDIAGPU.
    """

    def set_power_limit(self, value: int) -> None:
        """Placeholder for setting the GPU power limit; does nothing yet.

        Args:
            value: Requested power limit (units are not established by
                this stub -- TODO confirm once implemented).
        """
        # TODO: Call NVML


class NVIDIARemoteGPU(GPU):
class RemoteNVIDIAGPU(NativeNVIDIAGPU):
    """NVIDIA GPU whose privileged setters are meant to be served remotely.

    The setters below are placeholders: each is intended to call a server
    because the operation requires SYS_ADMIN, but none does anything yet.
    """

    def __init__(self, gpu_index: int, server_address: str) -> None:
        """Initialize the base class, then remember the server address.

        Args:
            gpu_index: Index of the GPU this object represents.
            server_address: Address of the server that should perform the
                privileged operations; stored as `self.server_address`.
        """
        super().__init__(gpu_index)
        self.server_address = server_address

    def set_power_limit(self, value: int) -> None:
        """Placeholder for setting the power limit; does nothing yet."""
        # TODO: Call server since SYS_ADMIN is required

    def set_freq_mem(self, value: int) -> None:
        """Placeholder for setting the memory frequency; does nothing yet."""
        # TODO: Call server since SYS_ADMIN is required

    def set_freq_core(self, value: int) -> None:
        """Placeholder for setting the core frequency; does nothing yet."""
        # TODO: Call server since SYS_ADMIN is required


""" AMD GPUs """

class NativeAMDGPU(GPU):
    """Placeholder for an AMD GPU; adds nothing to `GPU` yet."""

class RemoteAMDGPU(NativeAMDGPU):
    """Placeholder for a remotely managed AMD GPU; adds nothing to `NativeAMDGPU` yet."""



# Managing all GPUs - Base Abstract Class
class GPUs(abc.ABC):
    """Abstract interface for an object that manages all GPUs.

    `__init__`, `Init` and `Shutdown` are all abstract, so a subclass must
    override every one of them before it can be instantiated; otherwise
    instantiation raises `TypeError`.
    """

    @abc.abstractmethod
    def __init__(self) -> None:
        """Construct the manager. Must be overridden by subclasses."""

    @abc.abstractmethod
    def Init(self) -> None:
        """Initialize the manager. Must be overridden by subclasses."""

    @abc.abstractmethod
    def Shutdown(self) -> None:
        """Shut the manager down. Must be overridden by subclasses."""


# TODO: Enforce singleton pattern?

# NVIDIA GPUs
class NVIDIAGPUs(GPUs):
    """Manager for NVIDIA GPUs, backed by NVML through `pynvml`.

    NVML is initialized when the object is constructed and shut down when
    it is garbage-collected. `Init` and `Shutdown` implement the abstract
    interface of `GPUs`; without them the class could not be instantiated.
    """

    def __init__(self, ensure_homogeneuous: bool = True) -> None:
        """Initialize NVML.

        Args:
            ensure_homogeneuous: Currently unused. The (misspelled) name is
                kept as-is so existing keyword callers keep working.

        Raises:
            pynvml.NVMLError: Propagated from `pynvml.nvmlInit()` when NVML
                cannot be initialized.
        """
        # Set the flag before touching NVML so `__del__` stays safe even if
        # `nvmlInit()` raises and leaves a half-constructed object behind.
        self._nvml_initialized = False
        self.Init()

    def __del__(self) -> None:
        """Release NVML if this object initialized it."""
        self.Shutdown()

    def Init(self) -> None:
        """Initialize NVML once; later calls are no-ops."""
        if self._nvml_initialized:
            return
        pynvml.nvmlInit()
        self._nvml_initialized = True

    def Shutdown(self) -> None:
        """Shut NVML down if this object initialized it; otherwise do nothing."""
        # `getattr` covers the case where `__init__` never ran far enough to
        # create the flag (e.g. a subclass that skipped it).
        if getattr(self, "_nvml_initialized", False):
            self._nvml_initialized = False
            pynvml.nvmlShutdown()


#AMD GPUs

class AMDGPUs(GPUs):
    """Manager for AMD GPUs, backed by the `amdsmi` library.

    Only the constructor does real work (it calls `amdsmi_init()`); every
    other method is a placeholder with an empty body that returns `None`.
    """

    def __init__(self, ensure_homogeneuous: bool = True) -> None:
        """Initialize the AMD SMI library.

        Args:
            ensure_homogeneuous: Currently unused.
        """
        amdsmi_init()

    def Init(self) -> None:
        """Placeholder; does nothing yet."""

    def Shutdown(self) -> None:
        """Placeholder; does nothing yet."""

    def get_gpu_count(self) -> int:
        """Placeholder for the GPU count query; currently returns `None`."""

    def get_total_energy_consumption(self) -> float:
        """Placeholder for the energy query; currently returns `None`."""

    def set_power_limit(self, value: int) -> None:
        """Placeholder for setting the power limit; does nothing yet."""


# Ensure only one instance of GPUs is created
def get_gpus() -> GPUs:
    """Return a GPU manager for the vendor available on this machine.

    NVML is probed first: if `pynvml.nvmlInit()` succeeds, an `NVIDIAGPUs`
    manager is returned; if it raises `pynvml.NVMLError` (for example when
    the NVML library or driver is missing), an `AMDGPUs` manager is
    returned instead.

    Returns:
        An `NVIDIAGPUs` instance when NVML is usable, otherwise an
        `AMDGPUs` instance.
    """
    # `pynvml` exposes no `is_initialized()` helper, and nothing has
    # initialized NVML at this point anyway, so detect NVIDIA support by
    # actually attempting to initialize the library.
    try:
        pynvml.nvmlInit()
    except pynvml.NVMLError:
        return AMDGPUs()
    # The probe succeeded; undo it so the manager owns its own NVML
    # initialization/shutdown pair.
    pynvml.nvmlShutdown()
    return NVIDIAGPUs()

# Module-level GPU manager, created once at import time by calling `get_gpus()`.
GPUManager = get_gpus()

# Factory function to get GPU
def get_gpu(gpu_index: int, server_address: str | None = None) -> GPU:
    """Create the GPU object for `gpu_index`.

    Args:
        gpu_index: Index of the GPU to wrap.
        server_address: When given (not `None`), a remote GPU object bound
            to this address is created; otherwise a local one is created.

    Returns:
        An `NVIDIARemoteGPU` if `server_address` is provided, else an
        `NVIDIALocalGPU`.
    """
    if server_address is not None:
        return NVIDIARemoteGPU(gpu_index, server_address)
    return NVIDIALocalGPU(gpu_index)

0 comments on commit daa5a57

Please sign in to comment.