Skip to content

Commit

Permalink
include tensor id when converting numpy to pb tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
thodkatz committed Oct 15, 2024
1 parent 617684e commit 20bcd83
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tests/test_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
from tiktorch.proto import inference_pb2


def _numpy_to_pb_tensor(arr):
def _numpy_to_pb_tensor(arr, tensor_id: str = "dummy_tensor_name"):
"""
Makes sure that tensor was serialized/deserialized
"""
tensor = numpy_to_pb_tensor(arr)
tensor = numpy_to_pb_tensor(tensor_id, arr)
parsed = inference_pb2.Tensor()
parsed.ParseFromString(tensor.SerializeToString())
return parsed
Expand Down
4 changes: 2 additions & 2 deletions tiktorch/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def sample_to_pb_tensors(sample: Sample) -> List[inference_pb2.Tensor]:
return [xarray_to_pb_tensor(tensor_id, res_tensor.data) for tensor_id, res_tensor in sample.members.items()]


def numpy_to_pb_tensor(array: np.ndarray, axistags=None) -> inference_pb2.Tensor:
def numpy_to_pb_tensor(tensor_id: str, array: np.ndarray, axistags=None) -> inference_pb2.Tensor:
if axistags:
shape = [inference_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, axistags)]
else:
shape = [inference_pb2.NamedInt(size=dim) for dim in array.shape]
return inference_pb2.Tensor(dtype=str(array.dtype), shape=shape, buffer=bytes(array))
return inference_pb2.Tensor(tensorId=tensor_id, dtype=str(array.dtype), shape=shape, buffer=bytes(array))


def xarray_to_pb_tensor(tensor_id: str, array: xr.DataArray) -> inference_pb2.Tensor:
Expand Down

0 comments on commit 20bcd83

Please sign in to comment.