
Commit 15a0225

Merge pull request #229 from thodkatz/add-is-best-model-to-training-servicer

Add is best model to training servicer
thodkatz authored Jan 20, 2025
2 parents 5f5fbd1 + 492471c commit 15a0225
Showing 11 changed files with 108 additions and 12 deletions.
2 changes: 2 additions & 0 deletions proto/training.proto
@@ -21,6 +21,8 @@ service Training {

rpc GetLogs(ModelSession) returns (GetLogsResponse) {}

rpc GetBestModelIdx(ModelSession) returns (stream GetBestModelIdxResponse) {}

rpc Save(SaveRequest) returns (Empty) {}

rpc Export(ExportRequest) returns (Empty) {}
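
GetBestModelIdx is a server-streaming RPC, so a client iterates over the returned response stream. A minimal client-side sketch is below; the address, the insecure channel, and the TrainingStub name follow standard grpcio conventions and are assumptions rather than part of this change:

# Minimal sketch of consuming the new server-streaming RPC from a client.
# The server address and the insecure channel are illustrative; TrainingStub
# is the conventional generated stub name for the Training service.
import grpc

from tiktorch.proto import training_pb2, training_pb2_grpc

channel = grpc.insecure_channel("127.0.0.1:5567")
stub = training_pb2_grpc.TrainingStub(channel)

session = stub.Init(training_pb2.TrainingConfig(yaml_content="..."))
stub.Start(session)

# One GetBestModelIdxResponse arrives per newly detected best model.
for response in stub.GetBestModelIdx(session):
    print("new best model index:", response.id)
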
2 changes: 1 addition & 1 deletion pytest.ini
@@ -1,7 +1,7 @@
[pytest]
python_files = test_*.py
addopts =
--timeout 60
--timeout 120
-v
-s
--color=yes
30 changes: 28 additions & 2 deletions tests/test_server/test_grpc/test_training_servicer.py
@@ -17,7 +17,7 @@
from tiktorch.server.session.backend.base import TrainerSessionBackend
from tiktorch.server.session.process import TrainerSessionProcess
from tiktorch.server.session_manager import SessionManager
from tiktorch.trainer import ShouldStopCallbacks, Trainer, TrainerState
from tiktorch.trainer import BaseCallbacks, ShouldStopCallbacks, Trainer, TrainerState


@pytest.fixture(scope="module")
@@ -60,7 +60,7 @@ def unet2d_config_path(
trainer:
checkpoint_dir: {checkpoint_dir}
resume: {resume if resume else "null"}
validate_after_iters: 250
validate_after_iters: 1
log_after_iters: 2
max_num_epochs: 10000
max_num_iterations: 100000
@@ -260,11 +260,15 @@ def __init__(self):
self.num_iterations = 0
self.max_num_iterations = 100
self.should_stop_callbacks = ShouldStopCallbacks()
self.ping_is_best_callbacks = BaseCallbacks()

def fit(self):
print("Training has started")
trainer_is_called.set()

def should_stop_model_criteria(self) -> bool:
return False

class MockedTrainerSessionBackend(TrainerSessionProcess):
def init(self, trainer_yaml_config: str = ""):
self._worker = TrainerSessionBackend(MockedNominalTrainer())
@@ -366,6 +370,7 @@ def test_recover_training_failed(self):
class MockedExceptionTrainer:
def __init__(self):
self.should_stop_callbacks = ShouldStopCallbacks()
self.ping_is_best_callbacks = BaseCallbacks()

def fit(self):
raise Exception("mocked exception")
@@ -377,11 +382,15 @@ def __init__(self):
self.num_iterations = 0
self.max_num_iterations = 100
self.should_stop_callbacks = ShouldStopCallbacks()
self.ping_is_best_callbacks = BaseCallbacks()

def fit(self):
for epoch in range(self.max_num_epochs):
self.num_epochs += 1

def should_stop_model_criteria(self) -> bool:
return False

class MockedTrainerSessionBackend(TrainerSessionProcess):
def init(self, trainer_yaml_config: str):
if trainer_yaml_config == "nominal":
@@ -416,6 +425,7 @@ def assert_error(func, expected_message: str):
class MockedExceptionTrainer:
def __init__(self):
self.should_stop_callbacks = ShouldStopCallbacks()
self.ping_is_best_callbacks = BaseCallbacks()

def fit(self):
raise Exception("mocked exception")
@@ -665,6 +675,22 @@ def test_export_while_paused(self, grpc_stub):
# assume stopping training since model is exported
grpc_stub.CloseTrainerSession(training_session_id)

def test_best_model_ping(self, grpc_stub):
training_session_id = grpc_stub.Init(
training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())
)

grpc_stub.Start(training_session_id)

responses = grpc_stub.GetBestModelIdx(training_session_id)
received_updates = 0
for response in responses:
assert isinstance(response, training_pb2.GetBestModelIdxResponse)
assert response.id is not None
received_updates += 1
if received_updates >= 2:
break

def test_close_session(self, grpc_stub):
"""
Test closing a training session.
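
For reference, the same counter can be exercised directly against the backend, without going through gRPC. The sketch below is illustrative and not part of the committed tests; it assumes that calling ping_is_best_callbacks() dispatches the callback the supervisor registers on construction:

# Illustrative sketch (not part of the committed tests): drive the best-model
# counter directly through the backend, assuming that invoking
# ping_is_best_callbacks() dispatches the supervisor's registered callback.
trainer = MockedNominalTrainer()
backend = TrainerSessionBackend(trainer)

assert backend.get_best_model_idx() == 0
trainer.ping_is_best_callbacks()  # simulate detection of a new best model
assert backend.get_best_model_idx() == 1
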
4 changes: 2 additions & 2 deletions tiktorch/proto/training_pb2.py

Some generated files are not rendered by default.

33 changes: 33 additions & 0 deletions tiktorch/proto/training_pb2_grpc.py
@@ -50,6 +50,11 @@ def __init__(self, channel):
request_serializer=utils__pb2.ModelSession.SerializeToString,
response_deserializer=training__pb2.GetLogsResponse.FromString,
)
self.GetBestModelIdx = channel.unary_stream(
'/training.Training/GetBestModelIdx',
request_serializer=utils__pb2.ModelSession.SerializeToString,
response_deserializer=training__pb2.GetBestModelIdxResponse.FromString,
)
self.Save = channel.unary_unary(
'/training.Training/Save',
request_serializer=training__pb2.SaveRequest.SerializeToString,
@@ -122,6 +127,12 @@ def GetLogs(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def GetBestModelIdx(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def Save(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -190,6 +201,11 @@ def add_TrainingServicer_to_server(servicer, server):
request_deserializer=utils__pb2.ModelSession.FromString,
response_serializer=training__pb2.GetLogsResponse.SerializeToString,
),
'GetBestModelIdx': grpc.unary_stream_rpc_method_handler(
servicer.GetBestModelIdx,
request_deserializer=utils__pb2.ModelSession.FromString,
response_serializer=training__pb2.GetBestModelIdxResponse.SerializeToString,
),
'Save': grpc.unary_unary_rpc_method_handler(
servicer.Save,
request_deserializer=training__pb2.SaveRequest.FromString,
@@ -344,6 +360,23 @@ def GetLogs(request,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def GetBestModelIdx(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(request, target, '/training.Training/GetBestModelIdx',
utils__pb2.ModelSession.SerializeToString,
training__pb2.GetBestModelIdxResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def Save(request,
target,
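
The generated code registers GetBestModelIdx with a unary_stream handler, matching the stream return type declared in the proto. Serving the Training service follows the standard grpcio pattern; a rough sketch, in which the pool size, port, and servicer construction are assumptions:

# Rough sketch of serving the Training service; the pool size, port, and the
# use of the generated base servicer here are assumptions, not part of this diff.
from concurrent import futures

import grpc

from tiktorch.proto import training_pb2_grpc

server = grpc.server(futures.ThreadPoolExecutor(max_workers=8))
servicer = training_pb2_grpc.TrainingServicer()  # tiktorch's real servicer would be used in practice
training_pb2_grpc.add_TrainingServicer_to_server(servicer, server)
server.add_insecure_port("127.0.0.1:5567")
server.start()
server.wait_for_termination()
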
12 changes: 12 additions & 0 deletions tiktorch/server/grpc/training_servicer.py
@@ -2,6 +2,7 @@

import logging
import queue
import time
from pathlib import Path
from typing import Callable, List

@@ -87,6 +88,17 @@ def StreamUpdates(self, request: utils_pb2.ModelSession, context):
def GetLogs(self, request: utils_pb2.ModelSession, context):
raise NotImplementedError

def GetBestModelIdx(self, request, context):
session = self._getTrainerSession(context, request)
prev_best_model_idx = None
while context.is_active():
current_best_model_idx = session.client.get_best_model_idx()
if current_best_model_idx != prev_best_model_idx:
prev_best_model_idx = current_best_model_idx
yield training_pb2.GetBestModelIdxResponse(id=str(current_best_model_idx))
time.sleep(1)
logger.info("Client disconnected. Stopping stream.")

def GetStatus(self, request: utils_pb2.ModelSession, context):
session = self._getTrainerSession(context, request)
state = session.client.get_state()
3 changes: 3 additions & 0 deletions tiktorch/server/session/backend/base.py
@@ -94,3 +94,6 @@ def export(self, file_path: Path) -> None:

def get_state(self) -> TrainerState:
return self._supervisor.get_state()

def get_best_model_idx(self) -> int:
return self._supervisor.get_best_model_idx()
16 changes: 12 additions & 4 deletions tiktorch/server/session/backend/supervisor.py
@@ -52,6 +52,15 @@ def __init__(self, trainer: Trainer) -> None:
self._session_thread = threading.Thread(target=self._start_session, name="SessionThread")
self._command_queue_utils = CommandPriorityQueueUtils()
self.training_error_callbacks: ErrorCallbacks = BaseCallbacks()
self._best_model_idx = 0
self._trainer.ping_is_best_callbacks.register(self._increment_best_model_idx)

def get_best_model_idx(self) -> int:
return self._best_model_idx

def _increment_best_model_idx(self):
self._best_model_idx += 1
logger.debug(f"New best model detected with id {self._best_model_idx}")

def get_state(self) -> TrainerState:
logger.debug(f"Get state called {self._state}")
@@ -82,16 +91,15 @@ def _start_session(self):
def _fit(self):
try:
self._trainer.fit()
if self.is_training_finished():
logger.info(f"Training has finished: {self._get_num_iterations_epochs()} ")
self._state = TrainerState.FINISHED
except Exception as e:
logger.exception(f"Training error: {e}")
self.training_error_callbacks(e)
self._state = TrainerState.FAILED
return

if self.is_training_finished():
logger.info(f"Training has finished: {self._get_num_iterations_epochs()} ")
self._state = TrainerState.FINISHED

def is_training_finished(self):
return (
self._trainer.num_epochs == self._trainer.max_num_epochs
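
The supervisor now tracks a best-model counter via a register/notify pattern: the trainer owns the callback container, the supervisor registers _increment_best_model_idx on it, and the trainer invokes the container whenever a new best model is found. A minimal sketch of that pattern, assuming BaseCallbacks (defined in tiktorch.trainer, not shown in this diff) simply stores callables and dispatches them when called:

# Minimal sketch of the register/notify pattern the supervisor relies on.
# Assumption: BaseCallbacks stores callables and dispatches them when the
# container itself is called; the real class in tiktorch.trainer may differ.
from typing import Any, Callable, List


class BaseCallbacks:
    def __init__(self) -> None:
        self._callbacks: List[Callable[..., Any]] = []

    def register(self, callback: Callable[..., Any]) -> None:
        self._callbacks.append(callback)

    def __call__(self, *args: Any, **kwargs: Any) -> None:
        for callback in self._callbacks:
            callback(*args, **kwargs)
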
7 changes: 5 additions & 2 deletions tiktorch/server/session/process.py
@@ -19,7 +19,7 @@
from tiktorch.rpc.mp import BioModelClient, MPServer

from ...converters import Sample
from ...trainer import TrainerYamlParser
from ...trainer import TrainerState, TrainerYamlParser
from .backend import base
from .rpc_interface import IRPCModelSession, IRPCTrainer

@@ -145,9 +145,12 @@ def save(self, file_path: Path):
def export(self, file_path: Path):
self.worker.export(file_path)

def get_state(self):
def get_state(self) -> TrainerState:
return self.worker.get_state()

def get_best_model_idx(self) -> int:
return self.worker.get_best_model_idx()

def shutdown(self):
if self._worker is None:
return Shutdown()
4 changes: 4 additions & 0 deletions tiktorch/server/session/rpc_interface.py
@@ -89,3 +89,7 @@ def export(self, file_path: Path):
@exposed
def get_state(self) -> TrainerState:
raise NotImplementedError

@exposed
def get_best_model_idx(self) -> int:
raise NotImplementedError
7 changes: 6 additions & 1 deletion tiktorch/trainer.py
@@ -176,6 +176,7 @@ def __init__(
self._device = device
self.logs_callbacks: LogsCallbacks = BaseCallbacks()
self.should_stop_callbacks: Callbacks = ShouldStopCallbacks()
self.ping_is_best_callbacks = BaseCallbacks() # notification of having a new best model

def fit(self):
return super().fit()
@@ -184,7 +185,11 @@ def train(self):
return super().train()

def validate(self):
return super().validate()
eval_score = super().validate()
is_best = self._is_best_eval_score(eval_score)
if is_best:
self.ping_is_best_callbacks()
return eval_score

def save_state_dict(self, file_path: Path):
"""
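
The new validate() override delegates the "is this a new best model?" decision to _is_best_eval_score, which is inherited from the upstream trainer and not part of this diff. A rough sketch of the assumed semantics, shown for a higher-is-better metric (the class name below is hypothetical):

# Rough sketch of the assumed behaviour of _is_best_eval_score; the real method
# is inherited from the upstream trainer and may also handle lower-is-better
# metrics and checkpointing. The class name is hypothetical.
class _BestScoreTracker:
    def __init__(self) -> None:
        self.best_eval_score = float("-inf")

    def _is_best_eval_score(self, eval_score: float) -> bool:
        is_best = eval_score > self.best_eval_score
        if is_best:
            self.best_eval_score = eval_score
        return is_best
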
