Skip to content

Commit

Permalink
Merge pull request #102 from tum-pbs/2.3-develop
Browse files Browse the repository at this point in the history
2.3
  • Loading branch information
holl- authored Feb 26, 2023
2 parents 9d11b27 + cd6aaf0 commit f8d0009
Show file tree
Hide file tree
Showing 115 changed files with 35,648 additions and 6,758 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install --quiet tensorflow tensorflow-probability torch jax jaxlib plotly nbformat ipython pylint coverage pytest
pip install --quiet tensorflow tensorflow-probability torch jax jaxlib scikit-learn plotly nbformat ipython pylint coverage pytest
pip install .
- name: Test with pytest
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/update-gh-pages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ jobs:
run: pdoc --html --output-dir docs --force phi

- name: Build static HTML for Jupyter Notebooks
run: jupyter nbconvert --to html --execute --allow-errors docs/*.ipynb
run: |
jupyter nbconvert --to html --execute --allow-errors docs/*.ipynb
jupyter nbconvert --to html --output-dir docs/ docs/prerendered/*.ipynb
- name: Deploy 🚀
uses: JamesIves/[email protected] # See https://github.com/marketplace/actions/deploy-to-github-pages
Expand Down
15 changes: 11 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,23 @@ making it easy to build end-to-end differentiable functions involving both learn

Installation with [pip](https://pypi.org/project/pip/) on [Python 3.6](https://www.python.org/downloads/) and above:
``` bash
$ pip install phiflow dash
$ pip install phiflow
```
Install PyTorch, TensorFlow or Jax in addition to Φ<sub>Flow</sub> to enable machine learning capabilities and GPU execution.
See the [detailed installation instructions](https://tum-pbs.github.io/PhiFlow/Installation_Instructions.html) on how to compile the custom CUDA operators and verify your installation.
Install [PyTorch](https://pytorch.org/), [TensorFlow](https://www.tensorflow.org/install) or [Jax](https://github.com/google/jax#installation) in addition to Φ<sub>Flow</sub> to enable machine learning capabilities and GPU execution.
To enable the web UI, also install [Dash](https://pypi.org/project/dash/).
For optimal GPU performance, you may compile the custom CUDA operators, see the [detailed installation instructions](https://tum-pbs.github.io/PhiFlow/Installation_Instructions.html).

You can verify your installation by running
```bash
$ python3 -c "import phi; phi.verify()"
```
This will check for compatible PyTorch, Jax and TensorFlow installations as well.

## Documentation and Tutorials
[**Documentation Overview**](https://tum-pbs.github.io/PhiFlow/)
&nbsp;&nbsp; [**▶ YouTube Tutorials**](https://www.youtube.com/playlist?list=PLYLhRkuWBmZ5R6hYzusA2JBIUPFEE755O)
&nbsp;&nbsp; [**API**](https://tum-pbs.github.io/PhiFlow/phi/)
&nbsp;&nbsp; [**Demos**](https://github.com/tum-pbs/PhiFlow/tree/develop/demos)
&nbsp;&nbsp; [**Demos**](https://github.com/tum-pbs/PhiFlow/tree/master/demos)
&nbsp;&nbsp; [<img src="https://www.tensorflow.org/images/colab_logo_32px.png" height=16> **Playground**](https://colab.research.google.com/drive/1zBlQbmNguRt-Vt332YvdTqlV4DBcus2S#offline=true&sandboxMode=true)

To get started, check out our YouTube tutorial series and the following Jupyter notebooks:
Expand Down
7 changes: 3 additions & 4 deletions demos/train_identify_noise.py → demos/FNO_train_noise.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import math

from phi.jax.stax.flow import *


net = u_net(1, 2, in_spatial=2, use_res_blocks=True, activation='SiLU')
from phi.torch.flow import *

net = fno(1, 2, 3, modes=12, activation='GeLU')
optimizer = adam(net, learning_rate=1e-3)


Expand All @@ -25,6 +23,7 @@ def loss_function(scale: Tensor, smoothness: Tensor):

viewer = view(gui='dash', scene=True)
for i in viewer.range():
if i == 100: break
loss = update_weights(net, optimizer, loss_function, gt_scale, gt_smoothness)
print(f'Iter : {i}, Loss : {loss}')
viewer.log_scalars(loss=loss)
29 changes: 29 additions & 0 deletions demos/INN_Test_Script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import math
from phi.torch.flow import *
net = invertible_net(1, 3, True, 'u_net', 'SiLU')
optimizer = adam(net, learning_rate=1e-3)

print(parameter_count(net))

def loss_function(smoothness: Tensor):
grid = CenteredGrid(Noise(smoothness=smoothness), x=8, y=8)
pred_smoothness = field.native_call(net, grid)

return math.l2_loss(pred_smoothness - smoothness)

gt_smoothness = math.random_uniform(batch(examples=10), low=0.5, high=1)

viewer = view(gui='dash', scene=True)
for i in viewer.range():
if i > 100: break
loss = update_weights(net, optimizer, loss_function, gt_smoothness)
if i % 10 == 0: print(f'Iter : {i}, Loss : {loss}')
viewer.log_scalars(loss=loss)

grid = CenteredGrid(Noise(scale=1.0, smoothness=gt_smoothness), x=8, y=8)
pred = field.native_call(net, grid, False)
reconstructed_input = field.native_call(net, pred, True)

print('Loss between Predicted Tensor and original grid', math.l2_loss(pred - grid))
print('Loss between Predicted Tensor and GT tensor', math.l2_loss(pred - gt_smoothness))
print('Loss between Reconstructed Input and original grid:', math.l2_loss(reconstructed_input - grid))
6 changes: 3 additions & 3 deletions demos/differentiate_pressure.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@


DOMAIN = dict(x=80, y=64)
LEFT = StaggeredGrid(HardGeometryMask(Box(x=(-INF, 40), y=None)), 0, **DOMAIN)
RIGHT = StaggeredGrid(HardGeometryMask(Box(x=(40, INF), y=None)), extrapolation.ZERO, **DOMAIN)
LEFT = StaggeredGrid(Box(x=(-INF, 40), y=None), 0, **DOMAIN)
RIGHT = StaggeredGrid(Box(x=(40, INF), y=None), extrapolation.ZERO, **DOMAIN)
TARGET = RIGHT * StaggeredGrid(lambda x: math.exp(-0.5 * math.vec_squared(x - (50, 10), 'vector') / 32**2), extrapolation.ZERO, **DOMAIN) * (0, 2)


def loss(v0, p0):
v1, p = fluid.make_incompressible(v0 * LEFT, solve=Solve('CG-adaptive', 1e-5, 0, x0=p0))
v1, p = fluid.make_incompressible(v0 * LEFT, solve=Solve('CG-adaptive', 1e-5, x0=p0))
return field.l2_loss((v1 - TARGET) * RIGHT), v1, p


Expand Down
16 changes: 8 additions & 8 deletions demos/flip_liquid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
A liquid block collides with a rotated obstacle and falls into a liquid pool.
"""
from phi.field._point_cloud import distribute_points
from phi.torch.flow import *
# from phi.tf.flow import *
# from phi.torch.flow import *
from phi.tf.flow import *
# from phi.jax.flow import *


GRAVITY = tensor([0, -9.81])
DT = .2
OBSTACLE = Box(x=(1, 25), y=(30, 33)).rotated(-20)
ACCESSIBLE_CELLS = CenteredGrid(~OBSTACLE, 0, x=64, y=64)
_OBSTACLE_POINTS = PointCloud(Cuboid(field.support(1 - ACCESSIBLE_CELLS, 'points'), x=2, y=2), color='#000000', bounds=ACCESSIBLE_CELLS.bounds)
_OBSTACLE_POINTS = PointCloud(Cuboid(field.support(1 - ACCESSIBLE_CELLS, 'points'), x=2, y=2), bounds=ACCESSIBLE_CELLS.bounds)

particles = distribute_points(union(Box(x=(15, 30), y=(50, 60)), Box(x=None, y=(-INF, 5))), x=64, y=64) * (0, 0)
scene = vis.overlay(particles, _OBSTACLE_POINTS) # only for plotting
Expand All @@ -21,13 +21,13 @@
# @jit_compile
def step(particles):
# --- Grid Operations ---
velocity = prev_velocity = field.finite_fill(StaggeredGrid(particles, 0, x=64, y=64, scheme=Scheme(outside_points='clamp')))
occupied = CenteredGrid(particles.mask(), velocity.extrapolation.spatial_gradient(), velocity.bounds, velocity.resolution)
velocity = prev_velocity = field.finite_fill(resample(particles, StaggeredGrid(0, 0, x=64, y=64), scatter=True, outside_handling='clamp'))
occupied = resample(field.mask(particles), CenteredGrid(0, velocity.extrapolation.spatial_gradient(), velocity.bounds, velocity.resolution), scatter=True)
velocity, pressure = fluid.make_incompressible(velocity + GRAVITY * DT, [OBSTACLE], active=occupied)
# --- Particle Operations ---
particles += (velocity - prev_velocity) @ particles # FLIP update
# particles = velocity @ particles # PIC update
particles = advect.points(particles, velocity * ~OBSTACLE, DT, advect.finite_rk4)
particles += resample(velocity - prev_velocity, to=particles) # FLIP update
# particles = resample(velocity, particles) # PIC update
particles = advect.points(particles, velocity * mask(~OBSTACLE), DT, advect.finite_rk4)
particles = fluid.boundary_push(particles, [OBSTACLE, ~particles.bounds])
return particles, velocity, pressure

Expand Down
6 changes: 3 additions & 3 deletions demos/fluid_logo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

OBSTACLE_GEOMETRIES = [Box(x=(15 + x * 7, 15 + (x + 1) * 7), y=(41, 83)) for x in range(1, 10, 2)] + [Box['x,y', 43:50, 41:48], Box['x,y', 15:43, 83:90], Box['x,y', 50:85, 83:90]]
OBSTACLE = Obstacle(union(OBSTACLE_GEOMETRIES))
OBSTACLE_MASK = HardGeometryMask(OBSTACLE.geometry) @ CenteredGrid(0, extrapolation.BOUNDARY, **DOMAIN)
OBSTACLE_MASK = resample(OBSTACLE.geometry, to=CenteredGrid(0, extrapolation.BOUNDARY, **DOMAIN))

INFLOW = CenteredGrid(Box['x,y', 14:21, 6:10], extrapolation.BOUNDARY, **DOMAIN) + \
CenteredGrid(Box['x,y', 81:88, 6:10], extrapolation.BOUNDARY, **DOMAIN) * 0.9 + \
Expand All @@ -20,7 +20,7 @@

for _ in view('smoke, velocity, pressure, OBSTACLE_MASK', play=False, namespace=globals()).range(warmup=1):
smoke = advect.semi_lagrangian(smoke, velocity, 1) + INFLOW
buoyancy_force = smoke * (0, 0.1) @ velocity # resamples density to velocity sample points
buoyancy_force = resample(smoke * (0, 0.1), to=velocity)
velocity = advect.semi_lagrangian(velocity, velocity, 1) + buoyancy_force
velocity, pressure = fluid.make_incompressible(velocity, (OBSTACLE,), Solve('CG-adaptive', 1e-5, 0, x0=pressure))
velocity, pressure = fluid.make_incompressible(velocity, (OBSTACLE,), Solve('CG-adaptive', 1e-5, x0=pressure))
remaining_divergence = field.divergence(velocity)
4 changes: 2 additions & 2 deletions demos/fog.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
# Physics
temperature = diffuse.explicit(advect.mac_cormack(temperature, velocity, dt=1), 0.1, dt=1, substeps=2)
humidity = advect.mac_cormack(humidity, velocity, dt=1)
buoyancy_force = temperature * (0, 0.1) @ velocity # resamples smoke to velocity sample points
buoyancy_force = (temperature * (0, 0.1)).at(velocity)
velocity = advect.semi_lagrangian(velocity, velocity, 1) + buoyancy_force
velocity, pressure = fluid.make_incompressible(velocity, (), Solve('auto', 1e-5, 0, x0=pressure))
velocity, pressure = fluid.make_incompressible(velocity, (), Solve('auto', 1e-5, x0=pressure))
# Compute fog
fog = field.maximum(humidity - temperature, 0)
2 changes: 1 addition & 1 deletion demos/karman_vortex_street.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
def step(v, p, dt=1.):
v = advect.semi_lagrangian(v, v, dt)
v = v * (1 - BOUNDARY_MASK) + BOUNDARY_MASK * (SPEED, 0)
return fluid.make_incompressible(v, [CYLINDER], Solve('auto', 1e-5, 0, x0=p))
return fluid.make_incompressible(v, [CYLINDER], Solve('auto', 1e-5, x0=p))


for _ in view('vorticity,velocity,pressure', namespace=globals()).range():
Expand Down
5 changes: 2 additions & 3 deletions demos/moving_obstacle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@ def move_obstacle(obs: Obstacle):

obstacle = Obstacle(Box(x=(5, 11), y=(10, 16)), velocity=[1., 0], angular_velocity=tensor(0,))
velocity = StaggeredGrid(0, extrapolation.ZERO, **DOMAIN)
obstacle_mask = CenteredGrid(HardGeometryMask(obstacle.geometry), extrapolation.BOUNDARY, **DOMAIN)
obstacle_mask = CenteredGrid(obstacle.geometry, extrapolation.BOUNDARY, **DOMAIN)
pressure = None

for _ in view(velocity, obstacle_mask, play=True, namespace=globals()).range():
obstacle = move_obstacle(obstacle)
velocity = advect.mac_cormack(velocity, velocity, DT)
velocity, pressure = fluid.make_incompressible(velocity, (obstacle,))
fluid.masked_laplace.tracers.clear() # we will need to retrace because the matrix changes each step. This is not needed when JIT-compiling the physics.
obstacle_mask = HardGeometryMask(obstacle.geometry) @ pressure
obstacle_mask = resample(obstacle.geometry, pressure)
2 changes: 1 addition & 1 deletion demos/pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@

for _ in view('velocity, pressure', namespace=globals()).range():
velocity = advect.semi_lagrangian(velocity, velocity, DT)
velocity, pressure = fluid.make_incompressible(velocity, solve=Solve('CG-adaptive', 1e-5, 0, x0=pressure))
velocity = diffuse.explicit(velocity, 0.1, DT)
velocity, pressure = fluid.make_incompressible(velocity, solve=Solve('CG-adaptive', 1e-5, x0=pressure))
10 changes: 5 additions & 5 deletions demos/point_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from phi.flow import *


points1 = PointCloud(vec(x=1, y=1), color='#ba0a04')
points2 = PointCloud(vec(x=20, y=20), color='#ba0a04')
points1 = PointCloud(vec(x=1, y=1))
points2 = PointCloud(vec(x=20, y=20))
# points = points1 & points2
points = field.stack([points1, points2], instance('points'))

Expand All @@ -15,8 +15,8 @@
points = advect.advect(points, points * (-1, 1), -5) # Euler

# Grid sampling
scattered_data = field.sample(points, velocity.elements)
scattered_grid = points @ velocity
scattered_sgrid = points @ StaggeredGrid(0, 0, velocity.bounds, velocity.resolution)
scattered_data = field.sample(points, velocity.elements, scatter=True)
scattered_grid = points.at(velocity, scatter=True)
scattered_sgrid = resample(points, to=StaggeredGrid(0, 0, velocity.bounds, velocity.resolution), scatter=True)

view(namespace=globals())
1 change: 0 additions & 1 deletion demos/rotating_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,4 @@
obstacle = obstacle.copied_with(geometry=obstacle.geometry.rotated(-obstacle.angular_velocity * DT)) # rotate bar
velocity = advect.mac_cormack(velocity, velocity, DT)
velocity, pressure = fluid.make_incompressible(velocity, (obstacle,), Solve('CG-adaptive', 1e-5, 1e-5))
fluid.masked_laplace.tracers.clear() # we will need to retrace because the matrix changes each step. This is not needed when JIT-compiling the physics.
obstacle_mask = CenteredGrid(obstacle.geometry, extrapolation.ZERO, **DOMAIN)
10 changes: 5 additions & 5 deletions demos/smoke_embedded_mesh.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
from phi.flow import *

velocity = StaggeredGrid((0, 0), 0, x=32, y=32, bounds=Box(x=100, y=100)) # or CenteredGrid(...)
velocity_emb = velocity @ StaggeredGrid(0, velocity, x=64, y=64, bounds=Box(x=(30, 70), y=(40, 80)))
velocity_emb = velocity.at(StaggeredGrid(0, velocity, x=64, y=64, bounds=Box(x=(30, 70), y=(40, 80))))
smoke = CenteredGrid(0, extrapolation.BOUNDARY, x=200, y=200, bounds=Box(x=100, y=100))

OBSTACLE = Obstacle(Sphere(x=50, y=60, radius=5))
INFLOW = 0.2 * CenteredGrid(SoftGeometryMask(Sphere(x=50, y=9.5, radius=5)), 0, smoke.bounds, smoke.resolution)
INFLOW = 0.2 * resample(Sphere(x=50, y=9.5, radius=5), CenteredGrid(0, 0, smoke.bounds, smoke.resolution), soft=True)
pressure = None


# @jit_compile # Only for PyTorch, TensorFlow and Jax
def step(v, v_emb, s, p, dt=1.):
s = advect.mac_cormack(s, v_emb, dt) + INFLOW
buoyancy = s * (0, 0.1)
v_emb = advect.semi_lagrangian(v_emb, v_emb, dt) + (buoyancy @ v_emb) * dt
v = advect.semi_lagrangian(v, v, dt) + (buoyancy @ v) * dt
v, p = fluid.make_incompressible(v, [OBSTACLE], Solve('auto', 1e-5, 0, x0=p))
v_emb = advect.semi_lagrangian(v_emb, v_emb, dt) + buoyancy.at(v_emb) * dt
v = advect.semi_lagrangian(v, v, dt) + buoyancy.at(v) * dt
v, p = fluid.make_incompressible(v, [OBSTACLE], Solve('auto', 1e-5, x0=p))
# Perform the embedded pressure solve
p_emb_x0 = CenteredGrid(0, p, v_emb.bounds, v_emb.resolution)
v_emb = StaggeredGrid(v_emb, extrapolation.BOUNDARY, v_emb.bounds, v_emb.resolution)
Expand Down
6 changes: 3 additions & 3 deletions demos/smoke_plume.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@

velocity = StaggeredGrid((0, 0), 0, x=64, y=64, bounds=Box(x=100, y=100)) # or CenteredGrid(...)
smoke = CenteredGrid(0, extrapolation.BOUNDARY, x=200, y=200, bounds=Box(x=100, y=100))
INFLOW = 0.2 * CenteredGrid(SoftGeometryMask(Sphere(x=50, y=9.5, radius=5)), 0, smoke.bounds, smoke.resolution)
INFLOW = 0.2 * resample(Sphere(x=50, y=9.5, radius=5), to=smoke, soft=True)
pressure = None


# @jit_compile # Only for PyTorch, TensorFlow and Jax
def step(v, s, p, dt=1.):
s = advect.mac_cormack(s, v, dt) + INFLOW
buoyancy = s * (0, 0.1) @ v # resamples smoke to velocity sample points
buoyancy = resample(s * (0, 0.1), to=v)
v = advect.semi_lagrangian(v, v, dt) + buoyancy * dt
v, p = fluid.make_incompressible(v, (), Solve('auto', 1e-5, 0, x0=p))
v, p = fluid.make_incompressible(v, (), Solve(x0=p))
return v, s, p


Expand Down
6 changes: 3 additions & 3 deletions demos/smoke_plume_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@

velocity = StaggeredGrid((0, 0, 0), extrapolation.ZERO, x=32, y=32, z=32, bounds=Box(x=100, y=100, z=100)) # or CenteredGrid(...)
smoke = CenteredGrid(0, extrapolation.BOUNDARY, x=32, y=32, z=32, bounds=Box(x=100, y=100, z=100))
INFLOW = 0.2 * CenteredGrid(SoftGeometryMask(Sphere(x=50, y=50, z=10, radius=5)), 0, smoke.bounds, smoke.resolution)
INFLOW = 0.2 * resample(Sphere(x=50, y=50, z=10, radius=5), to=smoke, soft=True)
pressure = None


# @jit_compile # Only for PyTorch, TensorFlow and Jax
def step(v, s, p, dt=1.):
s = advect.mac_cormack(s, v, dt) + INFLOW
buoyancy = s * (0, 0, 0.1) @ v # resamples smoke to velocity sample points
buoyancy = resample(s * (0, 0, 0.1), to=v)
v = advect.semi_lagrangian(v, v, dt) + buoyancy * dt
v, p = fluid.make_incompressible(v, (), Solve('auto', 1e-5, 0, x0=p))
v, p = fluid.make_incompressible(v, (), Solve('auto', 1e-5, x0=p))
return v, s, p


Expand Down
10 changes: 5 additions & 5 deletions demos/smoke_plume_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@
viewer = view(smoke, velocity, namespace=globals(), play=False)
for _ in viewer.range(warmup=1):
# Resize grids if needed
inflow = SoftGeometryMask(INFLOW) @ CenteredGrid(0, smoke.extrapolation, x=smoke_res ** 2, y=smoke_res ** 2, bounds=BOUNDS)
smoke = smoke @ inflow
velocity = velocity @ StaggeredGrid(0, velocity.extrapolation, x=v_res ** 2, y=v_res ** 2, bounds=BOUNDS)
inflow = resample(INFLOW, CenteredGrid(0, smoke.extrapolation, x=smoke_res ** 2, y=smoke_res ** 2, bounds=BOUNDS), soft=True)
smoke = resample(smoke, inflow)
velocity = velocity.at(StaggeredGrid(0, velocity.extrapolation, x=v_res ** 2, y=v_res ** 2, bounds=BOUNDS))
# Physics step
smoke = advect.mac_cormack(smoke, velocity, 1) + inflow
buoyancy_force = smoke * (0, 0.1) @ velocity # resamples smoke to velocity sample points
buoyancy_force = (smoke * (0, 0.1)).at(velocity)
velocity = advect.semi_lagrangian(velocity, velocity, 1) + buoyancy_force
try:
with math.SolveTape() as solves:
velocity, pressure = fluid.make_incompressible(velocity, (), Solve(pressure_solver, 1e-5, 0))
velocity, pressure = fluid.make_incompressible(velocity, (), Solve(pressure_solver, 1e-5))
viewer.log_scalars(solve_time=solves[0].solve_time)
viewer.info(f"Presure solve {v_res**2}x{v_res**2} with {solves[0].method}: {solves[0].solve_time * 1000:.0f} ms ({solves[0].iterations} iterations)")
except ConvergenceException as err:
Expand Down
Loading

0 comments on commit f8d0009

Please sign in to comment.