From 492471c7a28c0bd0229384b7f18df28931af3b92 Mon Sep 17 00:00:00 2001 From: Theodoros Katzalis Date: Mon, 20 Jan 2025 17:37:45 +0100 Subject: [PATCH] Fix error handling of fit function of training supervisor --- tests/test_server/test_grpc/test_training_servicer.py | 6 ++++++ tiktorch/server/session/backend/supervisor.py | 7 +++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/test_server/test_grpc/test_training_servicer.py b/tests/test_server/test_grpc/test_training_servicer.py index 6cd5e4fe..741871a6 100644 --- a/tests/test_server/test_grpc/test_training_servicer.py +++ b/tests/test_server/test_grpc/test_training_servicer.py @@ -266,6 +266,9 @@ def fit(self): print("Training has started") trainer_is_called.set() + def should_stop_model_criteria(self) -> bool: + return False + class MockedTrainerSessionBackend(TrainerSessionProcess): def init(self, trainer_yaml_config: str = ""): self._worker = TrainerSessionBackend(MockedNominalTrainer()) @@ -385,6 +388,9 @@ def fit(self): for epoch in range(self.max_num_epochs): self.num_epochs += 1 + def should_stop_model_criteria(self) -> bool: + return False + class MockedTrainerSessionBackend(TrainerSessionProcess): def init(self, trainer_yaml_config: str): if trainer_yaml_config == "nominal": diff --git a/tiktorch/server/session/backend/supervisor.py b/tiktorch/server/session/backend/supervisor.py index d47e8cde..acd387e1 100644 --- a/tiktorch/server/session/backend/supervisor.py +++ b/tiktorch/server/session/backend/supervisor.py @@ -91,16 +91,15 @@ def _start_session(self): def _fit(self): try: self._trainer.fit() + if self.is_training_finished(): + logger.info(f"Training has finished: {self._get_num_iterations_epochs()} ") + self._state = TrainerState.FINISHED except Exception as e: logger.exception(f"Training error: {e}") self.training_error_callbacks(e) self._state = TrainerState.FAILED return - if self.is_training_finished(): - logger.info(f"Training has finished: {self._get_num_iterations_epochs()} ") - self._state = TrainerState.FINISHED - def is_training_finished(self): return ( self._trainer.num_epochs == self._trainer.max_num_epochs