Skip to content

Commit

Permalink
Merge pull request #5 from tum-pbs/develop
Browse files Browse the repository at this point in the history
PyTorch Integration
  • Loading branch information
the-rccg authored Jan 27, 2020
2 parents 94c7f39 + 4312c66 commit 6d02565
Show file tree
Hide file tree
Showing 81 changed files with 1,587 additions and 624 deletions.
19 changes: 10 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
[![Build Status](https://travis-ci.com/tum-pbs/PhiFlow.svg?token=8vG2QPsZzeswTApmkekH&branch=master)](https://travis-ci.com/tum-pbs/PhiFlow)
[![PyPI pyversions](https://img.shields.io/pypi/pyversions/phiflow.svg)](https://pypi.org/project/phiflow/)
[![PyPI license](https://img.shields.io/pypi/l/phiflow.svg)](https://pypi.org/project/phiflow/)
[![image](https://www.tensorflow.org/images/colab_logo_32px.png) Run in Google Colab](https://colab.research.google.com/drive/1S21OY8hzh1oZK2wQyL3BNXvSlrMTtRbV#offline=true&sandboxMode=true)
[<img src="https://colab.research.google.com/assets/colab-badge.svg" align="center">](https://colab.research.google.com/drive/1S21OY8hzh1oZK2wQyL3BNXvSlrMTtRbV#offline=true&sandboxMode=true)

![Gui](documentation/figures/WebInterface.png)

Expand All @@ -14,22 +14,23 @@ Having all functionality of a fluid simulation running in TensorFlow opens up th

## Features

- Support for a variety of differentiable simulation types, from Burgers over Navier-Stokes to the Schrödinger equation.
- Tight integration with [TensorFlow](https://www.tensorflow.org/) allowing for straightforward network training with fully differentiable simulations that run on the GPU.
- Variety of built-in fully-differentiable simulations, ranging from Burgers and Navier-Stokes to the Schrödinger equation.
- Tight integration with [TensorFlow](https://www.tensorflow.org/) and [PyTorch](https://pytorch.org/) (experimental) allowing for straightforward neural network training with fully differentiable simulations that run on the GPU.
- Object-oriented architecture enabling concise and expressive code, designed for ease of use and extensibility.
- Reusable simulation code, independent of backend and dimensionality, i.e. the exact same code can run a 2D fluid sim using NumPy and a 3D fluid sim on the GPU using TensorFlow.
- Reusable simulation code, independent of backend and dimensionality, i.e. the exact same code can run a 2D fluid sim using NumPy and a 3D fluid sim on the GPU using TensorFlow or PyTorch.
- Flexible, easy-to-use web interface featuring live visualizations and interactive controls that can affect simulations or network training on the fly.

## Installation

The following commands will get you Φ<sub>*Flow*</sub> + browser-GUI + NumPy execution:
To install Φ<sub>*Flow*</sub> with web interface, run:

```bash
$ pip install phiflow[gui]
```

See the [detailed installation instructions](documentation/Installation_Instructions.md) on how to install Φ<sub>*Flow*</sub>
with TensorFlow support.
Install TensorFlow or PyTorch in addition to Φ<sub>*Flow*</sub> to enable machine learning capabilities and GPU execution.

See the [detailed installation instructions](documentation/Installation_Instructions.md) on how to compile the custom CUDA operators and verify your installation.

## Documentation and Guides

Expand Down Expand Up @@ -72,7 +73,7 @@ The [software architecture documentation](documentation/Software_Architecture.md

## Version History

The [Version history](documentation/Version_History.md) lists all major changes since release.
The [Version history](https://github.com/tum-pbs/PhiFlow/releases) lists all major changes since release.

## Known Issues

Expand All @@ -82,7 +83,7 @@ Resampling / Advection: NumPy interpolation handles the boundaries slightly diff

## Contributions

Contributions are welcome! Check out [this document](documentation/Contributing.md) for some guidelines.
Contributions are welcome! Check out [this document](documentation/Contributing.md) for guidelines.

## Acknowledgements

Expand Down
4 changes: 1 addition & 3 deletions demos/burgers_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ class BurgersEquation(App):

def __init__(self, domain=Domain([64, 64], boundaries=PERIODIC)):
App.__init__(self, framerate=5)
initial_velocity = domain.centered_grid(data=lambda s: math.randfreq(s) * 2, components=domain.rank, name='velocity')
velocity = world.add(initial_velocity, physics=Burgers(viscosity=0.1))
self.add_field('Velocity', velocity)
world.add(BurgersVelocity(domain, velocity=lambda s: math.randfreq(s) * 2), physics=Burgers())


show()
15 changes: 8 additions & 7 deletions demos/fluid_logo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,20 @@ def create_tum_logo():
class SmokeLogo(App):

def __init__(self, resolution):
App.__init__(self, 'Fluid Logo', DESCRIPTION, summary='smoke' + 'x'.join([str(d) for d in resolution]), framerate=20)
smoke = self.smoke = world.add(Fluid(Domain(resolution, box=box[0:100, 0:100], boundaries=CLOSED), buoyancy_factor=0.1), physics=IncompressibleFlow())
App.__init__(self, 'Fluid Logo', DESCRIPTION, summary='fluid' + 'x'.join([str(d) for d in resolution]), framerate=20)
fluid = self.fluid = world.add(Fluid(Domain(resolution, box=box[0:100, 0:100], boundaries=CLOSED), buoyancy_factor=0.1), physics=IncompressibleFlow())
world.add_all(Inflow(box[6:10, 14:21], rate=1.0), Inflow(box[6:10, 79:86], 0.8), Inflow(box[49:50, 43:46], 0.1))
create_tum_logo()
# Add Fields
self.add_field('Density', lambda: smoke.density)
self.add_field('Velocity', lambda: smoke.velocity)
self.add_field('Domain', lambda: obstacle_mask(smoke).at(smoke.density))
self.add_field('Remaining Divergence', lambda: smoke.velocity.divergence())
self.add_field('Density', lambda: fluid.density)
self.add_field('Velocity', lambda: fluid.velocity)
self.add_field('Domain', lambda: obstacle_mask(fluid).at(fluid.density))
self.add_field('Remaining Divergence', lambda: fluid.velocity.divergence())
self.add_field('Pressure', lambda: fluid.solve_info.get('pressure', None))

def action_reset(self):
self.steps = 0
self.smoke.density = self.smoke.velocity = 0
self.fluid.density = self.fluid.velocity = 0


show(SmokeLogo([int(sys.argv[1])] * 2 if len(sys.argv) > 1 and __name__ == '__main__' else [128] * 2),
Expand Down
4 changes: 2 additions & 2 deletions demos/loader_mantaflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
DIMS = 2


class DataLoader(TFApp):
class DataLoader(App):

def __init__(self, scene_path, dims, mantaflowRes):
TFApp.__init__(self, 'Data Demo')
App.__init__(self, 'Data Demo')

smoke = world.add(Fluid(Domain([mantaflowRes - 1] * dims)), physics=IncompressibleFlow()) # 2D: YXc , 3D: ZYXc
smoke.velocity = smoke.density = placeholder # switch to TF tensors
Expand Down
33 changes: 17 additions & 16 deletions demos/manual_fluid_numpy_or_tf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# example that runs a "manual" simple INCOMPRESSIBLE_FLOW sim either in numpy or TF
# example that runs a "manual" simple incompressible fluid sim either in numpy or TF
# note, this example does not use the dash GUI, instead it creates PNG images via PIL

import sys
Expand All @@ -18,12 +18,12 @@
RES = 32
DT = 0.6

# by default, creates a numpy state, i.e. "INCOMPRESSIBLE_FLOW.density.data" is a numpy array
INCOMPRESSIBLE_FLOW = Fluid(Domain([RES] * DIM, boundaries=OPEN), batch_size=BATCH_SIZE)
# by default, creates a numpy state, i.e. "FLOW.density.data" is a numpy array
FLOW = Fluid(Domain([RES] * DIM, boundaries=OPEN), batch_size=BATCH_SIZE, buoyancy_factor=0.2)

if MODE == 'NumPy':
DENSITY = INCOMPRESSIBLE_FLOW.density
VELOCITY = INCOMPRESSIBLE_FLOW.velocity
DENSITY = FLOW.density
VELOCITY = FLOW.velocity
# no phiflow session for pure numpy, write to specific directory instead
IMG_PATH = os.path.expanduser("~/phi/data/manual/numpy")
if not os.path.exists(IMG_PATH):
Expand All @@ -33,9 +33,9 @@
SESSION = Session(SCENE)
IMG_PATH = SCENE.path
# create TF placeholders with the correct shapes
INCOMPRESSIBLE_FLOW_IN = INCOMPRESSIBLE_FLOW.copied_with(density=placeholder, velocity=placeholder)
DENSITY = INCOMPRESSIBLE_FLOW_IN.density
VELOCITY = INCOMPRESSIBLE_FLOW_IN.velocity
FLOW_IN = FLOW.copied_with(density=placeholder, velocity=placeholder)
DENSITY = FLOW_IN.density
VELOCITY = FLOW_IN.velocity

# optional , write images
SAVE_IMAGES = False
Expand All @@ -60,13 +60,13 @@ def save_img(array, scale, name, idx=0):
# def save_img(array, scale, name, idx=0):
print("(Skipping image output)")

# main , step 1: run INCOMPRESSIBLE_FLOW sim (numpy), or only set up graph for TF
# main , step 1: run FLOW sim (numpy), or only set up graph for TF

for i in range(STEPS if (MODE == 'NumPy') else GRAPH_STEPS):
# simulation step; note that the core is only 3 lines for the actual simulation
# the RESt is setting up the inflow, and debug info afterwards

INFLOW_DENSITY = math.zeros_like(INCOMPRESSIBLE_FLOW.density)
INFLOW_DENSITY = math.zeros_like(FLOW.density)
if DIM == 2:
# (batch, y, x, components)
INFLOW_DENSITY.data[..., (RES // 4 * 2):(RES // 4 * 3), (RES // 4):(RES // 4 * 3), 0] = 1.
Expand All @@ -75,8 +75,8 @@ def save_img(array, scale, name, idx=0):
INFLOW_DENSITY.data[..., (RES // 4 * 2):(RES // 4 * 3), (RES // 4 * 1):(RES // 4 * 3), (RES // 4):(RES // 4 * 3), 0] = 1.

DENSITY = advect.semi_lagrangian(DENSITY, VELOCITY, DT) + DT * INFLOW_DENSITY
VELOCITY = advect.semi_lagrangian(VELOCITY, VELOCITY, DT) + buoyancy(DENSITY, 9.81, INCOMPRESSIBLE_FLOW.buoyancy_factor) * DT
VELOCITY = divergence_free(VELOCITY, INCOMPRESSIBLE_FLOW.domain, obstacles=())
VELOCITY = advect.semi_lagrangian(VELOCITY, VELOCITY, DT) + buoyancy(DENSITY, 9.81, FLOW.buoyancy_factor) * DT
VELOCITY = divergence_free(VELOCITY, FLOW.domain, obstacles=())

if i == 0:
print("Density type: %s" % type(DENSITY.data)) # here we either have np array of tf tensor
Expand All @@ -92,15 +92,16 @@ def save_img(array, scale, name, idx=0):

if MODE == 'TensorFlow':
# for TF, all the work still needs to be done, feed empty state and start simulation
INCOMPRESSIBLE_FLOW_OUT = INCOMPRESSIBLE_FLOW.copied_with(density=DENSITY, velocity=VELOCITY, age=INCOMPRESSIBLE_FLOW.age + DT)
FLOW_OUT = FLOW.copied_with(density=DENSITY, velocity=VELOCITY, age=FLOW.age + DT)

# run session
for i in range(STEPS // GRAPH_STEPS):
INCOMPRESSIBLE_FLOW = SESSION.run(INCOMPRESSIBLE_FLOW_OUT, feed_dict={INCOMPRESSIBLE_FLOW_IN: INCOMPRESSIBLE_FLOW}) # Passes DENSITY and VELOCITY tensors
FLOW = SESSION.run(FLOW_OUT, feed_dict={FLOW_IN: FLOW}) # Passes DENSITY and VELOCITY tensors

# for TF, we only have RESults now after each GRAPH_STEPS iterations
if SAVE_IMAGES:
save_img(INCOMPRESSIBLE_FLOW.density.data, 10000., IMG_PATH + "/tf_%04d.png" % (GRAPH_STEPS * (i + 1) - 1))
save_img(FLOW.density.data, 10000., IMG_PATH + "/tf_%04d.png" % (GRAPH_STEPS * (i + 1) - 1))

print("Step SESSION.run %04d done, DENSITY shape %s, means %s %s" %
(i, INCOMPRESSIBLE_FLOW.density.data.shape, np.mean(INCOMPRESSIBLE_FLOW.density.data), np.mean(INCOMPRESSIBLE_FLOW.velocity.staggered_tensor())))
(i, FLOW.density.data.shape, np.mean(FLOW.density.data), np.mean(FLOW.velocity.staggered_tensor())))

6 changes: 3 additions & 3 deletions demos/optimize_pressure.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
"""


class PressureOptimization(TFApp):
class PressureOptimization(LearningApp):

def __init__(self):
TFApp.__init__(self, 'Pressure Optimization', DESCRIPTION, learning_rate=0.1, epoch_size=5)
LearningApp.__init__(self, 'Pressure Optimization', DESCRIPTION, learning_rate=0.1, epoch_size=5)
# --- Physics ---
domain = Domain([62, 62], boundaries=CLOSED)
with self.model_scope():
Expand All @@ -31,7 +31,7 @@ def __init__(self):
target_velocity = math.expand_dims(np.stack([target_velocity_y, np.zeros_like(target_velocity_y)], axis=-1), 0)
target_velocity *= self.editable_float('Target_Direction', 1, [-1, 1], log_scale=False)
# --- Optimization ---
loss = math.l2_loss(velocity.staggered_tensor()[:, :, 31:, :] - target_velocity[:, :, 31:, :])
loss = math.l2_loss(math.sub(velocity.staggered_tensor()[:, :, 31:, :], target_velocity[:, :, 31:, :]))
self.add_objective(loss, 'Loss')
# --- Display ---
gradient = StaggeredGrid(tf.gradients(loss, [component.data for component in optimizable_velocity.unstack()]))
Expand Down
4 changes: 2 additions & 2 deletions demos/simple_tfmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def network(density):
return f_o


class TrainingTest(TFApp):
class TrainingTest(LearningApp):

def __init__(self):
TFApp.__init__(self, 'Training', DESCRIPTION, learning_rate=2e-4, validation_batch_size=4, training_batch_size=8)
LearningApp.__init__(self, 'Training', DESCRIPTION, learning_rate=2e-4, validation_batch_size=4, training_batch_size=8)
# --- Setup simulation and placeholders ---
smoke_in, load_dict = load_state(Fluid(Domain(RESOLUTION)))
# --- Build neural network ---
Expand Down
2 changes: 1 addition & 1 deletion demos/wavepacket.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ def action_reset(self):
wave_vector=[1 * self.value_frequency, 0.6 * self.value_frequency])


show()
show(WavePacketDemo)
45 changes: 39 additions & 6 deletions documentation/Contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,47 @@
All contributions are welcome!
You can mail the developers to get in touch.

## Code Style
We have some code style guidelines that help make the code easier to read.
We mainly use PyLint for static code analysis with specific configuration files for
## Types of contributions we're looking for

We're open to all kind of contributions that improve or extend the Φ<sub>*Flow*</sub> library.
Have a look at the [roadmap](https://github.com/tum-pbs/PhiFlow/projects/1) to see what is planned and what's currently being done.

We especially welcome
- New equations / solvers
- Code optimizations or native (CUDA) implementations.
- Integrations with other computing libraries such as [PyTorch](https://pytorch.org/) or [Jax](https://github.com/google/jax).
- Bug fixes

Φ<sub>*Flow*</sub> is a framework, not an application collection.
While we list applications in the [demos](../demos) directory, these should be short and easy to understand.

## How to Contribute

We recommend you to contact the developers before starting your contribution.
There may already be similar internal work or planned changes that would affect how to code the contribution.
Also check the [roadmap](https://github.com/tum-pbs/PhiFlow/projects/1).

To contribute code, fork Φ<sub>*Flow*</sub> on GitHub, make your changes, and submit a pull request.
Make sure that your contribution passes all tests.

The code you contribute should be able to run in at least 1D, 2D and 3D without additional modifications required by the user.

## Style Guide
Style guidelines make the code more uniform and easier to read.
Generally we stick to the Python style guidelines as outlined in [PEP 8](https://www.python.org/dev/peps/pep-0008/), with some minor modifications outlined below.

Have a look at the [Zen](https://en.wikipedia.org/wiki/Zen_of_Python) [of Python](https://www.python.org/dev/peps/pep-0020/) for the philosophy behind the rules.
We would like to add the rule *Concise is better than repetitive.*

We use PyLint for static code analysis with specific configuration files for
[demos](../demos/.pylintrc),
[tests](../tests/.pylintrc) and the
[code base](../phi/.pylintrc).
PyLint is part of the automatic testing pipeline on [Travis CI](https://travis-ci.com/tum-pbs/PhiFlow). The warning log can be viewed online by selecting a Python 3.6 job on Travis CI and expanding the pylint output at the bottom.

Additional style choices
- Long lines are allowed.
- Code comments should go in the same line as the code they refer to (if possible).
- Code comments that title multiple lines preceed the block and have the format `# --- Comment ---`.
- **No line length limit**; long lines are allowed.
- **Code comments** should only describe information that is not obvious from the code. They should be used sparingly as the code should be understandable by itself. For documentation, use docstrings instead. Code comments that explain a single line of code should go in the same line as the code they refer to, if possible.
- Code comments that describe multiple lines precede the block and have the format `# --- Comment ---`.
- No empty lines inside of methods. To separate code blocks use multi-line comments as described above.
- Use the apostrophe character ' to enclose strings that affect the program / are not displayed to the user.
6 changes: 3 additions & 3 deletions documentation/Interactive_Training_Apps.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This document assumes you have some basic knowledge of `Apps` and how they interact with the GUI.
If not, checkout the [documentation](Web_Interface.md).

If the purpose of your application is to train a TensorFlow model, your main application class should extend [TFApp](../phi/tf/app.py) which in turn extends `App`.
If the purpose of your application is to train a TensorFlow model, your main application class should extend [LearningApp](../phi/tf/app.py) which in turn extends `App`.
This has a couple of benefits:

- Model parameters can be saved and loaded from the GUI
Expand All @@ -25,10 +25,10 @@ data sets such as the ones used in this example).
```python
from phi.tf.flow import *

class TrainingTest(TFApp):
class TrainingTest(LearningApp):

def __init__(self):
TFApp.__init__(self, "Training")
LearningApp.__init__(self, "Training")
fluid = world.add(Fluid(Domain([64] * 2), density=placeholder, velocity=placeholder))

with self.model_scope():
Expand Down
15 changes: 15 additions & 0 deletions documentation/Package_Info.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# phiflow

Research-oriented differentiable fluid simulation framework

[Project Homepage on GitHub](https://github.com/tum-pbs/PhiFlow)

## Installation

To install phiflow with web interface, run

```
$ pip install phiflow[gui]
```

To run phiflow with custom CUDA operators, you have to install it from source, see the [detailed installation instructions](documentation/Installation_Instructions.md).
2 changes: 1 addition & 1 deletion documentation/Reading_and_Writing_Data.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ whole_dataset = Dataset.load('~/phi/data/simpleplume')
training_data = Dataset.load('~/phi/data/simpleplume', range(1000), name='train')
```

Classes that extend [`TFApp`](../phi/tf/app.py) only need to call `self.set_data`, passing a training and validation dataset as well as a struct containing TensorFlow placeholders (see the [documentation](Interactive_Training_Apps.md)).
Classes that extend [`LearningApp`](../phi/tf/app.py) only need to call `self.set_data`, passing a training and validation dataset as well as a struct containing TensorFlow placeholders (see the [documentation](Interactive_Training_Apps.md)).

### Channels

Expand Down
2 changes: 1 addition & 1 deletion documentation/Scene_Format_Specification.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ The following content was created by running the [simpleplume.py](../demos/simpl
"module": "phi.physics.objects"
}
],
"type": "CollectiveState",
"type": "StateCollection",
"module": "phi.physics.collective"
}
}
Expand Down
1 change: 1 addition & 0 deletions documentation/Staggered_Grids.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ New grids can also be created from the simulation object.
from phi.tf.flow import *

centered_zeros = fluid.centered_grid('f0', 0)
centered_zeros = CenteredGrid.sample(0, fluid.domain)
staggered_zeros = fluid.staggered_grid('v', 0)
```

Expand Down
30 changes: 0 additions & 30 deletions documentation/Version_History.md

This file was deleted.

7 changes: 0 additions & 7 deletions documentation/v0.3.md

This file was deleted.

Loading

0 comments on commit 6d02565

Please sign in to comment.