From c9873b05c6786e078bd28b2c24184fa070dd5ff9 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Mon, 17 Oct 2022 13:59:50 +0200 Subject: [PATCH 01/14] [math] Padding fixes This mostly concerns PyTorch. * Add unit tests --- phi/math/_ops.py | 4 ++- phi/math/extrapolation.py | 2 ++ phi/torch/_torch_backend.py | 42 ++++++++++++++++--------- tests/commit/math/test_extrapolation.py | 31 +++++++++++++++++- 4 files changed, 62 insertions(+), 17 deletions(-) diff --git a/phi/math/_ops.py b/phi/math/_ops.py index ccec56e5b..de5b1d8aa 100644 --- a/phi/math/_ops.py +++ b/phi/math/_ops.py @@ -710,7 +710,7 @@ def concat_tensor(values: tuple or list, dim: str) -> Tensor: return NativeTensor(concatenated, broadcast_shape.with_sizes(backend.staticshape(concatenated))) -def pad(value: Tensor, widths: dict, mode: 'e_.Extrapolation', **kwargs) -> Tensor: +def pad(value: Tensor, widths: dict, mode: 'e_.Extrapolation' or Tensor or Number, **kwargs) -> Tensor: """ Pads a tensor along the specified dimensions, determining the added values using the given extrapolation. Unlike `Extrapolation.pad()`, this function can handle negative widths which slice off outer values. @@ -720,6 +720,7 @@ def pad(value: Tensor, widths: dict, mode: 'e_.Extrapolation', **kwargs) -> Tens widths: `dict` mapping dimension name (`str`) to `(lower, upper)` where `lower` and `upper` are `int` that can be positive (pad), negative (slice) or zero (pass). mode: `Extrapolation` used to determine values added from positive `widths`. + Assumes constant extrapolation if given a number or `Tensor` instead. kwargs: Additional padding arguments. These are ignored by the standard extrapolations defined in `phi.math.extrapolation` but can be used to pass additional contextual information to custom extrapolations. Grid classes from `phi.field` will pass the argument `bounds: Box`. @@ -727,6 +728,7 @@ def pad(value: Tensor, widths: dict, mode: 'e_.Extrapolation', **kwargs) -> Tens Returns: Padded `Tensor` """ + mode = mode if isinstance(mode, e_.Extrapolation) else e_.ConstantExtrapolation(mode) has_negative_widths = any(w[0] < 0 or w[1] < 0 for w in widths.values()) slices = None if has_negative_widths: diff --git a/phi/math/extrapolation.py b/phi/math/extrapolation.py index a5ad473f2..166b4eec8 100644 --- a/phi/math/extrapolation.py +++ b/phi/math/extrapolation.py @@ -242,6 +242,8 @@ def pad(self, value: Tensor, widths: dict, **kwargs): ordered_pad_widths = order_by_shape(value.shape, widths, default=(0, 0)) backend = choose_backend(native) result_tensor = backend.pad(native, ordered_pad_widths, 'constant', pad_value.native()) + if result_tensor is NotImplemented: + return Extrapolation.pad(self, value, widths, **kwargs) new_shape = value.shape.with_sizes(backend.staticshape(result_tensor)) return NativeTensor(result_tensor, new_shape) elif isinstance(value, CollapsedTensor): diff --git a/phi/torch/_torch_backend.py b/phi/torch/_torch_backend.py index c872a7682..f6e3398f5 100644 --- a/phi/torch/_torch_backend.py +++ b/phi/torch/_torch_backend.py @@ -218,30 +218,38 @@ def pad(self, value, pad_width, mode='constant', constant_values=0): mode = {'constant': 'constant', 'reflect': 'reflect', 'boundary': 'replicate', 'periodic': 'circular'}.get(mode, None) if not mode: return NotImplemented - # transpose for leading zero-pad: [(0, 0), (0, 0), ...] + # for PyTorch, we have to reshape value such that the outer 2 dimensions are not padded. ndims = self.ndims(value) - if ndims > 2 and pad_width[0] == pad_width[1] == (0, 0): - reordered = value - pad_width_reordered = pad_width[2:] - undo_transform = lambda x: x - elif ndims > 2 and pad_width[0] == (0, 0) and self.ndims(value) < 5: - reordered = torch.unsqueeze(value, 0) - pad_width_reordered = pad_width[1:] - undo_transform = lambda x: torch.squeeze(x, 0) - elif ndims < 4: - reordered = torch.unsqueeze(torch.unsqueeze(value, 0), 0) - pad_width_reordered = pad_width + no_pad_dims = [i for i in range(ndims) if pad_width[i] == (0, 0)] + pad_dims = [i for i in range(ndims) if pad_width[i] != (0, 0)] + if not pad_dims: + return value + if len(pad_dims) > 3: + return NotImplemented + value = torch.permute(value, no_pad_dims + pad_dims) + if len(no_pad_dims) == 0: + value = torch.unsqueeze(torch.unsqueeze(value, 0), 0) undo_transform = lambda x: torch.squeeze(torch.squeeze(x, 0), 0) + elif len(no_pad_dims) == 1: + value = torch.unsqueeze(value, 0) + undo_transform = lambda x: torch.squeeze(x, 0) + elif len(no_pad_dims) == 2: + undo_transform = lambda x: x else: - raise NotImplementedError() # TODO transpose to get (0, 0) to the front + old_shape = value.shape + value = self.reshape(value, (1, np.prod([value.shape[i] for i in range(len(no_pad_dims))]), *value.shape[len(no_pad_dims):])) + undo_transform = lambda x: x.view(*[old_shape[i] for i in range(len(no_pad_dims))], *x.shape[2:]) + pad_width_reordered = [pad_width[i] for i in pad_dims] pad_width_spatial = [item for sublist in reversed(pad_width_reordered) for item in sublist] # flatten try: constant_values = self.dtype(value).kind(constant_values) - result = torchf.pad(reordered, pad_width_spatial, mode, value=constant_values) # supports 3D to 5D (2 + 1D to 3D) + result = torchf.pad(value, pad_width_spatial, mode, value=constant_values) # supports 3D to 5D (batch, channel, 1D to 3D) except RuntimeError as err: warnings.warn(f"PyTorch error {err}", RuntimeWarning) return NotImplemented result = undo_transform(result) + inv_perm = tuple(np.argsort(no_pad_dims + pad_dims)) + result = torch.permute(result, inv_perm) return result def grid_sample(self, grid, coordinates, extrapolation: str): @@ -270,7 +278,11 @@ def grid_sample(self, grid, coordinates, extrapolation: str): return result def reshape(self, value, shape): - return torch.reshape(self.as_tensor(value), shape) + value = self.as_tensor(value) + if value.is_contiguous(): + return value.view(*shape) + else: + return torch.reshape(value, shape) def sum(self, value, axis=None, keepdims=False): if axis is None: diff --git a/tests/commit/math/test_extrapolation.py b/tests/commit/math/test_extrapolation.py index 3da6ad55b..dbfefba72 100644 --- a/tests/commit/math/test_extrapolation.py +++ b/tests/commit/math/test_extrapolation.py @@ -66,7 +66,7 @@ def test_pad(self): # TypeError('__bool__ should return bool, returned NotImplementedType') # self.assertEqual(val_out, func(val_in)) - def test_pad_tensor(self): + def test_pad_2d(self): for backend in BACKENDS: with backend: a = math.meshgrid(x=4, y=3) @@ -115,6 +115,35 @@ def test_pad_tensor(self): math.assert_close(p.x[0].y[:-1], a.x[-1]) # periodic math.assert_close(p.x[-2:].y[:-1], a.x[:2]) # periodic + def test_pad_3d(self): + for t in [ + math.ones(spatial(x=2, y=2, z=2)), + math.ones(spatial(x=2, y=2, z=2), batch(b1=2)), + math.ones(spatial(x=2, y=2, z=2), batch(b1=2, b2=2)), + math.ones(spatial(x=2, y=2, z=2), batch(b1=2, b2=2, b3=2)), + ]: + results = [] + for backend in BACKENDS: + with backend: + p = math.pad(t, {i: (1, 1) for i in 'xyz'}, 0) + results.append(p) + math.assert_close(*results) + + def test_pad_4d(self): + for t in [ + math.ones(spatial(x=2, y=2, z=2, w=2)), + math.ones(spatial(x=2, y=2, z=2, w=2), batch(b1=2)), + math.ones(spatial(x=2, y=2, z=2, w=2), batch(b1=2, b2=2)), + math.ones(spatial(x=2, y=2, z=2, w=2), batch(b1=2, b2=2, b3=2)), + ]: + results = [] + for backend in BACKENDS: + with backend: + p = math.pad(t, {i: (1, 1) for i in 'xyzw'}, 0) + results.append(p) + math.assert_close(*results) + + def test_pad_collapsed(self): a = math.zeros(spatial(b=2, x=10, y=10) & batch(batch=10)) p = math.pad(a, {'x': (1, 2)}, ZERO) From 97629e2e78a255d29114e04739941ff5cc836df4 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Tue, 18 Oct 2022 16:47:47 +0200 Subject: [PATCH 02/14] [math] Improved normal/tangential extrapolation slicing * Add unit test --- phi/field/_grid.py | 6 +++++- phi/math/extrapolation.py | 19 ++++++++++++++++++- tests/commit/math/test_extrapolation.py | 6 ++++++ 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/phi/field/_grid.py b/phi/field/_grid.py index 3be2cdb8d..1f4cfee38 100644 --- a/phi/field/_grid.py +++ b/phi/field/_grid.py @@ -387,6 +387,10 @@ def __getitem__(self, item): item = slicing_dict(self, item) if not item: return self + if 'vector' in item: + selection = item['vector'] + if isinstance(selection, int): + item['vector'] = self.resolution.names[selection] values = self._values[{dim: sel for dim, sel in item.items() if dim not in self.shape.spatial}] for dim, sel in item.items(): if dim in self.shape.spatial: @@ -410,7 +414,7 @@ def __getitem__(self, item): assert selection in item_names, f"Accessing field.vector['{selection}'] failed. Item names are {item_names}." selection = item_names.index(selection) if isinstance(selection, int): - dim = self.shape.spatial.names[selection] + dim = self.resolution.names[selection] comp_cells = GridCell(self.resolution, bounds).stagger(dim, *self.extrapolation.valid_outer_faces(dim)) return CenteredGrid(values, bounds=comp_cells.bounds, extrapolation=extrapolation) else: diff --git a/phi/math/extrapolation.py b/phi/math/extrapolation.py index 166b4eec8..8f3997f4e 100644 --- a/phi/math/extrapolation.py +++ b/phi/math/extrapolation.py @@ -152,6 +152,9 @@ def shortest_distance(self, start: Tensor, end: Tensor, domain_size: Tensor): def __getitem__(self, item): return self + def _getitem_with_domain(self, item: dict, dim: str, upper_edge: bool, all_dims: tuple): + return self[item] + def __abs__(self): raise NotImplementedError(self.__class__) @@ -904,7 +907,8 @@ def transform_coordinates(self, coordinates: Tensor, shape: Shape, **kwargs) -> def __getitem__(self, item): if isinstance(item, dict): - return combine_sides(**{dim: (e1[item], e2[item]) for dim, (e1, e2) in self.ext.items()}) + all_dims = tuple(self.ext.keys()) + return combine_sides(**{dim: (e1._getitem_with_domain(item, dim, False, all_dims), e2._getitem_with_domain(item, dim, True, all_dims)) for dim, (e1, e2) in self.ext.items()}) else: dim, face = item return self.ext[dim][face] @@ -994,9 +998,22 @@ def pad_values(self, value: Tensor, width: int, dim: str, upper_edge: bool, **kw result = stack(result, value.shape.only('vector')) return result + def _getitem_with_domain(self, item: dict, dim: str, upper_edge: bool, all_dims: tuple): + if 'vector' not in item: + return self + component = item['vector'] + assert isinstance(component, str), f"Selecting a component of normal/tangential must be done by dimension name but got {component}" + if component == dim: + return self.normal + else: + return self.tangential + def __eq__(self, other): return isinstance(other, _NormalTangentialExtrapolation) and self.normal == other.normal and self.tangential == other.tangential + def __hash__(self): + return hash(self.normal) + hash(self.tangential) + def __add__(self, other): return self._op2(other, lambda e1, e2: e1 + e2) diff --git a/tests/commit/math/test_extrapolation.py b/tests/commit/math/test_extrapolation.py index dbfefba72..5852f3e29 100644 --- a/tests/commit/math/test_extrapolation.py +++ b/tests/commit/math/test_extrapolation.py @@ -201,3 +201,9 @@ def test_map(self): self.assertEqual(ext, extrapolation.map(lambda e: e, ext)) ext = combine_sides(x=PERIODIC, y=(ONE, BOUNDARY)) self.assertEqual(ext, extrapolation.map(lambda e: e, ext)) + + def test_slice_normal_tangential(self): + INFLOW_LEFT = combine_by_direction(normal=1, tangential=0) + ext = combine_sides(x=(INFLOW_LEFT, BOUNDARY), y=0) + self.assertEqual(combine_sides(x=(1, BOUNDARY), y=0), ext[{'vector': 'x'}]) + self.assertEqual(combine_sides(x=(0, BOUNDARY), y=0), ext[{'vector': 'y'}]) From 0efa1810b47f6597df849adf8433ecf41ead7499 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Tue, 18 Oct 2022 18:02:45 +0200 Subject: [PATCH 03/14] [math] concat() for non-uniform tensors --- phi/math/_ops.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/phi/math/_ops.py b/phi/math/_ops.py index de5b1d8aa..5f22dfced 100644 --- a/phi/math/_ops.py +++ b/phi/math/_ops.py @@ -701,13 +701,18 @@ def inner_stack(*values): def concat_tensor(values: tuple or list, dim: str) -> Tensor: assert len(values) > 0, "concat() got empty sequence" assert isinstance(dim, str), f"dim must be a single-dimension Shape but got '{dim}' of type {type(dim)}" - broadcast_shape = merge_shapes(*[t.shape._with_item_name(dim, None).with_sizes([None] * t.shape.rank) for t in values]) - natives = [v.native(order=broadcast_shape.names) for v in values] - backend = choose_backend(*natives) - concatenated = backend.concat(natives, broadcast_shape.index(dim)) - if all([v.shape.get_item_names(dim) is not None for v in values]): - broadcast_shape = broadcast_shape._with_item_name(dim, sum([v.shape.get_item_names(dim) for v in values], ())) - return NativeTensor(concatenated, broadcast_shape.with_sizes(backend.staticshape(concatenated))) + + def inner_concat(*values): + broadcast_shape = merge_shapes(*[t.shape._with_item_name(dim, None).with_sizes([None] * t.shape.rank) for t in values]) + natives = [v.native(order=broadcast_shape.names) for v in values] + backend = choose_backend(*natives) + concatenated = backend.concat(natives, broadcast_shape.index(dim)) + if all([v.shape.get_item_names(dim) is not None for v in values]): + broadcast_shape = broadcast_shape._with_item_name(dim, sum([v.shape.get_item_names(dim) for v in values], ())) + return NativeTensor(concatenated, broadcast_shape.with_sizes(backend.staticshape(concatenated))) + + result = broadcast_op(inner_concat, values) + return result def pad(value: Tensor, widths: dict, mode: 'e_.Extrapolation' or Tensor or Number, **kwargs) -> Tensor: @@ -888,6 +893,7 @@ def broadcast_op(operation: Callable, dim = next(iter(iter_dims)) dim_type = None size = None + item_names = None unstacked = [] for tensor in tensors: if dim in tensor.shape.names: @@ -899,6 +905,8 @@ def broadcast_op(operation: Callable, else: assert size == len(unstacked_tensor) assert dim_type == tensor.shape.get_type(dim) + if item_names is None: + item_names = tensor.shape.get_item_names(dim) else: unstacked.append(tensor) result_unstacked = [] @@ -906,7 +914,7 @@ def broadcast_op(operation: Callable, gathered = [t[i] if isinstance(t, tuple) else t for t in unstacked] result_unstacked.append(broadcast_op(operation, gathered, iter_dims=set(iter_dims) - {dim})) if not no_return: - return TensorStack(result_unstacked, Shape((None,), (dim,), (dim_type,), (None,))) + return TensorStack(result_unstacked, Shape((None,), (dim,), (dim_type,), (item_names,))) def where(condition: Tensor or float or int, value_true: Tensor or float or int, value_false: Tensor or float or int): From 3e8ec212469e3bcf11f1437a087b6711d8b2fc58 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Tue, 18 Oct 2022 18:41:29 +0200 Subject: [PATCH 04/14] [field] Fix stagger() pad widths --- phi/field/_field_math.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/phi/field/_field_math.py b/phi/field/_field_math.py index ac7c2ecac..a3951c8e7 100644 --- a/phi/field/_field_math.py +++ b/phi/field/_field_math.py @@ -121,10 +121,16 @@ def stagger(field: CenteredGrid, all_lower = [] all_upper = [] if type == StaggeredGrid: - for dim in field.shape.spatial.names: - lo_valid, up_valid = extrapolation.valid_outer_faces(dim) - width_lower = {dim: (int(lo_valid), int(up_valid) - 1)} - width_upper = {dim: (int(lo_valid or up_valid) - 1, int(lo_valid and up_valid))} + for dim in field.resolution.names: + valid_lo, valid_up = extrapolation.valid_outer_faces(dim) + if valid_lo and valid_up: + width_lower, width_upper = {dim: (1, 0)}, {dim: (0, 1)} + elif valid_lo and not valid_up: + width_lower, width_upper = {dim: (1, -1)}, {dim: (0, 0)} + elif not valid_lo and valid_up: + width_lower, width_upper = {dim: (0, 0)}, {dim: (-1, 1)} + else: + width_lower, width_upper = {dim: (0, -1)}, {dim: (-1, 0)} all_lower.append(math.pad(field.values, width_lower, field.extrapolation, bounds=field.bounds)) all_upper.append(math.pad(field.values, width_upper, field.extrapolation, bounds=field.bounds)) all_upper = math.stack(all_upper, channel('vector')) From ed0ddb56eb4cb005cbb2bcbabea1329834829ee8 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Tue, 18 Oct 2022 18:42:24 +0200 Subject: [PATCH 05/14] [physics] Fix make_incompressible for mixed boundary conditions --- phi/physics/fluid.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/phi/physics/fluid.py b/phi/physics/fluid.py index 7fd762b63..7dedcc017 100644 --- a/phi/physics/fluid.py +++ b/phi/physics/fluid.py @@ -102,7 +102,7 @@ def make_incompressible(velocity: GridType, pressure = math.solve_linear(masked_laplace, f_args=[hard_bcs, active], y=div, solve=solve) # --- Subtract grad p --- grad_pressure = field.spatial_gradient(pressure, input_velocity.extrapolation, type=type(velocity)) * hard_bcs - velocity = velocity - grad_pressure + velocity = (velocity - grad_pressure).with_extrapolation(input_velocity.extrapolation) return velocity, pressure @@ -123,8 +123,8 @@ def masked_laplace(pressure: CenteredGrid, hard_bcs: Grid, active: CenteredGrid) Returns: `CenteredGrid` """ - grad = spatial_gradient(pressure, hard_bcs.extrapolation, type=type(hard_bcs)) - valid_grad = grad * hard_bcs + grad = spatial_gradient(pressure, extrapolation.NONE, type=type(hard_bcs)) + valid_grad = grad * field.bake_extrapolation(hard_bcs) div = divergence(valid_grad) laplace = where(active, div, pressure) return laplace @@ -189,10 +189,8 @@ def _pressure_extrapolation(vext: Extrapolation): return extrapolation.ZERO elif isinstance(vext, extrapolation.ConstantExtrapolation): return extrapolation.BOUNDARY - elif isinstance(vext, extrapolation._MixedExtrapolation): - return combine_sides(**{dim: (_pressure_extrapolation(lo), _pressure_extrapolation(hi)) for dim, (lo, hi) in vext.ext.items()}) else: - raise ValueError(f"Unsupported extrapolation: {type(vext)}") + return extrapolation.map(_pressure_extrapolation, vext) def _accessible_extrapolation(vext: Extrapolation): From dab8e1af4091a18a6251668b9c2475ee6b1e8f66 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Tue, 18 Oct 2022 18:56:40 +0200 Subject: [PATCH 06/14] [field] Fix diffuse.explicit() for constant non-zero extrapolation --- phi/physics/diffuse.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/phi/physics/diffuse.py b/phi/physics/diffuse.py index da61c92a2..4c4e192c1 100644 --- a/phi/physics/diffuse.py +++ b/phi/physics/diffuse.py @@ -31,8 +31,10 @@ def explicit(field: FieldType, amount = diffusivity * dt if isinstance(amount, Field): amount = amount.at(field) + ext = field.extrapolation for i in range(substeps): - field += amount / substeps * laplace(field).with_extrapolation(field.extrapolation) + field += amount / substeps * laplace(field).with_extrapolation(ext) + field = field.with_extrapolation(ext) return field From 9ecd8a2f2baaa0171968093217081e35c67fab71 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Tue, 18 Oct 2022 18:58:31 +0200 Subject: [PATCH 07/14] [demos] Update pipe.py to use boundary-driven flow --- demos/pipe.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/demos/pipe.py b/demos/pipe.py index 2cc048f2c..0d25b435d 100644 --- a/demos/pipe.py +++ b/demos/pipe.py @@ -3,14 +3,12 @@ """ from phi.flow import * -DOMAIN = dict(x=50, y=32, extrapolation=extrapolation.combine_sides(x=extrapolation.BOUNDARY, y=extrapolation.ZERO)) -DT = 1.0 -BOUNDARY_MASK = StaggeredGrid(Box(x=(-INF, 0.5), y=None), **DOMAIN) -velocity = StaggeredGrid(0, **DOMAIN) +DT = 1. +INFLOW_BC = extrapolation.combine_by_direction(normal=1, tangential=0) +velocity = StaggeredGrid(0, extrapolation.combine_sides(x=(INFLOW_BC, extrapolation.BOUNDARY), y=0), x=50, y=32) pressure = None for _ in view('velocity, pressure', namespace=globals()).range(): velocity = advect.semi_lagrangian(velocity, velocity, DT) - velocity = velocity * (1 - BOUNDARY_MASK) + BOUNDARY_MASK * (1, 0) velocity, pressure = fluid.make_incompressible(velocity, solve=Solve('CG-adaptive', 1e-5, 0, x0=pressure)) velocity = diffuse.explicit(velocity, 0.1, DT) From d2ca6e74083f871135891529fb8aec51ff0d75a0 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Tue, 18 Oct 2022 19:19:11 +0200 Subject: [PATCH 08/14] [field] Fix pad for Jax --- phi/math/extrapolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phi/math/extrapolation.py b/phi/math/extrapolation.py index 8f3997f4e..a3e04c4aa 100644 --- a/phi/math/extrapolation.py +++ b/phi/math/extrapolation.py @@ -243,7 +243,7 @@ def pad(self, value: Tensor, widths: dict, **kwargs): if isinstance(value, NativeTensor): native = value._native ordered_pad_widths = order_by_shape(value.shape, widths, default=(0, 0)) - backend = choose_backend(native) + backend = choose_backend(native, pad_value.native()) result_tensor = backend.pad(native, ordered_pad_widths, 'constant', pad_value.native()) if result_tensor is NotImplemented: return Extrapolation.pad(self, value, widths, **kwargs) From 0fb21321062fa684e99f619fd94f18207a2e4737 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Tue, 18 Oct 2022 21:25:52 +0200 Subject: [PATCH 09/14] [field] Fix batched_gather_nd for Jax --- phi/jax/_jax_backend.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/phi/jax/_jax_backend.py b/phi/jax/_jax_backend.py index 4822939bb..c32962999 100644 --- a/phi/jax/_jax_backend.py +++ b/phi/jax/_jax_backend.py @@ -346,6 +346,8 @@ def cast(self, x, dtype: DType): return jnp.array(x, to_numpy_dtype(dtype)) def batched_gather_nd(self, values, indices): + values = self.as_tensor(values) + indices = self.as_tensor(indices) assert indices.shape[-1] == self.ndims(values) - 2 batch_size = combined_dim(values.shape[0], indices.shape[0]) results = [] From 4d98504e9f95b75b05f6446e7b9e03fb98ced31b Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Tue, 18 Oct 2022 21:27:50 +0200 Subject: [PATCH 10/14] [math] Add pad argument to spatial_gradient() --- phi/math/_nd.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/phi/math/_nd.py b/phi/math/_nd.py index 98dbf8490..9e0693651 100644 --- a/phi/math/_nd.py +++ b/phi/math/_nd.py @@ -319,7 +319,8 @@ def shift(x: Tensor, offsets: tuple, dims: DimFilter = math.spatial, padding: Extrapolation or None = extrapolation.BOUNDARY, - stack_dim: Optional[Shape] = channel('shift')) -> list: + stack_dim: Optional[Shape] = channel('shift'), + extend_bounds=0) -> list: """ shift Tensor by a fixed offset and abiding by extrapolation @@ -343,7 +344,9 @@ def shift(x: Tensor, pad_lower = max(0, -min(offsets)) pad_upper = max(0, max(offsets)) if padding: - x = math.pad(x, {axis: (pad_lower, pad_upper) for axis in dims}, mode=padding) + x = math.pad(x, {axis: (pad_lower + extend_bounds, pad_upper + extend_bounds) for axis in dims}, mode=padding) + if extend_bounds: + assert padding is not None offset_tensors = [] for offset in offsets: components = [] @@ -435,7 +438,8 @@ def spatial_gradient(grid: Tensor, difference: str = 'central', padding: Extrapolation or None = extrapolation.BOUNDARY, dims: DimFilter = spatial, - stack_dim: Shape or None = channel('gradient')) -> Tensor: + stack_dim: Shape or None = channel('gradient'), + pad=0) -> Tensor: """ Calculates the spatial_gradient of a scalar channel from finite differences. The spatial_gradient vectors are in reverse order, lowest dimension first. @@ -448,6 +452,8 @@ def spatial_gradient(grid: Tensor, difference: type of difference, one of ('forward', 'backward', 'central') (default 'forward') padding: tensor padding mode stack_dim: name of the new vector dimension listing the spatial_gradient w.r.t. the various axes + pad: How many cells to extend the result compared to `grid`. + This value is added to the internal padding. For non-trivial extrapolations, this gives the correct result while manual padding before or after this operation would not respect the boundary locations. Returns: `Tensor` @@ -463,13 +469,13 @@ def spatial_gradient(grid: Tensor, if dx.vector.size in (None, 1): dx = dx.vector[0] if difference.lower() == 'central': - left, right = shift(grid, (-1, 1), dims, padding, stack_dim=stack_dim) + left, right = shift(grid, (-1, 1), dims, padding, stack_dim=stack_dim, extend_bounds=pad) return (right - left) / (dx * 2) elif difference.lower() == 'forward': - left, right = shift(grid, (0, 1), dims, padding, stack_dim=stack_dim) + left, right = shift(grid, (0, 1), dims, padding, stack_dim=stack_dim, extend_bounds=pad) return (right - left) / dx elif difference.lower() == 'backward': - left, right = shift(grid, (-1, 0), dims, padding, stack_dim=stack_dim) + left, right = shift(grid, (-1, 0), dims, padding, stack_dim=stack_dim, extend_bounds=pad) return (right - left) / dx else: raise ValueError('Invalid difference type: {}. Can be CENTRAL or FORWARD'.format(difference)) From f9fa54aa4afca54e2cecf2ef316931ad9338696b Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Tue, 18 Oct 2022 21:28:18 +0200 Subject: [PATCH 11/14] [field] Fix spatial_gradient(), divergence() for extrapolation=NONE --- phi/field/_field_math.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/phi/field/_field_math.py b/phi/field/_field_math.py index a3951c8e7..94c801a0b 100644 --- a/phi/field/_field_math.py +++ b/phi/field/_field_math.py @@ -70,8 +70,10 @@ def spatial_gradient(field: Field, if extrapolation is None: extrapolation = field.extrapolation.spatial_gradient() if type == CenteredGrid: - values = math.spatial_gradient(field.values, field.dx.vector.as_channel(name=stack_dim.name), difference='central', padding=field.extrapolation, stack_dim=stack_dim) - return CenteredGrid(values, bounds=field.bounds, extrapolation=extrapolation) + pad = 1 if extrapolation == math.extrapolation.NONE else 0 + values = math.spatial_gradient(field.values, field.dx.vector.as_channel(name=stack_dim.name), difference='central', padding=field.extrapolation, stack_dim=stack_dim, pad=pad) + bounds = Box(field.bounds.lower - field.dx, field.bounds.upper + field.dx) if extrapolation == math.extrapolation.NONE else field.bounds + return CenteredGrid(values, bounds=bounds, extrapolation=extrapolation) elif type == StaggeredGrid: assert stack_dim.name == 'vector' return stagger(field, lambda lower, upper: (upper - lower) / field.dx, extrapolation) @@ -175,6 +177,8 @@ def divergence(field: Grid) -> CenteredGrid: grad = (right - left) / (field.dx * 2) components = [grad.vector[i].div_[i] for i in range(grad.div_.size)] result = sum(components) + if field.extrapolation == math.extrapolation.NONE: + result = result.with_bounds(Box(field.bounds.lower + field.dx, field.bounds.upper - field.dx)) return result else: raise NotImplementedError(f"{type(field)} not supported. Only StaggeredGrid allowed.") From ac1131a4ea9316b341023a492e0d3e67471dd8e0 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Tue, 18 Oct 2022 21:36:53 +0200 Subject: [PATCH 12/14] [tests] Ignore boundary cells in divergence test for CenteredGrids --- tests/commit/physics/test_fluid.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/commit/physics/test_fluid.py b/tests/commit/physics/test_fluid.py index 0b8f3b831..25517f435 100644 --- a/tests/commit/physics/test_fluid.py +++ b/tests/commit/physics/test_fluid.py @@ -24,7 +24,10 @@ def _test_make_incompressible(self, grid_type: type, extrapolation: math.Extrapo for _ in range(2): velocity += smoke * (0, 0.1) @ velocity velocity, _ = fluid.make_incompressible(velocity) - math.assert_close(divergence(velocity).values, 0, abs_tolerance=2e-5) + if grid_type == StaggeredGrid: + math.assert_close(0, divergence(velocity).values, abs_tolerance=2e-5) + else: + math.assert_close(0, field.pad(divergence(velocity), -1).values, abs_tolerance=2e-5) if result is None: result = velocity else: From 60f9178f3699fd361137a3486b08eb813fd00cfb Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Thu, 20 Oct 2022 13:29:48 +0200 Subject: [PATCH 13/14] =?UTF-8?q?[=CE=A6]=20Bump=20version=20to=202.2.4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- phi/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phi/VERSION b/phi/VERSION index 6b4d15773..047615559 100644 --- a/phi/VERSION +++ b/phi/VERSION @@ -1 +1 @@ -2.2.3 \ No newline at end of file +2.2.4 \ No newline at end of file From a80684fee5d80ea3dfc513b5a74db5e95226857f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Barnab=C3=A1s=20B=C3=B6rcs=C3=B6k?= Date: Fri, 21 Oct 2022 15:55:25 +0200 Subject: [PATCH 14/14] [doc] Fix link in index.md --- docs/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.md b/docs/index.md index 46df4363d..e5aeab7af 100644 --- a/docs/index.md +++ b/docs/index.md @@ -40,7 +40,7 @@ #### Neural Networks -* [▶️ Introduction Video](https://youtu.be/YRi_c0v3HKs) +* [▶️ Introduction Video](https://youtu.be/aNigTqklCBc) * [Learning to Throw](https://tum-pbs.github.io/PhiFlow/Learn_to_Throw_Tutorial.html)