Skip to content

Commit

Permalink
Fix error handling of fit function of training supervisor
Browse files (browse the repository at this point in the history)
  • Loading branch information
thodkatz committed Jan 20, 2025
1 parent 309d72f commit 492471c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
6 changes: 6 additions & 0 deletions tests/test_server/test_grpc/test_training_servicer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -266,6 +266,9 @@ def fit(self):
print("Training has started")
trainer_is_called.set()

def should_stop_model_criteria(self) -> bool:
return False

class MockedTrainerSessionBackend(TrainerSessionProcess):
def init(self, trainer_yaml_config: str = ""):
self._worker = TrainerSessionBackend(MockedNominalTrainer())
Expand Down Expand Up @@ -385,6 +388,9 @@ def fit(self):
for epoch in range(self.max_num_epochs):
self.num_epochs += 1

def should_stop_model_criteria(self) -> bool:
return False

class MockedTrainerSessionBackend(TrainerSessionProcess):
def init(self, trainer_yaml_config: str):
if trainer_yaml_config == "nominal":
Expand Down
7 changes: 3 additions & 4 deletions tiktorch/server/session/backend/supervisor.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -91,16 +91,15 @@ def _start_session(self):
def _fit(self):
try:
self._trainer.fit()
if self.is_training_finished():
logger.info(f"Training has finished: {self._get_num_iterations_epochs()} ")
self._state = TrainerState.FINISHED
except Exception as e:
logger.exception(f"Training error: {e}")
self.training_error_callbacks(e)
self._state = TrainerState.FAILED
return

if self.is_training_finished():
logger.info(f"Training has finished: {self._get_num_iterations_epochs()} ")
self._state = TrainerState.FINISHED

def is_training_finished(self):
return (
self._trainer.num_epochs == self._trainer.max_num_epochs
Expand Down

0 comments on commit 492471c

Please sign in to comment.