From 27b3923db104b79e79ed2084ed59bfd3aa4c3909 Mon Sep 17 00:00:00 2001 From: Theodoros Katzalis Date: Thu, 19 Dec 2024 01:37:42 +0100 Subject: [PATCH] Separate listing devices as a utility for inference and training - The inference servicer had a procedure to list the available devices. This is needed for the training servicer as well. So listing devices is decoupled to be shared. --- examples/grpc_client.py | 4 +- proto/inference.proto | 15 ---- proto/training.proto | 4 +- proto/utils.proto | 16 +++++ .../test_grpc/test_inference_servicer.py | 20 +++--- tests/test_server/test_grpc/test_init.py | 2 +- .../test_grpc/test_training_servicer.py | 4 +- tiktorch/proto/inference_pb2.py | 62 +++++++--------- tiktorch/proto/inference_pb2_grpc.py | 49 ++++++------- tiktorch/proto/training_pb2.py | 56 +++++++-------- tiktorch/proto/training_pb2_grpc.py | 70 ++++++++++++++----- tiktorch/proto/utils_pb2.py | 22 ++++-- .../server/grpc/flight_control_servicer.py | 10 +-- tiktorch/server/grpc/inference_servicer.py | 27 +++---- tiktorch/server/grpc/training_servicer.py | 31 ++++---- tiktorch/server/grpc/utils_servicer.py | 18 +++++ 16 files changed, 228 insertions(+), 182 deletions(-) create mode 100644 tiktorch/server/grpc/utils_servicer.py diff --git a/examples/grpc_client.py b/examples/grpc_client.py index f0fe9ce5..f64ebd5b 100644 --- a/examples/grpc_client.py +++ b/examples/grpc_client.py @@ -1,12 +1,12 @@ import grpc -from tiktorch.proto import inference_pb2, inference_pb2_grpc +from tiktorch.proto import inference_pb2_grpc, utils_pb2 def run(): with grpc.insecure_channel("127.0.0.1:5567") as channel: stub = inference_pb2_grpc.InferenceStub(channel) - response = stub.ListDevices(inference_pb2.Empty()) + response = stub.ListDevices(utils_pb2.Empty()) print(response) diff --git a/proto/inference.proto b/proto/inference.proto index cf3439f0..3d10929b 100644 --- a/proto/inference.proto +++ b/proto/inference.proto @@ -19,15 +19,6 @@ service Inference { rpc Predict(PredictRequest) 
returns (PredictResponse) {} } -message Device { - enum Status { - AVAILABLE = 0; - IN_USE = 1; - } - - string id = 1; - Status status = 2; -} message CreateDatasetDescriptionRequest { string modelSessionId = 1; @@ -81,11 +72,6 @@ message LogEntry { string content = 3; } -message Devices { - repeated Device devices = 1; -} - - message PredictRequest { string modelSessionId = 1; @@ -97,7 +83,6 @@ message PredictResponse { repeated Tensor tensors = 1; } -message Empty {} service FlightControl { rpc Ping(Empty) returns (Empty) {} diff --git a/proto/training.proto b/proto/training.proto index 62406160..496a6eaa 100644 --- a/proto/training.proto +++ b/proto/training.proto @@ -5,10 +5,10 @@ package training; import "utils.proto"; -message Empty {} - service Training { + rpc ListDevices(Empty) returns (Devices) {} + rpc Init(TrainingConfig) returns (TrainingSessionId) {} rpc Start(TrainingSessionId) returns (Empty) {} diff --git a/proto/utils.proto b/proto/utils.proto index 6f0dbc8b..cb24d3e3 100644 --- a/proto/utils.proto +++ b/proto/utils.proto @@ -1,5 +1,7 @@ syntax = "proto3"; +message Empty {} + message NamedInt { uint32 size = 1; string name = 2; @@ -15,4 +17,18 @@ message Tensor { string dtype = 2; string tensorId = 3; repeated NamedInt shape = 4; +} + +message Device { + enum Status { + AVAILABLE = 0; + IN_USE = 1; + } + + string id = 1; + Status status = 2; +} + +message Devices { + repeated Device devices = 1; } \ No newline at end of file diff --git a/tests/test_server/test_grpc/test_inference_servicer.py b/tests/test_server/test_grpc/test_inference_servicer.py index b1de0213..52827193 100644 --- a/tests/test_server/test_grpc/test_inference_servicer.py +++ b/tests/test_server/test_grpc/test_inference_servicer.py @@ -9,7 +9,7 @@ from tiktorch import converters from tiktorch.converters import pb_tensor_to_xarray -from tiktorch.proto import inference_pb2, inference_pb2_grpc +from tiktorch.proto import inference_pb2, inference_pb2_grpc, utils_pb2 from 
tiktorch.server.data_store import DataStore from tiktorch.server.device_pool import TorchDevicePool from tiktorch.server.grpc import inference_servicer @@ -101,13 +101,13 @@ def test_model_init_failed_close_session(self, bioimage_model_explicit_add_one_s class TestDeviceManagement: def test_list_devices(self, grpc_stub): - resp = grpc_stub.ListDevices(inference_pb2.Empty()) + resp = grpc_stub.ListDevices(utils_pb2.Empty()) device_by_id = {d.id: d for d in resp.devices} assert "cpu" in device_by_id - assert inference_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status + assert utils_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status def _query_devices(self, grpc_stub): - dev_resp = grpc_stub.ListDevices(inference_pb2.Empty()) + dev_resp = grpc_stub.ListDevices(utils_pb2.Empty()) device_by_id = {d.id: d for d in dev_resp.devices} return device_by_id @@ -121,19 +121,19 @@ def test_if_model_create_fails_devices_are_released(self, grpc_stub): device_by_id = self._query_devices(grpc_stub) assert "cpu" in device_by_id - assert inference_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status + assert utils_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status def test_use_device(self, grpc_stub, bioimage_model_explicit_add_one_siso_v5): model_bytes = bioimage_model_explicit_add_one_siso_v5 device_by_id = self._query_devices(grpc_stub) assert "cpu" in device_by_id - assert inference_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status + assert utils_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status grpc_stub.CreateModelSession(valid_model_request(model_bytes, device_ids=["cpu"])) device_by_id = self._query_devices(grpc_stub) assert "cpu" in device_by_id - assert inference_pb2.Device.Status.IN_USE == device_by_id["cpu"].status + assert utils_pb2.Device.Status.IN_USE == device_by_id["cpu"].status def test_using_same_device_fails(self, grpc_stub, bioimage_model_explicit_add_one_siso_v5): model_bytes = bioimage_model_explicit_add_one_siso_v5 @@ 
-147,20 +147,20 @@ def test_closing_session_releases_devices(self, grpc_stub, bioimage_model_explic device_by_id = self._query_devices(grpc_stub) assert "cpu" in device_by_id - assert inference_pb2.Device.Status.IN_USE == device_by_id["cpu"].status + assert utils_pb2.Device.Status.IN_USE == device_by_id["cpu"].status grpc_stub.CloseModelSession(model) device_by_id_after_close = self._query_devices(grpc_stub) assert "cpu" in device_by_id_after_close - assert inference_pb2.Device.Status.AVAILABLE == device_by_id_after_close["cpu"].status + assert utils_pb2.Device.Status.AVAILABLE == device_by_id_after_close["cpu"].status class TestGetLogs: def test_returns_ack_message(self, bioimage_model_explicit_add_one_siso_v5, grpc_stub): model_bytes = bioimage_model_explicit_add_one_siso_v5 grpc_stub.CreateModelSession(valid_model_request(model_bytes)) - resp = grpc_stub.GetLogs(inference_pb2.Empty()) + resp = grpc_stub.GetLogs(utils_pb2.Empty()) record = next(resp) assert inference_pb2.LogEntry.Level.INFO == record.level assert "Sending model logs" == record.content diff --git a/tests/test_server/test_grpc/test_init.py b/tests/test_server/test_grpc/test_init.py index af7bcf8a..0c8b57a9 100644 --- a/tests/test_server/test_grpc/test_init.py +++ b/tests/test_server/test_grpc/test_init.py @@ -4,8 +4,8 @@ import grpc -from tiktorch.proto.inference_pb2 import Empty from tiktorch.proto.inference_pb2_grpc import FlightControlStub +from tiktorch.proto.utils_pb2 import Empty from tiktorch.server.grpc import serve from tiktorch.utils import wait diff --git a/tests/test_server/test_grpc/test_training_servicer.py b/tests/test_server/test_grpc/test_training_servicer.py index f40a8fcc..61ef620a 100644 --- a/tests/test_server/test_grpc/test_training_servicer.py +++ b/tests/test_server/test_grpc/test_training_servicer.py @@ -10,7 +10,7 @@ import pytest from tiktorch.converters import pb_state_to_trainer, trainer_state_to_pb -from tiktorch.proto import training_pb2, training_pb2_grpc +from 
tiktorch.proto import training_pb2, training_pb2_grpc, utils_pb2 from tiktorch.server.device_pool import TorchDevicePool from tiktorch.server.grpc import training_servicer from tiktorch.server.session.backend.base import TrainerSessionBackend @@ -342,7 +342,7 @@ def test_start_training_without_init(self, grpc_stub): Test starting training without initializing a session. """ with pytest.raises(grpc.RpcError) as excinfo: - grpc_stub.Start(training_pb2.Empty()) + grpc_stub.Start(utils_pb2.Empty()) assert excinfo.value.code() == grpc.StatusCode.FAILED_PRECONDITION assert "trainer-session with id doesn't exist" in excinfo.value.details() diff --git a/tiktorch/proto/inference_pb2.py b/tiktorch/proto/inference_pb2.py index e5c575ba..be9fbde2 100644 --- a/tiktorch/proto/inference_pb2.py +++ b/tiktorch/proto/inference_pb2.py @@ -14,45 +14,37 @@ from . import utils_pb2 as utils__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0finference.proto\x12\tinference\x1a\x0butils.proto\"c\n\x06\x44\x65vice\x12\n\n\x02id\x18\x01 \x01(\t\x12(\n\x06status\x18\x02 \x01(\x0e\x32\x18.inference.Device.Status\"#\n\x06Status\x12\r\n\tAVAILABLE\x10\x00\x12\n\n\x06IN_USE\x10\x01\"W\n\x1f\x43reateDatasetDescriptionRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x0c\n\x04mean\x18\x03 \x01(\x01\x12\x0e\n\x06stddev\x18\x04 \x01(\x01\" \n\x12\x44\x61tasetDescription\x12\n\n\x02id\x18\x01 \x01(\t\"\'\n\x04\x42lob\x12\x0e\n\x06\x66ormat\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"s\n\x19\x43reateModelSessionRequest\x12\x13\n\tmodel_uri\x18\x01 \x01(\tH\x00\x12%\n\nmodel_blob\x18\x02 \x01(\x0b\x32\x0f.inference.BlobH\x00\x12\x11\n\tdeviceIds\x18\x05 \x03(\tB\x07\n\x05model\")\n\tNamedInts\x12\x1c\n\tnamedInts\x18\x01 \x03(\x0b\x32\t.NamedInt\"/\n\x0bNamedFloats\x12 \n\x0bnamedFloats\x18\x01 \x03(\x0b\x32\x0b.NamedFloat\"\x1a\n\x0cModelSession\x12\n\n\x02id\x18\x01 \x01(\t\"\xa8\x01\n\x08LogEntry\x12\x11\n\ttimestamp\x18\x01 \x01(\r\x12(\n\x05level\x18\x02 
\x01(\x0e\x32\x19.inference.LogEntry.Level\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\"N\n\x05Level\x12\n\n\x06NOTSET\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x08\n\x04INFO\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\x0c\n\x08\x43RITICAL\x10\x05\"-\n\x07\x44\x65vices\x12\"\n\x07\x64\x65vices\x18\x01 \x03(\x0b\x32\x11.inference.Device\"U\n\x0ePredictRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x11\n\tdatasetId\x18\x02 \x01(\t\x12\x18\n\x07tensors\x18\x03 \x03(\x0b\x32\x07.Tensor\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\"\x07\n\x05\x45mpty2\xbe\x03\n\tInference\x12U\n\x12\x43reateModelSession\x12$.inference.CreateModelSessionRequest\x1a\x17.inference.ModelSession\"\x00\x12@\n\x11\x43loseModelSession\x12\x17.inference.ModelSession\x1a\x10.inference.Empty\"\x00\x12g\n\x18\x43reateDatasetDescription\x12*.inference.CreateDatasetDescriptionRequest\x1a\x1d.inference.DatasetDescription\"\x00\x12\x34\n\x07GetLogs\x12\x10.inference.Empty\x1a\x13.inference.LogEntry\"\x00\x30\x01\x12\x35\n\x0bListDevices\x12\x10.inference.Empty\x1a\x12.inference.Devices\"\x00\x12\x42\n\x07Predict\x12\x19.inference.PredictRequest\x1a\x1a.inference.PredictResponse\"\x00\x32o\n\rFlightControl\x12,\n\x04Ping\x12\x10.inference.Empty\x1a\x10.inference.Empty\"\x00\x12\x30\n\x08Shutdown\x12\x10.inference.Empty\x1a\x10.inference.Empty\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0finference.proto\x12\tinference\x1a\x0butils.proto\"W\n\x1f\x43reateDatasetDescriptionRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x0c\n\x04mean\x18\x03 \x01(\x01\x12\x0e\n\x06stddev\x18\x04 \x01(\x01\" \n\x12\x44\x61tasetDescription\x12\n\n\x02id\x18\x01 \x01(\t\"\'\n\x04\x42lob\x12\x0e\n\x06\x66ormat\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"s\n\x19\x43reateModelSessionRequest\x12\x13\n\tmodel_uri\x18\x01 \x01(\tH\x00\x12%\n\nmodel_blob\x18\x02 
\x01(\x0b\x32\x0f.inference.BlobH\x00\x12\x11\n\tdeviceIds\x18\x05 \x03(\tB\x07\n\x05model\")\n\tNamedInts\x12\x1c\n\tnamedInts\x18\x01 \x03(\x0b\x32\t.NamedInt\"/\n\x0bNamedFloats\x12 \n\x0bnamedFloats\x18\x01 \x03(\x0b\x32\x0b.NamedFloat\"\x1a\n\x0cModelSession\x12\n\n\x02id\x18\x01 \x01(\t\"\xa8\x01\n\x08LogEntry\x12\x11\n\ttimestamp\x18\x01 \x01(\r\x12(\n\x05level\x18\x02 \x01(\x0e\x32\x19.inference.LogEntry.Level\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\"N\n\x05Level\x12\n\n\x06NOTSET\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x08\n\x04INFO\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\x0c\n\x08\x43RITICAL\x10\x05\"U\n\x0ePredictRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x11\n\tdatasetId\x18\x02 \x01(\t\x12\x18\n\x07tensors\x18\x03 \x03(\x0b\x32\x07.Tensor\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor2\x96\x03\n\tInference\x12U\n\x12\x43reateModelSession\x12$.inference.CreateModelSessionRequest\x1a\x17.inference.ModelSession\"\x00\x12\x36\n\x11\x43loseModelSession\x12\x17.inference.ModelSession\x1a\x06.Empty\"\x00\x12g\n\x18\x43reateDatasetDescription\x12*.inference.CreateDatasetDescriptionRequest\x1a\x1d.inference.DatasetDescription\"\x00\x12*\n\x07GetLogs\x12\x06.Empty\x1a\x13.inference.LogEntry\"\x00\x30\x01\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12\x42\n\x07Predict\x12\x19.inference.PredictRequest\x1a\x1a.inference.PredictResponse\"\x00\x32G\n\rFlightControl\x12\x18\n\x04Ping\x12\x06.Empty\x1a\x06.Empty\"\x00\x12\x1c\n\x08Shutdown\x12\x06.Empty\x1a\x06.Empty\"\x00\x62\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'inference_pb2', globals()) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _DEVICE._serialized_start=43 - _DEVICE._serialized_end=142 - _DEVICE_STATUS._serialized_start=107 - _DEVICE_STATUS._serialized_end=142 - 
_CREATEDATASETDESCRIPTIONREQUEST._serialized_start=144 - _CREATEDATASETDESCRIPTIONREQUEST._serialized_end=231 - _DATASETDESCRIPTION._serialized_start=233 - _DATASETDESCRIPTION._serialized_end=265 - _BLOB._serialized_start=267 - _BLOB._serialized_end=306 - _CREATEMODELSESSIONREQUEST._serialized_start=308 - _CREATEMODELSESSIONREQUEST._serialized_end=423 - _NAMEDINTS._serialized_start=425 - _NAMEDINTS._serialized_end=466 - _NAMEDFLOATS._serialized_start=468 - _NAMEDFLOATS._serialized_end=515 - _MODELSESSION._serialized_start=517 - _MODELSESSION._serialized_end=543 - _LOGENTRY._serialized_start=546 - _LOGENTRY._serialized_end=714 - _LOGENTRY_LEVEL._serialized_start=636 - _LOGENTRY_LEVEL._serialized_end=714 - _DEVICES._serialized_start=716 - _DEVICES._serialized_end=761 - _PREDICTREQUEST._serialized_start=763 - _PREDICTREQUEST._serialized_end=848 - _PREDICTRESPONSE._serialized_start=850 - _PREDICTRESPONSE._serialized_end=893 - _EMPTY._serialized_start=895 - _EMPTY._serialized_end=902 - _INFERENCE._serialized_start=905 - _INFERENCE._serialized_end=1351 - _FLIGHTCONTROL._serialized_start=1353 - _FLIGHTCONTROL._serialized_end=1464 + _CREATEDATASETDESCRIPTIONREQUEST._serialized_start=43 + _CREATEDATASETDESCRIPTIONREQUEST._serialized_end=130 + _DATASETDESCRIPTION._serialized_start=132 + _DATASETDESCRIPTION._serialized_end=164 + _BLOB._serialized_start=166 + _BLOB._serialized_end=205 + _CREATEMODELSESSIONREQUEST._serialized_start=207 + _CREATEMODELSESSIONREQUEST._serialized_end=322 + _NAMEDINTS._serialized_start=324 + _NAMEDINTS._serialized_end=365 + _NAMEDFLOATS._serialized_start=367 + _NAMEDFLOATS._serialized_end=414 + _MODELSESSION._serialized_start=416 + _MODELSESSION._serialized_end=442 + _LOGENTRY._serialized_start=445 + _LOGENTRY._serialized_end=613 + _LOGENTRY_LEVEL._serialized_start=535 + _LOGENTRY_LEVEL._serialized_end=613 + _PREDICTREQUEST._serialized_start=615 + _PREDICTREQUEST._serialized_end=700 + _PREDICTRESPONSE._serialized_start=702 + 
_PREDICTRESPONSE._serialized_end=745 + _INFERENCE._serialized_start=748 + _INFERENCE._serialized_end=1154 + _FLIGHTCONTROL._serialized_start=1156 + _FLIGHTCONTROL._serialized_end=1227 # @@protoc_insertion_point(module_scope) diff --git a/tiktorch/proto/inference_pb2_grpc.py b/tiktorch/proto/inference_pb2_grpc.py index d967077f..b49f2e42 100644 --- a/tiktorch/proto/inference_pb2_grpc.py +++ b/tiktorch/proto/inference_pb2_grpc.py @@ -3,6 +3,7 @@ import grpc from . import inference_pb2 as inference__pb2 +from . import utils_pb2 as utils__pb2 class InferenceStub(object): @@ -22,7 +23,7 @@ def __init__(self, channel): self.CloseModelSession = channel.unary_unary( '/inference.Inference/CloseModelSession', request_serializer=inference__pb2.ModelSession.SerializeToString, - response_deserializer=inference__pb2.Empty.FromString, + response_deserializer=utils__pb2.Empty.FromString, ) self.CreateDatasetDescription = channel.unary_unary( '/inference.Inference/CreateDatasetDescription', @@ -31,13 +32,13 @@ def __init__(self, channel): ) self.GetLogs = channel.unary_stream( '/inference.Inference/GetLogs', - request_serializer=inference__pb2.Empty.SerializeToString, + request_serializer=utils__pb2.Empty.SerializeToString, response_deserializer=inference__pb2.LogEntry.FromString, ) self.ListDevices = channel.unary_unary( '/inference.Inference/ListDevices', - request_serializer=inference__pb2.Empty.SerializeToString, - response_deserializer=inference__pb2.Devices.FromString, + request_serializer=utils__pb2.Empty.SerializeToString, + response_deserializer=utils__pb2.Devices.FromString, ) self.Predict = channel.unary_unary( '/inference.Inference/Predict', @@ -96,7 +97,7 @@ def add_InferenceServicer_to_server(servicer, server): 'CloseModelSession': grpc.unary_unary_rpc_method_handler( servicer.CloseModelSession, request_deserializer=inference__pb2.ModelSession.FromString, - response_serializer=inference__pb2.Empty.SerializeToString, + 
response_serializer=utils__pb2.Empty.SerializeToString, ), 'CreateDatasetDescription': grpc.unary_unary_rpc_method_handler( servicer.CreateDatasetDescription, @@ -105,13 +106,13 @@ def add_InferenceServicer_to_server(servicer, server): ), 'GetLogs': grpc.unary_stream_rpc_method_handler( servicer.GetLogs, - request_deserializer=inference__pb2.Empty.FromString, + request_deserializer=utils__pb2.Empty.FromString, response_serializer=inference__pb2.LogEntry.SerializeToString, ), 'ListDevices': grpc.unary_unary_rpc_method_handler( servicer.ListDevices, - request_deserializer=inference__pb2.Empty.FromString, - response_serializer=inference__pb2.Devices.SerializeToString, + request_deserializer=utils__pb2.Empty.FromString, + response_serializer=utils__pb2.Devices.SerializeToString, ), 'Predict': grpc.unary_unary_rpc_method_handler( servicer.Predict, @@ -158,7 +159,7 @@ def CloseModelSession(request, metadata=None): return grpc.experimental.unary_unary(request, target, '/inference.Inference/CloseModelSession', inference__pb2.ModelSession.SerializeToString, - inference__pb2.Empty.FromString, + utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -191,7 +192,7 @@ def GetLogs(request, timeout=None, metadata=None): return grpc.experimental.unary_stream(request, target, '/inference.Inference/GetLogs', - inference__pb2.Empty.SerializeToString, + utils__pb2.Empty.SerializeToString, inference__pb2.LogEntry.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -208,8 +209,8 @@ def ListDevices(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/inference.Inference/ListDevices', - inference__pb2.Empty.SerializeToString, - inference__pb2.Devices.FromString, + utils__pb2.Empty.SerializeToString, + utils__pb2.Devices.FromString, options, channel_credentials, insecure, call_credentials, 
compression, wait_for_ready, timeout, metadata) @@ -242,13 +243,13 @@ def __init__(self, channel): """ self.Ping = channel.unary_unary( '/inference.FlightControl/Ping', - request_serializer=inference__pb2.Empty.SerializeToString, - response_deserializer=inference__pb2.Empty.FromString, + request_serializer=utils__pb2.Empty.SerializeToString, + response_deserializer=utils__pb2.Empty.FromString, ) self.Shutdown = channel.unary_unary( '/inference.FlightControl/Shutdown', - request_serializer=inference__pb2.Empty.SerializeToString, - response_deserializer=inference__pb2.Empty.FromString, + request_serializer=utils__pb2.Empty.SerializeToString, + response_deserializer=utils__pb2.Empty.FromString, ) @@ -272,13 +273,13 @@ def add_FlightControlServicer_to_server(servicer, server): rpc_method_handlers = { 'Ping': grpc.unary_unary_rpc_method_handler( servicer.Ping, - request_deserializer=inference__pb2.Empty.FromString, - response_serializer=inference__pb2.Empty.SerializeToString, + request_deserializer=utils__pb2.Empty.FromString, + response_serializer=utils__pb2.Empty.SerializeToString, ), 'Shutdown': grpc.unary_unary_rpc_method_handler( servicer.Shutdown, - request_deserializer=inference__pb2.Empty.FromString, - response_serializer=inference__pb2.Empty.SerializeToString, + request_deserializer=utils__pb2.Empty.FromString, + response_serializer=utils__pb2.Empty.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -302,8 +303,8 @@ def Ping(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/inference.FlightControl/Ping', - inference__pb2.Empty.SerializeToString, - inference__pb2.Empty.FromString, + utils__pb2.Empty.SerializeToString, + utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -319,7 +320,7 @@ def Shutdown(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, 
'/inference.FlightControl/Shutdown', - inference__pb2.Empty.SerializeToString, - inference__pb2.Empty.FromString, + utils__pb2.Empty.SerializeToString, + utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/tiktorch/proto/training_pb2.py b/tiktorch/proto/training_pb2.py index 669a1592..70d93933 100644 --- a/tiktorch/proto/training_pb2.py +++ b/tiktorch/proto/training_pb2.py @@ -14,39 +14,37 @@ from . import utils_pb2 as utils__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0etraining.proto\x12\x08training\x1a\x0butils.proto\"\x07\n\x05\x45mpty\"\x1f\n\x11TrainingSessionId\x12\n\n\x02id\x18\x01 \x01(\t\"\x87\x01\n\x04Logs\x12\'\n\x04mode\x18\x01 \x01(\x0e\x32\x19.training.Logs.ModelPhase\x12\x12\n\neval_score\x18\x02 \x01(\x01\x12\x0c\n\x04loss\x18\x03 \x01(\x01\x12\x11\n\titeration\x18\x04 \x01(\r\"!\n\nModelPhase\x12\t\n\x05Train\x10\x00\x12\x08\n\x04\x45val\x10\x01\"L\n\x14StreamUpdateResponse\x12\x16\n\x0e\x62\x65st_model_idx\x18\x01 \x01(\r\x12\x1c\n\x04logs\x18\x02 \x01(\x0b\x32\x0e.training.Logs\"/\n\x0fGetLogsResponse\x12\x1c\n\x04logs\x18\x01 \x03(\x0b\x32\x0e.training.Logs\"Z\n\x0ePredictRequest\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\x12.\n\tsessionId\x18\x02 \x01(\x0b\x32\x1b.training.TrainingSessionId\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\"6\n\x12ValidationResponse\x12 \n\x18validation_score_average\x18\x01 \x01(\x01\"\x8b\x01\n\x11GetStatusResponse\x12\x30\n\x05state\x18\x01 \x01(\x0e\x32!.training.GetStatusResponse.State\"D\n\x05State\x12\x08\n\x04Idle\x10\x00\x12\x0b\n\x07Running\x10\x01\x12\n\n\x06Paused\x10\x02\x12\n\n\x06\x46\x61iled\x10\x03\x12\x0c\n\x08\x46inished\x10\x04\",\n\x1eGetCurrentBestModelIdxResponse\x12\n\n\x02id\x18\x01 \x01(\r\"&\n\x0eTrainingConfig\x12\x14\n\x0cyaml_content\x18\x01 
\x01(\t2\xd2\x05\n\x08Training\x12?\n\x04Init\x12\x18.training.TrainingConfig\x1a\x1b.training.TrainingSessionId\"\x00\x12\x37\n\x05Start\x12\x1b.training.TrainingSessionId\x1a\x0f.training.Empty\"\x00\x12\x38\n\x06Resume\x12\x1b.training.TrainingSessionId\x1a\x0f.training.Empty\"\x00\x12\x37\n\x05Pause\x12\x1b.training.TrainingSessionId\x1a\x0f.training.Empty\"\x00\x12P\n\rStreamUpdates\x12\x1b.training.TrainingSessionId\x1a\x1e.training.StreamUpdateResponse\"\x00\x30\x01\x12\x43\n\x07GetLogs\x12\x1b.training.TrainingSessionId\x1a\x19.training.GetLogsResponse\"\x00\x12\x36\n\x04Save\x12\x1b.training.TrainingSessionId\x1a\x0f.training.Empty\"\x00\x12\x38\n\x06\x45xport\x12\x1b.training.TrainingSessionId\x1a\x0f.training.Empty\"\x00\x12@\n\x07Predict\x12\x18.training.PredictRequest\x1a\x19.training.PredictResponse\"\x00\x12G\n\tGetStatus\x12\x1b.training.TrainingSessionId\x1a\x1b.training.GetStatusResponse\"\x00\x12\x45\n\x13\x43loseTrainerSession\x12\x1b.training.TrainingSessionId\x1a\x0f.training.Empty\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0etraining.proto\x12\x08training\x1a\x0butils.proto\"\x1f\n\x11TrainingSessionId\x12\n\n\x02id\x18\x01 \x01(\t\"\x87\x01\n\x04Logs\x12\'\n\x04mode\x18\x01 \x01(\x0e\x32\x19.training.Logs.ModelPhase\x12\x12\n\neval_score\x18\x02 \x01(\x01\x12\x0c\n\x04loss\x18\x03 \x01(\x01\x12\x11\n\titeration\x18\x04 \x01(\r\"!\n\nModelPhase\x12\t\n\x05Train\x10\x00\x12\x08\n\x04\x45val\x10\x01\"L\n\x14StreamUpdateResponse\x12\x16\n\x0e\x62\x65st_model_idx\x18\x01 \x01(\r\x12\x1c\n\x04logs\x18\x02 \x01(\x0b\x32\x0e.training.Logs\"/\n\x0fGetLogsResponse\x12\x1c\n\x04logs\x18\x01 \x03(\x0b\x32\x0e.training.Logs\"Z\n\x0ePredictRequest\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\x12.\n\tsessionId\x18\x02 \x01(\x0b\x32\x1b.training.TrainingSessionId\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\"6\n\x12ValidationResponse\x12 
\n\x18validation_score_average\x18\x01 \x01(\x01\"\x8b\x01\n\x11GetStatusResponse\x12\x30\n\x05state\x18\x01 \x01(\x0e\x32!.training.GetStatusResponse.State\"D\n\x05State\x12\x08\n\x04Idle\x10\x00\x12\x0b\n\x07Running\x10\x01\x12\n\n\x06Paused\x10\x02\x12\n\n\x06\x46\x61iled\x10\x03\x12\x0c\n\x08\x46inished\x10\x04\",\n\x1eGetCurrentBestModelIdxResponse\x12\n\n\x02id\x18\x01 \x01(\r\"&\n\x0eTrainingConfig\x12\x14\n\x0cyaml_content\x18\x01 \x01(\t2\xbf\x05\n\x08Training\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12?\n\x04Init\x12\x18.training.TrainingConfig\x1a\x1b.training.TrainingSessionId\"\x00\x12.\n\x05Start\x12\x1b.training.TrainingSessionId\x1a\x06.Empty\"\x00\x12/\n\x06Resume\x12\x1b.training.TrainingSessionId\x1a\x06.Empty\"\x00\x12.\n\x05Pause\x12\x1b.training.TrainingSessionId\x1a\x06.Empty\"\x00\x12P\n\rStreamUpdates\x12\x1b.training.TrainingSessionId\x1a\x1e.training.StreamUpdateResponse\"\x00\x30\x01\x12\x43\n\x07GetLogs\x12\x1b.training.TrainingSessionId\x1a\x19.training.GetLogsResponse\"\x00\x12-\n\x04Save\x12\x1b.training.TrainingSessionId\x1a\x06.Empty\"\x00\x12/\n\x06\x45xport\x12\x1b.training.TrainingSessionId\x1a\x06.Empty\"\x00\x12@\n\x07Predict\x12\x18.training.PredictRequest\x1a\x19.training.PredictResponse\"\x00\x12G\n\tGetStatus\x12\x1b.training.TrainingSessionId\x1a\x1b.training.GetStatusResponse\"\x00\x12<\n\x13\x43loseTrainerSession\x12\x1b.training.TrainingSessionId\x1a\x06.Empty\"\x00\x62\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'training_pb2', globals()) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _EMPTY._serialized_start=41 - _EMPTY._serialized_end=48 - _TRAININGSESSIONID._serialized_start=50 - _TRAININGSESSIONID._serialized_end=81 - _LOGS._serialized_start=84 - _LOGS._serialized_end=219 - _LOGS_MODELPHASE._serialized_start=186 - _LOGS_MODELPHASE._serialized_end=219 - 
_STREAMUPDATERESPONSE._serialized_start=221 - _STREAMUPDATERESPONSE._serialized_end=297 - _GETLOGSRESPONSE._serialized_start=299 - _GETLOGSRESPONSE._serialized_end=346 - _PREDICTREQUEST._serialized_start=348 - _PREDICTREQUEST._serialized_end=438 - _PREDICTRESPONSE._serialized_start=440 - _PREDICTRESPONSE._serialized_end=483 - _VALIDATIONRESPONSE._serialized_start=485 - _VALIDATIONRESPONSE._serialized_end=539 - _GETSTATUSRESPONSE._serialized_start=542 - _GETSTATUSRESPONSE._serialized_end=681 - _GETSTATUSRESPONSE_STATE._serialized_start=613 - _GETSTATUSRESPONSE_STATE._serialized_end=681 - _GETCURRENTBESTMODELIDXRESPONSE._serialized_start=683 - _GETCURRENTBESTMODELIDXRESPONSE._serialized_end=727 - _TRAININGCONFIG._serialized_start=729 - _TRAININGCONFIG._serialized_end=767 - _TRAINING._serialized_start=770 - _TRAINING._serialized_end=1492 + _TRAININGSESSIONID._serialized_start=41 + _TRAININGSESSIONID._serialized_end=72 + _LOGS._serialized_start=75 + _LOGS._serialized_end=210 + _LOGS_MODELPHASE._serialized_start=177 + _LOGS_MODELPHASE._serialized_end=210 + _STREAMUPDATERESPONSE._serialized_start=212 + _STREAMUPDATERESPONSE._serialized_end=288 + _GETLOGSRESPONSE._serialized_start=290 + _GETLOGSRESPONSE._serialized_end=337 + _PREDICTREQUEST._serialized_start=339 + _PREDICTREQUEST._serialized_end=429 + _PREDICTRESPONSE._serialized_start=431 + _PREDICTRESPONSE._serialized_end=474 + _VALIDATIONRESPONSE._serialized_start=476 + _VALIDATIONRESPONSE._serialized_end=530 + _GETSTATUSRESPONSE._serialized_start=533 + _GETSTATUSRESPONSE._serialized_end=672 + _GETSTATUSRESPONSE_STATE._serialized_start=604 + _GETSTATUSRESPONSE_STATE._serialized_end=672 + _GETCURRENTBESTMODELIDXRESPONSE._serialized_start=674 + _GETCURRENTBESTMODELIDXRESPONSE._serialized_end=718 + _TRAININGCONFIG._serialized_start=720 + _TRAININGCONFIG._serialized_end=758 + _TRAINING._serialized_start=761 + _TRAINING._serialized_end=1464 # @@protoc_insertion_point(module_scope) diff --git 
a/tiktorch/proto/training_pb2_grpc.py b/tiktorch/proto/training_pb2_grpc.py index 5be55746..79bf33df 100644 --- a/tiktorch/proto/training_pb2_grpc.py +++ b/tiktorch/proto/training_pb2_grpc.py @@ -3,6 +3,7 @@ import grpc from . import training_pb2 as training__pb2 +from . import utils_pb2 as utils__pb2 class TrainingStub(object): @@ -14,6 +15,11 @@ def __init__(self, channel): Args: channel: A grpc.Channel. """ + self.ListDevices = channel.unary_unary( + '/training.Training/ListDevices', + request_serializer=utils__pb2.Empty.SerializeToString, + response_deserializer=utils__pb2.Devices.FromString, + ) self.Init = channel.unary_unary( '/training.Training/Init', request_serializer=training__pb2.TrainingConfig.SerializeToString, @@ -22,17 +28,17 @@ def __init__(self, channel): self.Start = channel.unary_unary( '/training.Training/Start', request_serializer=training__pb2.TrainingSessionId.SerializeToString, - response_deserializer=training__pb2.Empty.FromString, + response_deserializer=utils__pb2.Empty.FromString, ) self.Resume = channel.unary_unary( '/training.Training/Resume', request_serializer=training__pb2.TrainingSessionId.SerializeToString, - response_deserializer=training__pb2.Empty.FromString, + response_deserializer=utils__pb2.Empty.FromString, ) self.Pause = channel.unary_unary( '/training.Training/Pause', request_serializer=training__pb2.TrainingSessionId.SerializeToString, - response_deserializer=training__pb2.Empty.FromString, + response_deserializer=utils__pb2.Empty.FromString, ) self.StreamUpdates = channel.unary_stream( '/training.Training/StreamUpdates', @@ -47,12 +53,12 @@ def __init__(self, channel): self.Save = channel.unary_unary( '/training.Training/Save', request_serializer=training__pb2.TrainingSessionId.SerializeToString, - response_deserializer=training__pb2.Empty.FromString, + response_deserializer=utils__pb2.Empty.FromString, ) self.Export = channel.unary_unary( '/training.Training/Export', 
request_serializer=training__pb2.TrainingSessionId.SerializeToString, - response_deserializer=training__pb2.Empty.FromString, + response_deserializer=utils__pb2.Empty.FromString, ) self.Predict = channel.unary_unary( '/training.Training/Predict', @@ -67,13 +73,19 @@ def __init__(self, channel): self.CloseTrainerSession = channel.unary_unary( '/training.Training/CloseTrainerSession', request_serializer=training__pb2.TrainingSessionId.SerializeToString, - response_deserializer=training__pb2.Empty.FromString, + response_deserializer=utils__pb2.Empty.FromString, ) class TrainingServicer(object): """Missing associated documentation comment in .proto file.""" + def ListDevices(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def Init(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -143,6 +155,11 @@ def CloseTrainerSession(self, request, context): def add_TrainingServicer_to_server(servicer, server): rpc_method_handlers = { + 'ListDevices': grpc.unary_unary_rpc_method_handler( + servicer.ListDevices, + request_deserializer=utils__pb2.Empty.FromString, + response_serializer=utils__pb2.Devices.SerializeToString, + ), 'Init': grpc.unary_unary_rpc_method_handler( servicer.Init, request_deserializer=training__pb2.TrainingConfig.FromString, @@ -151,17 +168,17 @@ def add_TrainingServicer_to_server(servicer, server): 'Start': grpc.unary_unary_rpc_method_handler( servicer.Start, request_deserializer=training__pb2.TrainingSessionId.FromString, - response_serializer=training__pb2.Empty.SerializeToString, + response_serializer=utils__pb2.Empty.SerializeToString, ), 'Resume': grpc.unary_unary_rpc_method_handler( servicer.Resume, request_deserializer=training__pb2.TrainingSessionId.FromString, - 
response_serializer=training__pb2.Empty.SerializeToString, + response_serializer=utils__pb2.Empty.SerializeToString, ), 'Pause': grpc.unary_unary_rpc_method_handler( servicer.Pause, request_deserializer=training__pb2.TrainingSessionId.FromString, - response_serializer=training__pb2.Empty.SerializeToString, + response_serializer=utils__pb2.Empty.SerializeToString, ), 'StreamUpdates': grpc.unary_stream_rpc_method_handler( servicer.StreamUpdates, @@ -176,12 +193,12 @@ def add_TrainingServicer_to_server(servicer, server): 'Save': grpc.unary_unary_rpc_method_handler( servicer.Save, request_deserializer=training__pb2.TrainingSessionId.FromString, - response_serializer=training__pb2.Empty.SerializeToString, + response_serializer=utils__pb2.Empty.SerializeToString, ), 'Export': grpc.unary_unary_rpc_method_handler( servicer.Export, request_deserializer=training__pb2.TrainingSessionId.FromString, - response_serializer=training__pb2.Empty.SerializeToString, + response_serializer=utils__pb2.Empty.SerializeToString, ), 'Predict': grpc.unary_unary_rpc_method_handler( servicer.Predict, @@ -196,7 +213,7 @@ def add_TrainingServicer_to_server(servicer, server): 'CloseTrainerSession': grpc.unary_unary_rpc_method_handler( servicer.CloseTrainerSession, request_deserializer=training__pb2.TrainingSessionId.FromString, - response_serializer=training__pb2.Empty.SerializeToString, + response_serializer=utils__pb2.Empty.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -208,6 +225,23 @@ def add_TrainingServicer_to_server(servicer, server): class Training(object): """Missing associated documentation comment in .proto file.""" + @staticmethod + def ListDevices(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Training/ListDevices', + 
utils__pb2.Empty.SerializeToString, + utils__pb2.Devices.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + @staticmethod def Init(request, target, @@ -238,7 +272,7 @@ def Start(request, metadata=None): return grpc.experimental.unary_unary(request, target, '/training.Training/Start', training__pb2.TrainingSessionId.SerializeToString, - training__pb2.Empty.FromString, + utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -255,7 +289,7 @@ def Resume(request, metadata=None): return grpc.experimental.unary_unary(request, target, '/training.Training/Resume', training__pb2.TrainingSessionId.SerializeToString, - training__pb2.Empty.FromString, + utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -272,7 +306,7 @@ def Pause(request, metadata=None): return grpc.experimental.unary_unary(request, target, '/training.Training/Pause', training__pb2.TrainingSessionId.SerializeToString, - training__pb2.Empty.FromString, + utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -323,7 +357,7 @@ def Save(request, metadata=None): return grpc.experimental.unary_unary(request, target, '/training.Training/Save', training__pb2.TrainingSessionId.SerializeToString, - training__pb2.Empty.FromString, + utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -340,7 +374,7 @@ def Export(request, metadata=None): return grpc.experimental.unary_unary(request, target, '/training.Training/Export', training__pb2.TrainingSessionId.SerializeToString, - training__pb2.Empty.FromString, + utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, 
timeout, metadata) @@ -391,6 +425,6 @@ def CloseTrainerSession(request, metadata=None): return grpc.experimental.unary_unary(request, target, '/training.Training/CloseTrainerSession', training__pb2.TrainingSessionId.SerializeToString, - training__pb2.Empty.FromString, + utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/tiktorch/proto/utils_pb2.py b/tiktorch/proto/utils_pb2.py index ba10af3a..c0709d97 100644 --- a/tiktorch/proto/utils_pb2.py +++ b/tiktorch/proto/utils_pb2.py @@ -13,17 +13,25 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0butils.proto\"&\n\x08NamedInt\x12\x0c\n\x04size\x18\x01 \x01(\r\x12\x0c\n\x04name\x18\x02 \x01(\t\"(\n\nNamedFloat\x12\x0c\n\x04size\x18\x01 \x01(\x02\x12\x0c\n\x04name\x18\x02 \x01(\t\"S\n\x06Tensor\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x10\n\x08tensorId\x18\x03 \x01(\t\x12\x18\n\x05shape\x18\x04 \x03(\x0b\x32\t.NamedIntb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0butils.proto\"\x07\n\x05\x45mpty\"&\n\x08NamedInt\x12\x0c\n\x04size\x18\x01 \x01(\r\x12\x0c\n\x04name\x18\x02 \x01(\t\"(\n\nNamedFloat\x12\x0c\n\x04size\x18\x01 \x01(\x02\x12\x0c\n\x04name\x18\x02 \x01(\t\"S\n\x06Tensor\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x10\n\x08tensorId\x18\x03 \x01(\t\x12\x18\n\x05shape\x18\x04 \x03(\x0b\x32\t.NamedInt\"Y\n\x06\x44\x65vice\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1e\n\x06status\x18\x02 \x01(\x0e\x32\x0e.Device.Status\"#\n\x06Status\x12\r\n\tAVAILABLE\x10\x00\x12\n\n\x06IN_USE\x10\x01\"#\n\x07\x44\x65vices\x12\x18\n\x07\x64\x65vices\x18\x01 \x03(\x0b\x32\x07.Deviceb\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'utils_pb2', globals()) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - 
_NAMEDINT._serialized_start=15 - _NAMEDINT._serialized_end=53 - _NAMEDFLOAT._serialized_start=55 - _NAMEDFLOAT._serialized_end=95 - _TENSOR._serialized_start=97 - _TENSOR._serialized_end=180 + _EMPTY._serialized_start=15 + _EMPTY._serialized_end=22 + _NAMEDINT._serialized_start=24 + _NAMEDINT._serialized_end=62 + _NAMEDFLOAT._serialized_start=64 + _NAMEDFLOAT._serialized_end=104 + _TENSOR._serialized_start=106 + _TENSOR._serialized_end=189 + _DEVICE._serialized_start=191 + _DEVICE._serialized_end=280 + _DEVICE_STATUS._serialized_start=245 + _DEVICE_STATUS._serialized_end=280 + _DEVICES._serialized_start=282 + _DEVICES._serialized_end=317 # @@protoc_insertion_point(module_scope) diff --git a/tiktorch/server/grpc/flight_control_servicer.py b/tiktorch/server/grpc/flight_control_servicer.py index b1c41cd0..4ac2866e 100644 --- a/tiktorch/server/grpc/flight_control_servicer.py +++ b/tiktorch/server/grpc/flight_control_servicer.py @@ -3,7 +3,7 @@ import time from typing import Optional -from tiktorch.proto import inference_pb2, inference_pb2_grpc +from tiktorch.proto import inference_pb2_grpc, utils_pb2 logger = logging.getLogger(__name__) @@ -43,11 +43,11 @@ def _run_watchdog(): watchdog_thread.start() return watchdog_thread - def Ping(self, request: inference_pb2.Empty, context) -> inference_pb2.Empty: + def Ping(self, request: utils_pb2.Empty, context) -> utils_pb2.Empty: self.__last_ping = time.time() - return inference_pb2.Empty() + return utils_pb2.Empty() - def Shutdown(self, request: inference_pb2.Empty, context) -> inference_pb2.Empty: + def Shutdown(self, request: utils_pb2.Empty, context) -> utils_pb2.Empty: if self.__done_evt: self.__done_evt.set() - return inference_pb2.Empty() + return utils_pb2.Empty() diff --git a/tiktorch/server/grpc/inference_servicer.py b/tiktorch/server/grpc/inference_servicer.py index 230674ee..4318fa50 100644 --- a/tiktorch/server/grpc/inference_servicer.py +++ b/tiktorch/server/grpc/inference_servicer.py @@ -3,10 +3,11 @@ import 
grpc from tiktorch.converters import pb_tensors_to_sample, sample_to_pb_tensors -from tiktorch.proto import inference_pb2, inference_pb2_grpc +from tiktorch.proto import inference_pb2, inference_pb2_grpc, utils_pb2 from tiktorch.rpc.mp import BioModelClient from tiktorch.server.data_store import IDataStore -from tiktorch.server.device_pool import DeviceStatus, IDevicePool +from tiktorch.server.device_pool import IDevicePool +from tiktorch.server.grpc.utils_servicer import list_devices from tiktorch.server.session.process import InputSampleValidator, start_model_session_process from tiktorch.server.session_manager import Session, SessionManager @@ -55,9 +56,9 @@ def CreateDatasetDescription( id = session.client.api.create_dataset_description(mean=request.mean, stddev=request.stddev) return inference_pb2.DatasetDescription(id=id) - def CloseModelSession(self, request: inference_pb2.ModelSession, context) -> inference_pb2.Empty: + def CloseModelSession(self, request: inference_pb2.ModelSession, context) -> utils_pb2.Empty: self.__session_manager.close_session(request.id) - return inference_pb2.Empty() + return utils_pb2.Empty() def close_all_sessions(self): """ @@ -68,25 +69,13 @@ def close_all_sessions(self): self.__session_manager.close_all_sessions() assert len(self.__device_pool.list_reserved_devices()) == 0 - def GetLogs(self, request: inference_pb2.Empty, context): + def GetLogs(self, request: utils_pb2.Empty, context): yield inference_pb2.LogEntry( timestamp=int(time.time()), level=inference_pb2.LogEntry.Level.INFO, content="Sending model logs" ) - def ListDevices(self, request: inference_pb2.Empty, context) -> inference_pb2.Devices: - devices = self.__device_pool.list_devices() - pb_devices = [] - for dev in devices: - if dev.status == DeviceStatus.AVAILABLE: - pb_status = inference_pb2.Device.Status.AVAILABLE - elif dev.status == DeviceStatus.IN_USE: - pb_status = inference_pb2.Device.Status.IN_USE - else: - raise ValueError(f"Unknown status value 
{dev.status}") - - pb_devices.append(inference_pb2.Device(id=dev.id, status=pb_status)) - - return inference_pb2.Devices(devices=pb_devices) + def ListDevices(self, request: utils_pb2.Empty, context) -> utils_pb2.Devices: + return list_devices(self.__device_pool) def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_pb2.PredictResponse: session = self._getModelSession(context, request.modelSessionId) diff --git a/tiktorch/server/grpc/training_servicer.py b/tiktorch/server/grpc/training_servicer.py index 4dac6b18..9488a592 100644 --- a/tiktorch/server/grpc/training_servicer.py +++ b/tiktorch/server/grpc/training_servicer.py @@ -2,13 +2,15 @@ import logging import queue +from pathlib import Path from typing import Callable, List import grpc from tiktorch.converters import trainer_state_to_pb -from tiktorch.proto import training_pb2, training_pb2_grpc +from tiktorch.proto import training_pb2, training_pb2_grpc, utils_pb2 from tiktorch.server.device_pool import IDevicePool +from tiktorch.server.grpc.utils_servicer import list_devices from tiktorch.server.session.process import start_trainer_process from tiktorch.server.session.rpc_interface import IRPCTrainer from tiktorch.server.session_manager import Session, SessionManager @@ -28,6 +30,9 @@ def __init__( self._should_stop_callbacks: List[Callable] = [] self._session_manager = session_manager + def ListDevices(self, request: utils_pb2.Empty, context) -> utils_pb2.Devices: + return list_devices(self._device_pool) + def Init(self, request: training_pb2.TrainingConfig, context): parser = TrainerYamlParser(request.yaml_content) device = parser.get_device() @@ -50,27 +55,27 @@ def Init(self, request: training_pb2.TrainingConfig, context): def Start(self, request, context): session = self._getTrainerSession(context, request.id) session.client.start_training() - return training_pb2.Empty() + return utils_pb2.Empty() def Resume(self, request, context): session = self._getTrainerSession(context, 
request.id) session.client.resume_training() - return training_pb2.Empty() + return utils_pb2.Empty() def Pause(self, request: training_pb2.TrainingSessionId, context): session = self._getTrainerSession(context, request.id) session.client.pause_training() - return training_pb2.Empty() + return utils_pb2.Empty() - def Save(self, request: training_pb2.TrainingSessionId, context): - session = self._getTrainerSession(context, request.modelSessionId) - session.client.save() - return training_pb2.Empty() + def Save(self, request: training_pb2.SaveRequest, context): + session = self._getTrainerSession(context, request.sessionId.id) + session.client.save(Path(request.filePath)) + return utils_pb2.Empty() - def Export(self, request: training_pb2.TrainingSessionId, context): - session = self._getTrainerSession(context, request.modelSessionId) - session.client.export() - return training_pb2.Empty() + def Export(self, request: training_pb2.ExportRequest, context): + session = self._getTrainerSession(context, request.sessionId.id) + session.client.export(Path(request.filePath)) + return utils_pb2.Empty() def Predict(self, request: training_pb2.TrainingSessionId, context): raise NotImplementedError @@ -88,7 +93,7 @@ def GetStatus(self, request: training_pb2.TrainingSessionId, context): - def CloseTrainerSession(self, request: training_pb2.TrainingSessionId, context) -> training_pb2.Empty: + def CloseTrainerSession(self, request: training_pb2.TrainingSessionId, context) -> utils_pb2.Empty: self._session_manager.close_session(request.id) - return training_pb2.Empty() + return utils_pb2.Empty() def close_all_sessions(self): self._session_manager.close_all_sessions() diff --git a/tiktorch/server/grpc/utils_servicer.py b/tiktorch/server/grpc/utils_servicer.py new file mode 100644 index 00000000..bb23b40c --- /dev/null +++ b/tiktorch/server/grpc/utils_servicer.py @@ -0,0 +1,18 @@ +from tiktorch.proto import utils_pb2 +from tiktorch.server.device_pool import DeviceStatus, IDevicePool + + +def list_devices(device_pool: IDevicePool) -> utils_pb2.Devices: + devices =
device_pool.list_devices() + pb_devices = [] + for dev in devices: + if dev.status == DeviceStatus.AVAILABLE: + pb_status = utils_pb2.Device.Status.AVAILABLE + elif dev.status == DeviceStatus.IN_USE: + pb_status = utils_pb2.Device.Status.IN_USE + else: + raise ValueError(f"Unknown status value {dev.status}") + + pb_devices.append(utils_pb2.Device(id=dev.id, status=pb_status)) + + return utils_pb2.Devices(devices=pb_devices)