
Commit

Merge pull request #90 from tum-pbs/develop
2.2.4
holl- authored Oct 28, 2022
2 parents 52eac7f + 6225bde commit 4b7e5cb
Showing 14 changed files with 154 additions and 55 deletions.
8 changes: 3 additions & 5 deletions demos/pipe.py
@@ -3,14 +3,12 @@
 """
 from phi.flow import *
 
-DOMAIN = dict(x=50, y=32, extrapolation=extrapolation.combine_sides(x=extrapolation.BOUNDARY, y=extrapolation.ZERO))
-DT = 1.0
-BOUNDARY_MASK = StaggeredGrid(Box(x=(-INF, 0.5), y=None), **DOMAIN)
-velocity = StaggeredGrid(0, **DOMAIN)
+DT = 1.
+INFLOW_BC = extrapolation.combine_by_direction(normal=1, tangential=0)
+velocity = StaggeredGrid(0, extrapolation.combine_sides(x=(INFLOW_BC, extrapolation.BOUNDARY), y=0), x=50, y=32)
 pressure = None
 
 for _ in view('velocity, pressure', namespace=globals()).range():
     velocity = advect.semi_lagrangian(velocity, velocity, DT)
-    velocity = velocity * (1 - BOUNDARY_MASK) + BOUNDARY_MASK * (1, 0)
     velocity, pressure = fluid.make_incompressible(velocity, solve=Solve('CG-adaptive', 1e-5, 0, x0=pressure))
     velocity = diffuse.explicit(velocity, 0.1, DT)
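Note: the demo now expresses the inflow as a boundary condition baked into the velocity grid's extrapolation instead of overwriting masked cells each step. A minimal sketch of one step under the new setup (assuming the phi.flow 2.2 API used above; grid sizes are illustrative):

```python
from phi.flow import *

# Inflow boundary: unit velocity normal to the face, zero tangential component.
INFLOW_BC = extrapolation.combine_by_direction(normal=1, tangential=0)
# Lower x face: inflow; upper x face: open boundary; y walls: zero velocity.
BC = extrapolation.combine_sides(x=(INFLOW_BC, extrapolation.BOUNDARY), y=0)
velocity = StaggeredGrid(0, BC, x=50, y=32)

# One step of the loop above: no BOUNDARY_MASK blend is needed, because any
# operation that pads the grid now draws inflow values from the extrapolation.
velocity = advect.semi_lagrangian(velocity, velocity, 1.)
velocity, pressure = fluid.make_incompressible(velocity)
```

Keeping the inflow in the extrapolation means advection, gradients and the pressure solve all see consistent boundary values.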
2 changes: 1 addition & 1 deletion docs/index.md
@@ -40,7 +40,7 @@
 
 #### Neural Networks
 
-* [▶️ Introduction Video](https://youtu.be/YRi_c0v3HKs)
+* [▶️ Introduction Video](https://youtu.be/aNigTqklCBc)
 * [Learning to Throw](https://tum-pbs.github.io/PhiFlow/Learn_to_Throw_Tutorial.html)
2 changes: 1 addition & 1 deletion phi/VERSION
@@ -1 +1 @@
-2.2.3
+2.2.4
22 changes: 16 additions & 6 deletions phi/field/_field_math.py
@@ -70,8 +70,10 @@ def spatial_gradient(field: Field,
     if extrapolation is None:
         extrapolation = field.extrapolation.spatial_gradient()
     if type == CenteredGrid:
-        values = math.spatial_gradient(field.values, field.dx.vector.as_channel(name=stack_dim.name), difference='central', padding=field.extrapolation, stack_dim=stack_dim)
-        return CenteredGrid(values, bounds=field.bounds, extrapolation=extrapolation)
+        pad = 1 if extrapolation == math.extrapolation.NONE else 0
+        values = math.spatial_gradient(field.values, field.dx.vector.as_channel(name=stack_dim.name), difference='central', padding=field.extrapolation, stack_dim=stack_dim, pad=pad)
+        bounds = Box(field.bounds.lower - field.dx, field.bounds.upper + field.dx) if extrapolation == math.extrapolation.NONE else field.bounds
+        return CenteredGrid(values, bounds=bounds, extrapolation=extrapolation)
     elif type == StaggeredGrid:
         assert stack_dim.name == 'vector'
         return stagger(field, lambda lower, upper: (upper - lower) / field.dx, extrapolation)
@@ -121,10 +123,16 @@ def stagger(field: CenteredGrid,
     all_lower = []
     all_upper = []
     if type == StaggeredGrid:
-        for dim in field.shape.spatial.names:
-            lo_valid, up_valid = extrapolation.valid_outer_faces(dim)
-            width_lower = {dim: (int(lo_valid), int(up_valid) - 1)}
-            width_upper = {dim: (int(lo_valid or up_valid) - 1, int(lo_valid and up_valid))}
+        for dim in field.resolution.names:
+            valid_lo, valid_up = extrapolation.valid_outer_faces(dim)
+            if valid_lo and valid_up:
+                width_lower, width_upper = {dim: (1, 0)}, {dim: (0, 1)}
+            elif valid_lo and not valid_up:
+                width_lower, width_upper = {dim: (1, -1)}, {dim: (0, 0)}
+            elif not valid_lo and valid_up:
+                width_lower, width_upper = {dim: (0, 0)}, {dim: (-1, 1)}
+            else:
+                width_lower, width_upper = {dim: (0, -1)}, {dim: (-1, 0)}
             all_lower.append(math.pad(field.values, width_lower, field.extrapolation, bounds=field.bounds))
             all_upper.append(math.pad(field.values, width_upper, field.extrapolation, bounds=field.bounds))
         all_upper = math.stack(all_upper, channel('vector'))
@@ -169,6 +177,8 @@ def divergence(field: Grid) -> CenteredGrid:
         grad = (right - left) / (field.dx * 2)
         components = [grad.vector[i].div_[i] for i in range(grad.div_.size)]
         result = sum(components)
+        if field.extrapolation == math.extrapolation.NONE:
+            result = result.with_bounds(Box(field.bounds.lower + field.dx, field.bounds.upper - field.dx))
         return result
     else:
         raise NotImplementedError(f"{type(field)} not supported. Only StaggeredGrid allowed.")
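Note: the `extrapolation.NONE` branches grow the output of `spatial_gradient` by one cell per side (bounds enlarged by `dx`) and shrink `divergence` output correspondingly, since no values can be assumed outside the domain. A hedged sketch of the intended bounds behavior (assuming the phi 2.2 field API; not taken from the commit):

```python
from phi.flow import *

t = CenteredGrid(Noise(), extrapolation.PERIODIC, x=32, y=32)

# With a well-defined extrapolation, the gradient keeps the input bounds.
grad = field.spatial_gradient(t)
assert grad.bounds == t.bounds

# With extrapolation.NONE, the result gains one cell per side and its bounds
# grow by dx in every direction, matching the new code path above.
grad_none = field.spatial_gradient(t, extrapolation.NONE)
assert grad_none.resolution == spatial(x=34, y=34)
```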
6 changes: 5 additions & 1 deletion phi/field/_grid.py
@@ -387,6 +387,10 @@ def __getitem__(self, item):
         item = slicing_dict(self, item)
         if not item:
             return self
+        if 'vector' in item:
+            selection = item['vector']
+            if isinstance(selection, int):
+                item['vector'] = self.resolution.names[selection]
         values = self._values[{dim: sel for dim, sel in item.items() if dim not in self.shape.spatial}]
         for dim, sel in item.items():
             if dim in self.shape.spatial:
@@ -410,7 +414,7 @@ def __getitem__(self, item):
                 assert selection in item_names, f"Accessing field.vector['{selection}'] failed. Item names are {item_names}."
                 selection = item_names.index(selection)
             if isinstance(selection, int):
-                dim = self.shape.spatial.names[selection]
+                dim = self.resolution.names[selection]
                 comp_cells = GridCell(self.resolution, bounds).stagger(dim, *self.extrapolation.valid_outer_faces(dim))
                 return CenteredGrid(values, bounds=comp_cells.bounds, extrapolation=extrapolation)
             else:
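Note: `__getitem__` now normalizes integer component indices to dimension names early, using the grid's `resolution` to map index to name. The user-facing effect, sketched under the assumed 2.2 API:

```python
from phi.flow import *

v = StaggeredGrid(Noise(), extrapolation.ZERO, x=16, y=16)

# Integer and name-based component access now resolve to the same dimension.
vx_by_name = v.vector['x']
vx_by_index = v.vector[0]
assert vx_by_name.shape == vx_by_index.shape
```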
2 changes: 2 additions & 0 deletions phi/jax/_jax_backend.py
@@ -346,6 +346,8 @@ def cast(self, x, dtype: DType):
         return jnp.array(x, to_numpy_dtype(dtype))
 
     def batched_gather_nd(self, values, indices):
+        values = self.as_tensor(values)
+        indices = self.as_tensor(indices)
         assert indices.shape[-1] == self.ndims(values) - 2
         batch_size = combined_dim(values.shape[0], indices.shape[0])
         results = []
18 changes: 12 additions & 6 deletions phi/math/_nd.py
@@ -319,7 +319,8 @@ def shift(x: Tensor,
           offsets: tuple,
           dims: DimFilter = math.spatial,
           padding: Extrapolation or None = extrapolation.BOUNDARY,
-          stack_dim: Optional[Shape] = channel('shift')) -> list:
+          stack_dim: Optional[Shape] = channel('shift'),
+          extend_bounds=0) -> list:
     """
     shift Tensor by a fixed offset and abiding by extrapolation
@@ -343,7 +344,9 @@
     pad_lower = max(0, -min(offsets))
     pad_upper = max(0, max(offsets))
     if padding:
-        x = math.pad(x, {axis: (pad_lower, pad_upper) for axis in dims}, mode=padding)
+        x = math.pad(x, {axis: (pad_lower + extend_bounds, pad_upper + extend_bounds) for axis in dims}, mode=padding)
+    if extend_bounds:
+        assert padding is not None
     offset_tensors = []
     for offset in offsets:
         components = []
@@ -435,7 +438,8 @@ def spatial_gradient(grid: Tensor,
                      difference: str = 'central',
                      padding: Extrapolation or None = extrapolation.BOUNDARY,
                      dims: DimFilter = spatial,
-                     stack_dim: Shape or None = channel('gradient')) -> Tensor:
+                     stack_dim: Shape or None = channel('gradient'),
+                     pad=0) -> Tensor:
     """
     Calculates the spatial_gradient of a scalar channel from finite differences.
     The spatial_gradient vectors are in reverse order, lowest dimension first.
@@ -448,6 +452,8 @@
         difference: type of difference, one of ('forward', 'backward', 'central') (default 'forward')
         padding: tensor padding mode
         stack_dim: name of the new vector dimension listing the spatial_gradient w.r.t. the various axes
+        pad: How many cells to extend the result compared to `grid`.
+            This value is added to the internal padding. For non-trivial extrapolations, this gives the correct result while manual padding before or after this operation would not respect the boundary locations.
 
     Returns:
         `Tensor`
@@ -463,13 +469,13 @@
     if dx.vector.size in (None, 1):
         dx = dx.vector[0]
     if difference.lower() == 'central':
-        left, right = shift(grid, (-1, 1), dims, padding, stack_dim=stack_dim)
+        left, right = shift(grid, (-1, 1), dims, padding, stack_dim=stack_dim, extend_bounds=pad)
         return (right - left) / (dx * 2)
     elif difference.lower() == 'forward':
-        left, right = shift(grid, (0, 1), dims, padding, stack_dim=stack_dim)
+        left, right = shift(grid, (0, 1), dims, padding, stack_dim=stack_dim, extend_bounds=pad)
         return (right - left) / dx
     elif difference.lower() == 'backward':
-        left, right = shift(grid, (-1, 0), dims, padding, stack_dim=stack_dim)
+        left, right = shift(grid, (-1, 0), dims, padding, stack_dim=stack_dim, extend_bounds=pad)
         return (right - left) / dx
     else:
         raise ValueError('Invalid difference type: {}. Can be CENTRAL or FORWARD'.format(difference))
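Note: `extend_bounds`/`pad` enlarge the padding applied *before* differencing, so the extra output cells still honor the boundary condition; padding the finished gradient instead would place the boundary at the wrong location. A small sketch (assuming the public `phi.math` API; sizes are illustrative):

```python
from phi import math
from phi.math import extrapolation, spatial

x = math.random_uniform(spatial(x=8))

g0 = math.spatial_gradient(x, padding=extrapolation.ZERO)         # resolution x=8
g1 = math.spatial_gradient(x, padding=extrapolation.ZERO, pad=1)  # resolution x=10
print(g0.shape, g1.shape)
```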
28 changes: 19 additions & 9 deletions phi/math/_ops.py
@@ -701,16 +701,21 @@ def inner_stack(*values):
 def concat_tensor(values: tuple or list, dim: str) -> Tensor:
     assert len(values) > 0, "concat() got empty sequence"
     assert isinstance(dim, str), f"dim must be a single-dimension Shape but got '{dim}' of type {type(dim)}"
-    broadcast_shape = merge_shapes(*[t.shape._with_item_name(dim, None).with_sizes([None] * t.shape.rank) for t in values])
-    natives = [v.native(order=broadcast_shape.names) for v in values]
-    backend = choose_backend(*natives)
-    concatenated = backend.concat(natives, broadcast_shape.index(dim))
-    if all([v.shape.get_item_names(dim) is not None for v in values]):
-        broadcast_shape = broadcast_shape._with_item_name(dim, sum([v.shape.get_item_names(dim) for v in values], ()))
-    return NativeTensor(concatenated, broadcast_shape.with_sizes(backend.staticshape(concatenated)))
+
+    def inner_concat(*values):
+        broadcast_shape = merge_shapes(*[t.shape._with_item_name(dim, None).with_sizes([None] * t.shape.rank) for t in values])
+        natives = [v.native(order=broadcast_shape.names) for v in values]
+        backend = choose_backend(*natives)
+        concatenated = backend.concat(natives, broadcast_shape.index(dim))
+        if all([v.shape.get_item_names(dim) is not None for v in values]):
+            broadcast_shape = broadcast_shape._with_item_name(dim, sum([v.shape.get_item_names(dim) for v in values], ()))
+        return NativeTensor(concatenated, broadcast_shape.with_sizes(backend.staticshape(concatenated)))
+
+    result = broadcast_op(inner_concat, values)
+    return result
 
 
-def pad(value: Tensor, widths: dict, mode: 'e_.Extrapolation', **kwargs) -> Tensor:
+def pad(value: Tensor, widths: dict, mode: 'e_.Extrapolation' or Tensor or Number, **kwargs) -> Tensor:
     """
     Pads a tensor along the specified dimensions, determining the added values using the given extrapolation.
     Unlike `Extrapolation.pad()`, this function can handle negative widths which slice off outer values.
@@ -720,13 +725,15 @@ def pad(value: Tensor, widths: dict, mode: 'e_.Extrapolation', **kwargs) -> Tensor:
         widths: `dict` mapping dimension name (`str`) to `(lower, upper)`
             where `lower` and `upper` are `int` that can be positive (pad), negative (slice) or zero (pass).
         mode: `Extrapolation` used to determine values added from positive `widths`.
+            Assumes constant extrapolation if given a number or `Tensor` instead.
         kwargs: Additional padding arguments.
             These are ignored by the standard extrapolations defined in `phi.math.extrapolation` but can be used to pass additional contextual information to custom extrapolations.
             Grid classes from `phi.field` will pass the argument `bounds: Box`.
 
     Returns:
         Padded `Tensor`
     """
+    mode = mode if isinstance(mode, e_.Extrapolation) else e_.ConstantExtrapolation(mode)
     has_negative_widths = any(w[0] < 0 or w[1] < 0 for w in widths.values())
     slices = None
     if has_negative_widths:
@@ -886,6 +893,7 @@ def broadcast_op(operation: Callable,
         dim = next(iter(iter_dims))
         dim_type = None
         size = None
+        item_names = None
         unstacked = []
         for tensor in tensors:
             if dim in tensor.shape.names:
@@ -897,14 +905,16 @@
                 else:
                     assert size == len(unstacked_tensor)
                     assert dim_type == tensor.shape.get_type(dim)
+                if item_names is None:
+                    item_names = tensor.shape.get_item_names(dim)
             else:
                 unstacked.append(tensor)
         result_unstacked = []
         for i in range(size):
             gathered = [t[i] if isinstance(t, tuple) else t for t in unstacked]
             result_unstacked.append(broadcast_op(operation, gathered, iter_dims=set(iter_dims) - {dim}))
         if not no_return:
-            return TensorStack(result_unstacked, Shape((None,), (dim,), (dim_type,), (None,)))
+            return TensorStack(result_unstacked, Shape((None,), (dim,), (dim_type,), (item_names,)))
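Note: routing `concat_tensor` through `broadcast_op` lets concatenation handle stacked dimensions slice by slice, with item names now carried through the restack; `pad` additionally accepts plain numbers or tensors as shorthand for constant extrapolation. A hedged sketch of the user-facing behavior (assuming the public `phi.math` API):

```python
from phi import math
from phi.math import spatial, channel

a = math.zeros(spatial(x=4), channel(vector='x,y'))
b = math.ones(spatial(x=4), channel(vector='x,y'))

c = math.concat([a, b], 'x')       # item names on 'vector' survive the concat
p = math.pad(a, {'x': (1, 1)}, 0)  # a plain number now acts as constant padding
print(c.shape, p.shape)
```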
23 changes: 21 additions & 2 deletions phi/math/extrapolation.py
@@ -152,6 +152,9 @@ def shortest_distance(self, start: Tensor, end: Tensor, domain_size: Tensor):
     def __getitem__(self, item):
         return self
 
+    def _getitem_with_domain(self, item: dict, dim: str, upper_edge: bool, all_dims: tuple):
+        return self[item]
+
     def __abs__(self):
         raise NotImplementedError(self.__class__)
@@ -240,8 +243,10 @@ def pad(self, value: Tensor, widths: dict, **kwargs):
         if isinstance(value, NativeTensor):
             native = value._native
             ordered_pad_widths = order_by_shape(value.shape, widths, default=(0, 0))
-            backend = choose_backend(native)
+            backend = choose_backend(native, pad_value.native())
             result_tensor = backend.pad(native, ordered_pad_widths, 'constant', pad_value.native())
+            if result_tensor is NotImplemented:
+                return Extrapolation.pad(self, value, widths, **kwargs)
             new_shape = value.shape.with_sizes(backend.staticshape(result_tensor))
             return NativeTensor(result_tensor, new_shape)
         elif isinstance(value, CollapsedTensor):
@@ -902,7 +907,8 @@ def transform_coordinates(self, coordinates: Tensor, shape: Shape, **kwargs) ->
 
     def __getitem__(self, item):
         if isinstance(item, dict):
-            return combine_sides(**{dim: (e1[item], e2[item]) for dim, (e1, e2) in self.ext.items()})
+            all_dims = tuple(self.ext.keys())
+            return combine_sides(**{dim: (e1._getitem_with_domain(item, dim, False, all_dims), e2._getitem_with_domain(item, dim, True, all_dims)) for dim, (e1, e2) in self.ext.items()})
         else:
             dim, face = item
             return self.ext[dim][face]
@@ -992,9 +998,22 @@ def pad_values(self, value: Tensor, width: int, dim: str, upper_edge: bool, **kwargs):
             result = stack(result, value.shape.only('vector'))
         return result
 
+    def _getitem_with_domain(self, item: dict, dim: str, upper_edge: bool, all_dims: tuple):
+        if 'vector' not in item:
+            return self
+        component = item['vector']
+        assert isinstance(component, str), f"Selecting a component of normal/tangential must be done by dimension name but got {component}"
+        if component == dim:
+            return self.normal
+        else:
+            return self.tangential
+
     def __eq__(self, other):
         return isinstance(other, _NormalTangentialExtrapolation) and self.normal == other.normal and self.tangential == other.tangential
 
+    def __hash__(self):
+        return hash(self.normal) + hash(self.tangential)
+
     def __add__(self, other):
         return self._op2(other, lambda e1, e2: e1 + e2)
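Note: `_getitem_with_domain` gives mixed extrapolations enough context to slice `combine_by_direction` boundaries per component: selecting the component perpendicular to a face yields the normal extrapolation, any other component the tangential one. The explicit `__hash__` is needed because defining `__eq__` suppresses the inherited hash. A hedged sketch (assuming the 2.2 API):

```python
from phi.flow import *

INFLOW_BC = extrapolation.combine_by_direction(normal=1, tangential=0)
BC = extrapolation.combine_sides(x=(INFLOW_BC, extrapolation.BOUNDARY), y=0)
v = StaggeredGrid(0, BC, x=50, y=32)

# Slicing out a component resolves normal/tangential per face: the x-component
# sees the normal value (1) at the x faces and the tangential value elsewhere.
vx = v.vector['x']
```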
4 changes: 3 additions & 1 deletion phi/physics/diffuse.py
@@ -31,8 +31,10 @@ def explicit(field: FieldType,
     amount = diffusivity * dt
     if isinstance(amount, Field):
         amount = amount.at(field)
+    ext = field.extrapolation
     for i in range(substeps):
-        field += amount / substeps * laplace(field).with_extrapolation(field.extrapolation)
+        field += amount / substeps * laplace(field).with_extrapolation(ext)
+    field = field.with_extrapolation(ext)
     return field
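Note: caching `ext` before the substep loop and re-imposing it afterwards keeps the field's boundary condition stable even if the `field + laplace` sum would promote the extrapolation. Usage is unchanged; a quick check, assuming the 2.2 API:

```python
from phi.flow import *

t = CenteredGrid(Noise(), extrapolation.PERIODIC, x=64, y=64)
t_next = diffuse.explicit(t, diffusivity=0.1, dt=0.5, substeps=4)
assert t_next.extrapolation == t.extrapolation  # boundary condition preserved
```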
10 changes: 4 additions & 6 deletions phi/physics/fluid.py
@@ -102,7 +102,7 @@ def make_incompressible(velocity: GridType,
     pressure = math.solve_linear(masked_laplace, f_args=[hard_bcs, active], y=div, solve=solve)
     # --- Subtract grad p ---
     grad_pressure = field.spatial_gradient(pressure, input_velocity.extrapolation, type=type(velocity)) * hard_bcs
-    velocity = velocity - grad_pressure
+    velocity = (velocity - grad_pressure).with_extrapolation(input_velocity.extrapolation)
     return velocity, pressure
@@ -123,8 +123,8 @@ def masked_laplace(pressure: CenteredGrid, hard_bcs: Grid, active: CenteredGrid):
     Returns:
         `CenteredGrid`
     """
-    grad = spatial_gradient(pressure, hard_bcs.extrapolation, type=type(hard_bcs))
-    valid_grad = grad * hard_bcs
+    grad = spatial_gradient(pressure, extrapolation.NONE, type=type(hard_bcs))
+    valid_grad = grad * field.bake_extrapolation(hard_bcs)
     div = divergence(valid_grad)
     laplace = where(active, div, pressure)
     return laplace
@@ -189,10 +189,8 @@ def _pressure_extrapolation(vext: Extrapolation):
         return extrapolation.ZERO
     elif isinstance(vext, extrapolation.ConstantExtrapolation):
         return extrapolation.BOUNDARY
-    elif isinstance(vext, extrapolation._MixedExtrapolation):
-        return combine_sides(**{dim: (_pressure_extrapolation(lo), _pressure_extrapolation(hi)) for dim, (lo, hi) in vext.ext.items()})
     else:
-        raise ValueError(f"Unsupported extrapolation: {type(vext)}")
+        return extrapolation.map(_pressure_extrapolation, vext)
 
 
 def _accessible_extrapolation(vext: Extrapolation):
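Note: `make_incompressible` now restores the caller's extrapolation explicitly after subtracting the pressure gradient, and `_pressure_extrapolation` defers mixed cases to `extrapolation.map`, so per-side combinations of supported extrapolations should work without special-casing. A hedged round-trip check (assuming the 2.2 API):

```python
from phi.flow import *

BC = extrapolation.combine_sides(x=extrapolation.BOUNDARY, y=0)
v = StaggeredGrid(Noise(), BC, x=32, y=32)
v_df, p = fluid.make_incompressible(v)
assert v_df.extrapolation == v.extrapolation  # restored explicitly by the new code
```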