Skip to content

Commit

Permalink
Merge pull request #86 from tum-pbs/develop
Browse files Browse the repository at this point in the history
2.2.3
  • Loading branch information
holl- authored Oct 13, 2022
2 parents 49cf9c5 + fec82c1 commit 52eac7f
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 24 deletions.
2 changes: 1 addition & 1 deletion phi/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.2.2
2.2.3
2 changes: 1 addition & 1 deletion phi/jax/_jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self):
for device_type in ['cpu', 'gpu', 'tpu']:
try:
for jax_dev in jax.devices(device_type):
devices.append(ComputeDevice(self, jax_dev.device_kind, jax_dev.platform.upper(), -1, -1, f"id={jax_dev.id}", jax_dev))
devices.append(ComputeDevice(self, device_type.upper(), jax_dev.platform.upper(), -1, -1, f"id={jax_dev.id}", jax_dev))
except RuntimeError as err:
pass # this is just Jax not finding anything. jaxlib.xla_client._get_local_backends() could help but isn't currently available on GitHub actions
Backend.__init__(self, "Jax", devices, devices[-1])
Expand Down
71 changes: 59 additions & 12 deletions phi/math/_magic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@ def unstack(value, dim: DimFilter):
If multiple dimensions are given, the order of elements will be according to the dimension order in `dim`, i.e. elements along the last dimension will be neighbors in the returned `tuple`.
Args:
value: `Tensor` to unstack.
value: `phi.math.magic.Shapable`, such as `phi.math.Tensor`
dim: Dimensions as `Shape` or comma-separated `str` or dimension type, i.e. `channel`, `spatial`, `instance`, `batch`.
Returns:
`tuple` of `Tensor` objects.
Examples:
```python
unstack(math.zeros(spatial(x=5)), 'x')
# Out: (0.0, 0.0, 0.0, 0.0, 0.0)
```
"""
assert isinstance(value, Sliceable) and isinstance(value, Shaped)
dims = shape(value).only(dim)
Expand Down Expand Up @@ -55,7 +61,7 @@ def stack(values: tuple or list or dict, dim: Shape, **kwargs):
This makes repeated stacking and slicing along the same dimension very efficient.
Args:
values: Sequence of `Shapable` objects to be stacked.
values: Collection of `phi.math.magic.Shapable`, such as `phi.math.Tensor`
If a `dict`, keys must be of type `str` and are used as item names along `dim`.
dim: `Shape` with a least one dimension. None of these dimensions can be present with any of the `values`.
If `dim` is a single-dimension shape, its size is determined from `len(values)` and can be left undefined (`None`).
Expand All @@ -66,6 +72,19 @@ def stack(values: tuple or list or dict, dim: Shape, **kwargs):
Returns:
`Tensor` containing `values` stacked along `dim`.
Examples:
```python
stack({'x': 0, 'y': 1}, channel('vector'))
# Out: (x=0, y=1)
stack([math.zeros(batch(b=2)), math.ones(batch(b=2))], channel(c='x,y'))
# Out: (x=0.000, y=1.000); (x=0.000, y=1.000) (bᵇ=2, cᶜ=x,y)
stack([vec(x=1, y=0), vec(x=2, y=3.)], batch('b'))
# Out: (x=1.000, y=0.000); (x=2.000, y=3.000) (bᵇ=2, vectorᶜ=x,y)
```
"""
assert len(values) > 0, f"stack() got empty sequence {values}"
assert isinstance(dim, Shape)
Expand Down Expand Up @@ -129,11 +148,11 @@ def stack(values: tuple or list or dict, dim: Shape, **kwargs):

