diff --git a/proto/training.proto b/proto/training.proto index e0151152..eb873144 100644 --- a/proto/training.proto +++ b/proto/training.proto @@ -21,6 +21,8 @@ service Training { rpc GetLogs(ModelSession) returns (GetLogsResponse) {} + rpc GetBestModelIdx(ModelSession) returns (stream GetBestModelIdxResponse) {} + rpc Save(SaveRequest) returns (Empty) {} rpc Export(ExportRequest) returns (Empty) {} diff --git a/pytest.ini b/pytest.ini index 0e0eea50..d75ceb27 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,7 +1,7 @@ [pytest] python_files = test_*.py addopts = - --timeout 60 + --timeout 120 -v -s --color=yes diff --git a/tests/test_server/test_grpc/test_training_servicer.py b/tests/test_server/test_grpc/test_training_servicer.py index a5df763b..741871a6 100644 --- a/tests/test_server/test_grpc/test_training_servicer.py +++ b/tests/test_server/test_grpc/test_training_servicer.py @@ -17,7 +17,7 @@ from tiktorch.server.session.backend.base import TrainerSessionBackend from tiktorch.server.session.process import TrainerSessionProcess from tiktorch.server.session_manager import SessionManager -from tiktorch.trainer import ShouldStopCallbacks, Trainer, TrainerState +from tiktorch.trainer import BaseCallbacks, ShouldStopCallbacks, Trainer, TrainerState @pytest.fixture(scope="module") @@ -60,7 +60,7 @@ def unet2d_config_path( trainer: checkpoint_dir: {checkpoint_dir} resume: {resume if resume else "null"} - validate_after_iters: 250 + validate_after_iters: 1 log_after_iters: 2 max_num_epochs: 10000 max_num_iterations: 100000 @@ -260,11 +260,15 @@ def __init__(self): self.num_iterations = 0 self.max_num_iterations = 100 self.should_stop_callbacks = ShouldStopCallbacks() + self.ping_is_best_callbacks = BaseCallbacks() def fit(self): print("Training has started") trainer_is_called.set() + def should_stop_model_criteria(self) -> bool: + return False + class MockedTrainerSessionBackend(TrainerSessionProcess): def init(self, trainer_yaml_config: str = ""): self._worker = TrainerSessionBackend(MockedNominalTrainer()) @@ -366,6 +370,7 @@ def test_recover_training_failed(self): class MockedExceptionTrainer: def __init__(self): self.should_stop_callbacks = ShouldStopCallbacks() + self.ping_is_best_callbacks = BaseCallbacks() def fit(self): raise Exception("mocked exception") @@ -377,11 +382,15 @@ def __init__(self): self.num_iterations = 0 self.max_num_iterations = 100 self.should_stop_callbacks = ShouldStopCallbacks() + self.ping_is_best_callbacks = BaseCallbacks() def fit(self): for epoch in range(self.max_num_epochs): self.num_epochs += 1 + def should_stop_model_criteria(self) -> bool: + return False + class MockedTrainerSessionBackend(TrainerSessionProcess): def init(self, trainer_yaml_config: str): if trainer_yaml_config == "nominal": @@ -416,6 +425,7 @@ def assert_error(func, expected_message: str): class MockedExceptionTrainer: def __init__(self): self.should_stop_callbacks = ShouldStopCallbacks() + self.ping_is_best_callbacks = BaseCallbacks() def fit(self): raise Exception("mocked exception") @@ -665,6 +675,22 @@ def test_export_while_paused(self, grpc_stub): # assume stopping training since model is exported grpc_stub.CloseTrainerSession(training_session_id) + def test_best_model_ping(self, grpc_stub): + training_session_id = grpc_stub.Init( + training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()) + ) + + grpc_stub.Start(training_session_id) + + responses = grpc_stub.GetBestModelIdx(training_session_id) + received_updates = 0 + for response in responses: + assert isinstance(response, training_pb2.GetBestModelIdxResponse) + assert response.id is not None + received_updates += 1 + if received_updates >= 2: + break + def test_close_session(self, grpc_stub): """ Test closing a training session. diff --git a/tiktorch/proto/training_pb2.py b/tiktorch/proto/training_pb2.py index 18f1a52e..696528bf 100644 --- a/tiktorch/proto/training_pb2.py +++ b/tiktorch/proto/training_pb2.py @@ -14,7 +14,7 @@ from . import utils_pb2 as utils__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0etraining.proto\x12\x08training\x1a\x0butils.proto\"%\n\x17GetBestModelIdxResponse\x12\n\n\x02id\x18\x01 \x01(\t\"\x87\x01\n\x04Logs\x12\'\n\x04mode\x18\x01 \x01(\x0e\x32\x19.training.Logs.ModelPhase\x12\x12\n\neval_score\x18\x02 \x01(\x01\x12\x0c\n\x04loss\x18\x03 \x01(\x01\x12\x11\n\titeration\x18\x04 \x01(\r\"!\n\nModelPhase\x12\t\n\x05Train\x10\x00\x12\x08\n\x04\x45val\x10\x01\"L\n\x14StreamUpdateResponse\x12\x16\n\x0e\x62\x65st_model_idx\x18\x01 \x01(\r\x12\x1c\n\x04logs\x18\x02 \x01(\x0b\x32\x0e.training.Logs\"/\n\x0fGetLogsResponse\x12\x1c\n\x04logs\x18\x01 \x03(\x0b\x32\x0e.training.Logs\"F\n\x0bSaveRequest\x12%\n\x0emodelSessionId\x18\x01 \x01(\x0b\x32\r.ModelSession\x12\x10\n\x08\x66ilePath\x18\x02 \x01(\t\"H\n\rExportRequest\x12%\n\x0emodelSessionId\x18\x01 \x01(\x0b\x32\r.ModelSession\x12\x10\n\x08\x66ilePath\x18\x02 \x01(\t\"6\n\x12ValidationResponse\x12 \n\x18validation_score_average\x18\x01 \x01(\x01\"\x8b\x01\n\x11GetStatusResponse\x12\x30\n\x05state\x18\x01 \x01(\x0e\x32!.training.GetStatusResponse.State\"D\n\x05State\x12\x08\n\x04Idle\x10\x00\x12\x0b\n\x07Running\x10\x01\x12\n\n\x06Paused\x10\x02\x12\n\n\x06\x46\x61iled\x10\x03\x12\x0c\n\x08\x46inished\x10\x04\",\n\x1eGetCurrentBestModelIdxResponse\x12\n\n\x02id\x18\x01 \x01(\r\"&\n\x0eTrainingConfig\x12\x14\n\x0cyaml_content\x18\x01 \x01(\t2\xb3\x04\n\x08Training\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12\x31\n\x04Init\x12\x18.training.TrainingConfig\x1a\r.ModelSession\"\x00\x12 \n\x05Start\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12!\n\x06Resume\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12 \n\x05Pause\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12\x42\n\rStreamUpdates\x12\r.ModelSession\x1a\x1e.training.StreamUpdateResponse\"\x00\x30\x01\x12\x35\n\x07GetLogs\x12\r.ModelSession\x1a\x19.training.GetLogsResponse\"\x00\x12\'\n\x04Save\x12\x15.training.SaveRequest\x1a\x06.Empty\"\x00\x12+\n\x06\x45xport\x12\x17.training.ExportRequest\x1a\x06.Empty\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x12\x39\n\tGetStatus\x12\r.ModelSession\x1a\x1b.training.GetStatusResponse\"\x00\x12.\n\x13\x43loseTrainerSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0etraining.proto\x12\x08training\x1a\x0butils.proto\"%\n\x17GetBestModelIdxResponse\x12\n\n\x02id\x18\x01 \x01(\t\"\x87\x01\n\x04Logs\x12\'\n\x04mode\x18\x01 \x01(\x0e\x32\x19.training.Logs.ModelPhase\x12\x12\n\neval_score\x18\x02 \x01(\x01\x12\x0c\n\x04loss\x18\x03 \x01(\x01\x12\x11\n\titeration\x18\x04 \x01(\r\"!\n\nModelPhase\x12\t\n\x05Train\x10\x00\x12\x08\n\x04\x45val\x10\x01\"L\n\x14StreamUpdateResponse\x12\x16\n\x0e\x62\x65st_model_idx\x18\x01 \x01(\r\x12\x1c\n\x04logs\x18\x02 \x01(\x0b\x32\x0e.training.Logs\"/\n\x0fGetLogsResponse\x12\x1c\n\x04logs\x18\x01 \x03(\x0b\x32\x0e.training.Logs\"F\n\x0bSaveRequest\x12%\n\x0emodelSessionId\x18\x01 \x01(\x0b\x32\r.ModelSession\x12\x10\n\x08\x66ilePath\x18\x02 \x01(\t\"H\n\rExportRequest\x12%\n\x0emodelSessionId\x18\x01 \x01(\x0b\x32\r.ModelSession\x12\x10\n\x08\x66ilePath\x18\x02 \x01(\t\"6\n\x12ValidationResponse\x12 \n\x18validation_score_average\x18\x01 \x01(\x01\"\x8b\x01\n\x11GetStatusResponse\x12\x30\n\x05state\x18\x01 \x01(\x0e\x32!.training.GetStatusResponse.State\"D\n\x05State\x12\x08\n\x04Idle\x10\x00\x12\x0b\n\x07Running\x10\x01\x12\n\n\x06Paused\x10\x02\x12\n\n\x06\x46\x61iled\x10\x03\x12\x0c\n\x08\x46inished\x10\x04\",\n\x1eGetCurrentBestModelIdxResponse\x12\n\n\x02id\x18\x01 \x01(\r\"&\n\x0eTrainingConfig\x12\x14\n\x0cyaml_content\x18\x01 \x01(\t2\xfc\x04\n\x08Training\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12\x31\n\x04Init\x12\x18.training.TrainingConfig\x1a\r.ModelSession\"\x00\x12 \n\x05Start\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12!\n\x06Resume\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12 \n\x05Pause\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12\x42\n\rStreamUpdates\x12\r.ModelSession\x1a\x1e.training.StreamUpdateResponse\"\x00\x30\x01\x12\x35\n\x07GetLogs\x12\r.ModelSession\x1a\x19.training.GetLogsResponse\"\x00\x12G\n\x0fGetBestModelIdx\x12\r.ModelSession\x1a!.training.GetBestModelIdxResponse\"\x00\x30\x01\x12\'\n\x04Save\x12\x15.training.SaveRequest\x1a\x06.Empty\"\x00\x12+\n\x06\x45xport\x12\x17.training.ExportRequest\x1a\x06.Empty\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x12\x39\n\tGetStatus\x12\r.ModelSession\x1a\x1b.training.GetStatusResponse\"\x00\x12.\n\x13\x43loseTrainerSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x62\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'training_pb2', globals()) @@ -46,5 +46,5 @@ _TRAININGCONFIG._serialized_start=735 _TRAININGCONFIG._serialized_end=773 _TRAINING._serialized_start=776 - _TRAINING._serialized_end=1339 + _TRAINING._serialized_end=1412 # @@protoc_insertion_point(module_scope) diff --git a/tiktorch/proto/training_pb2_grpc.py b/tiktorch/proto/training_pb2_grpc.py index fffecd37..b51ae922 100644 --- a/tiktorch/proto/training_pb2_grpc.py +++ b/tiktorch/proto/training_pb2_grpc.py @@ -50,6 +50,11 @@ def __init__(self, channel): request_serializer=utils__pb2.ModelSession.SerializeToString, response_deserializer=training__pb2.GetLogsResponse.FromString, ) + self.GetBestModelIdx = channel.unary_stream( + '/training.Training/GetBestModelIdx', + request_serializer=utils__pb2.ModelSession.SerializeToString, + response_deserializer=training__pb2.GetBestModelIdxResponse.FromString, + ) self.Save = channel.unary_unary( '/training.Training/Save', request_serializer=training__pb2.SaveRequest.SerializeToString, @@ -122,6 +127,12 @@ def GetLogs(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def GetBestModelIdx(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def Save(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -190,6 +201,11 @@ def add_TrainingServicer_to_server(servicer, server): request_deserializer=utils__pb2.ModelSession.FromString, response_serializer=training__pb2.GetLogsResponse.SerializeToString, ), + 'GetBestModelIdx': grpc.unary_stream_rpc_method_handler( + servicer.GetBestModelIdx, + request_deserializer=utils__pb2.ModelSession.FromString, + response_serializer=training__pb2.GetBestModelIdxResponse.SerializeToString, + ), 'Save': grpc.unary_unary_rpc_method_handler( servicer.Save, request_deserializer=training__pb2.SaveRequest.FromString, @@ -344,6 +360,23 @@ def GetLogs(request, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + @staticmethod + def GetBestModelIdx(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/training.Training/GetBestModelIdx', + utils__pb2.ModelSession.SerializeToString, + training__pb2.GetBestModelIdxResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + @staticmethod def Save(request, target, diff --git a/tiktorch/server/grpc/training_servicer.py b/tiktorch/server/grpc/training_servicer.py index 70ca3ef8..f2264fba 100644 --- a/tiktorch/server/grpc/training_servicer.py +++ b/tiktorch/server/grpc/training_servicer.py @@ -2,6 +2,7 @@ import logging import queue +import time from pathlib import Path from typing import Callable, List @@ -87,6 +88,17 @@ def StreamUpdates(self, request: utils_pb2.ModelSession, context): def GetLogs(self, request: utils_pb2.ModelSession, context): raise NotImplementedError + def GetBestModelIdx(self, request, context): + session = self._getTrainerSession(context, request) + prev_best_model_idx = None + while context.is_active(): + current_best_model_idx = session.client.get_best_model_idx() + if current_best_model_idx != prev_best_model_idx: + prev_best_model_idx = current_best_model_idx + yield training_pb2.GetBestModelIdxResponse(id=str(current_best_model_idx)) + time.sleep(1) + logger.info("Client disconnected. Stopping stream.") + def GetStatus(self, request: utils_pb2.ModelSession, context): session = self._getTrainerSession(context, request) state = session.client.get_state() diff --git a/tiktorch/server/session/backend/base.py b/tiktorch/server/session/backend/base.py index 7a0053a4..c8c95cf8 100644 --- a/tiktorch/server/session/backend/base.py +++ b/tiktorch/server/session/backend/base.py @@ -94,3 +94,6 @@ def export(self, file_path: Path) -> None: def get_state(self) -> TrainerState: return self._supervisor.get_state() + + def get_best_model_idx(self) -> int: + return self._supervisor.get_best_model_idx() diff --git a/tiktorch/server/session/backend/supervisor.py b/tiktorch/server/session/backend/supervisor.py index d4d4e45b..acd387e1 100644 --- a/tiktorch/server/session/backend/supervisor.py +++ b/tiktorch/server/session/backend/supervisor.py @@ -52,6 +52,15 @@ def __init__(self, trainer: Trainer) -> None: self._session_thread = threading.Thread(target=self._start_session, name="SessionThread") self._command_queue_utils = CommandPriorityQueueUtils() self.training_error_callbacks: ErrorCallbacks = BaseCallbacks() + self._best_model_idx = 0 + self._trainer.ping_is_best_callbacks.register(self._increment_best_model_idx) + + def get_best_model_idx(self) -> int: + return self._best_model_idx + + def _increment_best_model_idx(self): + self._best_model_idx += 1 + logger.debug(f"New best model detected with id {self._best_model_idx}") def get_state(self) -> TrainerState: logger.debug(f"Get state called {self._state}") @@ -82,16 +91,15 @@ def _start_session(self): def _fit(self): try: self._trainer.fit() + if self.is_training_finished(): + logger.info(f"Training has finished: {self._get_num_iterations_epochs()} ") + self._state = TrainerState.FINISHED except Exception as e: logger.exception(f"Training error: {e}") self.training_error_callbacks(e) self._state = TrainerState.FAILED return - if self.is_training_finished(): - logger.info(f"Training has finished: {self._get_num_iterations_epochs()} ") - self._state = TrainerState.FINISHED - def is_training_finished(self): return ( self._trainer.num_epochs == self._trainer.max_num_epochs diff --git a/tiktorch/server/session/process.py b/tiktorch/server/session/process.py index dfb2c21c..b87e2f69 100644 --- a/tiktorch/server/session/process.py +++ b/tiktorch/server/session/process.py @@ -19,7 +19,7 @@ from tiktorch.rpc.mp import BioModelClient, MPServer from ...converters import Sample -from ...trainer import TrainerYamlParser +from ...trainer import TrainerState, TrainerYamlParser from .backend import base from .rpc_interface import IRPCModelSession, IRPCTrainer @@ -145,9 +145,12 @@ def save(self, file_path: Path): def export(self, file_path: Path): self.worker.export(file_path) - def get_state(self): + def get_state(self) -> TrainerState: return self.worker.get_state() + def get_best_model_idx(self) -> int: + return self.worker.get_best_model_idx() + def shutdown(self): if self._worker is None: return Shutdown() diff --git a/tiktorch/server/session/rpc_interface.py b/tiktorch/server/session/rpc_interface.py index b6dbb132..b35886e8 100644 --- a/tiktorch/server/session/rpc_interface.py +++ b/tiktorch/server/session/rpc_interface.py @@ -89,3 +89,7 @@ def export(self, file_path: Path): @exposed def get_state(self) -> TrainerState: raise NotImplementedError + + @exposed + def get_best_model_idx(self) -> int: + raise NotImplementedError diff --git a/tiktorch/trainer.py b/tiktorch/trainer.py index ef54e20f..1ab7551d 100644 --- a/tiktorch/trainer.py +++ b/tiktorch/trainer.py @@ -176,6 +176,7 @@ def __init__( self._device = device self.logs_callbacks: LogsCallbacks = BaseCallbacks() self.should_stop_callbacks: Callbacks = ShouldStopCallbacks() + self.ping_is_best_callbacks = BaseCallbacks() # notification of having a new best model def fit(self): return super().fit() @@ -184,7 +185,11 @@ def train(self): return super().train() def validate(self): - return super().validate() + eval_score = super().validate() + is_best = self._is_best_eval_score(eval_score) + if is_best: + self.ping_is_best_callbacks() + return eval_score def save_state_dict(self, file_path: Path): """