Skip to content

Commit

Permalink
Add utils proto
Browse files Browse the repository at this point in the history
Move NamedInt and Tensor proto to a separate file so training proto can
use as well
  • Loading branch information
thodkatz committed Dec 20, 2024
1 parent 6c1875a commit 637c4d1
Show file tree
Hide file tree
Showing 11 changed files with 182 additions and 443 deletions.
20 changes: 5 additions & 15 deletions proto/inference.proto
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
syntax = "proto3";

package inference;

import "utils.proto";


service Inference {
rpc CreateModelSession(CreateModelSessionRequest) returns (ModelSession) {}

Expand Down Expand Up @@ -80,22 +85,7 @@ message Devices {
repeated Device devices = 1;
}

message NamedInt {
uint32 size = 1;
string name = 2;
}

message NamedFloat {
float size = 1;
string name = 2;
}

message Tensor {
bytes buffer = 1;
string dtype = 2;
string tensorId = 3;
repeated NamedInt shape = 4;
}

message PredictRequest {
string modelSessionId = 1;
Expand Down
19 changes: 4 additions & 15 deletions proto/training.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ syntax = "proto3";

package training;

import "utils.proto";


message Empty {}

Expand Down Expand Up @@ -57,28 +59,15 @@ message GetLogsResponse {
}


message NamedInt {
uint32 size = 1;
string name = 2;
}


message Tensor {
bytes buffer = 1;
string dtype = 2;
repeated NamedInt shape = 4;
}


message PredictRequest {
repeated Tensor tensors = 1;
TrainingSessionId id = 2;
TrainingSessionId sessionId = 2;
}


message PredictResponse {
uint32 best_model_idx = 1;
repeated Tensor tensors = 2;
repeated Tensor tensors = 1;
}

message ValidationResponse {
Expand Down
18 changes: 18 additions & 0 deletions proto/utils.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
syntax = "proto3";

message NamedInt {
uint32 size = 1;
string name = 2;
}

message NamedFloat {
float size = 1;
string name = 2;
}

message Tensor {
bytes buffer = 1;
string dtype = 2;
string tensorId = 3;
repeated NamedInt shape = 4;
}
34 changes: 17 additions & 17 deletions tests/test_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
xarray_to_pb_tensor,
xr_tensors_to_sample,
)
from tiktorch.proto import inference_pb2
from tiktorch.proto import utils_pb2


def _numpy_to_pb_tensor(arr, tensor_id: str = "dummy_tensor_name"):
"""
Makes sure that tensor was serialized/deserialized
"""
tensor = numpy_to_pb_tensor(tensor_id, arr)
parsed = inference_pb2.Tensor()
parsed = utils_pb2.Tensor()
parsed.ParseFromString(tensor.SerializeToString())
return parsed

Expand All @@ -31,7 +31,7 @@ def to_pb_tensor(tensor_id: str, arr: xr.DataArray):
Makes sure that tensor was serialized/deserialized
"""
tensor = xarray_to_pb_tensor(tensor_id, arr)
parsed = inference_pb2.Tensor()
parsed = utils_pb2.Tensor()
parsed.ParseFromString(tensor.SerializeToString())
return parsed

Expand All @@ -40,7 +40,7 @@ class TestNumpyToPBTensor:
def test_should_serialize_to_tensor_type(self):
arr = np.arange(9)
tensor = _numpy_to_pb_tensor(arr)
assert isinstance(tensor, inference_pb2.Tensor)
assert isinstance(tensor, utils_pb2.Tensor)

@pytest.mark.parametrize("np_dtype,dtype_str", [(np.int64, "int64"), (np.uint8, "uint8"), (np.float32, "float32")])
def test_should_have_dtype_as_str(self, np_dtype, dtype_str):
Expand All @@ -65,12 +65,12 @@ def test_should_have_serialized_bytes(self):

class TestPBTensorToNumpy:
def test_should_raise_on_empty_dtype(self):
tensor = inference_pb2.Tensor(dtype="", shape=[inference_pb2.NamedInt(size=1), inference_pb2.NamedInt(size=2)])
tensor = utils_pb2.Tensor(dtype="", shape=[utils_pb2.NamedInt(size=1), utils_pb2.NamedInt(size=2)])
with pytest.raises(ValueError):
pb_tensor_to_numpy(tensor)

def test_should_raise_on_empty_shape(self):
tensor = inference_pb2.Tensor(dtype="int64", shape=[])
tensor = utils_pb2.Tensor(dtype="int64", shape=[])
with pytest.raises(ValueError):
pb_tensor_to_numpy(tensor)

Expand Down Expand Up @@ -109,7 +109,7 @@ class TestXarrayToPBTensor:
def test_should_serialize_to_tensor_type(self):
xarr = xr.DataArray(np.arange(8).reshape((2, 4)), dims=("x", "y"))
pb_tensor = to_pb_tensor("input0", xarr)
assert isinstance(pb_tensor, inference_pb2.Tensor)
assert isinstance(pb_tensor, utils_pb2.Tensor)
assert len(pb_tensor.shape) == 2
dim1 = pb_tensor.shape[0]
dim2 = pb_tensor.shape[1]
Expand Down Expand Up @@ -137,12 +137,12 @@ def test_should_have_serialized_bytes(self):

class TestPBTensorToXarray:
def test_should_raise_on_empty_dtype(self):
tensor = inference_pb2.Tensor(dtype="", shape=[inference_pb2.NamedInt(size=1), inference_pb2.NamedInt(size=2)])
tensor = utils_pb2.Tensor(dtype="", shape=[utils_pb2.NamedInt(size=1), utils_pb2.NamedInt(size=2)])
with pytest.raises(ValueError):
pb_tensor_to_xarray(tensor)

def test_should_raise_on_empty_shape(self):
tensor = inference_pb2.Tensor(dtype="int64", shape=[])
tensor = utils_pb2.Tensor(dtype="int64", shape=[])
with pytest.raises(ValueError):
pb_tensor_to_xarray(tensor)

Expand Down Expand Up @@ -178,19 +178,19 @@ def test_should_same_data(self, shape):
class TestSample:
def test_pb_tensors_to_sample(self):
arr_1 = np.arange(32 * 32, dtype=np.int64).reshape(32, 32)
tensor_1 = inference_pb2.Tensor(
tensor_1 = utils_pb2.Tensor(
dtype="int64",
tensorId="input1",
buffer=bytes(arr_1),
shape=[inference_pb2.NamedInt(name="x", size=32), inference_pb2.NamedInt(name="y", size=32)],
shape=[utils_pb2.NamedInt(name="x", size=32), utils_pb2.NamedInt(name="y", size=32)],
)

arr_2 = np.arange(64 * 64, dtype=np.int64).reshape(64, 64)
tensor_2 = inference_pb2.Tensor(
tensor_2 = utils_pb2.Tensor(
dtype="int64",
tensorId="input2",
buffer=bytes(arr_2),
shape=[inference_pb2.NamedInt(name="x", size=64), inference_pb2.NamedInt(name="y", size=64)],
shape=[utils_pb2.NamedInt(name="x", size=64), utils_pb2.NamedInt(name="y", size=64)],
)

sample = pb_tensors_to_sample([tensor_1, tensor_2])
Expand Down Expand Up @@ -218,17 +218,17 @@ def test_sample_to_pb_tensors(self):
tensors_ids = ["input1", "input2"]
sample = xr_tensors_to_sample(tensors_ids, [tensor_1, tensor_2])

pb_tensor_1 = inference_pb2.Tensor(
pb_tensor_1 = utils_pb2.Tensor(
dtype="int64",
tensorId="input1",
buffer=bytes(arr_1),
shape=[inference_pb2.NamedInt(name="x", size=32), inference_pb2.NamedInt(name="y", size=32)],
shape=[utils_pb2.NamedInt(name="x", size=32), utils_pb2.NamedInt(name="y", size=32)],
)
pb_tensor_2 = inference_pb2.Tensor(
pb_tensor_2 = utils_pb2.Tensor(
dtype="int64",
tensorId="input2",
buffer=bytes(arr_2),
shape=[inference_pb2.NamedInt(name="x", size=64), inference_pb2.NamedInt(name="y", size=64)],
shape=[utils_pb2.NamedInt(name="x", size=64), utils_pb2.NamedInt(name="y", size=64)],
)
expected_tensors = [pb_tensor_1, pb_tensor_2]

Expand Down
24 changes: 12 additions & 12 deletions tiktorch/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from bioimageio.core import Sample, Tensor
from bioimageio.spec.model.v0_5 import TensorId

from tiktorch.proto import inference_pb2, training_pb2
from tiktorch.proto import inference_pb2, training_pb2, utils_pb2
from tiktorch.trainer import TrainerState

trainer_state_to_pb = {
Expand All @@ -21,7 +21,7 @@
pb_state_to_trainer = {value: key for key, value in trainer_state_to_pb.items()}


def pb_tensors_to_sample(pb_tensors: List[inference_pb2.Tensor]) -> Sample:
def pb_tensors_to_sample(pb_tensors: List[utils_pb2.Tensor]) -> Sample:
return Sample(
members={TensorId(tensor.tensorId): Tensor.from_xarray(pb_tensor_to_xarray(tensor)) for tensor in pb_tensors},
id=None,
Expand All @@ -41,21 +41,21 @@ def xr_tensors_to_sample(tensor_ids: List[str], tensors_data: List[xr.DataArray]
)


def sample_to_pb_tensors(sample: Sample) -> List[inference_pb2.Tensor]:
def sample_to_pb_tensors(sample: Sample) -> List[utils_pb2.Tensor]:
return [xarray_to_pb_tensor(tensor_id, res_tensor.data) for tensor_id, res_tensor in sample.members.items()]


def numpy_to_pb_tensor(tensor_id: str, array: np.ndarray, axistags=None) -> inference_pb2.Tensor:
def numpy_to_pb_tensor(tensor_id: str, array: np.ndarray, axistags=None) -> utils_pb2.Tensor:
if axistags:
shape = [inference_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, axistags)]
shape = [utils_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, axistags)]
else:
shape = [inference_pb2.NamedInt(size=dim) for dim in array.shape]
return inference_pb2.Tensor(tensorId=tensor_id, dtype=str(array.dtype), shape=shape, buffer=bytes(array))
shape = [utils_pb2.NamedInt(size=dim) for dim in array.shape]
return utils_pb2.Tensor(tensorId=tensor_id, dtype=str(array.dtype), shape=shape, buffer=bytes(array))


def xarray_to_pb_tensor(tensor_id: str, array: xr.DataArray) -> inference_pb2.Tensor:
shape = [inference_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, array.dims)]
return inference_pb2.Tensor(tensorId=tensor_id, dtype=str(array.dtype), shape=shape, buffer=bytes(array.data))
def xarray_to_pb_tensor(tensor_id: str, array: xr.DataArray) -> utils_pb2.Tensor:
shape = [utils_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, array.dims)]
return utils_pb2.Tensor(tensorId=tensor_id, dtype=str(array.dtype), shape=shape, buffer=bytes(array.data))


def name_int_tuples_to_pb_NamedInts(name_int_tuples) -> inference_pb2.NamedInts:
Expand All @@ -70,7 +70,7 @@ def name_float_tuples_to_pb_NamedFloats(name_float_tuples) -> inference_pb2.Name
)


def pb_tensor_to_xarray(tensor: inference_pb2.Tensor) -> inference_pb2.Tensor:
def pb_tensor_to_xarray(tensor: utils_pb2.Tensor) -> xr.DataArray:
if not tensor.dtype:
raise ValueError("Tensor dtype is not specified")

Expand All @@ -82,7 +82,7 @@ def pb_tensor_to_xarray(tensor: inference_pb2.Tensor) -> inference_pb2.Tensor:
return xr.DataArray(data, dims=[d.name for d in tensor.shape])


def pb_tensor_to_numpy(tensor: inference_pb2.Tensor) -> np.ndarray:
def pb_tensor_to_numpy(tensor: utils_pb2.Tensor) -> np.ndarray:
if not tensor.dtype:
raise ValueError("Tensor dtype is not specified")

Expand Down
48 changes: 3 additions & 45 deletions tiktorch/proto/data_store_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 637c4d1

Please sign in to comment.