def concat(values: tuple or list, dim: str or Shape, **kwargs):
"""
Concatenates a sequence of tensors along one dimension.
Concatenates a sequence of `phi.math.magic.Shapable` objects, e.g. `Tensor`, along one dimension.
The shapes of all values must be equal, except for the size of the concat dimension.
Args:
values: Tensors to concatenate
values: Tuple or list of `phi.math.magic.Shapable`, such as `phi.math.Tensor`
dim: Concatenation dimension, must be present in all `values`.
The size along `dim` is determined from `values` and can be set to undefined (`None`).
**kwargs: Additional keyword arguments required by specific implementations.
Expand All @@ -142,6 +161,16 @@ def concat(values: tuple or list, dim: str or Shape, **kwargs):
Returns:
Concatenated `Tensor`
Examples:
```python
concat([math.zeros(batch(b=10)), math.ones(batch(b=10))], 'b')
# Out: (bᵇ=20) 0.500 ± 0.500 (0e+00...1e+00)
concat([vec(x=1, y=0), vec(z=2.)], 'vector')
# Out: (x=1.000, y=0.000, z=2.000) float64
```
"""
assert len(values) > 0, f"concat() got empty sequence {values}"
if isinstance(dim, Shape):
Expand Down Expand Up @@ -186,14 +215,14 @@ def expand(value, dims: Shape, **kwargs):
Additionally, it replaces the traditional `unsqueeze` / `expand_dims` functions.
Args:
value: `Tensor`
value: `phi.math.magic.Shapable`, such as `phi.math.Tensor`
dims: Dimensions to be added as `Shape`
**kwargs: Additional keyword arguments required by specific implementations.
Adding spatial dimensions to fields requires the `bounds: Box` argument specifying the physical extent of the new dimensions.
Adding batch dimensions must always work without keyword arguments.
Returns:
Expanded `Shapable`.
Same type as `value`.
"""
if hasattr(value, '__expand__'):
result = value.__expand__(dims, **kwargs)
Expand Down Expand Up @@ -271,7 +300,7 @@ def pack_dims(value, dims: DimFilter, packed_dim: Shape, pos: int or None = None
`unpack_dim()`
Args:
value: Tensor containing the dimensions `dims`.
value: `phi.math.magic.Shapable`, such as `phi.math.Tensor`.
dims: Dimensions to be compressed in the specified order.
packed_dim: Single-dimension `Shape`.
pos: Index of new dimension. `None` for automatic, `-1` for last, `0` for first.
Expand All @@ -280,7 +309,13 @@ def pack_dims(value, dims: DimFilter, packed_dim: Shape, pos: int or None = None
Adding batch dimensions must always work without keyword arguments.
Returns:
`Tensor` with compressed shape.
Same type as `value`.
Examples:
```python
pack_dims(math.zeros(spatial(x=4, y=3)), spatial, instance('points'))
# Out: (pointsⁱ=12) const 0.0
```
"""
assert isinstance(value, Shapable) and isinstance(value, Sliceable) and isinstance(value, Shaped), f"value must be Shapable but got {type(value)}"
dims = shape(value).only(dims)
Expand Down Expand Up @@ -316,15 +351,21 @@ def unpack_dim(value, dim: str or Shape, unpacked_dims: Shape, **kwargs):
`pack_dims()`
Args:
value: `Tensor` for which one dimension should be split.
value: `phi.math.magic.Shapable`, such as `Tensor`, for which one dimension should be split.
dim: Dimension to be decompressed.
unpacked_dims: `Shape`: Ordered dimensions to replace `dim`, fulfilling `unpacked_dims.volume == shape(self)[dim].rank`.
**kwargs: Additional keyword arguments required by specific implementations.
Adding spatial dimensions to fields requires the `bounds: Box` argument specifying the physical extent of the new dimensions.
Adding batch dimensions must always work without keyword arguments.
Returns:
`Tensor` with decompressed shape
Same type as `value`.
Examples:
```python
unpack_dim(math.zeros(instance(points=12)), 'points', spatial(x=4, y=3))
# Out: (xˢ=4, yˢ=3) const 0.0
```
"""
assert isinstance(value, Shapable) and isinstance(value, Sliceable) and isinstance(value, Shaped), f"value must be Shapable but got {type(value)}"
if isinstance(dim, Shape):
Expand Down Expand Up @@ -354,14 +395,20 @@ def flatten(value, flat_dim: Shape = instance('flat'), **kwargs):
The order of the values in memory is not changed.
Args:
value: `Tensor`
value: `phi.math.magic.Shapable`, such as `Tensor`.
flat_dim: Dimension name and type as `Shape` object. The size is ignored.
**kwargs: Additional keyword arguments required by specific implementations.
Adding spatial dimensions to fields requires the `bounds: Box` argument specifying the physical extent of the new dimensions.
Adding batch dimensions must always work without keyword arguments.
Returns:
`Tensor`
Same type as `value`.
Examples:
```python
flatten(math.zeros(spatial(x=4, y=3)))
# Out: (flatⁱ=12) const 0.0
```
"""
assert isinstance(flat_dim, Shape) and flat_dim.rank == 1, flat_dim
assert isinstance(value, Shapable) and isinstance(value, Shaped), f"value must be Shapable but got {type(value)}"
Expand Down
13 changes: 9 additions & 4 deletions phi/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1797,7 +1797,7 @@ def gather(values: Tensor, indices: Tensor, dims: DimFilter or None = None):

