Skip to content

Commit

Permalink
Finished draft for HFGPLO; fixed import errors found during usage
Browse files Browse the repository at this point in the history
  • Loading branch information
parthraut committed Feb 8, 2024
1 parent 133eefe commit bf71b16
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 50 deletions.
23 changes: 5 additions & 18 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,12 @@

from zeus.monitor import ZeusMonitor
from zeus.optimizer import GlobalPowerLimitOptimizer

if __name__ == '__main__':
monitor = ZeusMonitor(gpu_indices=[0,1,2,3])

monitor.begin_window()

measurement = monitor.end_window("heavy computation")
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl, PreTrainedModel

from zeus.optimizer import HFGPLO
from zeus.monitor import ZeusMonitor

print(f"Energy: {measurement.total_energy} J")
print(f"Time : {measurement.time} s")


plo = GlobalPowerLimitOptimizer(monitor)

# training loop
import pdb

plo.on_step_begin()
pdb.set_trace()




1 change: 1 addition & 0 deletions zeus/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
"""A collection of optimizers for various knobs."""

from zeus.optimizer.power_limit import GlobalPowerLimitOptimizer
from zeus.optimizer.power_limit import HFGPLO
52 changes: 20 additions & 32 deletions zeus/optimizer/power_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,11 +481,29 @@ def _save_profile(self) -> None:

# only import when type checking
if TYPE_CHECKING:
from transformers import Trainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl, PreTrainedModel
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl, PreTrainedModel

# to avoid hard dependency on HuggingFace Transformers, import classes dynamically
def import_hf_classes():
try:
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl, PreTrainedModel
return TrainerCallback, TrainingArguments, TrainerState, TrainerControl, PreTrainedModel
except ImportError:
return None


def make_hf(cls: Type[Callback], name: str | None = None) -> Type[TrainerCallback]:

# Attempt to import HuggingFace classes
hf_classes = import_hf_classes()
if hf_classes is None:
raise ImportError("Hugging Face is not installed. Please install it to use this feature.")

TrainerCallback, TrainingArguments, TrainerState, TrainerControl, PreTrainedModel = hf_classes

class Wrapper(TrainerCallback):
# goal: help(HFGPLO) should show the init signature of GlobalPowerLimitOptimizer
# if that doesn't work, then standard class
def __init__(self, *args, **kwargs) -> None:
self.plo = cls(*args, **kwargs) # keep it args, kwargs, or specify to GlobalPowerLimitOptimizer?

Expand All @@ -502,43 +520,13 @@ def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: Tr
self.plo.on_epoch_end()

def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: PreTrainedModel, **kwargs) -> None:
self.plo.on_evaluate() # what to set metric to?

# NO MATCH
def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: PreTrainedModel, **kwargs) -> None:
# self.plo.on_init_end() no match to zeus callback, should be overridden?
pass

# NO MATCH
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: PreTrainedModel, **kwargs) -> None:
# self.plo.on_log() no match
pass

# NO MATCH
def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: PreTrainedModel, **kwargs) -> None:
# self.plo.on_predict() no match
pass

# NO MATCH
def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: PreTrainedModel, **kwargs) -> None:
# self.plo.on_prediction_step() no match
pass

# NO MATCH
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: PreTrainedModel, **kwargs) -> None:
# self.plo.on_save() no match
pass
self.plo.on_evaluate() # what to set metric to? think it is called with metric, look into it

def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: PreTrainedModel, **kwargs) -> None:
self.plo.on_step_begin()

def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: PreTrainedModel, **kwargs) -> None:
self.plo.on_step_end()

# NO MATCH
def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: PreTrainedModel, **kwargs) -> None:
# self.plo.on_substep_end() no match
pass

def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: PreTrainedModel, **kwargs) -> None:
self.plo.on_train_begin()
Expand Down

0 comments on commit bf71b16

Please sign in to comment.