Skip to content

Commit

Permalink
Separate listing devices as a utility for inference and training
Browse files Browse the repository at this point in the history
- The inference servicer had a procedure to list the available devices.
  This is needed for the training servicer as well, so the device-listing
  procedure is decoupled into a shared utility.
  • Loading branch information
thodkatz committed Dec 20, 2024
1 parent 637c4d1 commit 27b3923
Show file tree
Hide file tree
Showing 16 changed files with 228 additions and 182 deletions.
4 changes: 2 additions & 2 deletions examples/grpc_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import grpc

from tiktorch.proto import inference_pb2, inference_pb2_grpc
from tiktorch.proto import inference_pb2_grpc, utils_pb2


def run():
with grpc.insecure_channel("127.0.0.1:5567") as channel:
stub = inference_pb2_grpc.InferenceStub(channel)
response = stub.ListDevices(inference_pb2.Empty())
response = stub.ListDevices(utils_pb2.Empty())
print(response)


Expand Down
15 changes: 0 additions & 15 deletions proto/inference.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,6 @@ service Inference {
rpc Predict(PredictRequest) returns (PredictResponse) {}
}

message Device {
enum Status {
AVAILABLE = 0;
IN_USE = 1;
}

string id = 1;
Status status = 2;
}

message CreateDatasetDescriptionRequest {
string modelSessionId = 1;
Expand Down Expand Up @@ -81,11 +72,6 @@ message LogEntry {
string content = 3;
}

message Devices {
repeated Device devices = 1;
}



message PredictRequest {
string modelSessionId = 1;
Expand All @@ -97,7 +83,6 @@ message PredictResponse {
repeated Tensor tensors = 1;
}

message Empty {}

service FlightControl {
rpc Ping(Empty) returns (Empty) {}
Expand Down
4 changes: 2 additions & 2 deletions proto/training.proto
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ package training;
import "utils.proto";


message Empty {}


service Training {
rpc ListDevices(Empty) returns (Devices) {}

rpc Init(TrainingConfig) returns (TrainingSessionId) {}

rpc Start(TrainingSessionId) returns (Empty) {}
Expand Down
16 changes: 16 additions & 0 deletions proto/utils.proto
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
syntax = "proto3";

message Empty {}

message NamedInt {
uint32 size = 1;
string name = 2;
Expand All @@ -15,4 +17,18 @@ message Tensor {
string dtype = 2;
string tensorId = 3;
repeated NamedInt shape = 4;
}

message Device {
enum Status {
AVAILABLE = 0;
IN_USE = 1;
}

string id = 1;
Status status = 2;
}

message Devices {
repeated Device devices = 1;
}
20 changes: 10 additions & 10 deletions tests/test_server/test_grpc/test_inference_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from tiktorch import converters
from tiktorch.converters import pb_tensor_to_xarray
from tiktorch.proto import inference_pb2, inference_pb2_grpc
from tiktorch.proto import inference_pb2, inference_pb2_grpc, utils_pb2
from tiktorch.server.data_store import DataStore
from tiktorch.server.device_pool import TorchDevicePool
from tiktorch.server.grpc import inference_servicer
Expand Down Expand Up @@ -101,13 +101,13 @@ def test_model_init_failed_close_session(self, bioimage_model_explicit_add_one_s

class TestDeviceManagement:
def test_list_devices(self, grpc_stub):
resp = grpc_stub.ListDevices(inference_pb2.Empty())
resp = grpc_stub.ListDevices(utils_pb2.Empty())
device_by_id = {d.id: d for d in resp.devices}
assert "cpu" in device_by_id
assert inference_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status
assert utils_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status

def _query_devices(self, grpc_stub):
dev_resp = grpc_stub.ListDevices(inference_pb2.Empty())
dev_resp = grpc_stub.ListDevices(utils_pb2.Empty())
device_by_id = {d.id: d for d in dev_resp.devices}
return device_by_id

Expand All @@ -121,19 +121,19 @@ def test_if_model_create_fails_devices_are_released(self, grpc_stub):

device_by_id = self._query_devices(grpc_stub)
assert "cpu" in device_by_id
assert inference_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status
assert utils_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status

def test_use_device(self, grpc_stub, bioimage_model_explicit_add_one_siso_v5):
model_bytes = bioimage_model_explicit_add_one_siso_v5
device_by_id = self._query_devices(grpc_stub)
assert "cpu" in device_by_id
assert inference_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status
assert utils_pb2.Device.Status.AVAILABLE == device_by_id["cpu"].status

grpc_stub.CreateModelSession(valid_model_request(model_bytes, device_ids=["cpu"]))

device_by_id = self._query_devices(grpc_stub)
assert "cpu" in device_by_id
assert inference_pb2.Device.Status.IN_USE == device_by_id["cpu"].status
assert utils_pb2.Device.Status.IN_USE == device_by_id["cpu"].status

def test_using_same_device_fails(self, grpc_stub, bioimage_model_explicit_add_one_siso_v5):
model_bytes = bioimage_model_explicit_add_one_siso_v5
Expand All @@ -147,20 +147,20 @@ def test_closing_session_releases_devices(self, grpc_stub, bioimage_model_explic

device_by_id = self._query_devices(grpc_stub)
assert "cpu" in device_by_id
assert inference_pb2.Device.Status.IN_USE == device_by_id["cpu"].status
assert utils_pb2.Device.Status.IN_USE == device_by_id["cpu"].status

grpc_stub.CloseModelSession(model)

device_by_id_after_close = self._query_devices(grpc_stub)
assert "cpu" in device_by_id_after_close
assert inference_pb2.Device.Status.AVAILABLE == device_by_id_after_close["cpu"].status
assert utils_pb2.Device.Status.AVAILABLE == device_by_id_after_close["cpu"].status


class TestGetLogs:
def test_returns_ack_message(self, bioimage_model_explicit_add_one_siso_v5, grpc_stub):
model_bytes = bioimage_model_explicit_add_one_siso_v5
grpc_stub.CreateModelSession(valid_model_request(model_bytes))
resp = grpc_stub.GetLogs(inference_pb2.Empty())
resp = grpc_stub.GetLogs(utils_pb2.Empty())
record = next(resp)
assert inference_pb2.LogEntry.Level.INFO == record.level
assert "Sending model logs" == record.content
Expand Down
2 changes: 1 addition & 1 deletion tests/test_server/test_grpc/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import grpc

from tiktorch.proto.inference_pb2 import Empty
from tiktorch.proto.inference_pb2_grpc import FlightControlStub
from tiktorch.proto.utils_pb2 import Empty
from tiktorch.server.grpc import serve
from tiktorch.utils import wait

Expand Down
4 changes: 2 additions & 2 deletions tests/test_server/test_grpc/test_training_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest

from tiktorch.converters import pb_state_to_trainer, trainer_state_to_pb
from tiktorch.proto import training_pb2, training_pb2_grpc
from tiktorch.proto import training_pb2, training_pb2_grpc, utils_pb2
from tiktorch.server.device_pool import TorchDevicePool
from tiktorch.server.grpc import training_servicer
from tiktorch.server.session.backend.base import TrainerSessionBackend
Expand Down Expand Up @@ -342,7 +342,7 @@ def test_start_training_without_init(self, grpc_stub):
Test starting training without initializing a session.
"""
with pytest.raises(grpc.RpcError) as excinfo:
grpc_stub.Start(training_pb2.Empty())
grpc_stub.Start(utils_pb2.Empty())
assert excinfo.value.code() == grpc.StatusCode.FAILED_PRECONDITION
assert "trainer-session with id doesn't exist" in excinfo.value.details()

Expand Down
62 changes: 27 additions & 35 deletions tiktorch/proto/inference_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 27b3923

Please sign in to comment.