From 486ad1c4a99e50ea05abec6eb5b233f4e53755e8 Mon Sep 17 00:00:00 2001 From: Theodoros Katzalis Date: Mon, 9 Dec 2024 17:14:42 +0100 Subject: [PATCH] Add training service - Supported operations: start, resume, pause, shutdown - pytorch-3dunet package is used as the framework to create the models --- proto/training.proto | 106 +++++ setup.cfg | 2 +- .../test_grpc/test_training_servicer.py | 417 ++++++++++++++++++ .../test_training/test_training.py | 126 ------ .../test_worker/test_commands.py | 8 +- tiktorch/converters.py | 13 +- tiktorch/proto/training_pb2.py | 163 +++++++ tiktorch/proto/training_pb2_grpc.py | 396 +++++++++++++++++ tiktorch/rpc/interface.py | 1 + tiktorch/server/grpc/__init__.py | 9 +- tiktorch/server/grpc/training_servicer.py | 100 +++++ tiktorch/server/session/backend/base.py | 81 ++-- tiktorch/server/session/backend/commands.py | 146 +++++- tiktorch/server/session/backend/supervisor.py | 305 ++++++++----- tiktorch/server/session/process.py | 57 ++- tiktorch/server/session/rpc_interface.py | 40 ++ tiktorch/trainer.py | 190 ++++++++ 17 files changed, 1866 insertions(+), 294 deletions(-) create mode 100644 proto/training.proto create mode 100644 tests/test_server/test_grpc/test_training_servicer.py delete mode 100644 tests/test_server/test_training/test_training.py create mode 100644 tiktorch/proto/training_pb2.py create mode 100644 tiktorch/proto/training_pb2_grpc.py create mode 100644 tiktorch/server/grpc/training_servicer.py create mode 100644 tiktorch/trainer.py diff --git a/proto/training.proto b/proto/training.proto new file mode 100644 index 00000000..ed280cbc --- /dev/null +++ b/proto/training.proto @@ -0,0 +1,106 @@ +syntax = "proto3"; + +package training; + + +message Empty {} + + +service Training { + rpc Init(TrainingConfig) returns (TrainingSessionId) {} + + rpc Start(TrainingSessionId) returns (Empty) {} + + rpc Resume(TrainingSessionId) returns (Empty) {} + + rpc Pause(TrainingSessionId) returns (Empty) {} + + rpc StreamUpdates(TrainingSessionId) returns (stream StreamUpdateResponse) {} + + rpc GetLogs(TrainingSessionId) returns (GetLogsResponse) {} + + rpc Save(TrainingSessionId) returns (Empty) {} + + rpc Export(TrainingSessionId) returns (Empty) {} + + rpc Predict(PredictRequest) returns (PredictResponse) {} + + rpc GetStatus(TrainingSessionId) returns (GetStatusResponse) {} + + rpc CloseTrainerSession(TrainingSessionId) returns (Empty) {} +} + +message TrainingSessionId { + string id = 1; +} + +message Logs { + enum ModelPhase { + Train = 0; + Eval = 1; + } + ModelPhase mode = 1; + double eval_score = 2; + double loss = 3; + uint32 iteration = 4; +} + + +message StreamUpdateResponse { + uint32 best_model_idx = 1; + Logs logs = 2; +} + + +message GetLogsResponse { + repeated Logs logs = 1; +} + + +message NamedInt { + uint32 size = 1; + string name = 2; +} + + +message Tensor { + bytes buffer = 1; + string dtype = 2; + repeated NamedInt shape = 4; +} + + +message PredictRequest { + repeated Tensor tensors = 1; + TrainingSessionId id = 2; +} + + +message PredictResponse { + uint32 best_model_idx = 1; + repeated Tensor tensors = 2; +} + +message ValidationResponse { + double validation_score_average = 1; +} + +message GetStatusResponse { + enum State { + Idle = 0; + Running = 1; + Paused = 2; + Failed = 3; + Finished = 4; + } + State state = 1; +} + + +message GetCurrentBestModelIdxResponse { + uint32 id = 1; +} + +message TrainingConfig { + string yaml_content = 1; +} diff --git a/setup.cfg b/setup.cfg index aecec9f7..50d7fac8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,5 +12,5 @@ max-line-length = 120 [flake8] max-line-length = 120 -ignore=E203 +ignore=E203,W503 exclude = tiktorch/proto/*,vendor diff --git a/tests/test_server/test_grpc/test_training_servicer.py b/tests/test_server/test_grpc/test_training_servicer.py new file mode 100644 index 00000000..5ce1ee3f --- /dev/null +++ b/tests/test_server/test_grpc/test_training_servicer.py @@ -0,0 +1,417 @@ +import tempfile +import threading +import time +from pathlib import Path + +import grpc +import h5py +import numpy as np +import pytest + +from tiktorch.converters import trainer_state_to_pb +from tiktorch.proto import training_pb2, training_pb2_grpc +from tiktorch.server.device_pool import TorchDevicePool +from tiktorch.server.grpc import training_servicer +from tiktorch.server.session_manager import SessionManager +from tiktorch.trainer import Callbacks, TrainerState, TrainerYamlParser + + +@pytest.fixture(scope="module") +def grpc_add_to_server(): + return training_pb2_grpc.add_TrainingServicer_to_server + + +@pytest.fixture(scope="module") +def grpc_servicer(): + return training_servicer.TrainingServicer(TorchDevicePool(), SessionManager()) + + +@pytest.fixture(autouse=True) +def clean(grpc_servicer): + yield + grpc_servicer.close_all_sessions() + + +@pytest.fixture(scope="module") +def grpc_stub_cls(): + return training_pb2_grpc.TrainingStub + + +def unet2d_config_path(checkpoint_dir, train_data_dir, val_data_path, device: str = "cpu"): + return f""" +device: {device} # Use CPU for faster test execution, change to 'cuda' if GPU is available and necessary +model: + name: UNet2D + in_channels: 3 + out_channels: 2 + layer_order: gcr + f_maps: 16 + num_groups: 4 + final_sigmoid: false + is_segmentation: true +trainer: + checkpoint_dir: {checkpoint_dir} + resume: null + validate_after_iters: 2 + log_after_iters: 2 + max_num_epochs: 1 + max_num_iterations: 2 + eval_score_higher_is_better: True +optimizer: + learning_rate: 0.0002 + weight_decay: 0.00001 +loss: + name: CrossEntropyLoss +eval_metric: + name: MeanIoU + ignore_index: null +lr_scheduler: + name: MultiStepLR + milestones: [2, 3] + gamma: 0.5 +loaders: + dataset: StandardHDF5Dataset + batch_size: 1 + num_workers: 1 + raw_internal_path: raw + label_internal_path: label + weight_internal_path: null + train: + file_paths: + - {train_data_dir} + + slice_builder: + name: SliceBuilder + patch_shape: [1, 64, 64] + stride_shape: [1, 64, 64] + skip_shape_check: true + + transformer: + raw: + - name: Standardize + - name: RandomFlip + - name: RandomRotate90 + - name: RandomRotate + axes: [[2, 1]] + angle_spectrum: 30 + mode: reflect + - name: ElasticDeformation + execution_probability: 1.0 + spline_order: 3 + - name: AdditiveGaussianNoise + execution_probability: 1.0 + - name: AdditivePoissonNoise + execution_probability: 1.0 + - name: ToTensor + expand_dims: true + label: + - name: RandomFlip + - name: RandomRotate90 + - name: RandomRotate + axes: [[2, 1]] + angle_spectrum: 30 + mode: reflect + - name: ElasticDeformation + execution_probability: 1.0 + spline_order: 0 + - name: ToTensor + # do not expand dims for cross-entropy loss + expand_dims: false + # cross-entropy loss requires target to be of type 'long' + dtype: 'long' + weight: + - name: ToTensor + expand_dims: false + val: + file_paths: + - {val_data_path} + + slice_builder: + name: SliceBuilder + patch_shape: [1, 64, 64] + stride_shape: [1, 64, 64] + skip_shape_check: true + + transformer: + raw: + - name: Standardize + - name: ToTensor + expand_dims: true + label: + - name: ToTensor + expand_dims: false + dtype: 'long' + weight: + - name: ToTensor + expand_dims: false +""" + + +def create_random_dataset(shape, channel_per_class): + tmp = tempfile.NamedTemporaryFile(delete=False) + + with h5py.File(tmp.name, "w") as f: + l_shape = w_shape = shape + # make sure that label and weight tensors are 3D + if len(shape) == 4: + l_shape = shape[1:] + w_shape = shape[1:] + + if channel_per_class: + l_shape = (2,) + l_shape + + f.create_dataset("raw", data=np.random.rand(*shape)) + f.create_dataset("label", data=np.random.randint(0, 2, l_shape)) + f.create_dataset("weight_map", data=np.random.rand(*w_shape)) + + return tmp.name + + +def prepare_unet2d_test_environment(device: str = "cpu") -> str: + checkpoint_dir = Path(tempfile.mkdtemp()) + + shape = (3, 1, 128, 128) + binary_loss = False + train = create_random_dataset(shape, binary_loss) + val = create_random_dataset(shape, binary_loss) + + return unet2d_config_path(checkpoint_dir=checkpoint_dir, train_data_dir=train, val_data_path=val, device=device) + + +class TestTrainingServicer: + def assert_state(self, grpc_stub, training_session_id: str, state_to_check: TrainerState): + response = grpc_stub.GetStatus(training_session_id) + assert response.state == trainer_state_to_pb[state_to_check] + + def poll_for_state(self, grpc_stub, session_id, expected_state: TrainerState, timeout=3, poll_interval=0.1): + start_time = time.time() + + while True: + status_response = grpc_stub.GetStatus(session_id) + current_state = status_response.state + + if current_state == trainer_state_to_pb[expected_state]: + return current_state + + if time.time() - start_time > timeout: + pytest.fail(f"Timeout: State did not transition to {expected_state} within {timeout} seconds.") + + time.sleep(poll_interval) + + def test_init_success(self, grpc_stub): + """ + Test that a session initializes successfully with valid YAML. + """ + response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) + assert response.id is not None, "Failed to initialize training session" + + def test_init_invalid_yaml(self, grpc_stub): + """ + Test that initializing with invalid YAML raises an error. + """ + invalid_yaml = "invalid_yaml_content: {unbalanced_braces" + with pytest.raises(grpc.RpcError) as excinfo: + grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=invalid_yaml)) + assert "expected ',' or '}', but got" in excinfo.value.details() + + def test_start_training_success(self, grpc_stub, monkeypatch): + """ + Test starting training after successful initialization. + """ + trainer_is_called = threading.Event() + + class MockedNominalTrainer: + def __init__(self): + self.num_epochs = 0 + self.max_num_epochs = 10 + self.num_iterations = 0 + self.max_num_iterations = 100 + self.should_stop_callbacks = Callbacks() + + def fit(self): + print("Training has started") + trainer_is_called.set() + + monkeypatch.setattr(TrainerYamlParser, "parse", lambda *args: MockedNominalTrainer()) + init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) + grpc_stub.Start(training_pb2.TrainingSessionId(id=init_response.id)) + trainer_is_called.wait(timeout=5) + + def test_concurrent_state_transitions(self, grpc_stub): + """ + Test concurrent calls to Start, Pause, and Resume to ensure no deadlocks or race conditions. + + The test should exit gracefully without hanging processes or threads. + """ + init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) + training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + + threads = [] + for _ in range(3): + threads.append(threading.Thread(target=lambda: grpc_stub.Start(training_session_id))) + threads.append(threading.Thread(target=lambda: grpc_stub.Pause(training_session_id))) + threads.append(threading.Thread(target=lambda: grpc_stub.Resume(training_session_id))) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + def test_queueing_multiple_commands(self, grpc_stub): + init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) + training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + + def assert_state(state_to_check): + self.assert_state(grpc_stub, training_session_id, state_to_check) + + grpc_stub.Start(training_session_id) + assert_state(TrainerState.RUNNING) + + for _ in range(3): + grpc_stub.Pause(training_session_id) + assert_state(TrainerState.PAUSED) + + grpc_stub.Resume(training_session_id) + assert_state(TrainerState.RUNNING) + + def test_error_handling_on_invalid_state_transitions_after_training_started(self, grpc_stub): + init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) + training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + + # Attempt to start again while already running + grpc_stub.Start(training_session_id) + with pytest.raises(grpc.RpcError) as excinfo: + grpc_stub.Start(training_session_id) + assert "Invalid state transition: TrainerState.RUNNING -> TrainerState.RUNNING" in excinfo.value.details() + + # Attempt to pause again while already paused + grpc_stub.Pause(training_session_id) + with pytest.raises(grpc.RpcError) as excinfo: + grpc_stub.Pause(training_session_id) + assert "Invalid state transition: TrainerState.PAUSED -> TrainerState.PAUSED" in excinfo.value.details() + + # Attempt to resume again while already resumed + grpc_stub.Resume(training_session_id) + with pytest.raises(grpc.RpcError) as excinfo: + grpc_stub.Resume(training_session_id) + assert "Invalid state transition: TrainerState.RUNNING -> TrainerState.RUNNING" in excinfo.value.details() + + def test_error_handling_on_invalid_state_transitions_before_training_started(self, grpc_stub): + init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) + training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + + # Attempt to resume before start + with pytest.raises(grpc.RpcError) as excinfo: + grpc_stub.Resume(training_session_id) + assert "Invalid state transition: TrainerState.IDLE -> TrainerState.RUNNING" in excinfo.value.details() + + # Attempt to pause before start + with pytest.raises(grpc.RpcError) as excinfo: + grpc_stub.Pause(training_session_id) + assert "Invalid state transition: TrainerState.IDLE -> TrainerState.PAUSED" in excinfo.value.details() + + def test_start_training_without_init(self, grpc_stub): + """ + Test starting training without initializing a session. + """ + with pytest.raises(grpc.RpcError) as excinfo: + grpc_stub.Start(training_pb2.Empty()) + assert excinfo.value.code() == grpc.StatusCode.FAILED_PRECONDITION + assert "trainer-session with id doesn't exist" in excinfo.value.details() + + def test_recover_training_failed(self, grpc_stub, monkeypatch): + class MockedExceptionTrainer: + def __init__(self): + self.should_stop_callbacks = Callbacks() + + def fit(self): + raise Exception("mocked exception") + + class MockedNominalTrainer: + def __init__(self): + self.num_epochs = 0 + self.max_num_epochs = 10 + self.num_iterations = 0 + self.max_num_iterations = 100 + self.should_stop_callbacks = Callbacks() + + def fit(self): + for epoch in range(self.max_num_epochs): + self.num_epochs += 1 + + monkeypatch.setattr(TrainerYamlParser, "parse", lambda *args: MockedExceptionTrainer()) + + init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) + training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + + grpc_stub.Start(training_session_id) + + # client detects that state is failed, closes the session and starts a new one + self.poll_for_state(grpc_stub=grpc_stub, session_id=training_session_id, expected_state=TrainerState.FAILED) + + grpc_stub.CloseTrainerSession(training_session_id) + + monkeypatch.setattr(TrainerYamlParser, "parse", lambda *args: MockedNominalTrainer()) + + init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) + training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + + grpc_stub.Start(training_session_id) + self.poll_for_state(grpc_stub=grpc_stub, session_id=training_session_id, expected_state=TrainerState.FINISHED) + + def test_graceful_shutdown_for_any_state(self, grpc_stub): + # after init + init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) + training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + grpc_stub.CloseTrainerSession(training_session_id) + + # after start + init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) + training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + grpc_stub.Start(training_session_id) + grpc_stub.CloseTrainerSession(training_session_id) + + # after pause + init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) + training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + grpc_stub.Start(training_session_id) + grpc_stub.Pause(training_session_id) + grpc_stub.CloseTrainerSession(training_session_id) + + # after resume + init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) + training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + grpc_stub.Start(training_session_id) + grpc_stub.Pause(training_session_id) + grpc_stub.Resume(training_session_id) + grpc_stub.CloseTrainerSession(training_session_id) + + # attempt to close again + with pytest.raises(grpc.RpcError) as excinfo: + grpc_stub.CloseTrainerSession(training_session_id) + assert "Unknown session" in excinfo.value.details() + + def test_close_session(self, grpc_stub): + """ + Test closing a training session. + """ + init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) + training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + grpc_stub.CloseTrainerSession(training_session_id) + + # attempt to perform an operation while session is closed + operations = [grpc_stub.Start, grpc_stub.Pause, grpc_stub.Resume] + for operation in operations: + with pytest.raises(grpc.RpcError) as excinfo: + operation(training_session_id) + assert "doesn't exist" in excinfo.value.details() + + def test_multiple_sessions(self, grpc_stub): + response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) + assert response.id is not None + + response = grpc_stub.Init( + training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment(device="gpu")) + ) + assert response.id is not None diff --git a/tests/test_server/test_training/test_training.py b/tests/test_server/test_training/test_training.py deleted file mode 100644 index 3492d6f4..00000000 --- a/tests/test_server/test_training/test_training.py +++ /dev/null @@ -1,126 +0,0 @@ -import threading -import time -from concurrent.futures import Future - -import numpy as np -import pytest -import xarray as xr - -from tiktorch.server.session import State -from tiktorch.server.session.backend import commands -from tiktorch.server.session.backend.supervisor import Supervisor -from tiktorch.utils import wait - - -class TestExemplumSupervisor: - class DummyCmd(commands.ICommand): - def execute(self, ctx): - pass - - class DummyExemplum: - def __init__(self): - self.iteration_count = 0 - self.max_num_iterations = 0 - self._break_cb = None - self._devs = [] - - def set_break_callback(self, cb): - self._break_cb = cb - - def predict_sample_without_blocking(self, input_tensors): - return [xr.DataArray(np.array([42]), dims=("x",))] - - def set_max_num_iterations(self, val): - self.max_num_iterations = val - - def stop_training(self, max_num_iterations=None, max_num_epochs=None): - return self._break_cb and self._break_cb() or self.iteration_count >= self.max_num_iterations - - def train(self): - while not self.stop_training(): - self.iteration_count += 1 - time.sleep(0.01) - - @pytest.fixture - def exemplum(self): - return self.DummyExemplum() - - @pytest.fixture - def supervisor(self, exemplum): - return Supervisor(exemplum) - - @pytest.fixture - def worker_thread(self, supervisor): - t = threading.Thread(target=supervisor.run) - t.start() - yield t - supervisor.send_command(commands.StopCmd()) - t.join() - - def test_not_running_worker_has_stopped_status(self, supervisor): - assert State.Stopped == supervisor.state - - def test_started_worker_has_idle_status(self, supervisor, worker_thread): - cmd = self.DummyCmd().awaitable - supervisor.send_command(cmd) - cmd.wait() - - assert State.Paused == supervisor.state - - def test_resuming_transitions_to_idle_with_no_devices(self, supervisor, worker_thread): - cmd = commands.ResumeCmd().awaitable - supervisor.send_command(cmd) - cmd.wait() - - assert State.Idle == supervisor.state - - def test_transition_to_running(self, supervisor, worker_thread): - cmd = commands.ResumeCmd() - supervisor.send_command(cmd) - - add_work = commands.SetMaxNumIterations(2).awaitable - supervisor.send_command(add_work) - add_work.wait() - - assert supervisor.state == State.Running - - def test_exception_during_train_should_transition_to_paused(self, supervisor, worker_thread, exemplum): - train_called = threading.Event() - train_proceed = threading.Event() - - def _exc(): - train_called.set() - train_proceed.wait() - raise Exception() - - exemplum.train = _exc - - cmd = commands.ResumeCmd() - supervisor.send_command(cmd) - - assert supervisor.state == State.Paused - add_work = commands.SetMaxNumIterations(2).awaitable - supervisor.send_command(add_work) - add_work.wait() - - train_called.wait() - wait(lambda: supervisor.state == State.Running, max_wait=1) - train_proceed.set() - wait(lambda: supervisor.state == State.Paused, max_wait=1) - - def test_finished_training_should_transition_to_paused(self, supervisor, worker_thread, exemplum): - cmd = commands.ResumeCmd() - supervisor.send_command(cmd) - - add_work = commands.SetMaxNumIterations(2).awaitable - supervisor.send_command(add_work) - add_work.wait() - assert supervisor.state == State.Running - time.sleep(0.1) # FIXME: Find a better way to wait for pause event with timeout - assert supervisor.state == State.Idle - - def test_forward(self, supervisor, worker_thread, exemplum): - fut = Future() - forward_cmd = commands.ForwardPass(fut, [xr.DataArray(np.array([1]), dims=("x",))]) - supervisor.send_command(forward_cmd) - assert [42] == fut.result() diff --git a/tests/test_server/test_training/test_worker/test_commands.py b/tests/test_server/test_training/test_worker/test_commands.py index f8307fa7..38ad4289 100644 --- a/tests/test_server/test_training/test_worker/test_commands.py +++ b/tests/test_server/test_training/test_worker/test_commands.py @@ -9,17 +9,17 @@ class TestCommandQueue: def test_stop_command_has_higher_priorityj(self): cmd_queue = cmds.CommandPriorityQueue() - stop_cmd = cmds.StopCmd() - cmd_queue.put_nowait(cmds.ResumeCmd()) + stop_cmd = cmds.ShutdownCmd() + cmd_queue.put_nowait(cmds.ResumeTrainingCmd()) cmd_queue.put_nowait(stop_cmd) - cmd_queue.put_nowait(cmds.PauseCmd()) + cmd_queue.put_nowait(cmds.PauseTrainingCmd()) received_cmd = cmd_queue.get_nowait() assert stop_cmd is received_cmd def test_queue_order_is_stable(self): cmd_queue = cmds.CommandPriorityQueue() - stop_cmds = [cmds.StopCmd() for _ in range(100)] + stop_cmds = [cmds.ShutdownCmd() for _ in range(100)] for cmd in stop_cmds: cmd_queue.put_nowait(cmd) diff --git a/tiktorch/converters.py b/tiktorch/converters.py index 3159c05c..2ec632a2 100644 --- a/tiktorch/converters.py +++ b/tiktorch/converters.py @@ -7,7 +7,18 @@ from bioimageio.core import Sample, Tensor from bioimageio.spec.model.v0_5 import TensorId -from tiktorch.proto import inference_pb2 +from tiktorch.proto import inference_pb2, training_pb2 +from tiktorch.trainer import TrainerState + +trainer_state_to_pb = { + TrainerState.IDLE: training_pb2.GetStatusResponse.State.Idle, + TrainerState.RUNNING: training_pb2.GetStatusResponse.State.Running, + TrainerState.PAUSED: training_pb2.GetStatusResponse.State.Paused, + TrainerState.FAILED: training_pb2.GetStatusResponse.State.Failed, + TrainerState.FINISHED: training_pb2.GetStatusResponse.State.Finished, +} + +pb_state_to_trainer = {value: key for key, value in trainer_state_to_pb.items()} def pb_tensors_to_sample(pb_tensors: List[inference_pb2.Tensor]) -> Sample: diff --git a/tiktorch/proto/training_pb2.py b/tiktorch/proto/training_pb2.py new file mode 100644 index 00000000..80d87d2b --- /dev/null +++ b/tiktorch/proto/training_pb2.py @@ -0,0 +1,163 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: training.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0etraining.proto\x12\x08training\"\x07\n\x05\x45mpty\"\x1f\n\x11TrainingSessionId\x12\n\n\x02id\x18\x01 \x01(\t\"\x87\x01\n\x04Logs\x12\'\n\x04mode\x18\x01 \x01(\x0e\x32\x19.training.Logs.ModelPhase\x12\x12\n\neval_score\x18\x02 \x01(\x01\x12\x0c\n\x04loss\x18\x03 \x01(\x01\x12\x11\n\titeration\x18\x04 \x01(\r\"!\n\nModelPhase\x12\t\n\x05Train\x10\x00\x12\x08\n\x04\x45val\x10\x01\"L\n\x14StreamUpdateResponse\x12\x16\n\x0e\x62\x65st_model_idx\x18\x01 \x01(\r\x12\x1c\n\x04logs\x18\x02 \x01(\x0b\x32\x0e.training.Logs\"/\n\x0fGetLogsResponse\x12\x1c\n\x04logs\x18\x01 \x03(\x0b\x32\x0e.training.Logs\"&\n\x08NamedInt\x12\x0c\n\x04size\x18\x01 \x01(\r\x12\x0c\n\x04name\x18\x02 \x01(\t\"J\n\x06Tensor\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12!\n\x05shape\x18\x04 \x03(\x0b\x32\x12.training.NamedInt\"\\\n\x0ePredictRequest\x12!\n\x07tensors\x18\x01 \x03(\x0b\x32\x10.training.Tensor\x12\'\n\x02id\x18\x02 \x01(\x0b\x32\x1b.training.TrainingSessionId\"L\n\x0fPredictResponse\x12\x16\n\x0e\x62\x65st_model_idx\x18\x01 \x01(\r\x12!\n\x07tensors\x18\x02 \x03(\x0b\x32\x10.training.Tensor\"6\n\x12ValidationResponse\x12 \n\x18validation_score_average\x18\x01 \x01(\x01\"\x8b\x01\n\x11GetStatusResponse\x12\x30\n\x05state\x18\x01 \x01(\x0e\x32!.training.GetStatusResponse.State\"D\n\x05State\x12\x08\n\x04Idle\x10\x00\x12\x0b\n\x07Running\x10\x01\x12\n\n\x06Paused\x10\x02\x12\n\n\x06\x46\x61iled\x10\x03\x12\x0c\n\x08\x46inished\x10\x04\",\n\x1eGetCurrentBestModelIdxResponse\x12\n\n\x02id\x18\x01 \x01(\r\"&\n\x0eTrainingConfig\x12\x14\n\x0cyaml_content\x18\x01 \x01(\t2\xd2\x05\n\x08Training\x12?\n\x04Init\x12\x18.training.TrainingConfig\x1a\x1b.training.TrainingSessionId\"\x00\x12\x37\n\x05Start\x12\x1b.training.TrainingSessionId\x1a\x0f.training.Empty\"\x00\x12\x38\n\x06Resume\x12\x1b.training.TrainingSessionId\x1a\x0f.training.Empty\"\x00\x12\x37\n\x05Pause\x12\x1b.training.TrainingSessionId\x1a\x0f.training.Empty\"\x00\x12P\n\rStreamUpdates\x12\x1b.training.TrainingSessionId\x1a\x1e.training.StreamUpdateResponse\"\x00\x30\x01\x12\x43\n\x07GetLogs\x12\x1b.training.TrainingSessionId\x1a\x19.training.GetLogsResponse\"\x00\x12\x36\n\x04Save\x12\x1b.training.TrainingSessionId\x1a\x0f.training.Empty\"\x00\x12\x38\n\x06\x45xport\x12\x1b.training.TrainingSessionId\x1a\x0f.training.Empty\"\x00\x12@\n\x07Predict\x12\x18.training.PredictRequest\x1a\x19.training.PredictResponse\"\x00\x12G\n\tGetStatus\x12\x1b.training.TrainingSessionId\x1a\x1b.training.GetStatusResponse\"\x00\x12\x45\n\x13\x43loseTrainerSession\x12\x1b.training.TrainingSessionId\x1a\x0f.training.Empty\"\x00\x62\x06proto3') + + + +_EMPTY = DESCRIPTOR.message_types_by_name['Empty'] +_TRAININGSESSIONID = DESCRIPTOR.message_types_by_name['TrainingSessionId'] +_LOGS = DESCRIPTOR.message_types_by_name['Logs'] +_STREAMUPDATERESPONSE = DESCRIPTOR.message_types_by_name['StreamUpdateResponse'] +_GETLOGSRESPONSE = DESCRIPTOR.message_types_by_name['GetLogsResponse'] +_NAMEDINT = DESCRIPTOR.message_types_by_name['NamedInt'] +_TENSOR = DESCRIPTOR.message_types_by_name['Tensor'] +_PREDICTREQUEST = DESCRIPTOR.message_types_by_name['PredictRequest'] +_PREDICTRESPONSE = DESCRIPTOR.message_types_by_name['PredictResponse'] +_VALIDATIONRESPONSE = DESCRIPTOR.message_types_by_name['ValidationResponse'] +_GETSTATUSRESPONSE = DESCRIPTOR.message_types_by_name['GetStatusResponse'] +_GETCURRENTBESTMODELIDXRESPONSE = DESCRIPTOR.message_types_by_name['GetCurrentBestModelIdxResponse'] +_TRAININGCONFIG = DESCRIPTOR.message_types_by_name['TrainingConfig'] +_LOGS_MODELPHASE = _LOGS.enum_types_by_name['ModelPhase'] +_GETSTATUSRESPONSE_STATE = _GETSTATUSRESPONSE.enum_types_by_name['State'] +Empty = _reflection.GeneratedProtocolMessageType('Empty', (_message.Message,), { + 'DESCRIPTOR' : _EMPTY, + '__module__' : 'training_pb2' + # @@protoc_insertion_point(class_scope:training.Empty) + }) +_sym_db.RegisterMessage(Empty) + +TrainingSessionId = _reflection.GeneratedProtocolMessageType('TrainingSessionId', (_message.Message,), { + 'DESCRIPTOR' : _TRAININGSESSIONID, + '__module__' : 'training_pb2' + # @@protoc_insertion_point(class_scope:training.TrainingSessionId) + }) +_sym_db.RegisterMessage(TrainingSessionId) + +Logs = _reflection.GeneratedProtocolMessageType('Logs', (_message.Message,), { + 'DESCRIPTOR' : _LOGS, + '__module__' : 'training_pb2' + # @@protoc_insertion_point(class_scope:training.Logs) + }) +_sym_db.RegisterMessage(Logs) + +StreamUpdateResponse = _reflection.GeneratedProtocolMessageType('StreamUpdateResponse', (_message.Message,), { + 'DESCRIPTOR' : _STREAMUPDATERESPONSE, + '__module__' : 'training_pb2' + # @@protoc_insertion_point(class_scope:training.StreamUpdateResponse) + }) +_sym_db.RegisterMessage(StreamUpdateResponse) + +GetLogsResponse = _reflection.GeneratedProtocolMessageType('GetLogsResponse', (_message.Message,), { + 'DESCRIPTOR' : _GETLOGSRESPONSE, + '__module__' : 'training_pb2' + # @@protoc_insertion_point(class_scope:training.GetLogsResponse) + }) +_sym_db.RegisterMessage(GetLogsResponse) + +NamedInt = _reflection.GeneratedProtocolMessageType('NamedInt', (_message.Message,), { + 'DESCRIPTOR' : _NAMEDINT, + '__module__' : 'training_pb2' + # @@protoc_insertion_point(class_scope:training.NamedInt) + }) +_sym_db.RegisterMessage(NamedInt) + +Tensor = _reflection.GeneratedProtocolMessageType('Tensor', (_message.Message,), { + 'DESCRIPTOR' : _TENSOR, + '__module__' : 'training_pb2' + # @@protoc_insertion_point(class_scope:training.Tensor) + }) +_sym_db.RegisterMessage(Tensor) + +PredictRequest = _reflection.GeneratedProtocolMessageType('PredictRequest', (_message.Message,), { + 'DESCRIPTOR' : _PREDICTREQUEST, + '__module__' : 'training_pb2' + # @@protoc_insertion_point(class_scope:training.PredictRequest) + }) +_sym_db.RegisterMessage(PredictRequest) + +PredictResponse = _reflection.GeneratedProtocolMessageType('PredictResponse', (_message.Message,), { + 'DESCRIPTOR' : _PREDICTRESPONSE, + '__module__' : 'training_pb2' + # @@protoc_insertion_point(class_scope:training.PredictResponse) + }) +_sym_db.RegisterMessage(PredictResponse) + +ValidationResponse = _reflection.GeneratedProtocolMessageType('ValidationResponse', (_message.Message,), { + 'DESCRIPTOR' : _VALIDATIONRESPONSE, + '__module__' : 'training_pb2' + # @@protoc_insertion_point(class_scope:training.ValidationResponse) + }) +_sym_db.RegisterMessage(ValidationResponse) + +GetStatusResponse = _reflection.GeneratedProtocolMessageType('GetStatusResponse', (_message.Message,), { + 'DESCRIPTOR' : _GETSTATUSRESPONSE, + '__module__' : 'training_pb2' + # @@protoc_insertion_point(class_scope:training.GetStatusResponse) + }) +_sym_db.RegisterMessage(GetStatusResponse) + +GetCurrentBestModelIdxResponse = _reflection.GeneratedProtocolMessageType('GetCurrentBestModelIdxResponse', (_message.Message,), { + 'DESCRIPTOR' : _GETCURRENTBESTMODELIDXRESPONSE, + '__module__' : 'training_pb2' + # @@protoc_insertion_point(class_scope:training.GetCurrentBestModelIdxResponse) + }) +_sym_db.RegisterMessage(GetCurrentBestModelIdxResponse) + +TrainingConfig = _reflection.GeneratedProtocolMessageType('TrainingConfig', (_message.Message,), { + 'DESCRIPTOR' : _TRAININGCONFIG, + '__module__' : 'training_pb2' + # @@protoc_insertion_point(class_scope:training.TrainingConfig) + }) +_sym_db.RegisterMessage(TrainingConfig) + +_TRAINING = DESCRIPTOR.services_by_name['Training'] +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _EMPTY._serialized_start=28 + _EMPTY._serialized_end=35 + _TRAININGSESSIONID._serialized_start=37 + _TRAININGSESSIONID._serialized_end=68 + _LOGS._serialized_start=71 + _LOGS._serialized_end=206 + _LOGS_MODELPHASE._serialized_start=173 + _LOGS_MODELPHASE._serialized_end=206 + _STREAMUPDATERESPONSE._serialized_start=208 + _STREAMUPDATERESPONSE._serialized_end=284 + _GETLOGSRESPONSE._serialized_start=286 + _GETLOGSRESPONSE._serialized_end=333 + _NAMEDINT._serialized_start=335 + _NAMEDINT._serialized_end=373 + _TENSOR._serialized_start=375 + _TENSOR._serialized_end=449 + _PREDICTREQUEST._serialized_start=451 + _PREDICTREQUEST._serialized_end=543 + _PREDICTRESPONSE._serialized_start=545 + _PREDICTRESPONSE._serialized_end=621 + _VALIDATIONRESPONSE._serialized_start=623 + _VALIDATIONRESPONSE._serialized_end=677 + _GETSTATUSRESPONSE._serialized_start=680 + _GETSTATUSRESPONSE._serialized_end=819 + _GETSTATUSRESPONSE_STATE._serialized_start=751 + _GETSTATUSRESPONSE_STATE._serialized_end=819 + _GETCURRENTBESTMODELIDXRESPONSE._serialized_start=821 + _GETCURRENTBESTMODELIDXRESPONSE._serialized_end=865 + _TRAININGCONFIG._serialized_start=867 + _TRAININGCONFIG._serialized_end=905 + _TRAINING._serialized_start=908 + _TRAINING._serialized_end=1630 +# @@protoc_insertion_point(module_scope) diff --git a/tiktorch/proto/training_pb2_grpc.py b/tiktorch/proto/training_pb2_grpc.py new file mode 100644 index 00000000..5be55746 --- /dev/null +++ b/tiktorch/proto/training_pb2_grpc.py @@ -0,0 +1,396 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from . import training_pb2 as training__pb2 + + +class TrainingStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Init = channel.unary_unary( + '/training.Training/Init', + request_serializer=training__pb2.TrainingConfig.SerializeToString, + response_deserializer=training__pb2.TrainingSessionId.FromString, + ) + self.Start = channel.unary_unary( + '/training.Training/Start', + request_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_deserializer=training__pb2.Empty.FromString, + ) + self.Resume = channel.unary_unary( + '/training.Training/Resume', + request_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_deserializer=training__pb2.Empty.FromString, + ) + self.Pause = channel.unary_unary( + '/training.Training/Pause', + request_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_deserializer=training__pb2.Empty.FromString, + ) + self.StreamUpdates = channel.unary_stream( + '/training.Training/StreamUpdates', + request_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_deserializer=training__pb2.StreamUpdateResponse.FromString, + ) + self.GetLogs = channel.unary_unary( + '/training.Training/GetLogs', + request_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_deserializer=training__pb2.GetLogsResponse.FromString, + ) + self.Save = channel.unary_unary( + '/training.Training/Save', + request_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_deserializer=training__pb2.Empty.FromString, + ) + self.Export = channel.unary_unary( + '/training.Training/Export', + request_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_deserializer=training__pb2.Empty.FromString, + ) + self.Predict = channel.unary_unary( + '/training.Training/Predict', + request_serializer=training__pb2.PredictRequest.SerializeToString, + response_deserializer=training__pb2.PredictResponse.FromString, + ) + self.GetStatus = channel.unary_unary( + '/training.Training/GetStatus', + request_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_deserializer=training__pb2.GetStatusResponse.FromString, + ) + self.CloseTrainerSession = channel.unary_unary( + '/training.Training/CloseTrainerSession', + request_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_deserializer=training__pb2.Empty.FromString, + ) + + +class TrainingServicer(object): + """Missing associated documentation comment in .proto file.""" + + def Init(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Start(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Resume(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Pause(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def StreamUpdates(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetLogs(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Save(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Export(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Predict(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetStatus(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CloseTrainerSession(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_TrainingServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Init': grpc.unary_unary_rpc_method_handler( + servicer.Init, + request_deserializer=training__pb2.TrainingConfig.FromString, + response_serializer=training__pb2.TrainingSessionId.SerializeToString, + ), + 'Start': grpc.unary_unary_rpc_method_handler( + servicer.Start, + request_deserializer=training__pb2.TrainingSessionId.FromString, + response_serializer=training__pb2.Empty.SerializeToString, + ), + 'Resume': grpc.unary_unary_rpc_method_handler( + servicer.Resume, + request_deserializer=training__pb2.TrainingSessionId.FromString, + response_serializer=training__pb2.Empty.SerializeToString, + ), + 'Pause': grpc.unary_unary_rpc_method_handler( + servicer.Pause, + request_deserializer=training__pb2.TrainingSessionId.FromString, + response_serializer=training__pb2.Empty.SerializeToString, + ), + 'StreamUpdates': grpc.unary_stream_rpc_method_handler( + servicer.StreamUpdates, + request_deserializer=training__pb2.TrainingSessionId.FromString, + response_serializer=training__pb2.StreamUpdateResponse.SerializeToString, + ), + 'GetLogs': grpc.unary_unary_rpc_method_handler( + servicer.GetLogs, + request_deserializer=training__pb2.TrainingSessionId.FromString, + response_serializer=training__pb2.GetLogsResponse.SerializeToString, + ), + 'Save': grpc.unary_unary_rpc_method_handler( + servicer.Save, + request_deserializer=training__pb2.TrainingSessionId.FromString, + response_serializer=training__pb2.Empty.SerializeToString, + ), + 'Export': grpc.unary_unary_rpc_method_handler( + servicer.Export, + request_deserializer=training__pb2.TrainingSessionId.FromString, + response_serializer=training__pb2.Empty.SerializeToString, + ), + 'Predict': grpc.unary_unary_rpc_method_handler( + servicer.Predict, + request_deserializer=training__pb2.PredictRequest.FromString, + response_serializer=training__pb2.PredictResponse.SerializeToString, + ), + 'GetStatus': grpc.unary_unary_rpc_method_handler( + servicer.GetStatus, + request_deserializer=training__pb2.TrainingSessionId.FromString, + response_serializer=training__pb2.GetStatusResponse.SerializeToString, + ), + 'CloseTrainerSession': grpc.unary_unary_rpc_method_handler( + servicer.CloseTrainerSession, + request_deserializer=training__pb2.TrainingSessionId.FromString, + response_serializer=training__pb2.Empty.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'training.Training', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class Training(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Init(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/Init', + training__pb2.TrainingConfig.SerializeToString, + training__pb2.TrainingSessionId.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Start(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/Start', + training__pb2.TrainingSessionId.SerializeToString, + training__pb2.Empty.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Resume(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/Resume', + training__pb2.TrainingSessionId.SerializeToString, + training__pb2.Empty.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Pause(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/Pause', + training__pb2.TrainingSessionId.SerializeToString, + training__pb2.Empty.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def StreamUpdates(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/training.Training/StreamUpdates', + training__pb2.TrainingSessionId.SerializeToString, + training__pb2.StreamUpdateResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetLogs(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/GetLogs', + training__pb2.TrainingSessionId.SerializeToString, + training__pb2.GetLogsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Save(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/Save', + training__pb2.TrainingSessionId.SerializeToString, + training__pb2.Empty.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Export(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/Export', + training__pb2.TrainingSessionId.SerializeToString, + training__pb2.Empty.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Predict(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/Predict', + training__pb2.PredictRequest.SerializeToString, + training__pb2.PredictResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetStatus(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/GetStatus', + training__pb2.TrainingSessionId.SerializeToString, + training__pb2.GetStatusResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def CloseTrainerSession(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/CloseTrainerSession', + training__pb2.TrainingSessionId.SerializeToString, + training__pb2.Empty.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/tiktorch/rpc/interface.py b/tiktorch/rpc/interface.py index d47543a5..f00f98b1 100644 --- a/tiktorch/rpc/interface.py +++ b/tiktorch/rpc/interface.py @@ -6,6 +6,7 @@ class RPCInterfaceMeta(type): def __new__(mcls, name, bases, namespace, **kwargs): cls = super().__new__(mcls, name, bases, namespace, **kwargs) + exposed = {name for name, value in namespace.items() if getattr(value, "__exposed__", False)} for base in bases: diff --git a/tiktorch/server/grpc/__init__.py b/tiktorch/server/grpc/__init__.py index a2132a51..21aaf93c 100644 --- a/tiktorch/server/grpc/__init__.py +++ b/tiktorch/server/grpc/__init__.py @@ -6,7 +6,7 @@ import grpc -from tiktorch.proto import data_store_pb2_grpc, inference_pb2_grpc +from tiktorch.proto import data_store_pb2_grpc, inference_pb2_grpc, training_pb2_grpc from tiktorch.server.data_store import DataStore from tiktorch.server.device_pool import IDevicePool, TorchDevicePool from tiktorch.server.session_manager import SessionManager @@ -14,6 +14,7 @@ from .data_store_servicer import DataStoreServicer from .flight_control_servicer import FlightControlServicer from .inference_servicer import InferenceServicer +from .training_servicer import TrainingServicer def _print_available_devices(device_pool: IDevicePool) -> None: @@ -51,16 +52,20 @@ def serve(host, port, *, connection_file_path: Optional[str] = None, kill_timeou ) data_store = DataStore() - device_pool = TorchDevicePool() + inference_svc = InferenceServicer(device_pool, SessionManager(), data_store) fligh_svc = FlightControlServicer(done_evt=done_evt, kill_timeout=kill_timeout) data_svc = DataStoreServicer(data_store) + + training_svc = TrainingServicer(device_pool=device_pool, session_manager=SessionManager()) + _print_available_devices(device_pool) inference_pb2_grpc.add_InferenceServicer_to_server(inference_svc, server) inference_pb2_grpc.add_FlightControlServicer_to_server(fligh_svc, server) data_store_pb2_grpc.add_DataStoreServicer_to_server(data_svc, server) + training_pb2_grpc.add_TrainingServicer_to_server(training_svc, server) acquired_port = server.add_insecure_port(f"{host}:{port}") print() diff --git a/tiktorch/server/grpc/training_servicer.py b/tiktorch/server/grpc/training_servicer.py new file mode 100644 index 00000000..d4f571eb --- /dev/null +++ b/tiktorch/server/grpc/training_servicer.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import logging +import queue +from typing import Callable, List + +import grpc + +from tiktorch.converters import trainer_state_to_pb +from tiktorch.proto import training_pb2, training_pb2_grpc +from tiktorch.server.device_pool import IDevicePool +from tiktorch.server.session.process import start_trainer_process +from tiktorch.server.session.rpc_interface import IRPCTrainer +from tiktorch.server.session_manager import Session, SessionManager +from tiktorch.trainer import TrainerYamlParser + +logger = logging.getLogger(__name__) + + +class TrainingServicer(training_pb2_grpc.TrainingServicer): + def __init__( + self, + device_pool: IDevicePool, + session_manager: SessionManager[IRPCTrainer], + ) -> None: + self._device_pool = device_pool + self._logs_queue_stream = queue.Queue() + self._should_stop_callbacks: List[Callable] = [] + self._session_manager = session_manager + + def Init(self, request: training_pb2.TrainingConfig, context): + parser = TrainerYamlParser(request.yaml_content) + device = parser.get_device() + + _, client = start_trainer_process() + session = self._session_manager.create_session(client) + session.on_close(client.shutdown) + + lease = self._device_pool.lease([device]) + session.on_close(lease.terminate) + + client.init(request.yaml_content) + + return training_pb2.TrainingSessionId(id=session.id) + + def Start(self, request, context): + session = self._getTrainerSession(context, request.id) + session.client.start_training() + return training_pb2.Empty() + + def Resume(self, request, context): + session = self._getTrainerSession(context, request.id) + session.client.resume_training() + return training_pb2.Empty() + + def Pause(self, request: training_pb2.TrainingSessionId, context): + session = self._getTrainerSession(context, request.id) + session.client.pause_training() + return training_pb2.Empty() + + def Save(self, request: training_pb2.TrainingSessionId, context): + session = self._getTrainerSession(context, request.modelSessionId) + session.client.save() + return training_pb2.Empty() + + def Export(self, request: training_pb2.TrainingSessionId, context): + session = self._getTrainerSession(context, request.modelSessionId) + session.client.export() + return training_pb2.Empty() + + def Predict(self, request: training_pb2.TrainingSessionId, context): + raise NotImplementedError + + def StreamUpdates(self, request: training_pb2.TrainingSessionId, context): + raise NotImplementedError + + def GetLogs(self, request: training_pb2.TrainingSessionId, context): + raise NotImplementedError + + def GetStatus(self, request: training_pb2.TrainingSessionId, context): + session = self._getTrainerSession(context, request.id) + state = session.client.get_state() + return training_pb2.GetStatusResponse(state=trainer_state_to_pb[state]) + + def CloseTrainerSession(self, request: training_pb2.TrainingSessionId, context) -> training_pb2.Empty: + self._session_manager.close_session(request.id) + return training_pb2.Empty() + + def close_all_sessions(self): + self._session_manager.close_all_sessions() + + def _getTrainerSession(self, context, trainer_session_id: str) -> Session[IRPCTrainer]: + session = self._session_manager.get(trainer_session_id) + + if session is None: + context.abort( + grpc.StatusCode.FAILED_PRECONDITION, f"trainer-session with id {trainer_session_id} doesn't exist" + ) + + return session diff --git a/tiktorch/server/session/backend/base.py b/tiktorch/server/session/backend/base.py index eab3ea44..471c2d71 100644 --- a/tiktorch/server/session/backend/base.py +++ b/tiktorch/server/session/backend/base.py @@ -1,60 +1,91 @@ from __future__ import annotations import logging -import threading -import typing +from abc import ABC from concurrent.futures import Future from bioimageio.core import PredictionPipeline from tiktorch.configkeys import TRAINING, VALIDATION -from tiktorch.server.session import types -from tiktorch.server.session.backend import commands, supervisor +from tiktorch.server.session.backend import commands +from tiktorch.server.session.backend.supervisor import BioModelSupervisor, QueueTasks, TrainerState, TrainerSupervisor from tiktorch.tiktypes import TikTensorBatch +from tiktorch.trainer import Trainer logger = logging.getLogger(__name__) -class SessionBackend: +class SessionBackend(ABC): + def __init__(self, supervisor): + self._supervisor = supervisor + self._queue_tasks = QueueTasks(supervisor) + self._queue_tasks.start() + + def shutdown(self): + self._queue_tasks.shutdown() + logger.debug("Shutdown complete") + + +class BioModelSessionBackend(SessionBackend): + """Session backend for bioimageio models + + Currently used only for inference. + """ + def __init__(self, pipeline: PredictionPipeline): - self._supervisor = supervisor.Supervisor(pipeline) - self._supervisor_thread = threading.Thread(target=self._supervisor.run, name="ModelThread") - self._supervisor_thread.start() + supervisor = BioModelSupervisor(pipeline) + super().__init__(supervisor) def update_dataset(self, name: str, *, data: TikTensorBatch, labels: TikTensorBatch) -> None: assert name in (TRAINING, VALIDATION), f"{name} not in ({TRAINING}, {VALIDATION})" update_cmd = commands.UpdateDatasetCmd(name, raw_data=data, labels=labels) - self._supervisor.send_command(update_cmd) + self._queue_tasks.send_command(update_cmd) def set_max_num_iterations(self, num: int) -> None: - self._supervisor.send_command(commands.SetMaxNumIterations(num)) + self._queue_tasks.send_command(commands.SetMaxNumIterations(num)) def forward(self, input_tensors): res = Future() - self._supervisor.send_command(commands.ForwardPass(res, input_tensors)) + self._queue_tasks.send_command(commands.ForwardPass(res, input_tensors)) return res - def shutdown(self) -> None: - logger.debug("Shutting down...") - stop_cmd = commands.StopCmd() - self._supervisor.send_command(stop_cmd.awaitable) - stop_cmd.awaitable.wait() +class TrainerSessionBackend(SessionBackend): + """Session backend for training - self._supervisor_thread.join() + Currently, supports only custom unet models decoupled from bioimageio models + """ - logger.debug("Shutdown complete") + def __init__(self, trainer: Trainer): + self._trainer = trainer + supervisor = TrainerSupervisor(trainer) + super().__init__(supervisor) + + def forward(self, input_tensors): + res = Future() + self._queue_tasks.send_command(commands.ForwardPass(res, input_tensors)) + return res def resume_training(self) -> None: - resume_cmd = commands.ResumeCmd() - self._supervisor.send_command(resume_cmd.awaitable) + resume_cmd = commands.ResumeTrainingCmd() + self._queue_tasks.send_command(resume_cmd.awaitable) resume_cmd.awaitable.wait() def pause_training(self) -> None: - self._supervisor.send_command(commands.PauseCmd()) + pause_cmd = commands.PauseTrainingCmd() + self._queue_tasks.send_command(pause_cmd.awaitable) + pause_cmd.awaitable.wait() + + def start_training(self) -> None: + start_cmd = commands.StartTrainingCmd() + self._queue_tasks.send_command(start_cmd.awaitable) + start_cmd.awaitable.wait() + + def save(self) -> None: + raise NotImplementedError - def get_idle(self) -> bool: - return self._supervisor.state == types.State.Paused + def export(self) -> None: + raise NotImplementedError - def on_idle(self, callback: typing.Callable[[], None]) -> None: - self._supervisor.on_idle(callback) + def get_state(self) -> TrainerState: + return self._supervisor.get_state() diff --git a/tiktorch/server/session/backend/commands.py b/tiktorch/server/session/backend/commands.py index 8943b325..a99c5cd7 100644 --- a/tiktorch/server/session/backend/commands.py +++ b/tiktorch/server/session/backend/commands.py @@ -6,11 +6,12 @@ import threading import typing from dataclasses import dataclass, field +from typing import Generic, Type, TypeVar -from tiktorch.server.session import types +from tiktorch.trainer import TrainerState if typing.TYPE_CHECKING: - from tiktorch.server.session.backend.supervisor import Supervisor + from tiktorch.server.session.backend.supervisor import BioModelSupervisor, Supervisors, TrainerSupervisor # from tiktorch.server.datasets import DynamicDataset @@ -20,27 +21,41 @@ __all__ = [ "ICommand", "AwaitableCommand", - "PauseCmd", - "ResumeCmd", - "StopCmd", + "StartTrainingCmd", + "PauseTrainingCmd", + "ResumeTrainingCmd", + "ShutdownWithTeardownCmd", + "NominalShutdownCmd", + "ShutdownWithErrorCmd", + "SetResumeStateTrainingCmd", + "SetPauseStateTrainingCmd", + "SetStartStateTrainingCmd", "UpdateDatasetCmd", "SetMaxNumIterations", ] +SupervisorType = TypeVar("SupervisorType") -class Context: + +class Context(Generic[SupervisorType]): """ Command execution context Contains modifiable entities as attributes """ - def __init__(self, *, supervisor: Supervisor) -> None: + def __init__(self, *, supervisor: SupervisorType) -> None: self.session = supervisor class ICommand: __awaitable = None + def __init__(self, is_termination_signal: bool = False): + self._is_termination_signal = is_termination_signal + + def is_stop(self): + return self._is_termination_signal + @property def awaitable(self): if not self.__awaitable: @@ -51,18 +66,33 @@ def awaitable(self): def execute(self, ctx: Context) -> None: raise NotImplementedError() + def is_command(self, command_to_check: Type[ICommand]): + """ + Identify the command even if it is wrapped as an awaitable one + """ + if isinstance(self, AwaitableCommand): + return isinstance(self._cmd, command_to_check) + else: + return isinstance(self, command_to_check) + class AwaitableCommand(ICommand): def __init__(self, cmd: ICommand): self._cmd = cmd self._done_evt = threading.Event() + self._exception: Exception | None = None # Store the exception + super().__init__(is_termination_signal=self._cmd.is_stop()) def wait(self): self._done_evt.wait() + if self._exception is not None: + raise self._exception def execute(self, ctx: Context) -> None: try: self._cmd.execute(ctx) + except Exception as e: + self._exception = e finally: self._done_evt.set() @@ -70,55 +100,128 @@ def __repr__(self): return f"Awaitable {self._cmd!r}" -class PauseCmd(ICommand): - def execute(self, ctx: Context) -> None: - ctx.session.transition_to(types.State.Paused) +class PauseTrainingCmd(ICommand): + def execute(self, ctx: Context[TrainerSupervisor]) -> None: + ctx.session.pause() -class ResumeCmd(ICommand): - def execute(self, ctx: Context) -> None: - ctx.session.transition_to(types.State.Running) +class ResumeTrainingCmd(ICommand): + def execute(self, ctx: Context[TrainerSupervisor]) -> None: + ctx.session.resume() + + +class SetStartStateTrainingCmd(ICommand): + def execute(self, ctx: Context[TrainerSupervisor]) -> None: + ctx.session.transition_to_state(new_state=TrainerState.RUNNING, valid_states={TrainerState.IDLE}) + +class SetPauseStateTrainingCmd(ICommand): + def execute(self, ctx: Context[TrainerSupervisor]) -> None: + ctx.session.transition_to_state(new_state=TrainerState.PAUSED, valid_states={TrainerState.RUNNING}) + + +class SetResumeStateTrainingCmd(ICommand): + def execute(self, ctx: Context[TrainerSupervisor]) -> None: + ctx.session.transition_to_state(new_state=TrainerState.RUNNING, valid_states={TrainerState.PAUSED}) + + +class ShutdownCmd(ICommand): + def __init__(self): + super().__init__(is_termination_signal=True) -class StopCmd(ICommand): def execute(self, ctx: Context) -> None: - ctx.session.transition_to(types.State.Stopped) + pass + + +class ShutdownWithTeardownCmd(ShutdownCmd): + def execute(self, ctx: Context[Supervisors]) -> None: + ctx.session.shutdown() + + +class NominalShutdownCmd(ShutdownCmd): + def execute(self, ctx: Context[TrainerSupervisor]) -> None: + ctx.session.transition_to_state(new_state=TrainerState.FINISHED, valid_states={TrainerState.RUNNING}) + + +class ShutdownWithErrorCmd(ShutdownCmd): + def execute(self, ctx: Context[TrainerSupervisor]) -> None: + ctx.session.transition_to_state(new_state=TrainerState.FAILED, valid_states={TrainerState.RUNNING}) + + +class StartTrainingCmd(ICommand): + def execute(self, ctx: Context[TrainerSupervisor]) -> None: + ctx.session.start() class UpdateDatasetCmd(ICommand): def __init__(self, name, *, raw_data, labels): + super().__init__() self._name = name self._raw_data = raw_data self._labels = labels - def execute(self, ctx: Context) -> None: + def execute(self, ctx: Context[BioModelSupervisor]) -> None: logger.warning("Not Implemented") + ctx.session.update_dataset() # dataset = ctx.exemplum.get_dataset(self._name) # dataset.update(self._raw_data, self._labels) class SetMaxNumIterations(ICommand): def __init__(self, num_iterations: int) -> None: + super().__init__() self._num_iterations = num_iterations - def execute(self, ctx: Context) -> None: + def execute(self, ctx: Context[BioModelSupervisor]) -> None: ctx.session.set_max_num_iterations(self._num_iterations) class ForwardPass(ICommand): def __init__(self, future, input_tensors): + super().__init__() self._input_tensors = input_tensors self._future = future - def execute(self, ctx: Context) -> None: + def execute(self, ctx: Context[Supervisors]) -> None: try: self._future.set_result(ctx.session.forward(self._input_tensors)) except Exception as e: self._future.set_exception(e) +class CommandPriorityQueueUtils: + """ + Utility for managing and processing commands in a priority queue. + """ + + def __init__(self) -> None: + self.queue = CommandPriorityQueue() + + def send_command(self, cmd: ICommand) -> None: + if not isinstance(cmd, ICommand): + raise ValueError(f"Expected instance of ICommand got {cmd}") + + logger.debug("Sending command %s", cmd) + self.queue.put(cmd) + + def process_commands(self, session): + cmd: ICommand = self.queue.get() + ctx = Context(supervisor=session) + logger.debug("Executing %s", cmd) + + try: + cmd.execute(ctx) + except Exception as e: + logger.exception(f"Failed to execute %s with exception {e}", cmd) + finally: + self.queue.task_done() + logger.debug(f"Finished executing {cmd}") + + return cmd.is_stop() + + class CommandPriorityQueue(queue.PriorityQueue): - COMMAND_PRIORITIES = {StopCmd: 0} + COMMAND_PRIORITIES = {ShutdownWithTeardownCmd: 0} @dataclass(order=True) class _PrioritizedItem: @@ -129,7 +232,10 @@ class _PrioritizedItem: @classmethod def _make_queue_item(cls, cmd: ICommand): - priority = cls.COMMAND_PRIORITIES.get(type(cmd), 999) + if cmd.is_stop(): + priority = 0 + else: + priority = cls.COMMAND_PRIORITIES.get(type(cmd), 999) return cls._PrioritizedItem((priority, next(cls.__counter)), cmd) def put(self, cmd: ICommand, block=True, timeout=None) -> None: diff --git a/tiktorch/server/session/backend/supervisor.py b/tiktorch/server/session/backend/supervisor.py index a5594e08..e980db47 100644 --- a/tiktorch/server/session/backend/supervisor.py +++ b/tiktorch/server/session/backend/supervisor.py @@ -1,137 +1,220 @@ -from __future__ import annotations - import logging -import queue +import threading +from typing import Generic, Set, TypeVar, Union from bioimageio.core import PredictionPipeline, Sample -from tiktorch.server.session import types from tiktorch.server.session.backend import commands +from tiktorch.server.session.backend.commands import CommandPriorityQueueUtils, ShutdownWithTeardownCmd +from tiktorch.trainer import Callbacks, ErrorCallbacks, Trainer, TrainerState logger = logging.getLogger(__name__) -class Supervisor: - def __init__(self, pipeline: PredictionPipeline) -> None: - self._state = types.State.Stopped +class StateTransitionError(Exception): + def __init__(self, current_state: TrainerState, transitioning_state: TrainerState, valid_states: Set[TrainerState]): + super().__init__( + f"Invalid state transition: {current_state} -> {transitioning_state}. Valids are {valid_states}" + ) + self.current_state = current_state + self.transitioning_state = transitioning_state + self.valid_states = valid_states + + def __reduce__(self): + return ( + self.__class__, + (self.current_state, self.transitioning_state, self.valid_states), + ) - self._command_queue = commands.CommandPriorityQueue() - self._pipeline = pipeline - # self._pipeline.set_break_callback(self.has_commands) - self._idle_callbacks = [] - def send_command(self, cmd: commands.ICommand) -> None: - if not isinstance(cmd, commands.ICommand): - raise ValueError(f"Expected instance of ICommand got {cmd}") +class TrainerSupervisor: + """Training supervisor for custom models supported by the 'Trainer' interface. - logger.debug("Sending command %s", cmd) - self._command_queue.put(cmd) + Monitoring the training thread and its status. + """ - @property - def state(self): + def __init__(self, trainer: Trainer) -> None: + super().__init__() + self._trainer = trainer + self._trainer.should_stop_callbacks.register(self._should_stop) + self._shutdown_event = threading.Event() + self._state = TrainerState.IDLE + self._pause_triggered = False + self._session_thread = threading.Thread(target=self._start_session, name="SessionThread") + self._command_queue_utils = CommandPriorityQueueUtils() + self.training_error_callbacks: ErrorCallbacks = Callbacks() + + def get_state(self) -> TrainerState: + logger.debug(f"Get state called {self._state}") return self._state - def has_commands(self): - return not self._command_queue.empty() + def start(self): + if self._state != TrainerState.IDLE: + raise StateTransitionError( + current_state=self._state, transitioning_state=TrainerState.RUNNING, valid_states={TrainerState.IDLE} + ) + self._session_thread.start() + self._pause_triggered = False + start_cmd = commands.SetStartStateTrainingCmd() + self._command_queue_utils.send_command(start_cmd.awaitable) + start_cmd.awaitable.wait() + + def _start_session(self): + logger.info("Starting session worker") + try: + while True: + if self._command_queue_utils.process_commands(self): + break + + if self._state == TrainerState.RUNNING: + self._fit() + except Exception as e: + logger.exception(f"Uncaught exception in session worker. Exception: {e}") + finally: + logger.info("Stopped session worker") + + def _fit(self): + try: + self._trainer.fit() + except Exception as e: + logger.exception(f"Training error: {e}") + self.training_error_callbacks(e) + self._command_queue_utils.send_command(commands.ShutdownWithErrorCmd()) + return + + if self.is_training_finished(): + logger.info(f"Training has finished: {self._get_num_iterations_epochs()} ") + self._command_queue_utils.send_command(commands.NominalShutdownCmd()) + + def is_training_finished(self): + return ( + self._trainer.num_epochs == self._trainer.max_num_epochs + or self._trainer.num_iterations == self._trainer.max_num_iterations + ) + + def _get_num_iterations_epochs(self) -> str: + iterations = f"Iterations[{self._trainer.num_iterations}/{self._trainer.max_num_iterations}]" + epochs = f"Epochs[{self._trainer.num_epochs}/{self._trainer.max_num_epochs}]" + return f"{iterations}, {epochs}" + + def resume(self): + self._check_transition_to_state(TrainerState.RUNNING, valid_states={TrainerState.PAUSED}) + self._pause_triggered = False + resume_cmd = commands.SetResumeStateTrainingCmd() + self._command_queue_utils.send_command(resume_cmd.awaitable) + resume_cmd.awaitable.wait() # make sure that the state has actually changed (acknowledge) + logger.info(f"Resume training: {self._get_num_iterations_epochs()}") + + def pause(self): + self._check_transition_to_state(TrainerState.PAUSED, valid_states={TrainerState.RUNNING}) + self._pause_triggered = True + pause_cmd = commands.SetPauseStateTrainingCmd() + self._command_queue_utils.send_command(pause_cmd.awaitable) + pause_cmd.awaitable.wait() # make sure that the state has actually changed (acknowledge) + + def shutdown(self): + if not self._session_thread.is_alive(): + # nothing to do if session thread not alive + return + self._pause_triggered = True + self._command_queue_utils.send_command(commands.ShutdownCmd()) + self._session_thread.join() + + def forward(self, input_tensors): + self.pause() + self._trainer.forward(input_tensors) + self.resume() + + def save(self): + raise NotImplementedError + + def export(self): + raise NotImplementedError + + def _should_stop(self): + return self._pause_triggered + + def transition_to_state(self, new_state: TrainerState, valid_states: Set[TrainerState]): + """ + Should be used via the ICommands to monitor the state of the training + """ + self._check_transition_to_state(new_state, valid_states) + logger.info(f"State transition: {self._state} -> {new_state}") + self._state = new_state + + def _check_transition_to_state(self, new_state: TrainerState, valid_states: Set[TrainerState]): + if self._state not in valid_states: + raise StateTransitionError( + current_state=self._state, transitioning_state=new_state, valid_states=valid_states + ) + + +class BioModelSupervisor: + """Supervisor for bioimageio models + + Currently used only for inference. - def has_work(self): - return self._pipeline.max_num_iterations and self._pipeline.max_num_iterations > self._pipeline.iteration_count + Allows to serialize and offload commands by multiple threads requests. + """ + + def __init__(self, pipeline: PredictionPipeline) -> None: + super().__init__() + self._pipeline = pipeline def forward(self, sample: Sample): results = self._pipeline.predict_sample_without_blocking(sample) return results - def transition_to(self, new_state: types.State) -> None: - logger.debug("Attempting transition to state %s", new_state) - self._state = new_state - self._update_state() - - def set_max_num_iterations(self, num: int): - self._pipeline.set_max_num_iterations(num) - self._update_state() - - def on_idle(self, callback): - self._idle_callbacks.append(callback) - self._notify_idle() - - def _notify_idle(self): - if self._state in (types.State.Idle, types.State.Paused): - idle_cbs = self._idle_callbacks - self._idle_callbacks = [] - for cb in idle_cbs: - try: - cb() - except Exception: - logger.exception("Exception during idle callback") - - def run(self): + def set_max_num_iterations(self, num_iterations: int): + raise NotImplementedError + + def update_dataset(self): + raise NotImplementedError + + def shutdown(self): + pass + + +Supervisors = Union[BioModelSupervisor, TrainerSupervisor] +SupervisorTypeVar = TypeVar("SupervisorTypeVar", bound=Supervisors) + + +class QueueTasks(Generic[SupervisorTypeVar]): + """ + A task queue manager for processing commands with a supervisor. + + Serializes multiple async requests wrapped as commands. + """ + + def __init__(self, supervisor: SupervisorTypeVar) -> None: + self._command_queue = CommandPriorityQueueUtils() + self._supervisor = supervisor + self._thread = threading.Thread(target=self._run, name="QueueTasksWorker") + + def start(self): + self._thread.start() + + def _run(self): logger.info("Starting session worker") try: - self._run() - except Exception: - logger.exception("Uncaught exception in session worker") + while True: + if self._command_queue.process_commands(self._supervisor): + break + except Exception as e: + logger.exception(f"Uncaught exception in session worker {e}") finally: logger.info("Stopped session worker") - def _run(self): - self._set_state(types.State.Paused) - - while True: - self._process_commands() - - if self.state == types.State.Stopped: - break - - elif self._state == types.State.Idle or self._state == types.State.Paused: - with self._command_queue.not_empty: - self._command_queue.not_empty.wait() - - elif self._state == types.State.Running: - self._train() - self._update_state() - - def _process_commands(self): - while not self._command_queue.empty(): - try: - cmd = self._command_queue.get_nowait() - logger.debug("Executing %s", cmd) - ctx = commands.Context(supervisor=self) - - try: - cmd.execute(ctx) - except Exception: - logger.exception("Failed to execute %s", cmd) - finally: - self._command_queue.task_done() - - except queue.Empty: - pass - - def _train(self): - logger.info( - "Start session for %d iterations", self._pipeline.max_num_iterations - self._pipeline.iteration_count - ) - try: - self._pipeline.train() - except Exception: - logger.error("Exception during session training. Pausing...", exc_info=True) - # FIXME: Should we use PauseCmd here? Maybe we should only know about ICommand on this level. - self.send_command(commands.PauseCmd()) - - self._update_state() - - def _update_state(self): - if self._state == types.State.Running: - should_idle = not self.has_work() - if should_idle: - self._set_state(types.State.Idle) - - elif self._state == types.State.Idle: - should_run = self.has_work() - if should_run: - self._set_state(types.State.Running) - - def _set_state(self, new_state: types.State) -> None: - self._state = new_state - self._notify_idle() - logger.debug("Set new state %s", self._state) + def send_command(self, command: commands.ICommand): + self._command_queue.send_command(command) + + def shutdown(self): + if not self._thread.is_alive(): + logger.debug("Worker thread isn't alive") + return + logger.debug("Shutting down...") + stop_cmd = ShutdownWithTeardownCmd() + self.send_command(stop_cmd.awaitable) + stop_cmd.awaitable.wait() + logger.debug("Shutdown complete") + self._thread.join() diff --git a/tiktorch/server/session/process.py b/tiktorch/server/session/process.py index f289697e..3d164127 100644 --- a/tiktorch/server/session/process.py +++ b/tiktorch/server/session/process.py @@ -17,8 +17,9 @@ from tiktorch.rpc.mp import BioModelClient, MPServer from ...converters import Sample +from ...trainer import TrainerYamlParser from .backend import base -from .rpc_interface import IRPCModelSession +from .rpc_interface import IRPCModelSession, IRPCTrainer class InputSampleValidator: @@ -77,11 +78,11 @@ class ModelSessionProcess(IRPCModelSession): def __init__(self) -> None: super().__init__() self._datasets = {} - self._worker: Optional[base.SessionBackend] = None + self._worker: Optional[base.BioModelSessionBackend] = None def init(self, model_bytes: bytes, devices: List[str]): prediction_pipeline = _get_prediction_pipeline_from_model_bytes(model_bytes, devices) - self._worker = base.SessionBackend(prediction_pipeline) + self._worker = base.BioModelSessionBackend(prediction_pipeline) def forward(self, sample: Sample) -> Future: res = self.worker.forward(sample) @@ -99,12 +100,56 @@ def shutdown(self) -> Shutdown: return Shutdown() @property - def worker(self) -> base.SessionBackend: + def worker(self) -> base.BioModelSessionBackend: if self._worker is None: raise ValueError("Server isn't initialized") return self._worker +class TrainerSessionProcess(IRPCTrainer): + def __init__(self): + self._worker: Optional[base.TrainerSessionBackend] = None + + @property + def worker(self) -> base.TrainerSessionBackend: + if self._worker is None: + raise ValueError("Server isn't initialized") + return self._worker + + def init(self, trainer_yaml_config): + parser = TrainerYamlParser(trainer_yaml_config) + trainer = parser.parse() + self._worker = base.TrainerSessionBackend(trainer) + + def forward(self, input_tensors) -> Future: + res = self.worker.forward(input_tensors) + return res + + def resume_training(self): + self.worker.resume_training() + + def start_training(self): + self.worker.start_training() + + def pause_training(self): + self.worker.pause_training() + + def save(self): + self.worker.save() + + def export(self): + self.worker.export() + + def get_state(self): + return self.worker.get_state() + + def shutdown(self): + if self._worker is None: + return Shutdown() + self.worker.shutdown() + return Shutdown() + + def _run_server(api: RPCInterface, conn: Connection, log_queue: Optional[_mp.Queue] = None): try: # from: https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667 @@ -125,6 +170,10 @@ def _run_server(api: RPCInterface, conn: Connection, log_queue: Optional[_mp.Que T = TypeVar("T", bound=RPCInterface) +def start_trainer_process(log_queue: Optional[_mp.Queue] = None) -> Tuple[_mp.Process, TrainerSessionProcess]: + return start_process(interface_class=TrainerSessionProcess, log_queue=log_queue) + + def start_process(interface_class: Type[T], log_queue: Optional[_mp.Queue] = None) -> Tuple[_mp.Process, T]: client_conn, server_conn = _mp.Pipe() proc = _mp.Process( diff --git a/tiktorch/server/session/rpc_interface.py b/tiktorch/server/session/rpc_interface.py index 4efface8..db714cb5 100644 --- a/tiktorch/server/session/rpc_interface.py +++ b/tiktorch/server/session/rpc_interface.py @@ -2,7 +2,9 @@ from tiktorch.converters import Sample from tiktorch.rpc import RPCInterface, exposed +from tiktorch.rpc.exceptions import Shutdown from tiktorch.tiktypes import TikTensorBatch +from tiktorch.trainer import TrainerState from tiktorch.types import ModelState @@ -46,3 +48,41 @@ def create_dataset_description(self, mean, stddev) -> str: @exposed def forward(self, input_tensors: Sample): raise NotImplementedError + + +class IRPCTrainer(RPCInterface): + @exposed + def init(self, trainer_yaml_config: str): + raise NotImplementedError + + @exposed + def forward(self, input_tensors: Sample): + raise NotImplementedError + + @exposed + def resume_training(self) -> None: + raise NotImplementedError + + @exposed + def pause_training(self) -> None: + raise NotImplementedError + + @exposed + def start_training(self) -> None: + raise NotImplementedError + + @exposed + def shutdown(self) -> Shutdown: + raise NotImplementedError + + @exposed + def save(self): + raise NotImplementedError + + @exposed + def export(self): + raise NotImplementedError + + @exposed + def get_state(self) -> TrainerState: + raise NotImplementedError diff --git a/tiktorch/trainer.py b/tiktorch/trainer.py new file mode 100644 index 00000000..34fb79c7 --- /dev/null +++ b/tiktorch/trainer.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Generic, List, TypeVar + +import torch +import yaml +from pytorch3dunet.datasets.utils import get_train_loaders +from pytorch3dunet.unet3d.losses import get_loss_criterion +from pytorch3dunet.unet3d.metrics import get_evaluation_metric +from pytorch3dunet.unet3d.model import get_model +from pytorch3dunet.unet3d.trainer import UNetTrainer +from pytorch3dunet.unet3d.utils import create_lr_scheduler, create_optimizer, get_tensorboard_formatter +from torch import nn + +T = TypeVar("T", bound=Callable) + + +class Callbacks(Generic[T]): + def __init__(self): + self._callbacks: List[T] = [] + + def register(self, callback: T): + self._callbacks.append(callback) + + def unregister(self, callback: T): + self._callbacks.remove(callback) + + def __call__(self, *args, **kwargs): + for callback in self._callbacks: + callback(*args, **kwargs) + + +ErrorCallbacks = Callbacks[Callable[[Exception], None]] + + +class ModelPhase(Enum): + Train = "train" + Eval = "val" + + +@dataclass(frozen=True) +class Logs: + mode: ModelPhase + loss: float + eval_score: float + iteration: int + + +LogsCallbacks = Callbacks[Callable[[Logs], None]] + + +class TrainerState(Enum): + IDLE = 0 + RUNNING = 1 + PAUSED = 2 + FAILED = 3 + FINISHED = 4 + + +class Trainer(UNetTrainer): + def __init__( + self, + model, + optimizer, + lr_scheduler, + loss_criterion, + eval_criterion, + loaders, + checkpoint_dir, + max_num_epochs, + max_num_iterations, + validate_after_iters=200, + log_after_iters=100, + validate_iters=None, + num_iterations=1, + num_epoch=0, + eval_score_higher_is_better=True, + tensorboard_formatter=None, + skip_train_validation=False, + resume=None, + pre_trained=None, + **kwargs, + ): + super().__init__( + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + loss_criterion=loss_criterion, + eval_criterion=eval_criterion, + loaders=loaders, + checkpoint_dir=checkpoint_dir, + max_num_epochs=max_num_epochs, + max_num_iterations=max_num_iterations, + validate_after_iters=validate_after_iters, + log_after_iters=log_after_iters, + validate_iters=validate_iters, + num_iterations=num_iterations, + num_epoch=num_epoch, + eval_score_higher_is_better=eval_score_higher_is_better, + tensorboard_formatter=tensorboard_formatter, + skip_train_validation=skip_train_validation, + resume=resume, + pre_trained=pre_trained, + **kwargs, + ) + self._should_stop_callbacks: List[Callable] = [] + self._logs_callbacks: List[Callable] = [] + self.logs_callbacks: LogsCallbacks = Callbacks() + self.should_stop_callbacks: Callbacks = Callbacks() + + def fit(self): + return super().fit() + + def train(self): + return super().train() + + def validate(self): + return super().validate() + + def forward(self, input_tensors): + self.model.eval() + with torch.no_grad(): + self.model(input_tensors) + + def should_stop(self): + return self.should_stop_callbacks() or super().should_stop() + + def _log_stats(self, phase, loss_avg, eval_score_avg): + logs = Logs(mode=ModelPhase(phase), loss=loss_avg, eval_score=eval_score_avg, iteration=self.num_iterations) + self.logs_callbacks(logs) + return super()._log_stats(phase, loss_avg, eval_score_avg) + + +class TrainerYamlParser: + def __init__(self, yaml_string: str): + self._yaml_string = yaml_string + self._yaml_config = yaml.safe_load(self._yaml_string) + + def get_device(self): + return self._yaml_config["device"] + + def parse(self) -> Trainer: + """ + Source: pytorch 3d unet + """ + + config = self._yaml_config + + model = get_model(config["model"]) + + if torch.cuda.device_count() > 1 and not config["device"] == "cpu": + model = nn.DataParallel(model) + if torch.cuda.is_available() and not config["device"] == "cpu": + model = model.cuda() + + # Create loss criterion + loss_criterion = get_loss_criterion(config) + # Create evaluation metric + eval_criterion = get_evaluation_metric(config) + + # Create data loaders + loaders = get_train_loaders(config) + + # Create the optimizer + optimizer = create_optimizer(config["optimizer"], model) + + # Create learning rate adjustment strategy + lr_scheduler = create_lr_scheduler(config.get("lr_scheduler", None), optimizer) + + trainer_config = config["trainer"] + # Create tensorboard formatter + tensorboard_formatter = get_tensorboard_formatter(trainer_config.pop("tensorboard_formatter", None)) + # Create trainer + resume = trainer_config.pop("resume", None) + pre_trained = trainer_config.pop("pre_trained", None) + + return Trainer( + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + loss_criterion=loss_criterion, + eval_criterion=eval_criterion, + loaders=loaders, + tensorboard_formatter=tensorboard_formatter, + resume=resume, + pre_trained=pre_trained, + **trainer_config, + )