def scatter(base_grid: Tensor or Shape,
indices: Tensor,
values: Tensor,
values: Tensor or float,
mode: str = 'update',
outside_handling: str = 'discard',
indices_gradient=False):
Expand Down Expand Up @@ -1843,6 +1843,7 @@ def scatter(base_grid: Tensor or Shape,
assert isinstance(indices_gradient, bool)
grid_shape = base_grid if isinstance(base_grid, Shape) else base_grid.shape
assert indices.shape.channel.names == ('vector',) or (grid_shape.spatial_rank + grid_shape.instance_rank == 1 and indices.shape.channel_rank == 0)
values = wrap(values)
batches = values.shape.non_channel.non_instance & indices.shape.non_channel.non_instance
channels = grid_shape.channel & values.shape.channel
# --- Set up grid ---
Expand All @@ -1855,10 +1856,14 @@ def scatter(base_grid: Tensor or Shape,
if outside_handling == 'clamp':
indices = clip(indices, 0, tensor(grid_shape.spatial, channel('vector')) - 1)
elif outside_handling == 'discard':
indices_inside = min_((round_(indices) >= 0) & (round_(indices) < tensor(grid_shape.spatial, channel('vector'))), 'vector')
indices = boolean_mask(indices, indices.shape.instance.name, indices_inside)
indices_linear = pack_dims(indices, instance, instance(_scatter_instance=1))
indices_inside = min_((round_(indices_linear) >= 0) & (round_(indices_linear) < tensor(grid_shape.spatial, channel('vector'))), 'vector')
indices_linear = boolean_mask(indices_linear, '_scatter_instance', indices_inside)
if instance(values).rank > 0:
values = boolean_mask(values, values.shape.instance.name, indices_inside)
values_linear = pack_dims(values, instance, instance(_scatter_instance=1))
values_linear = boolean_mask(values_linear, '_scatter_instance', indices_inside)
values = unpack_dim(values_linear, '_scatter_instance', instance(values))
indices = unpack_dim(indices_linear, '_scatter_instance', instance(indices))
if indices.shape.is_non_uniform:
raise NotImplementedError()
lists = indices.shape.instance & values.shape.instance
Expand Down
2 changes: 1 addition & 1 deletion phi/math/_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def name(self) -> str:
See Also:
`Shape.names`.
"""
assert self.rank == 1, "Shape.name is only defined for shapes of rank 1."
assert self.rank == 1, f"Shape.name is only defined for shapes of rank 1. shape={self}"
return self.names[0]

@property
Expand Down
6 changes: 5 additions & 1 deletion phi/tf/_tf_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
class TFBackend(Backend):

def __init__(self):
devices = [ComputeDevice(self, device.name, device.device_type, device.memory_limit, -1, str(device), device.name) for device in device_lib.list_local_devices()]
devices = [ComputeDevice(self, device.name, simple_device_type(device.device_type), device.memory_limit, -1, str(device), device.name) for device in device_lib.list_local_devices()]
# Example refs: '/device:CPU:0'
default_device_ref = '/' + os.path.basename(tf.zeros(()).device)
default_device = None
Expand Down Expand Up @@ -659,3 +659,7 @@ def matrix_solve_least_squares(self, matrix: TensorType, rhs: TensorType) -> Tup


_TAPES = []


def simple_device_type(t: str):
return t[len('XLA_'):] if t.startswith('XLA_') else t
6 changes: 4 additions & 2 deletions phi/vis/_matplotlib/_matplotlib_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def animate(self, fig: plt.Figure, frames: int, plot_frame_function: Callable, i
rc('animation', html='html5')

base_axes = tuple(fig.axes)
positions = {a: (a.figbox.p0, a.figbox.p1) for a in base_axes}
titles = {a: a.get_title() for a in base_axes}
positions = {a: (a.get_subplotspec().get_position(a.figure).p0, a.get_subplotspec().get_position(a.figure).p1) for a in base_axes}
# titles = {a: a.get_title() for a in base_axes}
specs = {a: a.get_subplotspec() for a in base_axes}

def clear_and_plot(frame: int):
Expand Down Expand Up @@ -133,6 +133,8 @@ def show(self, figure):
return HTML(figure.to_html5_video())
else:
figure._fig.show()
else:
raise ValueError(f"{type(figure)} is not a valid {self.name} figure")

def save(self, figure, path: str, dpi: float):
if isinstance(figure, plt.Figure):
Expand Down
8 changes: 6 additions & 2 deletions phi/vis/_vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,12 @@ def show(*model: VisModel or SampledField or tuple or list or Tensor or Geometry
return plots.show(plots.current_figure)
else:
plots = default_plots() if lib is None else get_plots(lib)
fig = plot(*model, lib=plots, **config)
return plots.show(fig)
fig_tensor = plot(*model, lib=plots, **config)
if isinstance(fig_tensor, Tensor):
for fig in fig_tensor:
plots.show(fig)
else:
return plots.show(fig_tensor)


def close(figure=None):
Expand Down
7 changes: 7 additions & 0 deletions tests/commit/math/test__ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,13 @@ def test_scatter_2d_discard(self):
updated = math.scatter(base, indices, values, mode='update', outside_handling='discard')
math.assert_close(updated, math.tensor([[1, 1, 1], [12, 1, 1]], spatial('y,x')))

def test_scatter_single(self):
base = math.zeros(spatial(x=3, y=2))
indices = vec(x=1, y=0)
values = 1
updated = math.scatter(base, indices, values, outside_handling='discard')
math.assert_close(updated, math.tensor([[0, 1, 0], [0, 0, 0]], spatial('y,x')))

def test_sin(self):
for backend in BACKENDS:
with backend:
Expand Down

0 comments on commit 52eac7f

Please sign in to comment.