From feea07cf7edbf77e577c4cfee540dad825e16d70 Mon Sep 17 00:00:00 2001 From: Maksim Novikov Date: Thu, 28 Mar 2019 13:47:41 +0100 Subject: [PATCH] Convert CRLF to LF in repo. Enforce correct line endings --- .editorconfig | 16 + .gitattributes | 1 + tests/handler/test_inference.py | 70 ++-- tiktorch/handler/inference.py | 298 ++++++++-------- tiktorch/tiktypes.py | 578 ++++++++++++++++---------------- 5 files changed, 490 insertions(+), 473 deletions(-) create mode 100644 .editorconfig create mode 100644 .gitattributes diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000..1afe1baf --- /dev/null +++ b/.editorconfig @@ -0,0 +1,16 @@ +root = true + +[*] +indent_style = space +indent_size = 4 +tab_width = 8 +end_of_line = lf +charset = utf-8 +trim_trailing_whitespace = true +insert_final_newline = true + +[*.{yml,yaml}] +indent_size = 2 + +[Makefile] +indent_style = tab diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..6313b56c --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +* text=auto eol=lf diff --git a/tests/handler/test_inference.py b/tests/handler/test_inference.py index 4905aabf..ae6d9b0c 100644 --- a/tests/handler/test_inference.py +++ b/tests/handler/test_inference.py @@ -1,37 +1,37 @@ -import torch - -from torch import multiprocessing as mp - -from tiktorch.handler.inference import IInference, InferenceProcess, run -from tiktorch.tiktypes import TikTensor -from tiktorch.rpc.mp import MPClient, Shutdown - -from tests.data.tiny_models import TinyConvNet2d - - -def test_inference(tiny_model_2d): - config = tiny_model_2d["config"] - in_channels = config["input_channels"] - model = TinyConvNet2d(in_channels=in_channels) - inference = InferenceProcess(config=config, model=model) - data = TikTensor(torch.zeros(in_channels, 15, 15), (0,)) - pred = inference.forward(data) - assert isinstance(pred.result(), TikTensor) - try: - inference.shutdown() - except Shutdown: - pass - - -def test_inference_in_proc(tiny_model_2d, log_queue): - config = tiny_model_2d["config"] - in_channels = config["input_channels"] - model = TinyConvNet2d(in_channels=in_channels) - handler_conn, inference_conn = mp.Pipe() - p = mp.Process(target=run, kwargs={"conn": inference_conn, "model": model, "config": config, "log_queue": log_queue}) - p.start() - client = MPClient(IInference(), handler_conn) - data = TikTensor(torch.zeros(in_channels, 15, 15), (0,)) - f = client.forward(data) +import torch + +from torch import multiprocessing as mp + +from tiktorch.handler.inference import IInference, InferenceProcess, run +from tiktorch.tiktypes import TikTensor +from tiktorch.rpc.mp import MPClient, Shutdown + +from tests.data.tiny_models import TinyConvNet2d + + +def test_inference(tiny_model_2d): + config = tiny_model_2d["config"] + in_channels = config["input_channels"] + model = TinyConvNet2d(in_channels=in_channels) + inference = InferenceProcess(config=config, model=model) + data = TikTensor(torch.zeros(in_channels, 15, 15), (0,)) + pred = inference.forward(data) + assert isinstance(pred.result(), TikTensor) + try: + inference.shutdown() + except Shutdown: + pass + + +def test_inference_in_proc(tiny_model_2d, log_queue): + config = tiny_model_2d["config"] + in_channels = config["input_channels"] + model = TinyConvNet2d(in_channels=in_channels) + handler_conn, inference_conn = mp.Pipe() + p = mp.Process(target=run, kwargs={"conn": inference_conn, "model": model, "config": config, "log_queue": log_queue}) + p.start() + client = MPClient(IInference(), handler_conn) + data = TikTensor(torch.zeros(in_channels, 15, 15), (0,)) + f = client.forward(data) f.result() client.shutdown() diff --git a/tiktorch/handler/inference.py b/tiktorch/handler/inference.py index 4297080f..361e784b 100644 --- a/tiktorch/handler/inference.py +++ b/tiktorch/handler/inference.py @@ -1,33 +1,33 @@ -import logging -import os -import queue -import torch.nn -import threading +import logging +import os +import queue +import torch.nn +import threading import multiprocessing as mp - -from concurrent.futures import ThreadPoolExecutor, Future -from multiprocessing.connection import Connection -from typing import Any, List, Generic, Iterator, Iterable, Sequence, TypeVar, Mapping, Callable, Dict, Optional, Tuple - -from .constants import SHUTDOWN, SHUTDOWN_ANSWER, REPORT_EXCEPTION, REQUEST_FOR_DEVICES -from tiktorch.rpc import RPCInterface, exposed, Shutdown -from tiktorch.rpc.mp import MPServer -from tiktorch.tiktypes import TikTensor, TikTensorBatch + +from concurrent.futures import ThreadPoolExecutor, Future +from multiprocessing.connection import Connection +from typing import Any, List, Generic, Iterator, Iterable, Sequence, TypeVar, Mapping, Callable, Dict, Optional, Tuple + +from .constants import SHUTDOWN, SHUTDOWN_ANSWER, REPORT_EXCEPTION, REQUEST_FOR_DEVICES +from tiktorch.rpc import RPCInterface, exposed, Shutdown +from tiktorch.rpc.mp import MPServer +from tiktorch.tiktypes import TikTensor, TikTensorBatch from tiktorch import log - - -class IInference(RPCInterface): - @exposed - def set_devices(self, device_names: Sequence[str]): - raise NotImplementedError() - - @exposed - def shutdown(self): - raise NotImplementedError() - - @exposed - def forward(self, data: TikTensorBatch): - raise NotImplementedError() + + +class IInference(RPCInterface): + @exposed + def set_devices(self, device_names: Sequence[str]): + raise NotImplementedError() + + @exposed + def shutdown(self): + raise NotImplementedError() + + @exposed + def forward(self, data: TikTensorBatch): + raise NotImplementedError() def run(conn: Connection, config: dict, model: torch.nn.Module, log_queue: Optional[mp.Queue] = None): @@ -37,125 +37,125 @@ def run(conn: Connection, config: dict, model: torch.nn.Module, log_queue: Optio srv.listen() -class InferenceProcess(IInference): - """ - Process for neural network inference - """ - - name = "tiktorch.InferenceProcess" - - def __init__(self, config: dict, model: torch.nn.Module) -> None: - self.logger = logging.getLogger(self.name) - self.logger.info("Starting") - self.config = config - self.training_model = model - self.model = model.__class__() - self.model.eval() - self._shutdown_event = threading.Event() - - self.devices = [] - self._forward_queue = queue.Queue() - self.batch_size: int = config.get("inference_batch_size", None) - if self.batch_size is None: - self.batch_size = 1 - self.increase_batch_size = True - else: - self.increase_batch_size = False - - self.forward_thread = threading.Thread(target=self._forward_worker) - self.forward_thread.start() - - def _forward_worker(self) -> None: - while not self._shutdown_event.is_set(): - data_batch, fut_batch = [], [] - while not self._forward_queue.empty() and len(data_batch) < self.batch_size: - data, fut = self._forward_queue.get() - data_batch.append(data) - fut_batch.append(fut) - - if data_batch: - self._forward(TikTensorBatch(data_batch), fut_batch) - - def shutdown(self) -> None: - self._shutdown_event.set() - self.forward_thread.join() - self.logger.debug("Shutdown complete") - raise Shutdown - - def set_devices(self, device_names: Sequence[str]): - raise NotImplementedError() - # todo: with lock - # torch.cuda.empty_cache() - # os.environ['CUDA_VISIBLE_DEVICES'] = ... - - def update_inference_model(self): - self.model.load_state_dict(self.training_model.state_dict()) - assert not self.model.training, "Model switched back to training mode somehow???" - - def forward(self, data: TikTensor) -> Future: - fut = Future() - self._forward_queue.put((data, fut)) - return fut - - def _forward(self, data: TikTensorBatch, fut: List[Future]) -> None: - """ - :param data: input data to neural network - :return: predictions - """ - keys: List = [d.id for d in data] - data: List[torch.Tensor] = data.as_torch() - - # TODO: fixT return data - - self.logger.debug("this is forward") - - start = 0 - last_batch_size = self.batch_size - - def create_end_generator(start, end, batch_size): - for batch_end in range(start + batch_size, end, batch_size): - yield batch_end - - yield end - - end_generator = create_end_generator(start, len(keys), self.batch_size) - while start < len(keys): - # todo: callback - end = next(end_generator) - self.update_inference_model() - try: - with torch.no_grad(): - pred = self.model(torch.stack(data[start:end])) - except Exception as e: - if self.batch_size > last_batch_size: - self.logger.info( - "forward pass with batch size %d threw exception %s. Using previous batch size %d again.", - self.batch_size, - e, - last_batch_size, - ) - self.batch_size = last_batch_size - self.increase_batch_size = False - else: - last_batch_size = self.batch_size - self.batch_size //= 2 - if self.batch_size == 0: - self.logger.error("Forward pass failed. Processed %d/%d", start, len(keys)) - break - - self.increase_batch_size = True - self.logger.info( - "forward pass with batch size %d threw exception %s. Trying again with smaller batch_size %d", - last_batch_size, - e, - self.batch_size, - ) - end_generator = create_end_generator(start, len(keys), self.batch_size) - else: - for i in range(start, end): - fut[i].set_result(TikTensor(pred[i], id_=keys[i])) - start = end - last_batch_size = self.batch_size - if self.increase_batch_size: - self.batch_size += 1 +class InferenceProcess(IInference): + """ + Process for neural network inference + """ + + name = "tiktorch.InferenceProcess" + + def __init__(self, config: dict, model: torch.nn.Module) -> None: + self.logger = logging.getLogger(self.name) + self.logger.info("Starting") + self.config = config + self.training_model = model + self.model = model.__class__() + self.model.eval() + self._shutdown_event = threading.Event() + + self.devices = [] + self._forward_queue = queue.Queue() + self.batch_size: int = config.get("inference_batch_size", None) + if self.batch_size is None: + self.batch_size = 1 + self.increase_batch_size = True + else: + self.increase_batch_size = False + + self.forward_thread = threading.Thread(target=self._forward_worker) + self.forward_thread.start() + + def _forward_worker(self) -> None: + while not self._shutdown_event.is_set(): + data_batch, fut_batch = [], [] + while not self._forward_queue.empty() and len(data_batch) < self.batch_size: + data, fut = self._forward_queue.get() + data_batch.append(data) + fut_batch.append(fut) + + if data_batch: + self._forward(TikTensorBatch(data_batch), fut_batch) + + def shutdown(self) -> None: + self._shutdown_event.set() + self.forward_thread.join() + self.logger.debug("Shutdown complete") + raise Shutdown + + def set_devices(self, device_names: Sequence[str]): + raise NotImplementedError() + # todo: with lock + # torch.cuda.empty_cache() + # os.environ['CUDA_VISIBLE_DEVICES'] = ... + + def update_inference_model(self): + self.model.load_state_dict(self.training_model.state_dict()) + assert not self.model.training, "Model switched back to training mode somehow???" + + def forward(self, data: TikTensor) -> Future: + fut = Future() + self._forward_queue.put((data, fut)) + return fut + + def _forward(self, data: TikTensorBatch, fut: List[Future]) -> None: + """ + :param data: input data to neural network + :return: predictions + """ + keys: List = [d.id for d in data] + data: List[torch.Tensor] = data.as_torch() + + # TODO: fixT return data + + self.logger.debug("this is forward") + + start = 0 + last_batch_size = self.batch_size + + def create_end_generator(start, end, batch_size): + for batch_end in range(start + batch_size, end, batch_size): + yield batch_end + + yield end + + end_generator = create_end_generator(start, len(keys), self.batch_size) + while start < len(keys): + # todo: callback + end = next(end_generator) + self.update_inference_model() + try: + with torch.no_grad(): + pred = self.model(torch.stack(data[start:end])) + except Exception as e: + if self.batch_size > last_batch_size: + self.logger.info( + "forward pass with batch size %d threw exception %s. Using previous batch size %d again.", + self.batch_size, + e, + last_batch_size, + ) + self.batch_size = last_batch_size + self.increase_batch_size = False + else: + last_batch_size = self.batch_size + self.batch_size //= 2 + if self.batch_size == 0: + self.logger.error("Forward pass failed. Processed %d/%d", start, len(keys)) + break + + self.increase_batch_size = True + self.logger.info( + "forward pass with batch size %d threw exception %s. Trying again with smaller batch_size %d", + last_batch_size, + e, + self.batch_size, + ) + end_generator = create_end_generator(start, len(keys), self.batch_size) + else: + for i in range(start, end): + fut[i].set_result(TikTensor(pred[i], id_=keys[i])) + start = end + last_batch_size = self.batch_size + if self.increase_batch_size: + self.batch_size += 1 end_generator = create_end_generator(start, len(keys), self.batch_size) diff --git a/tiktorch/tiktypes.py b/tiktorch/tiktypes.py index 71ae045b..90d9afe6 100644 --- a/tiktorch/tiktypes.py +++ b/tiktorch/tiktypes.py @@ -1,290 +1,290 @@ -""" -Types defining interop between processes on the server -""" -import torch - -from typing import List, Tuple, Optional, Union, Sequence - - -class TikTensor: - """ - Containter for pytorch tensor to transfer additional properties - e.g. position of array in dataset (id_) - """ - - def __init__(self, tensor: torch.Tensor, id_: Optional[Tuple[int, ...]] = None, label: Optional[torch.Tensor] = None) -> None: - self._torch = tensor - self.id = id_ - self.label = label - - def as_torch(self, with_label=False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - if with_label: - return self._torch, self.label - else: - return self._torch - - @property - def dtype(self): - return self._torch.dtype - - @property - def shape(self): - return self._torch.shape - - -class TikTensorBatch: - """ - Batch of TikTensor - """ - - def __init__(self, tensors: List[TikTensor]): - assert all([isinstance(t, TikTensor) for t in tensors]) - self._tensors = tensors - - def tensor_metas(self): - return [{"dtype": t.dtype.str, "shape": t.shape, "id": t.id} for t in self._tensors] - - def __len__(self): - return len(self._tensors) - - def __iter__(self): - for item in self._tensors: - yield item - - def as_torch(self, with_label=False) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]: - return [t.as_torch(with_label=with_label) for t in self._tensors] - - @property - def ids(self) -> List[Tuple[int]]: - return [t.id for t in self._tensors] - - -class PointAndBatchPointBase: - order: str = "" - b: int - t: int - c: int - z: int - y: int - x: int - - def __init__(self, b: int = 0, t: int = 0, c: int = 0, z: int = 0, y: int = 0, x: int = 0): - self.b = b - self.t = t - self.c = c - self.z = z - self.y = y - self.x = x - super().__init__() - - def __getitem__(self, key: Union[int, str]): - if isinstance(key, int): - key = self.order[key] - - return getattr(self, key) - - def __setitem__(self, key: Union[int, str], item: int): - if isinstance(key, int): - key = self.order[key] - - return setattr(self, key, item) - - def __repr__(self): - return f"{self.__class__.__name__}({', '.join([f'{a}:{getattr(self, a)}' for a in self.order])})" - - def __len__(self): - return len(self.order) - - def __iter__(self): - for a in self.order: - yield getattr(self, a) - - def __bool__(self): - return bool(len(self)) - - @staticmethod - def upcast_dim( - a: "PointAndBatchPointBase", b: "PointAndBatchPointBase" - ) -> Tuple["PointAndBatchPointBase", "PointAndBatchPointBase"]: - space_dim_a = len(a) - 1 - if a.__class__.__name__.startswith("Batch"): - space_dim_a -= 1 - - space_dim_b = len(b) - 1 - if b.__class__.__name__.startswith("Batch"): - space_dim_b -= 1 - - if space_dim_a < space_dim_b: - a = a.as_d(space_dim_b) - elif space_dim_a > space_dim_b: - b = b.as_d(space_dim_a) - - return a, b - - def __lt__(self, other): - me, other = self.upcast_dim(self, other) - return all([m < o for m, o in zip(me, other)]) - - def __gt__(self, other): - me, other = self.upcast_dim(self, other) - return all([m > o for m, o in zip(me, other)]) - - def __eq__(self, other): - me, other = self.upcast_dim(self, other) - return all([m == o for m, o in zip(me, other)]) - - def __le__(self, other): - me, other = self.upcast_dim(self, other) - return all([m <= o for m, o in zip(me, other)]) - - def __ge__(self, other): - me, other = self.upcast_dim(self, other) - return all([m >= o for m, o in zip(me, other)]) - - def as_d(self, d: int) -> "PointAndBatchPointBase": - """ - :param d: number of spacial dimensions - """ - if d == 2: - return self.as_2d() - elif d == 3: - return self.as_3d() - elif d == 4: - return self.as_4d() - else: - raise NotImplementedError(f"Unclear number of dimensions d={d}") - - def as_2d(self): - raise NotImplementedError("To be implemented in subclass!") - - def as_3d(self): - raise NotImplementedError("To be implemented in subclass!") - - def as_4d(self): - raise NotImplementedError("To be implemented in subclass!") - - def drop_batch(self): - raise NotImplementedError("To be implemented in subclass!") - -class BatchPointBase(PointAndBatchPointBase): - def __init__(self, b: int = 0, t: int = 0, c: int = 0, z: int = 0, y: int = 0, x: int = 0): - super().__init__(b=b, t=t, c=c, z=z, y=y, x=x) - - @staticmethod - def from_spacetime( - b: int, c: int, spacetime: Sequence[int] - ) -> Union["BatchPoint2D", "BatchPoint3D", "BatchPoint4D"]: - """ - :return: a suitable BatchPoint instance - :raises: ValueError - """ - if len(spacetime) == 4: - t, z, y, x = spacetime - return BatchPoint4D(b, t, c, z, y, x) - elif len(spacetime) == 3: - return BatchPoint3D(b, c, *spacetime) - elif len(spacetime) == 2: - return BatchPoint2D(b, c, *spacetime) - else: - raise ValueError(f"Uninterpretable spacetime: {spacetime}") - - def as_2d(self) -> "BatchPoint2D": - return BatchPoint2D(b=self.b, c=self.c, y=self.y, x=self.x) - - def as_3d(self) -> "BatchPoint3D": - return BatchPoint3D(b=self.b, c=self.c, z=self.z, y=self.y, x=self.x) - - def as_4d(self) -> "BatchPoint4D": - return BatchPoint4D(b=self.b, t=self.t, c=self.c, z=self.z, y=self.y, x=self.x) - - -class BatchPoint2D(BatchPointBase): - order: str = "bcyx" - - def __init__(self, b: int = 0, c: int = 0, y: int = 0, x: int = 0): - super().__init__(b=b, c=c, y=y, x=x) - - def drop_batch(self) -> "Point2D": - return Point2D(c=self.c, y=self.y, x=self.x) - - -class BatchPoint3D(BatchPointBase): - order: str = "bczyx" - - def __init__(self, b: int = 0, c: int = 0, z: int = 0, y: int = 0, x: int = 0): - super().__init__(b=b, c=c, z=z, y=y, x=x) - - def drop_batch(self) -> "Point3D": - return Point3D(c=self.c, z=self.z, y=self.y, x=self.x) - - -class BatchPoint4D(BatchPointBase): - order: str = "btczyx" - - def __init__(self, b: int = 0, t: int = 0, c: int = 0, z: int = 0, y: int = 0, x: int = 0): - super().__init__(b=b, t=t, c=c, z=z, y=y, x=x) - - def drop_batch(self): - return Point4D(t=self.t, c=self.c, z=self.z, y=self.y, x=self.x) - - -class PointBase(PointAndBatchPointBase): - def __init__(self, t: int = 0, c: int = 0, z: int = 0, y: int = 0, x: int = 0): - super().__init__(t=t, c=c, z=z, y=y, x=x) - - @staticmethod - def from_spacetime(cls, c: int, spacetime: Sequence[int]) -> Union["Point2D", "Point3D", "Point4D"]: - """ - :return: a suitable BatchPoint instance - :raises: ValueError - """ - if len(spacetime) == 4: - t, z, y, x = spacetime - return Point4D(t, c, z, y, x) - elif len(spacetime) == 3: - return Point3D(c, *spacetime) - elif len(spacetime) == 2: - return Point2D(c, *spacetime) - else: - raise ValueError(f"Uninterpretable spacetime: {spacetime}") - - def as_2d(self) -> "Point2D": - return Point2D(c=self.c, y=self.y, x=self.x) - - def as_3d(self) -> "Point3D": - return Point3D(c=self.c, z=self.z, y=self.y, x=self.x) - - def as_4d(self) -> "Point4D": - return Point4D(t=self.t, c=self.c, z=self.z, y=self.y, x=self.x) - - def drop_batch(self) -> "PointBase": - return self - -class Point2D(PointBase): - order: str = "cyx" - - def __init__(self, c: int = 0, y: int = 0, x: int = 0): - super().__init__(c=c, y=y, x=x) - - def add_batch(self) -> BatchPoint2D: - return BatchPoint2D(c=self.c, y=self.y, x=self.x) - - -class Point3D(PointBase): - order: str = "czyx" - - def __init__(self, c: int = 0, z: int = 0, y: int = 0, x: int = 0): - super().__init__(c=c, z=z, y=y, x=x) - - def add_batch(self) -> BatchPoint3D: - return BatchPoint3D(c=self.c, z=self.z, y=self.y, x=self.x) - - -class Point4D(PointBase): - order: str = "tczyx" - - def __init__(self, t: int = 0, c: int = 0, z: int = 0, y: int = 0, x: int = 0): - super().__init__(t=t, c=c, z=z, y=y, x=x) - - def add_batch(self) -> BatchPoint4D: +""" +Types defining interop between processes on the server +""" +import torch + +from typing import List, Tuple, Optional, Union, Sequence + + +class TikTensor: + """ + Containter for pytorch tensor to transfer additional properties + e.g. position of array in dataset (id_) + """ + + def __init__(self, tensor: torch.Tensor, id_: Optional[Tuple[int, ...]] = None, label: Optional[torch.Tensor] = None) -> None: + self._torch = tensor + self.id = id_ + self.label = label + + def as_torch(self, with_label=False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if with_label: + return self._torch, self.label + else: + return self._torch + + @property + def dtype(self): + return self._torch.dtype + + @property + def shape(self): + return self._torch.shape + + +class TikTensorBatch: + """ + Batch of TikTensor + """ + + def __init__(self, tensors: List[TikTensor]): + assert all([isinstance(t, TikTensor) for t in tensors]) + self._tensors = tensors + + def tensor_metas(self): + return [{"dtype": t.dtype.str, "shape": t.shape, "id": t.id} for t in self._tensors] + + def __len__(self): + return len(self._tensors) + + def __iter__(self): + for item in self._tensors: + yield item + + def as_torch(self, with_label=False) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]: + return [t.as_torch(with_label=with_label) for t in self._tensors] + + @property + def ids(self) -> List[Tuple[int]]: + return [t.id for t in self._tensors] + + +class PointAndBatchPointBase: + order: str = "" + b: int + t: int + c: int + z: int + y: int + x: int + + def __init__(self, b: int = 0, t: int = 0, c: int = 0, z: int = 0, y: int = 0, x: int = 0): + self.b = b + self.t = t + self.c = c + self.z = z + self.y = y + self.x = x + super().__init__() + + def __getitem__(self, key: Union[int, str]): + if isinstance(key, int): + key = self.order[key] + + return getattr(self, key) + + def __setitem__(self, key: Union[int, str], item: int): + if isinstance(key, int): + key = self.order[key] + + return setattr(self, key, item) + + def __repr__(self): + return f"{self.__class__.__name__}({', '.join([f'{a}:{getattr(self, a)}' for a in self.order])})" + + def __len__(self): + return len(self.order) + + def __iter__(self): + for a in self.order: + yield getattr(self, a) + + def __bool__(self): + return bool(len(self)) + + @staticmethod + def upcast_dim( + a: "PointAndBatchPointBase", b: "PointAndBatchPointBase" + ) -> Tuple["PointAndBatchPointBase", "PointAndBatchPointBase"]: + space_dim_a = len(a) - 1 + if a.__class__.__name__.startswith("Batch"): + space_dim_a -= 1 + + space_dim_b = len(b) - 1 + if b.__class__.__name__.startswith("Batch"): + space_dim_b -= 1 + + if space_dim_a < space_dim_b: + a = a.as_d(space_dim_b) + elif space_dim_a > space_dim_b: + b = b.as_d(space_dim_a) + + return a, b + + def __lt__(self, other): + me, other = self.upcast_dim(self, other) + return all([m < o for m, o in zip(me, other)]) + + def __gt__(self, other): + me, other = self.upcast_dim(self, other) + return all([m > o for m, o in zip(me, other)]) + + def __eq__(self, other): + me, other = self.upcast_dim(self, other) + return all([m == o for m, o in zip(me, other)]) + + def __le__(self, other): + me, other = self.upcast_dim(self, other) + return all([m <= o for m, o in zip(me, other)]) + + def __ge__(self, other): + me, other = self.upcast_dim(self, other) + return all([m >= o for m, o in zip(me, other)]) + + def as_d(self, d: int) -> "PointAndBatchPointBase": + """ + :param d: number of spacial dimensions + """ + if d == 2: + return self.as_2d() + elif d == 3: + return self.as_3d() + elif d == 4: + return self.as_4d() + else: + raise NotImplementedError(f"Unclear number of dimensions d={d}") + + def as_2d(self): + raise NotImplementedError("To be implemented in subclass!") + + def as_3d(self): + raise NotImplementedError("To be implemented in subclass!") + + def as_4d(self): + raise NotImplementedError("To be implemented in subclass!") + + def drop_batch(self): + raise NotImplementedError("To be implemented in subclass!") + +class BatchPointBase(PointAndBatchPointBase): + def __init__(self, b: int = 0, t: int = 0, c: int = 0, z: int = 0, y: int = 0, x: int = 0): + super().__init__(b=b, t=t, c=c, z=z, y=y, x=x) + + @staticmethod + def from_spacetime( + b: int, c: int, spacetime: Sequence[int] + ) -> Union["BatchPoint2D", "BatchPoint3D", "BatchPoint4D"]: + """ + :return: a suitable BatchPoint instance + :raises: ValueError + """ + if len(spacetime) == 4: + t, z, y, x = spacetime + return BatchPoint4D(b, t, c, z, y, x) + elif len(spacetime) == 3: + return BatchPoint3D(b, c, *spacetime) + elif len(spacetime) == 2: + return BatchPoint2D(b, c, *spacetime) + else: + raise ValueError(f"Uninterpretable spacetime: {spacetime}") + + def as_2d(self) -> "BatchPoint2D": + return BatchPoint2D(b=self.b, c=self.c, y=self.y, x=self.x) + + def as_3d(self) -> "BatchPoint3D": + return BatchPoint3D(b=self.b, c=self.c, z=self.z, y=self.y, x=self.x) + + def as_4d(self) -> "BatchPoint4D": + return BatchPoint4D(b=self.b, t=self.t, c=self.c, z=self.z, y=self.y, x=self.x) + + +class BatchPoint2D(BatchPointBase): + order: str = "bcyx" + + def __init__(self, b: int = 0, c: int = 0, y: int = 0, x: int = 0): + super().__init__(b=b, c=c, y=y, x=x) + + def drop_batch(self) -> "Point2D": + return Point2D(c=self.c, y=self.y, x=self.x) + + +class BatchPoint3D(BatchPointBase): + order: str = "bczyx" + + def __init__(self, b: int = 0, c: int = 0, z: int = 0, y: int = 0, x: int = 0): + super().__init__(b=b, c=c, z=z, y=y, x=x) + + def drop_batch(self) -> "Point3D": + return Point3D(c=self.c, z=self.z, y=self.y, x=self.x) + + +class BatchPoint4D(BatchPointBase): + order: str = "btczyx" + + def __init__(self, b: int = 0, t: int = 0, c: int = 0, z: int = 0, y: int = 0, x: int = 0): + super().__init__(b=b, t=t, c=c, z=z, y=y, x=x) + + def drop_batch(self): + return Point4D(t=self.t, c=self.c, z=self.z, y=self.y, x=self.x) + + +class PointBase(PointAndBatchPointBase): + def __init__(self, t: int = 0, c: int = 0, z: int = 0, y: int = 0, x: int = 0): + super().__init__(t=t, c=c, z=z, y=y, x=x) + + @staticmethod + def from_spacetime(cls, c: int, spacetime: Sequence[int]) -> Union["Point2D", "Point3D", "Point4D"]: + """ + :return: a suitable BatchPoint instance + :raises: ValueError + """ + if len(spacetime) == 4: + t, z, y, x = spacetime + return Point4D(t, c, z, y, x) + elif len(spacetime) == 3: + return Point3D(c, *spacetime) + elif len(spacetime) == 2: + return Point2D(c, *spacetime) + else: + raise ValueError(f"Uninterpretable spacetime: {spacetime}") + + def as_2d(self) -> "Point2D": + return Point2D(c=self.c, y=self.y, x=self.x) + + def as_3d(self) -> "Point3D": + return Point3D(c=self.c, z=self.z, y=self.y, x=self.x) + + def as_4d(self) -> "Point4D": + return Point4D(t=self.t, c=self.c, z=self.z, y=self.y, x=self.x) + + def drop_batch(self) -> "PointBase": + return self + +class Point2D(PointBase): + order: str = "cyx" + + def __init__(self, c: int = 0, y: int = 0, x: int = 0): + super().__init__(c=c, y=y, x=x) + + def add_batch(self) -> BatchPoint2D: + return BatchPoint2D(c=self.c, y=self.y, x=self.x) + + +class Point3D(PointBase): + order: str = "czyx" + + def __init__(self, c: int = 0, z: int = 0, y: int = 0, x: int = 0): + super().__init__(c=c, z=z, y=y, x=x) + + def add_batch(self) -> BatchPoint3D: + return BatchPoint3D(c=self.c, z=self.z, y=self.y, x=self.x) + + +class Point4D(PointBase): + order: str = "tczyx" + + def __init__(self, t: int = 0, c: int = 0, z: int = 0, y: int = 0, x: int = 0): + super().__init__(t=t, c=c, z=z, y=y, x=x) + + def add_batch(self) -> BatchPoint4D: return BatchPoint4D(t=self.t, c=self.c, z=self.z, y=self.y, x=self.x) \ No newline at end of file