From 20bcd835dcf69333532a72531d2eb631fccbba72 Mon Sep 17 00:00:00 2001 From: Theodoros Katzalis Date: Tue, 15 Oct 2024 21:46:52 +0200 Subject: [PATCH] include tensor id when converting numpy to pb tensor --- tests/test_converters.py | 4 ++-- tiktorch/converters.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_converters.py b/tests/test_converters.py index a197a617..cf5a51d2 100644 --- a/tests/test_converters.py +++ b/tests/test_converters.py @@ -16,11 +16,11 @@ from tiktorch.proto import inference_pb2 -def _numpy_to_pb_tensor(arr): +def _numpy_to_pb_tensor(arr, tensor_id: str = "dummy_tensor_name"): """ Makes sure that tensor was serialized/deserialized """ - tensor = numpy_to_pb_tensor(arr) + tensor = numpy_to_pb_tensor(tensor_id, arr) parsed = inference_pb2.Tensor() parsed.ParseFromString(tensor.SerializeToString()) return parsed diff --git a/tiktorch/converters.py b/tiktorch/converters.py index 5a6669b8..3159c05c 100644 --- a/tiktorch/converters.py +++ b/tiktorch/converters.py @@ -34,12 +34,12 @@ def sample_to_pb_tensors(sample: Sample) -> List[inference_pb2.Tensor]: return [xarray_to_pb_tensor(tensor_id, res_tensor.data) for tensor_id, res_tensor in sample.members.items()] -def numpy_to_pb_tensor(array: np.ndarray, axistags=None) -> inference_pb2.Tensor: +def numpy_to_pb_tensor(tensor_id: str, array: np.ndarray, axistags=None) -> inference_pb2.Tensor: if axistags: shape = [inference_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, axistags)] else: shape = [inference_pb2.NamedInt(size=dim) for dim in array.shape] - return inference_pb2.Tensor(dtype=str(array.dtype), shape=shape, buffer=bytes(array)) + return inference_pb2.Tensor(tensorId=tensor_id, dtype=str(array.dtype), shape=shape, buffer=bytes(array)) def xarray_to_pb_tensor(tensor_id: str, array: xr.DataArray) -> inference_pb2.Tensor: