diff --git a/.travis.yml b/.travis.yml index 44b7b93e8..dde0f3e8b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,12 +3,20 @@ language: python python: - "2.7" - "3.6" + - "3.7" + - "3.8" + +matrix: + allow_failures: + - python: "3.8" before_install: - pip install -U pip install: - - pip install --quiet tensorflow==1.14.0 nbformat ipython pylint + - if [[ $TRAVIS_PYTHON_VERSION != 3.8 ]]; then pip install --quiet tensorflow==1.14.0; fi + - if [[ $TRAVIS_PYTHON_VERSION == "3.8" ]]; then pip install --quiet https://github.com/ppwwyyxx/tensorflow-wheels/releases/download/v0.2/tensorflow-1.15.0-cp38-cp38-linux_x86_64.whl numba llvmlite pytest; fi + - pip install nbformat ipython pylint; - pip install . script: diff --git a/documentation/Contributing.md b/CONTRIBUTING.md similarity index 97% rename from documentation/Contributing.md rename to CONTRIBUTING.md index 73c38cdcf..5b9950107 100644 --- a/documentation/Contributing.md +++ b/CONTRIBUTING.md @@ -47,3 +47,4 @@ Additional style choices - Code comments that describe multiple lines precede the block and have the format `# --- Comment ---`. - No empty lines inside of methods. To separate code blocks use multi-line comments as described above. - Use the apostrophe character ' to enclose strings that affect the program / are not displayed to the user. +- **Sphinx Docstring** format is used throughout the code base diff --git a/README.md b/README.md index 923809cd6..8d631b03d 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![Build Status](https://travis-ci.com/tum-pbs/PhiFlow.svg?token=8vG2QPsZzeswTApmkekH&branch=master)](https://travis-ci.com/tum-pbs/PhiFlow) [![PyPI pyversions](https://img.shields.io/pypi/pyversions/phiflow.svg)](https://pypi.org/project/phiflow/) [![PyPI license](https://img.shields.io/pypi/l/phiflow.svg)](https://pypi.org/project/phiflow/) -[](https://colab.research.google.com/drive/1S21OY8hzh1oZK2wQyL3BNXvSlrMTtRbV#offline=true&sandboxMode=true) +[![Google Collab Book](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1S21OY8hzh1oZK2wQyL3BNXvSlrMTtRbV#offline=true&sandboxMode=true) ![Gui](documentation/figures/WebInterface.png) @@ -34,7 +34,7 @@ See the [detailed installation instructions](documentation/Installation_Instruct ## Documentation and Guides -| [Index](documentation) | [Demos](demos) / [Tests](tests) | [Source](phi) | ![image](https://www.tensorflow.org/images/colab_logo_32px.png) [Fluids Tutorial](https://colab.research.google.com/drive/1S21OY8hzh1oZK2wQyL3BNXvSlrMTtRbV#offline=true&sandboxMode=true) / [Playground](https://colab.research.google.com/drive/1zBlQbmNguRt-Vt332YvdTqlV4DBcus2S#offline=true&sandboxMode=true) | +| [Index](documentation) | [Demos](demos) / [Tests](tests) | [Source](phi) | [ Fluids Tutorial](https://colab.research.google.com/drive/1S21OY8hzh1oZK2wQyL3BNXvSlrMTtRbV#offline=true&sandboxMode=true) / [Playground](https://colab.research.google.com/drive/1zBlQbmNguRt-Vt332YvdTqlV4DBcus2S#offline=true&sandboxMode=true) | |------------------------|---------------------------------|---------------| -----------------------------| If you would like to get right into it and have a look at some code, check out the @@ -83,7 +83,7 @@ Resampling / Advection: NumPy interpolation handles the boundaries slightly diff ## Contributions -Contributions are welcome! Check out [this document](documentation/Contributing.md) for guidelines. +Contributions are welcome! Check out [this document](CONTRIBUTING.md) for guidelines. ## Acknowledgements diff --git a/demos/simple_tfmodel.py b/demos/simple_tfmodel.py index 1c7382db6..c4b9898e7 100644 --- a/demos/simple_tfmodel.py +++ b/demos/simple_tfmodel.py @@ -1,7 +1,8 @@ +# coding=utf-8 from phi.tf.flow import * -RESOLUTION = y, x = 64, 64 +DOMAIN = Domain([64, 64], boundaries=OPEN) # [y, x] DATAPATH = os.path.expanduser('~/phi/data/smoke/') # at least 10 sims, has to match RESOLUTION DESCRIPTION = u""" Train a neural network to reproduce the flow field given the marker density. @@ -46,7 +47,7 @@ class TrainingTest(LearningApp): def __init__(self): LearningApp.__init__(self, 'Training', DESCRIPTION, learning_rate=2e-4, validation_batch_size=4, training_batch_size=8) # --- Setup simulation and placeholders --- - smoke_in, load_dict = load_state(Fluid(Domain(RESOLUTION))) + smoke_in, load_dict = load_state(Fluid(DOMAIN)) # --- Build neural network --- with self.model_scope(): pred_vel = network(smoke_in.density.data) diff --git a/documentation/Package_Info.md b/documentation/Package_Info.md index 5a633453c..7ca83a570 100644 --- a/documentation/Package_Info.md +++ b/documentation/Package_Info.md @@ -12,4 +12,4 @@ To install phiflow with web interface, run $ pip install phiflow[gui] ``` -To run phiflow with custom CUDA operators, you have to install it from source, see the [detailed installation instructions](documentation/Installation_Instructions.md). \ No newline at end of file +To run phiflow with custom CUDA operators, you have to install it from source, see the [detailed installation instructions](https://github.com/tum-pbs/PhiFlow/blob/master/documentation/Installation_Instructions.md). \ No newline at end of file diff --git a/documentation/Software_Architecture.md b/documentation/Software_Architecture.md index 88b1a937a..7a70781df 100644 --- a/documentation/Software_Architecture.md +++ b/documentation/Software_Architecture.md @@ -3,27 +3,27 @@ ## Context -![Context](documentation/figures/Context.png) +![Context](./figures/Context.png) | Actor / System | Description | |----------------------|-------------------------------------------------------------------------------------------------| | ML Researcher | Scientist interested in training ML models, publishing results | | User | Person who wants to run built-in simulations and store / analyse the results | | NumPy | Non-differentiable Python computing library | -| TensorFlow | Machine-learning framework supporting GPU computations and reverse-mode differentiation | +| TensorFlow, PyTorch | Machine-learning frameworks supporting GPU computation and reverse-mode differentiation | ## Building Blocks -![Building Blocks](documentation/figures/Building_Blocks.png) +![Building Blocks](./figures/Building_Blocks.png) | Actor / System | Description | |-------------------|-------------------------------------------------------------------------------------------------------| -| Model | Allows setting up simulations and GUI | -| TF Model | Trains neural networks, creates logs, visualizes results with UI | +| App | Allows setting up simulations and GUI | +| LearningApp | Trains neural networks, creates logs, visualizes results with UI | | Data | Writes and loads data from disc | -| UI | Hosts web server to display data of Model | +| UI | Hosts web server to display data of Model | | Physics | Defines simulation classes, implements built-in simulations like Navier-Stokes, Schrödinger | ## Module dependencies -![Module Diagram](documentation/figures/Module_Diagram.png) +![Module Diagram](./figures/Module_Diagram.png) diff --git a/documentation/Structs.ipynb b/documentation/Structs.ipynb index 9921d2a24..fa9566df0 100644 --- a/documentation/Structs.ipynb +++ b/documentation/Structs.ipynb @@ -6,10 +6,9 @@ "metadata": {}, "outputs": [], "source": [ - "import sys\n", + "import sys; sys.path.append('../')\n", "import os\n", - "import numpy\n", - "sys.path.append('../')" + "import numpy" ] }, { @@ -31,12 +30,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Examples" + "## Building structs from lists and dicts\n", + "\n", + "The following cell declares three structs, `a`, `b` and `c`. A full list of what classes can be used to construct structs is given later." ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -50,12 +51,59 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "All functions in `phi.math` can be called on structs. This broadcasts the corresponding calls to all contained arrays." + "Like arrays and tensors, structs have data types and shapes. Since a struct can contain many tensors with different types and shapes, these properties return the results in the same structure as the original struct." ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[dtype('int32'), dtype('float64')]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "struct.dtype(a)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[(), (2,)]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "struct.shape(a)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Functions in `phi.math` can be called on structs. This broadcasts the corresponding calls to all contained arrays." + ] + }, + { + "cell_type": "code", + "execution_count": 5, "metadata": { "scrolled": false }, @@ -67,7 +115,7 @@ " {'x0': 0.0, 'x1': 0.479425538604203, 'x2': 0.8414709848078965})" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -82,7 +130,7 @@ "metadata": {}, "source": [ "Note that math.sin calls numpy.sin in this case.\n", - "If the struct contained TensorFlow tensors, tensorflow.sin would be called instead." + "If the struct contained TensorFlow or PyTorch tensors, the corresponding sin functions would be called instead." ] }, { @@ -96,14 +144,21 @@ "- Tuples\n", "- Dicts containing strings as keys\n", "- NumPy arrays with `dtype=numpy.object`\n", - "- Subclasses of [`phi.math.struct.Struct`](../phi/struct/struct.py)\n", + "- Subclasses of [`phi.struct.Struct`](../phi/struct/struct.py)\n", "\n", "All `phi.math` functions and functions of the struct API work with any of the above types.\n", "\n", "While all entries of lists, tuples, dicts and NumPy arrays are expected to hold data,\n", "subclasses of the `Struct` class can define further properties which are not subject to the above mentioned functions.\n", "\n", - "Struct items are categorized into *variables* and *constants*. Variables are expected to change when a struct is run through an algorithm while constants should generally only be changed by the user. Typically, variables hold data in the form of tensors or grids while constants hold scalar values, booleans or strings." + "Struct items are separated into three categories:\n", + "- *Variables*: properties that change over time. They span the state space of a physical system in which the system moves over time.\n", + "- *Constants*: system characteristics that specify a certain system but do not change over time.\n", + "- *Derived quantities*: properties that may change over time but can fully be derived from variables and constants.\n", + "\n", + "Variables are expected to change when a struct is run through an algorithm while constants should generally only be changed by the user. Each item can hold data (tensors), other structs or other values.\n", + "\n", + "Typically, variables hold data while constants hold scalar values, booleans or strings." ] }, { @@ -112,12 +167,14 @@ "source": [ "## Iterating over structs\n", "\n", - "The struct interface provides the function `map` which iterates over all data-holding items of a struct and its sub-structs by default." + "The struct interface provides the function `map` which iterates over all data-holding items of a struct and its sub-structs by default.\n", + "\n", + "All items of lists and dicts are considered to be data-holding variables." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 13, "metadata": { "scrolled": true }, @@ -128,7 +185,7 @@ "(['1', '[0. 0.]'], {'x0': '0', 'x1': '0.5', 'x2': '1'})" ] }, - "execution_count": 5, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -146,7 +203,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -155,7 +212,7 @@ "('[1, array([0., 0.])]', \"{'x0': 0, 'x1': 0.5, 'x2': 1}\")" ] }, - "execution_count": 6, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -174,7 +231,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -183,7 +240,7 @@ "(['1', '[0. 0.]'], \"{'x0': 0, 'x1': 0.5, 'x2': 1}\")" ] }, - "execution_count": 7, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -197,7 +254,9 @@ "metadata": {}, "source": [ "The parameter `item_condition` can further specify which types of items should be affected by a struct operation.\n", - "The constants `VARIABLES`, `CONSTANTS`, `DATA` and `ALL_ITEMS` are part of the `struct` package and can be used for the `item_condition`." + "The constants `VARIABLES`, `CONSTANTS`, `DATA` and `ALL_ITEMS` are part of the `struct` package and can be used for the `item_condition`.\n", + "\n", + "The item condition can also be set through the context. Operations within a `with item_condition:` block that do not override the `item_condition` parameter, use the context item condition instead. The default context item condition is `DATA`." ] }, { @@ -210,7 +269,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -219,7 +278,7 @@ "([0, 1], {'x0': 'x0', 'x1': 'x1', 'x2': 'x2'})" ] }, - "execution_count": 8, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -234,23 +293,36 @@ "source": [ "## Usages in Φ*Flow*\n", "\n", - "In Φ*Flow*, structs are mostly used to store simulation states, i.e.\n", - "each attribute holds a tensor such as density or velocity of a fluid simulation.\n", - "In particular, the state base class [`phi.physics.physics.State`](../phi/physics/physics.py) extends `Struct`.\n", - "All Field classes such as StaggeredGrid are also structs.\n", + "In Φ*Flow*, structs are used to represent simulation states.\n", + "Not only the state base class -- [`phi.physics.physics.State`](../phi/physics/physics.py) -- extends `Struct`, but also objects depended upon by states such as `Domain`, `Sphere`, `CenteredGrid` etc. inherit from `Struct`.\n", "\n", - "Properties are used to hold additional parameters for the simulation that should be included in the `description.json` file. Typical examples of these include `viscosity` or `buoyancy_factor`." + "Variables hold the current state of the system (e.g. current velocity field) while constants describe the system itself (e.g. fluid viscosity).\n", + "\n", + "Let's have a look at the structure of a `Fluid` state." ] }, { "cell_type": "code", - "execution_count": 10, - "metadata": {}, + "execution_count": 22, + "metadata": { + "scrolled": true + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ + "--- Variables ---\n", + "age (float)\n", + "density.age (float)\n", + "density.data (ndarray)\n", + "velocity.age (float)\n", + "velocity.data.0.age (float)\n", + "velocity.data.0.data (ndarray)\n", + "velocity.data.1.age (float)\n", + "velocity.data.1.data (ndarray)\n", + "\n", + "--- Constants ---\n", "buoyancy_factor (float)\n", "domain.boundaries.friction (float)\n", "domain.boundaries.name (str)\n", @@ -259,61 +331,25 @@ "domain.resolution (ndarray)\n", "domain.box.lower (ndarray)\n", "domain.box.upper (ndarray)\n", - "density.data (ndarray)\n", - "density.box.lower (ndarray)\n", - "density.box.upper (ndarray)\n", - "density.extrapolation (str)\n", - "density.interpolation (str)\n", - "density.age (float)\n", - "density.name (str)\n", - "density.tags.0 (str)\n", - "density.tags.1 (str)\n", - "velocity.resolution (ndarray)\n", - "velocity.box.lower (ndarray)\n", - "velocity.box.upper (ndarray)\n", - "velocity.data.0.data (ndarray)\n", - "velocity.data.0.box.lower (ndarray)\n", - "velocity.data.0.box.upper (ndarray)\n", - "velocity.data.0.extrapolation (str)\n", - "velocity.data.0.interpolation (str)\n", - "velocity.data.0.age (float)\n", - "velocity.data.0.name (str)\n", - "velocity.data.0.tags.0 (str)\n", - "velocity.data.0.tags.1 (str)\n", - "velocity.data.1.data (ndarray)\n", - "velocity.data.1.box.lower (ndarray)\n", - "velocity.data.1.box.upper (ndarray)\n", - "velocity.data.1.extrapolation (str)\n", - "velocity.data.1.interpolation (str)\n", - "velocity.data.1.age (float)\n", - "velocity.data.1.name (str)\n", - "velocity.data.1.tags.0 (str)\n", - "velocity.data.1.tags.1 (str)\n", - "velocity.flags.0.field_types.0 (str)\n", - "velocity.flags.0.is_data_bound (bool)\n", - "velocity.flags.0.is_structure_bound (bool)\n", - "velocity.flags.0.name (str)\n", - "velocity.flags.0.propagators.0 (str)\n", - "velocity.flags.0.propagators.1 (str)\n", - "velocity.age (float)\n", - "velocity.name (str)\n", - "velocity.tags.0 (str)\n", - "velocity.tags.1 (str)\n", - "age (float)\n", - "name (str)\n", - "tags.0 (str)\n", - "tags.1 (str)\n" + "name (str)\n" ] } ], "source": [ "from phi.flow import *\n", + "\n", "fluid = Fluid(Domain([80, 64]))\n", "\n", "def print_name(trace):\n", " print('%s (%s)' % (trace.path(), type(trace.value).__name__))\n", - " return trace.value\n", - "struct.map(print_name, fluid, trace=True, item_condition=None);" + "\n", + "print('--- Variables ---')\n", + "with struct.VARIABLES:\n", + " struct.map(print_name, fluid, trace=True, content_type=struct.INVALID);\n", + "\n", + "print('\\n--- Constants ---')\n", + "with struct.CONSTANTS:\n", + " struct.map(print_name, fluid, trace=True, content_type=struct.INVALID);" ] }, { @@ -331,11 +367,35 @@ "\n", "Initializer functions such as `zeros` or `placeholder` internally call their counterparts in NumPy or TensorFlow.\n", "They can take 1D-tensors describing the shape as input but also support structs holding shapes.\n", - "The call `zeros(StaggeredGrid([1,65,65,2]))` will return a `StaggeredGrid` holding a NumPy array.\n", - "\n", - "Some states simplify this even further by allowing a syntax like `Fluid(density=zeros)` or `Fluid(velocity=placeholder)`.\n", "\n", - "The `placeholder` and `variable` initializers also infer the name of the resulting tensors from the attribute names." + "Some states simplify this even further by allowing a syntax like `Fluid(density=zeros)` or `Fluid(velocity=placeholder)`." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'small array': array([[0.]], dtype=float32),\n", + " 'large array': array([[[0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " ...,\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)}" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "math.zeros({'small array': (1,1), 'large array': [1, 64, 64]})" ] }, { @@ -354,23 +414,49 @@ "source": [ "### Session\n", "\n", - "The [`Session`](../phi/tf/session.py) class is a customized version of `tf.Session` which accepts structs for the `fetches` argument as well as inside the `feed_dict`.\n", + "The TensorFlow [`Session`](../phi/tf/session.py) class is a customized version of `tf.Session` which accepts structs for the `fetches` argument as well as inside the `feed_dict`.\n", "\n", "This can be used to quickly run states through a graph like so:" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 27, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning:\n", + "\n", + "Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.\n", + "For more information, please see:\n", + " * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n", + " * https://github.com/tensorflow/addons\n", + "If you depend on functionality not listed there, please file an issue.\n", + "\n", + "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\ops\\control_flow_ops.py:423: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "Colocations handled automatically by placer.\n" + ] + }, { "data": { "text/plain": [ "Fluid[density: Grid[16x16(1), size=[16. 16.]], velocity: StaggeredGrid[16x16, size=[16. 16.]]]" ] }, - "execution_count": 13, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } @@ -392,14 +478,15 @@ "## Validity\n", "\n", "As structs are supposed to hold data in a specific structure, there is a preferred data type for each entry.\n", - "For a CenteredGrid, the `data` attribute should be a tensor or array with a certain rank and the `velocity` of a `Fluid` object should be a `StaggeredGrid`.\n", + "For a `CenteredGrid`, the `data` attribute should be a tensor or array with a certain rank and the `velocity` of a `Fluid` object should be a `StaggeredGrid`.\n", "\n", - "An entry is _valid_ if its value if of the preferred data type.\n", - "Subclasses of `Struct` can implement validity checks and modify their entries to make them valid.\n", + "An item is _valid_ if its value fulfills all those restrictions and can be passed to a solver.\n", + "Subclasses of `Struct` implement validity checks and may modify their entries to make them valid.\n", + "This allows the shorthand notation `Fluid(density=1)` to create a `CenteredGrid` full of ones.\n", "\n", - "This hierarchy is not always needed, however. Many math functions return invalid structs such as `math.staticshape(obj)` which returns a struct containing shapes instead of data.\n", - "Code dealing with invalid structs should always be enclosed in a `with struct.unsafe():` block.\n", - "This context skips all data validation steps." + "When representing some other property such as the `shape` or `dtype` of a struct, these restrictions do not apply.\n", + "To model this behavior properly, `Struct` objects remember their content type.\n", + "Item validation is only performed if the content type is equal to `VALID`." ] }, { @@ -407,9 +494,9 @@ "metadata": {}, "source": [ "# Immutability\n", - "While structs can be mutable in principle, the struct interface does not allow for changing a struct.\n", - "Attributes and properties can be \"changed\" using the `copy_with` function.\n", - "In this way, the struct isn't altered but rather a duplicate with the new values is created." + "While structs can be mutable in principle, the public struct API does not allow for changing a struct.\n", + "Variables and constants can be \"changed\" using the `copy_with` function and `copied_with` method.\n", + "This does not alter the struct but creates a duplicate with the new values." ] }, { @@ -432,7 +519,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 42, "metadata": {}, "outputs": [ { @@ -446,17 +533,17 @@ "@struct.definition()\n", "class MyStruct(struct.Struct):\n", " \n", - " def __init__(self, a, p, other, **kwargs):\n", + " def __init__(self, v, c, other, **kwargs):\n", " struct.Struct.__init__(self, **struct.kwargs(locals(), ignore=['other']))\n", " self._other = other\n", " \n", " @struct.variable()\n", - " def a(self, a):\n", - " return a\n", + " def v(self, v):\n", + " return v\n", "\n", " @struct.constant()\n", - " def p(self, p):\n", - " return p\n", + " def c(self, c):\n", + " return c\n", " \n", " @property\n", " def other(self):\n", @@ -468,68 +555,269 @@ ], "source": [ "from phi.struct.python_generator import generate\n", - "print(generate('MyStruct', attributes=['a'], properties=['p'], other=['other']))" + "print(generate('MyStruct', variables=['v'], constants=['c'], others=['other']))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The methods themselves are used for validation. In addition to `self`, each attribute and property gets the intended value as an input. The function can either directly return this value without any validity checks, raise an error for invalid values or transform the value into a valid value.\n", + "The methods themselves are used for validation. In addition to `self`, each attribute and property gets the intended value as an input. The function can either directly return this value without any validity checks, raise an error for invalid values or transform the value into a valid value." + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "from phi import struct\n", + "\n", + "\n", + "@struct.definition()\n", + "class MyStruct(struct.Struct):\n", + " \n", + " def __init__(self, v, c, other, **kwargs):\n", + " struct.Struct.__init__(self, **struct.kwargs(locals(), ignore=['other']))\n", + " self._other = other\n", + " \n", + " @struct.variable()\n", + " def v(self, v):\n", + " return v\n", "\n", - "Unless created inside a `with struct.unsafe()` block, structs are always valid when viewed from outside." + " @struct.constant()\n", + " def c(self, c):\n", + " return c\n", + " \n", + " @property\n", + " def other(self):\n", + " return self._other" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can iterate over specific items the same way as before." ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "0 \n", - "1 \n", - "2 \n" + "--- Variables ---\n", + "v (int)\n", + "\n", + "--- Constants ---\n", + "c (int)\n" ] } ], "source": [ - "from phi import struct\n", + "mystruct = MyStruct(v=0, c=0, other=None)\n", + "\n", + "print('--- Variables ---')\n", + "with struct.VARIABLES:\n", + " struct.map(print_name, mystruct, trace=True, content_type=struct.INVALID);\n", "\n", + "print('\\n--- Constants ---')\n", + "with struct.CONSTANTS:\n", + " struct.map(print_name, mystruct, trace=True, content_type=struct.INVALID);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "By inheriting from `Struct`, `MyStruct` obtains implementations for `dtype` and `shape`, making it look like a tensor. " + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MyStruct[]\n" + ] + } + ], + "source": [ + "print(mystruct.dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "int32 0\n" + ] + } + ], + "source": [ + "print(mystruct.dtype.v, mystruct.dtype.c)" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "() 0\n" + ] + } + ], + "source": [ + "print(mystruct.shape.v, mystruct.shape.c)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that `shape` and `dtype` use the context item condition (`DATA` by default). Therefore only the variable `v` is affected at the moment. To obtain the shapes of other items, we can use the `with item_condition:` syntax." + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "int32 int32\n" + ] + } + ], + "source": [ + "with struct.ALL_ITEMS:\n", + " print(mystruct.dtype.v, mystruct.dtype.c)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The inherited `copied_with` method can be used to \"change\" variables and constants." + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 1\n" + ] + } + ], + "source": [ + "changed = mystruct.copied_with(v=1, c=1)\n", + "print(changed.v, changed.c)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Custom validation functions\n", + "Each item is declared as a function. This function is called upon validation to return a valid value for the item or raise an error.\n", "\n", + "Let's have the variable always be of type float and convert the constant to a string." + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [], + "source": [ "@struct.definition()\n", "class MyStruct(struct.Struct):\n", " \n", - " def __init__(self, a, p, other, **kwargs):\n", + " def __init__(self, v, c, other, **kwargs):\n", " struct.Struct.__init__(self, **struct.kwargs(locals(), ignore=['other']))\n", " self._other = other\n", " \n", " @struct.variable()\n", - " def a(self, a):\n", - " return a\n", + " def v(self, v):\n", + " return float(v)\n", "\n", " @struct.constant()\n", - " def p(self, p):\n", - " return str(p)\n", + " def c(self, c):\n", + " return str(c)\n", " \n", " @property\n", " def other(self):\n", - " return self._other\n", - "\n", - "mystruct = MyStruct(a=0, p=0, other=None)\n", - "print(mystruct.p, type(mystruct.p))\n", - "mystruct = mystruct.copied_with(p=1)\n", - "print(mystruct.p, type(mystruct.p))" + " return self._other" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 67, "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "float64 \n" + ] + } + ], + "source": [ + "mystruct = MyStruct(v=0, c=0, other=None)\n", + "with struct.ALL_ITEMS:\n", + " print(mystruct.dtype.v, mystruct.dtype.c)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To skip validation, we could declare a different content type." + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "int32 int32\n" + ] + } + ], + "source": [ + "mystruct = MyStruct(v=0, c=0, other=None, content_type=struct.INVALID)\n", + "with struct.ALL_ITEMS:\n", + " print(mystruct.dtype.v, mystruct.dtype.c)" + ] } ], "metadata": { @@ -548,7 +836,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.9" + "version": "3.6.5" } }, "nbformat": 4, diff --git a/documentation/figures/Building_Blocks.png b/documentation/figures/Building_Blocks.png index 615a99be5..ac35a04ce 100644 Binary files a/documentation/figures/Building_Blocks.png and b/documentation/figures/Building_Blocks.png differ diff --git a/documentation/figures/Context.png b/documentation/figures/Context.png index 918c4683f..62dee4e4a 100644 Binary files a/documentation/figures/Context.png and b/documentation/figures/Context.png differ diff --git a/documentation/figures/Module_Diagram.png b/documentation/figures/Module_Diagram.png index dad3a4532..a9f26d399 100644 Binary files a/documentation/figures/Module_Diagram.png and b/documentation/figures/Module_Diagram.png differ diff --git a/phi/app/app.py b/phi/app/app.py index 486a14c5f..109e8fa51 100644 --- a/phi/app/app.py +++ b/phi/app/app.py @@ -263,8 +263,7 @@ def field_generator(): return trace.find_in(world_state) self.add_field(field.name[0].upper() + field.name[1:], field_generator) return None - with struct.unsafe(): - struct.map(add_default_field, self.world.state, leaf_condition=lambda x: isinstance(x, (CenteredGrid, StaggeredGrid)), trace=True) + struct.map(add_default_field, self.world.state, leaf_condition=lambda x: isinstance(x, (CenteredGrid, StaggeredGrid)), trace=True, content_type=struct.INVALID) def add_custom_property(self, key, value): self._custom_properties[key] = value diff --git a/phi/backend/__init__.py b/phi/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/phi/backend/backend.py b/phi/backend/backend.py new file mode 100644 index 000000000..513d1de18 --- /dev/null +++ b/phi/backend/backend.py @@ -0,0 +1,278 @@ +class Backend: + + def __init__(self, name): + self.name = name + + def __str__(self): + return self.name + + def __repr__(self): + return self.name + + def matches_name(self, name): + return self.name.lower() == name.lower() + + def is_applicable(self, values): + for value in values: + if self.is_tensor(value): + return True + return False + + # --- Abstract math functions --- + + def is_tensor(self, x): + raise NotImplementedError() + + def as_tensor(self, x): + raise NotImplementedError() + + def equal(self, x, y): + raise NotImplementedError() + + def random_uniform(self, shape): + raise NotImplementedError(self) + + def stack(self, values, axis=0): + raise NotImplementedError(self) + + def concat(self, values, axis): + raise NotImplementedError(self) + + def pad(self, value, pad_width, mode='constant', constant_values=0): + """ + Pad a tensor. + :param value: tensor + :param pad_width: 2D tensor specifying the number of values padded to the edges of each axis in the form [[before axis 0, after axis 0], ...] including batch and component axes. + :param mode: + 'constant', + 'reflect', + 'replicate', + 'circular' + ('wrap' is deprecated, use 'circular' instead, 'symmetric' may not be supported by all backends and defaults to 'replicate'). + :param constant_values: used for out-of-bounds points if mode='constant' + """ + raise NotImplementedError(self) + + def reshape(self, value, shape): + raise NotImplementedError(self) + + def sum(self, value, axis=None, keepdims=False): + raise NotImplementedError(self) + + def prod(self, value, axis=None): + raise NotImplementedError(self) + + def divide_no_nan(self, x, y): + """ Computes x/y but returns 0 if y=0. """ + raise NotImplementedError(self) + + def where(self, condition, x=None, y=None): + raise NotImplementedError(self) + + def mean(self, value, axis=None, keepdims=False): + raise NotImplementedError(self) + + def py_func(self, func, inputs, Tout, shape_out, stateful=True, name=None, grad=None): + raise NotImplementedError(self) + + def resample(self, inputs, sample_coords, interpolation='linear', boundary='constant'): + """ + Interpolates a regular grid at the sample coordinates. + :param inputs: grid data + :param sample_coords: tensor of floating grid locations. The last dimension must match the dimensions of inputs. The first grid point of dimension i lies at position 0, the last at data.shape[i]-1. + :param interpolation: only 'linear' is currently supported + :param boundary: + 'constant'/'zero', + 'replicate', + 'circular' + ('symmetric' may not be supported by all backends and defaults to 'replicate') + """ + raise NotImplementedError(self) + + def range(self, start, limit=None, delta=1, dtype=None): + raise NotImplementedError(self) + + def zeros_like(self, tensor): + raise NotImplementedError(self) + + def ones_like(self, tensor): + raise NotImplementedError(self) + + def dot(self, a, b, axes): + raise NotImplementedError(self) + + def matmul(self, A, b): + raise NotImplementedError(self) + + def while_loop(self, cond, body, loop_vars, shape_invariants=None, parallel_iterations=10, back_prop=True, + swap_memory=False, name=None, maximum_iterations=None): + raise NotImplementedError(self) + + def abs(self, x): + raise NotImplementedError(self) + + def sign(self, x): + raise NotImplementedError(self) + + def round(self, x): + raise NotImplementedError(self) + + def ceil(self, x): + raise NotImplementedError(self) + + def floor(self, x): + raise NotImplementedError(self) + + def max(self, x, axis=None): + raise NotImplementedError(self) + + def min(self, x, axis=None): + raise NotImplementedError(self) + + def maximum(self, a, b): + raise NotImplementedError(self) + + def minimum(self, a, b): + raise NotImplementedError(self) + + def with_custom_gradient(self, function, inputs, gradient, input_index=0, output_index=None, name_base='custom_gradient_func'): + raise NotImplementedError(self) + + def sqrt(self, x): + raise NotImplementedError(self) + + def exp(self, x): + raise NotImplementedError(self) + + def conv(self, tensor, kernel, padding='same'): + raise NotImplementedError(self) + + def expand_dims(self, a, axis=0, number=1): + raise NotImplementedError(self) + + def shape(self, tensor): + raise NotImplementedError(self) + + def staticshape(self, tensor): + raise NotImplementedError(self) + + def to_float(self, x): + raise NotImplementedError(self) + + def to_int(self, x, int64=False): + raise NotImplementedError(self) + + def to_complex(self, x): + raise NotImplementedError(self) + + def dimrange(self, tensor): + return range(1, len(tensor.shape) - 1) + + def gather(self, values, indices): + raise NotImplementedError(self) + + def gather_nd(self, values, indices): + raise NotImplementedError(self) + + def flatten(self, x): + return self.reshape(x, (-1,)) + + def unstack(self, tensor, axis=0): + raise NotImplementedError(self) + + def std(self, x, axis=None): + raise NotImplementedError(self) + + def boolean_mask(self, x, mask): + raise NotImplementedError(self) + + def isfinite(self, x): + raise NotImplementedError(self) + + def scatter(self, points, indices, values, shape, duplicates_handling='undefined'): + """ + This method expects the first dimension of indices and values to be the batch dimension. + The batch dimension need not be specified in the indices array. + + All indices must be non-negative and are expected to be within bounds. Otherwise the behaviour is undefined. + + :param indices: + :param values: + :param shape: + :param duplicates_handling: one of ('undefined', 'add', 'mean', 'any', 'last', 'no duplicates') + """ + raise NotImplementedError(self) + + def any(self, boolean_tensor, axis=None, keepdims=False): + raise NotImplementedError(self) + + def all(self, boolean_tensor, axis=None, keepdims=False): + raise NotImplementedError(self) + + def fft(self, x): + """ + Computes the n-dimensional FFT along all but the first and last dimensions. + + :param x: tensor of dimension 3 or higher + """ + raise NotImplementedError(self) + + def ifft(self, k): + """ + Computes the n-dimensional inverse FFT along all but the first and last dimensions. + + :param k: tensor of dimension 3 or higher + """ + raise NotImplementedError(self) + + def imag(self, complex): + raise NotImplementedError(self) + + def real(self, complex): + raise NotImplementedError(self) + + def cast(self, x, dtype): + raise NotImplementedError(self) + + def sin(self, x): + raise NotImplementedError(self) + + def cos(self, x): + raise NotImplementedError(self) + + def dtype(self, array): + raise NotImplementedError(self) + + def tile(self, value, multiples): + raise NotImplementedError(self) + + def sparse_tensor(self, indices, values, shape): + raise NotImplementedError(self) + + # --- Math function with default implementation --- + + def ndims(self, tensor): + return len(self.staticshape(tensor)) + + def size(self, array): + return self.prod(self.shape(array)) + + def batch_gather(self, tensor, batches): + if isinstance(batches, int): + batches = [batches] + return tensor[batches, ...] + + def add(self, a, b): + return self.as_tensor(a) * self.as_tensor(b) + + def sub(self, a, b): + return self.as_tensor(a) - self.as_tensor(b) + + def mul(self, a, b): + return self.as_tensor(a) * self.as_tensor(b) + + def div(self, numerator, denominator): + return self.as_tensor(numerator) / self.as_tensor(denominator) + + def pow(self, base, exp): + return self.as_tensor(base) ** self.as_tensor(exp) diff --git a/phi/math/base_backend.py b/phi/backend/dynamic_backend.py similarity index 50% rename from phi/math/base_backend.py rename to phi/backend/dynamic_backend.py index 8aaa5440c..be9189c41 100644 --- a/phi/math/base_backend.py +++ b/phi/backend/dynamic_backend.py @@ -1,281 +1,10 @@ -class Backend: +from .backend import Backend - def __init__(self, name): - self.name = name - def __str__(self): - return self.name - - def __repr__(self): - return self.name - - def matches_name(self, name): - return self.name.lower() == name.lower() - - def is_applicable(self, values): - for value in values: - if self.is_tensor(value): - return True - return False - - # --- Abstract math functions --- - - def is_tensor(self, x): - raise NotImplementedError() - - def as_tensor(self, x): - raise NotImplementedError() - - def equal(self, x, y): - raise NotImplementedError() - - def random_uniform(self, shape): - raise NotImplementedError(self) - - def stack(self, values, axis=0): - raise NotImplementedError(self) - - def concat(self, values, axis): - raise NotImplementedError(self) - - def pad(self, value, pad_width, mode='constant', constant_values=0): - """ - Pad a tensor. - :param value: tensor - :param pad_width: 2D tensor specifying the number of values padded to the edges of each axis in the form [[before axis 0, after axis 0], ...] including batch and component axes. - :param mode: - 'constant', - 'reflect', - 'replicate', - 'circular' - ('wrap' is deprecated, use 'circular' instead, 'symmetric' may not be supported by all backends and defaults to 'replicate'). - :param constant_values: used for out-of-bounds points if mode='constant' - """ - raise NotImplementedError(self) - - def reshape(self, value, shape): - raise NotImplementedError(self) - - def sum(self, value, axis=None, keepdims=False): - raise NotImplementedError(self) - - def prod(self, value, axis=None): - raise NotImplementedError(self) - - def divide_no_nan(self, x, y): - """ Computes x/y but returns 0 if y=0. """ - raise NotImplementedError(self) - - def where(self, condition, x=None, y=None): - raise NotImplementedError(self) - - def mean(self, value, axis=None, keepdims=False): - raise NotImplementedError(self) - - def py_func(self, func, inputs, Tout, shape_out, stateful=True, name=None, grad=None): - raise NotImplementedError(self) - - def resample(self, inputs, sample_coords, interpolation='linear', boundary='constant'): - """ - Interpolates a regular grid at the sample coordinates. - :param inputs: grid data - :param sample_coords: tensor of floating grid locations. The last dimension must match the dimensions of inputs. The first grid point of dimension i lies at position 0, the last at data.shape[i]-1. - :param interpolation: only 'linear' is currently supported - :param boundary: - 'constant'/'zero', - 'replicate', - 'circular' - ('symmetric' may not be supported by all backends and defaults to 'replicate') - """ - raise NotImplementedError(self) - - def range(self, start, limit=None, delta=1, dtype=None): - raise NotImplementedError(self) - - def zeros_like(self, tensor): - raise NotImplementedError(self) - - def ones_like(self, tensor): - raise NotImplementedError(self) - - def dot(self, a, b, axes): - raise NotImplementedError(self) - - def matmul(self, A, b): - raise NotImplementedError(self) - - def while_loop(self, cond, body, loop_vars, shape_invariants=None, parallel_iterations=10, back_prop=True, - swap_memory=False, name=None, maximum_iterations=None): - raise NotImplementedError(self) - - def abs(self, x): - raise NotImplementedError(self) - - def sign(self, x): - raise NotImplementedError(self) - - def round(self, x): - raise NotImplementedError(self) - - def ceil(self, x): - raise NotImplementedError(self) - - def floor(self, x): - raise NotImplementedError(self) - - def max(self, x, axis=None): - raise NotImplementedError(self) - - def min(self, x, axis=None): - raise NotImplementedError(self) - - def maximum(self, a, b): - raise NotImplementedError(self) - - def minimum(self, a, b): - raise NotImplementedError(self) - - def with_custom_gradient(self, function, inputs, gradient, input_index=0, output_index=None, name_base='custom_gradient_func'): - raise NotImplementedError(self) - - def sqrt(self, x): - raise NotImplementedError(self) - - def exp(self, x): - raise NotImplementedError(self) - - def conv(self, tensor, kernel, padding='same'): - raise NotImplementedError(self) - - def expand_dims(self, a, axis=0, number=1): - raise NotImplementedError(self) - - def shape(self, tensor): - raise NotImplementedError(self) - - def staticshape(self, tensor): - raise NotImplementedError(self) - - def to_float(self, x): - raise NotImplementedError(self) - - def to_int(self, x, int64=False): - raise NotImplementedError(self) - - def to_complex(self, x): - raise NotImplementedError(self) - - def dimrange(self, tensor): - return range(1, len(tensor.shape) - 1) - - def gather(self, values, indices): - raise NotImplementedError(self) - - def gather_nd(self, values, indices): - raise NotImplementedError(self) - - def flatten(self, x): - return self.reshape(x, (-1,)) - - def unstack(self, tensor, axis=0): - raise NotImplementedError(self) - - def std(self, x, axis=None): - raise NotImplementedError(self) - - def boolean_mask(self, x, mask): - raise NotImplementedError(self) - - def isfinite(self, x): - raise NotImplementedError(self) - - def scatter(self, points, indices, values, shape, duplicates_handling='undefined'): - """ - This method expects the first dimension of indices and values to be the batch dimension. - The batch dimension need not be specified in the indices array. - - All indices must be non-negative and are expected to be within bounds. Otherwise the behaviour is undefined. - - :param indices: - :param values: - :param shape: - :param duplicates_handling: one of ('undefined', 'add', 'mean', 'any', 'last', 'no duplicates') - """ - raise NotImplementedError(self) - - def any(self, boolean_tensor, axis=None, keepdims=False): - raise NotImplementedError(self) - - def all(self, boolean_tensor, axis=None, keepdims=False): - raise NotImplementedError(self) - - def fft(self, x): - """ - Computes the n-dimensional FFT along all but the first and last dimensions. - - :param x: tensor of dimension 3 or higher - """ - raise NotImplementedError(self) - - def ifft(self, k): - """ - Computes the n-dimensional inverse FFT along all but the first and last dimensions. - - :param k: tensor of dimension 3 or higher - """ - raise NotImplementedError(self) - - def imag(self, complex): - raise NotImplementedError(self) - - def real(self, complex): - raise NotImplementedError(self) - - def cast(self, x, dtype): - raise NotImplementedError(self) - - def sin(self, x): - raise NotImplementedError(self) - - def cos(self, x): - raise NotImplementedError(self) - - def dtype(self, array): - raise NotImplementedError(self) - - def tile(self, value, multiples): - raise NotImplementedError(self) - - def sparse_tensor(self, indices, values, shape): - raise NotImplementedError(self) - - # --- Math function with default implementation --- - - def ndims(self, tensor): - return len(self.staticshape(tensor)) - - def size(self, array): - return self.prod(self.shape(array)) - - def batch_gather(self, tensor, batches): - if isinstance(batches, int): - batches = [batches] - return tensor[batches, ...] - - def add(self, a, b): - return self.as_tensor(a) * self.as_tensor(b) - - def sub(self, a, b): - return self.as_tensor(a) - self.as_tensor(b) - - def mul(self, a, b): - return self.as_tensor(a) * self.as_tensor(b) - - def div(self, numerator, denominator): - return self.as_tensor(numerator) / self.as_tensor(denominator) +class NoBackendFound(Exception): - def pow(self, base, exp): - return self.as_tensor(base) ** self.as_tensor(exp) + def __init__(self, msg): + Exception.__init__(self, msg) class DynamicBackend(Backend): @@ -504,10 +233,4 @@ def pow(self, base, exp): return self.choose_backend([base, exp]).pow(base, exp) -class NoBackendFound(Exception): - - def __init__(self, msg): - Exception.__init__(self, msg) - - DYNAMIC_BACKEND = DynamicBackend() diff --git a/phi/math/scipy_backend.py b/phi/backend/scipy_backend.py similarity index 87% rename from phi/math/scipy_backend.py rename to phi/backend/scipy_backend.py index 50085286f..0a1317d62 100644 --- a/phi/math/scipy_backend.py +++ b/phi/backend/scipy_backend.py @@ -6,13 +6,17 @@ import scipy.signal import scipy.sparse import six -from phi.struct.tensorop import collapsed_gather_nd, expand -from .base_backend import Backend +from .backend import Backend +from .tensorop import collapsed_gather_nd, expand class SciPyBackend(Backend): + """ + Core Python Backend using NumPy & SciPy + """ + def __init__(self): Backend.__init__(self, "SciPy") @@ -40,12 +44,15 @@ def is_applicable(self, values): # --- Abstract math functions --- def as_tensor(self, x): + """ as array """ return np.array(x) def is_tensor(self, x): + """ is array """ return isinstance(x, np.ndarray) def equal(self, x, y): + """ array equality comparison """ return np.equal(x, y) def divide_no_nan(self, x, y): @@ -54,12 +61,15 @@ def divide_no_nan(self, x, y): return np.where(y == 0, 0, result) def random_uniform(self, shape): + """ random array [0.0, 1.0) """ return np.random.random(shape).astype('f') def rank(self, value): + """ len(shape), number of dimensions """ return len(value.shape) def range(self, start, limit=None, delta=1, dtype=None): + """ range syntax to arange syntax """ if limit is None: start, limit = 0, start return np.arange(start, limit, delta, dtype) @@ -88,13 +98,13 @@ def pad(self, value, pad_width, mode='constant', constant_values=0): def _single_mode_pad(self, value, pad_width, single_mode, constant_values=0): if np.sum(np.array(pad_width)) == 0: return value - if single_mode == 'wrap': + if single_mode.lower() == 'wrap': warnings.warn("padding mode 'wrap' is deprecated. Use 'circular' instead.", DeprecationWarning, stacklevel=2) - if single_mode.lower() == 'constant': + elif single_mode.lower() == 'constant': return np.pad(value, pad_width, 'constant', constant_values=constant_values) - if single_mode.lower() == 'circular': + elif single_mode.lower() == 'circular': single_mode = 'wrap' - if single_mode.lower() == 'replicate': + elif single_mode.lower() == 'replicate': single_mode = 'edge' return np.pad(value, pad_width, single_mode.lower()) @@ -123,26 +133,28 @@ def py_func(self, func, inputs, Tout, shape_out, stateful=True, name=None, grad= return result def resample(self, inputs, sample_coords, interpolation='linear', boundary='constant'): - if boundary.lower() == 'zero' or boundary.lower() == 'constant': + """ resample input array at certain coordinates """ + if boundary.lower() in ('zero', 'constant'): pass # default elif boundary.lower() == 'replicate': sample_coords = clamp(sample_coords, inputs.shape[1:-1]) elif boundary.lower() == 'circular': - inputs = self.pad(inputs, [[0,0]] + [[0,1]] * tensor_spatial_rank(inputs) + [[0,0]], mode='circular') - sample_coords = sample_coords % self.to_float(self.staticshape(inputs)[1:-1]) + resolution = self.staticshape(inputs)[1:-1] + inputs = self.pad(inputs, [[0, 0]] + [[0, 1]] * tensor_spatial_rank(inputs) + [[0, 0]], mode='circular') + sample_coords = sample_coords % self.to_float(resolution) else: raise ValueError("Unsupported boundary: %s" % boundary) - + # Interpolate import scipy.interpolate points = [np.arange(dim) for dim in inputs.shape[1:-1]] result = [] for batch in range(sample_coords.shape[0]): components = [] for dim in range(inputs.shape[-1]): - resampled = scipy.interpolate.interpn(points, inputs[batch, ..., dim], sample_coords[batch, ...], method=interpolation.lower(), bounds_error=False, fill_value=0) + resampled = scipy.interpolate.interpn(points,inputs[batch, ..., dim], sample_coords[batch, ...], + method=interpolation.lower(), bounds_error=False, fill_value=0) components.append(resampled) result.append(np.stack(components, -1)) - result = np.stack(result).astype(inputs.dtype) return result @@ -208,6 +220,7 @@ def exp(self, x): return np.exp(x) def conv(self, tensor, kernel, padding="SAME"): + """ apply convolution of kernel on tensor """ assert tensor.shape[-1] == kernel.shape[-2] # kernel = kernel[[slice(None)] + [slice(None, None, -1)] + [slice(None)]*(len(kernel.shape)-3) + [slice(None)]] if padding.lower() == "same": @@ -300,9 +313,9 @@ def fft(self, x): if rank == 1: return np.fft.fft(x, axis=1) elif rank == 2: - return np.fft.fft2(x, axes=[1,2]) + return np.fft.fft2(x, axes=[1, 2]) else: - return np.fft.fftn(x, axes=list(range(1,rank + 1))) + return np.fft.fftn(x, axes=list(range(1, rank + 1))) def ifft(self, k): rank = len(k.shape) - 2 @@ -310,15 +323,15 @@ def ifft(self, k): if rank == 1: return np.fft.ifft(k, axis=1) elif rank == 2: - return np.fft.ifft2(k, axes=[1,2]) + return np.fft.ifft2(k, axes=[1, 2]) else: - return np.fft.ifftn(k, axes=list(range(1,rank + 1))) + return np.fft.ifftn(k, axes=list(range(1, rank + 1))) - def imag(self, complex): - return np.imag(complex) + def imag(self, complex_arr): + return np.imag(complex_arr) - def real(self, complex): - return np.real(complex) + def real(self, complex_arr): + return np.real(complex_arr) def sin(self, x): return np.sin(x) @@ -338,7 +351,7 @@ def sparse_tensor(self, indices, values, shape): def clamp(coordinates, shape): assert coordinates.shape[-1] == len(shape) for i in range(len(shape)): - coordinates[...,i] = np.maximum(0, np.minimum(shape[i] - 1, coordinates[...,i])) + coordinates[...,i] = np.maximum(0, np.minimum(shape[i] - 1, coordinates[..., i])) return coordinates diff --git a/phi/backend/tensorop.py b/phi/backend/tensorop.py new file mode 100644 index 000000000..8611e2e81 --- /dev/null +++ b/phi/backend/tensorop.py @@ -0,0 +1,64 @@ +import numpy as np + + +def _is_leaf(tensor_like, leaf_condition): + if not isinstance(tensor_like, (tuple, list, np.ndarray)): + return True + if leaf_condition is not None and leaf_condition(tensor_like): + return True + return False + + +def collapse(tensor_like, leaf_condition=None): + if _is_leaf(tensor_like, leaf_condition): + return tensor_like + collapsed_elements = tuple([collapse(element, leaf_condition) for element in tensor_like]) + first = collapsed_elements[0] + for element in collapsed_elements[1:]: + if element != first: + return collapsed_elements + return first + + +def collapsed_gather_nd(collapsed, nd_index, leaf_condition=None): + if isinstance(collapsed, (tuple, list, np.ndarray)): + if leaf_condition is not None and leaf_condition(collapsed): + return collapsed + # collapsed = np.array(collapsed) + if len(nd_index) == 1: + return collapsed[nd_index[0]] + else: + return collapsed_gather_nd(collapsed[nd_index[0]], nd_index[1:]) + else: + return collapsed + + +def expand(collapsed, shape): + if len(shape) == 0: + return collapsed + if isinstance(collapsed, (tuple, list, np.ndarray)): + if len(collapsed) == shape[0]: + return [expand(item, shape[1:]) for item in collapsed] + elif len(collapsed) == 1: + item = expand(collapsed[0], shape[1:]) + return [item] * shape[0] + else: + raise ValueError('Cannot match shape: requested %d but actual %d' % (shape[0], len(collapsed))) + else: + return [expand(collapsed, shape[1:])] * shape[0] + + +class CollapsedTensor(object): + + def __init__(self, collapsed, leaf_condition=None, shape=None): + self.collapsed = collapsed + self.leaf_condition = leaf_condition + self.shape = shape + + def __getitem__(self, item): + return collapsed_gather_nd(self.collapsed, item, self.leaf_condition) + + def expand(self, shape=None): + shape = self.shape if shape is None else shape + assert shape is not None + return expand(self.collapsed, shape) diff --git a/phi/data/fluidformat.py b/phi/data/fluidformat.py index d4166319e..44b6d17c4 100644 --- a/phi/data/fluidformat.py +++ b/phi/data/fluidformat.py @@ -351,8 +351,7 @@ def f(value): return value.data else: return value - with struct.unsafe(): - data = struct.map(f, obj, lambda x: isinstance(x, (field.StaggeredGrid, field.CenteredGrid))) + data = struct.map(f, obj, lambda x: isinstance(x, (field.StaggeredGrid, field.CenteredGrid)), content_type='format') return data diff --git a/phi/data/reader.py b/phi/data/reader.py index 99e9bfc66..ae6b185d4 100644 --- a/phi/data/reader.py +++ b/phi/data/reader.py @@ -1,6 +1,11 @@ import math from bisect import bisect_left -from collections import Iterable +try: + # Python 3 + from collections.abc import Iterable +except ImportError: + # Python 2.7 + from collections import Iterable from sys import getsizeof import numpy as np @@ -40,8 +45,7 @@ def _get_batch(self, indices): data_list = self._cache.get(indices, self._load, add_to_cache=True) data = list_swap_axes(data_list) data_map = {self.streams[i]: data[i] for i in range(len(self._streams))} - with struct.unsafe(): - return struct.map(lambda stream: data_map[stream], self._fields) + return struct.map(lambda stream: data_map[stream], self._fields, content_type=struct.INVALID) def _load(self, indices): result = [] diff --git a/phi/flow.py b/phi/flow.py index b987ef543..61c41c209 100644 --- a/phi/flow.py +++ b/phi/flow.py @@ -12,7 +12,9 @@ from .physics.material import * from .physics.domain import * from .physics.field.effect import * +from .physics.pressuresolver.solver_api import PoissonDomain, PoissonSolver from .physics.pressuresolver.sparse import SparseCG, SparseSciPy +from .physics.pressuresolver.geom import GeometricCG from .data.fluidformat import * from .data.dataset import * diff --git a/phi/geom/geometry.py b/phi/geom/geometry.py index 8ac59b286..06a0eb572 100644 --- a/phi/geom/geometry.py +++ b/phi/geom/geometry.py @@ -20,19 +20,23 @@ def rank(self): raise NotImplementedError() -@struct.definition() +@struct.definition(traits=[math.BATCHED]) class AABox(Geometry): + """ + Axis-aligned box, defined by lower and upper corner. + AABoxes can be created using the shorthand notation box[slices], (e.g. box[:,0:1] to create an inifinite-height box from x=0 to x=1). + """ def __init__(self, lower, upper, **kwargs): Geometry.__init__(self, **struct.kwargs(locals())) - @struct.constant() + @struct.constant(min_rank=1) def lower(self, lower): - return math.to_float(math.as_tensor(lower)) + return math.to_float(lower) - @struct.constant() + @struct.constant(min_rank=1) def upper(self, upper): - return math.to_float(math.as_tensor(upper)) + return math.to_float(upper) def get_lower(self, axis): return self._get(self.lower, axis) @@ -42,8 +46,8 @@ def get_upper(self, axis): @staticmethod def _get(vector, axis): - if math.ndims(vector) == 0: - return vector + if vector.shape[-1] == 1: + return vector[...,0] else: return vector[...,axis] @@ -88,10 +92,10 @@ def without_axis(self, axis): return self.copied_with(lower=lower, upper=upper) def __repr__(self): - try: + if self.is_valid: return '%s at (%s)' % ('x'.join([str(x) for x in self.size]), ','.join([str(x) for x in self.lower])) - except TypeError: - return '%s at %s' % (self.size, self.lower) + else: + return struct.Struct.__repr__(self) @staticmethod def to_box(value, resolution_hint=None): @@ -127,22 +131,22 @@ def __getitem__(self, item): return AABox(lower, upper) -box = AABoxGenerator() +box = AABoxGenerator() # Instantiate an AABox using the syntax box[slices] -@struct.definition() +@struct.definition(traits=[math.BATCHED]) class Sphere(Geometry): def __init__(self, center, radius, **kwargs): Geometry.__init__(self, **struct.kwargs(locals())) - @struct.constant() + @struct.constant(min_rank=0) def radius(self, radius): - return math.as_tensor(radius) + return radius - @struct.constant() + @struct.constant(min_rank=1) def center(self, center): - return math.as_tensor(center) + return center def value_at(self, location): center = math.batch_align(self.center, 1, location) diff --git a/phi/math/__init__.py b/phi/math/__init__.py index e424e5d91..b18b25551 100644 --- a/phi/math/__init__.py +++ b/phi/math/__init__.py @@ -1,9 +1,9 @@ -from .base_backend import DYNAMIC_BACKEND -from .scipy_backend import SciPyBackend -from .struct_backend import StructBroadcastBackend +from phi.backend.dynamic_backend import DYNAMIC_BACKEND +from phi.backend.scipy_backend import SciPyBackend +from phi.struct.struct_backend import StructBroadcastBackend from .math_util import types, is_static_shape, zeros, ones, randn, randfreq -from .nd import (spatial_rank, spatial_dimensions, axes, all_dimensions, - is_scalar, +from .helper import is_scalar, axes +from .nd import (spatial_rank, spatial_dimensions, all_dimensions, indices_tensor, normalize_to, batch_align, batch_align_scalar, @@ -13,6 +13,7 @@ fftfreq, downsample2x, upsample2x, interpolate_linear, spatial_sum,) +from .batched import BATCHED, ShapeMismatch # Setup Backend diff --git a/phi/math/batched.py b/phi/math/batched.py new file mode 100644 index 000000000..f2c735b01 --- /dev/null +++ b/phi/math/batched.py @@ -0,0 +1,90 @@ +from phi.struct import Trait +from phi.backend.dynamic_backend import DYNAMIC_BACKEND as math + + +class Batched(Trait): + """ +Structs with this trait can tag items with the keyword 'min_rank', representing the number of innate inner dimensions of tensor values of that item. +All further, outer dimensions are assumed to be batch dimensions. + +Example: + @struct.definition(traits=[math.BATCHED])\n + class MyStruct(struct.Struct): + @struct.constant(min_rank=1)\n + def batched_constant(self, c): + assert math.ndims(c) >= 1 # Will always be fulfilled\n + return c + +For each item with the 'min_rank' keyword, (1) the specified minimum tensor rank is ensured by expanding the dimensions if necessary and (2) batch shape checks are performed during validation. +Also, all additional dimensions, called `batch dimensions` are collected and cross-checked between all items. +The batch dimensions of a valid batched struct can be accessed as `struct.batch_shape`, and `struct.batch_rank` holds the corresponding length. + +When a struct with inconsistent batch dimensions is validated, a `ShapeMismatch` error is raised, typically upon struct creation. + +The Batched trait also ensures that all values are converted to tensors before the validation function is called. + """ + + def check_argument(self, struct_class, item, keyword, value): + assert keyword == 'min_rank' + assert isinstance(value, int) or callable(value), value + + def endow(self, struct): + struct.batch_shape = None + struct.batch_rank = None + + def pre_validate_struct(self, struct): + struct.batch_shape = None + struct.batch_rank = None + + def pre_validated(self, struct, item, value): + tensor = math.as_tensor(value) + min_rank = item.trait_kwargs['min_rank'] + if callable(min_rank): + min_rank = min_rank(struct) + shape = math.staticshape(value) + if len(shape) < min_rank: + tensor = math.expand_dims(tensor, axis=0, number=min_rank - len(shape)) + shape = math.staticshape(value) + batch_shape = shape[:-min_rank if min_rank != 0 else None] + if struct.batch_shape is None: + struct.batch_shape = batch_shape + else: + struct.batch_shape = _combined_shape(batch_shape, struct.batch_shape, item, struct) + struct.batch_rank = len(struct.batch_shape) + return tensor + + +BATCHED = Batched(keywords=['min_rank']) + + +def _combined_shape(shape1, shape2, prop, obj): + rank = max(len(shape1), len(shape2)) + resulting_shape = [] + for i in range(1, rank+1): + dim1 = shape1[-i] if len(shape1) >= i else 1 + dim2 = shape2[-i] if len(shape2) >= i else 1 + try: + resulting_shape.append(_combined_dim(dim1, dim2)) + except AssertionError: + raise ShapeMismatch("Batch dimension %d with value %d of '%s' of %s does not match other properties with value %d. Occured during comparison of batch shapes %s and %s" % (-i, dim1, prop, obj, dim2, shape1, shape2)) + return tuple(resulting_shape[::-1]) + + +def _combined_dim(dim1, dim2): + if dim1 is None or dim2 is None: + return None + if dim1 == 1: + return dim2 + if dim2 == 1: + return dim1 + assert dim1 == dim2 + return dim1 + + +class ShapeMismatch(ValueError): + """ +Raised when a shape check fails, i.e. when tensors that require compatible shapes do not match. +It is a subclass of `ValueError` because ValueErrors are often raised in this case. + """ + def __init__(self, *args): + ValueError.__init__(self, *args) diff --git a/phi/math/blas.py b/phi/math/blas.py index 2bc45c5c9..85cefbc19 100644 --- a/phi/math/blas.py +++ b/phi/math/blas.py @@ -1,5 +1,5 @@ # coding=utf-8 -from .base_backend import DYNAMIC_BACKEND as math +from phi.backend.dynamic_backend import DYNAMIC_BACKEND as math def conjugate_gradient(k, apply_A, initial_x=None, accuracy=1e-5, max_iterations=1024, back_prop=False): @@ -36,6 +36,8 @@ def loop_condition(_1, _2, _3, residual, _i): def loop_condition(*_args): return True + non_batch_dims = tuple(range(1, len(k.shape))) + def loop_body(pressure, momentum, A_times_momentum, residual, loop_index): """ iteratively solve for: @@ -44,11 +46,11 @@ def loop_body(pressure, momentum, A_times_momentum, residual, loop_index): laplace_momentum : A_times_momentum residual : residual """ - tmp = math.sum(momentum * A_times_momentum, axis=1, keepdims=True) # t = sum(mAm) - a = math.divide_no_nan(math.sum(momentum * residual, axis=1, keepdims=True), tmp) # a = sum(mr)/sum(mAm) + tmp = math.sum(momentum * A_times_momentum, axis=non_batch_dims, keepdims=True) # t = sum(mAm) + a = math.divide_no_nan(math.sum(momentum * residual, axis=non_batch_dims, keepdims=True), tmp) # a = sum(mr)/sum(mAm) pressure += a * momentum # p += am residual -= a * A_times_momentum # r -= aAm - momentum = residual - math.divide_no_nan(math.sum(residual * A_times_momentum, axis=1, keepdims=True) * momentum, tmp) # m = r-sum(rAm)*m/t = r-sum(rAm)*m/sum(mAm) + momentum = residual - math.divide_no_nan(math.sum(residual * A_times_momentum, axis=non_batch_dims, keepdims=True) * momentum, tmp) # m = r-sum(rAm)*m/t = r-sum(rAm)*m/sum(mAm) A_times_momentum = apply_A(momentum) # Am = A*m return [pressure, momentum, A_times_momentum, residual, loop_index + 1] diff --git a/phi/math/helper.py b/phi/math/helper.py new file mode 100644 index 000000000..b0fce124c --- /dev/null +++ b/phi/math/helper.py @@ -0,0 +1,92 @@ +from phi.struct.tensorop import collapsed_gather_nd + +from phi.backend.dynamic_backend import DYNAMIC_BACKEND as math + + + +def spatial_rank(tensor): + """ The spatial rank of a tensor is ndims - 2. """ + return math.ndims(tensor) - 2 + + +def spatial_dimensions(obj): + return tuple(range(1, len(math.staticshape(obj)) - 1)) + + +def axes(obj): + return tuple(range(len(math.staticshape(obj)) - 2)) + + +def all_dimensions(tensor): + return range(len(math.staticshape(tensor))) + + +def is_scalar(obj): + return len(math.staticshape(obj)) == 0 + + +def _get_pad_width_axes(rank, axes, val_true=(1, 1), val_false=(0, 0)): + mid_shape = [] + for i in range(rank): + if _contains_axis(axes, i, rank): + mid_shape.append(val_true) + else: + mid_shape.append(val_false) + return [[0, 0]] + mid_shape + [[0, 0]] + + +def _get_pad_width(rank, axis_widths=(1, 1)): + return [[0, 0]] + [axis_widths] * rank + [[0, 0]] + + +def _dim_shifted(tensor, axis, relative_shifts, components=None, diminish_others=(0, 0), diminish_other_condition=None): + assert len(relative_shifts) >= 2 + total_shift = max(relative_shifts) - min(relative_shifts) + # --- Handle diminish_others --- + if isinstance(diminish_others, tuple): + slice_others = slice(diminish_others[0], -diminish_others[1] if diminish_others[1] != 0 else None) + else: + raise ValueError("Illegal diminish_others arguemnt: '%s'" % diminish_others) + # --- Handle components --- + if components is None: + component_slice = slice(None) + elif isinstance(components, int): + component_slice = slice(components, components+1) + elif isinstance(components, slice): + component_slice = components + else: + raise ValueError("Illegal components argument: '%s'" % components) + # --- Slice tensor to create shifts --- + rank = spatial_rank(tensor) + shifted_tensors = [] + for shift in relative_shifts: + shift_start = shift - min(relative_shifts) + shift_end = shift_start - total_shift + if shift_end == 0: + shift_end = None + slices = [] + for ax in range(rank): + if ax == axis: + slices.append(slice(shift_start, shift_end)) + else: + if diminish_other_condition is None or diminish_other_condition(ax): + slices.append(slice_others) + else: + slices.append(slice(None)) + sliced_tensor = tensor[(slice(None),) + tuple(slices) + (component_slice,)] + shifted_tensors.append(sliced_tensor) + return shifted_tensors + + +def _contains_axis(axes, axis, sp_rank): + assert -sp_rank <= axis < sp_rank + return (axes is None) or (axis in axes) or (axis + sp_rank in axes) + + +def map_for_axes(function, obj, axes, rank): + if axes is None: + return function(obj) + else: + return [(function(collapsed_gather_nd(obj, i)) if _contains_axis(axes, i, rank) + else collapsed_gather_nd(obj, i)) + for i in range(rank)] diff --git a/phi/math/math_util.py b/phi/math/math_util.py index 7865ef176..655cba8e3 100644 --- a/phi/math/math_util.py +++ b/phi/math/math_util.py @@ -1,16 +1,19 @@ +import warnings + import numpy as np from numbers import Number from phi import struct from phi.struct.functions import mappable -from .base_backend import DYNAMIC_BACKEND as math -from .base_backend import NoBackendFound +from phi.backend.dynamic_backend import DYNAMIC_BACKEND as math +from phi.backend.dynamic_backend import NoBackendFound from .nd import fftfreq -@mappable(item_condition=struct.ALL_ITEMS, unsafe_context=True) +@mappable(item_condition=struct.ALL_ITEMS, content_type=type) def types(x): + warnings.warn("math.types is deprecated. Use struct.dtype isntead.", DeprecationWarning) try: return math.dtype(x) except NoBackendFound: diff --git a/phi/math/nd.py b/phi/math/nd.py index 28a5e7c41..55e8705f3 100644 --- a/phi/math/nd.py +++ b/phi/math/nd.py @@ -4,29 +4,8 @@ import numpy as np from phi import struct -from phi.struct.tensorop import collapsed_gather_nd -from .base_backend import DYNAMIC_BACKEND as math - - -def spatial_rank(tensor): - """ The spatial rank of a tensor is ndims - 2. """ - return math.ndims(tensor) - 2 - - -def spatial_dimensions(obj): - return tuple(range(1, len(math.staticshape(obj)) - 1)) - - -def axes(obj): - return tuple(range(len(math.staticshape(obj)) - 2)) - - -def all_dimensions(tensor): - return range(len(math.staticshape(tensor))) - - -def is_scalar(obj): - return len(math.staticshape(obj)) == 0 +from phi.backend.dynamic_backend import DYNAMIC_BACKEND as math +from .helper import _get_pad_width_axes, _get_pad_width, spatial_rank, _dim_shifted, _contains_axis, spatial_dimensions, all_dimensions def indices_tensor(tensor, dtype=np.float32): @@ -76,13 +55,13 @@ def batch_align(tensor, innate_dims, target, convert_to_same_backend=True): assert target_ndims >= ndims if target_ndims == ndims: return tensor - return math.expand_dims(tensor, axis=-innate_dims-1, number=target_ndims - ndims) + return math.expand_dims(tensor, axis=(-innate_dims - 1), number=(target_ndims - ndims)) def batch_align_scalar(tensor, innate_spatial_dims, target): if math.staticshape(tensor)[-1] != 1: tensor = math.expand_dims(tensor, -1) - result = batch_align(tensor, innate_spatial_dims+1, target) + result = batch_align(tensor, innate_spatial_dims + 1, target) return result @@ -100,10 +79,10 @@ def blur(field, radius, cutoff=None, kernel="1/1+x"): if cutoff is None: cutoff = min(int(round(radius * 3)), *field.shape[1:-1]) - xyz = np.meshgrid(*[range(-int(cutoff), (cutoff)+1) for _ in field.shape[1:-1]]) - d = np.float32(np.sqrt(np.sum([x ** 2 for x in xyz], axis=0))) + xyz = np.meshgrid(*[range(-int(cutoff), (cutoff) + 1) for _ in field.shape[1:-1]]) + d = np.float32(np.sqrt(np.sum([x**2 for x in xyz], axis=0))) if kernel == "1/1+x": - weights = np.float32(1) / ( d / radius + 1) + weights = np.float32(1) / (d / radius + 1) elif kernel.lower() == "gauss": weights = math.exp(- d / radius / 2) else: @@ -146,49 +125,33 @@ def l_n_loss(tensor, n, batch_norm=True): # Divergence -def divergence(vel, dx=1, difference='central'): +def divergence(tensor, dx=1, difference='central'): """ Computes the spatial divergence of a vector channel from finite differences. - :param vel: tensor of shape (batch size, spatial dimensions..., spatial rank) + :param tensor: vector field; tensor of shape (batch size, spatial dimensions..., spatial rank) :param dx: distance between adjacent grid points (default 1) :param difference: type of difference, one of ('forward', 'central') (default 'forward') :return: tensor of shape (batch size, spatial dimensions..., 1) """ - assert difference in ('central', 'forward') - rank = spatial_rank(vel) + assert difference in ('central', 'forward', 'backward'), difference + rank = spatial_rank(tensor) if difference == 'forward': - return _forward_divergence_nd(vel) / dx ** rank + return _divergence_nd(tensor, (0, 1)) / dx ** rank + elif difference == 'backward': + return _divergence_nd(tensor, (-1, 0)) / dx ** rank else: - return _central_divergence_nd(vel) / (2 * dx) ** rank + return _divergence_nd(tensor, (-1, 1)) / (2 * dx) ** rank -def _forward_divergence_nd(field): - rank = spatial_rank(field) - dims = range(rank) - components = [] - for dimension in dims: - vq = field[...,rank-dimension-1] - upper_slices = [(slice(1, None) if i == dimension else slice(None)) for i in dims] - lower_slices = [(slice(-1) if i == dimension else slice(None)) for i in dims] - diff = vq[(slice(None),)+upper_slices] - vq[(slice(None),)+lower_slices] - padded = math.pad(diff, [[0,0]] + [([0,1] if i == dimension else [0,0]) for i in dims]) - components.append(padded) - return math.expand_dims(math.sum(components, 0), -1) - - -def _central_divergence_nd(tensor): +def _divergence_nd(tensor, relative_shifts): rank = spatial_rank(tensor) - dims = range(rank) + tensor = math.pad(tensor, _get_pad_width(rank, (-relative_shifts[0], relative_shifts[1]))) components = [] - tensor = math.pad(tensor, [[0, 0]] + [[1, 1]]*rank + [[0, 0]]) - for dimension in dims: - upper_slices = [(slice(2, None) if i == dimension else slice(1, -1)) for i in dims] - lower_slices = [(slice(-2) if i == dimension else slice(1, -1)) for i in dims] - diff = tensor[(slice(None),) + upper_slices + [rank - dimension - 1]] \ - - tensor[(slice(None),) + lower_slices + [rank - dimension - 1]] - components.append(diff) - return math.expand_dims(math.sum(components, 0), -1) + for dimension in range(rank): + lower, upper = _dim_shifted(tensor, dimension, relative_shifts, diminish_others=(-relative_shifts[0], relative_shifts[1]), components=rank - dimension - 1) + components.append(upper - lower) + return math.sum(components, 0) # Gradient @@ -201,57 +164,29 @@ def gradient(tensor, dx=1, difference='forward', padding='replicate'): :param tensor: channel with shape (batch_size, spatial_dimensions..., 1) :param dx: physical distance between grid points (default 1) :param difference: type of difference, one of ('forward', 'backward', 'central') (default 'forward') + :param padding: tensor padding mode :return: tensor of shape (batch_size, spatial_dimensions..., spatial rank) """ - if tensor.shape[-1] != 1: - raise ValueError('Gradient requires a scalar channel as input') - dims = range(spatial_rank(tensor)) - field = tensor[..., 0] - - if 1 in field.shape[1:]: - raise ValueError('All spatial dimensions must have size larger than 1, got {}'.format(tensor.shape)) - + assert tensor.shape[-1] == 1, "Gradient requires a scalar channel as input" + assert 1 not in tensor.shape[1:-1], "All spatial dimensions must have size larger than 1, got %s" % tensor.shape if difference.lower() == 'central': - return _central_diff_nd(tensor, dims, padding) / (dx * 2) + return _gradient_nd(tensor, padding, (-1, 1)) / (dx * 2) elif difference.lower() == 'forward': - return _forward_diff_nd(field, dims, padding) / dx + return _gradient_nd(tensor, padding, (0, 1)) / dx elif difference.lower() == 'backward': - return _backward_diff_nd(field, dims, padding) / dx + return _gradient_nd(tensor, padding, (-1, 0)) / dx else: raise ValueError('Invalid difference type: {}. Can be CENTRAL or FORWARD'.format(difference)) -def _backward_diff_nd(field, dims, padding): - df_dq = [] - for dimension in dims: - upper_slices = tuple([(slice(1, None) if i==dimension else slice(None)) for i in dims]) - lower_slices = tuple([(slice(-1) if i==dimension else slice(None)) for i in dims]) - diff = field[(slice(None),)+upper_slices] - field[(slice(None),)+lower_slices] - padded = math.pad(diff, [[0,0]]+[([1,0] if i == dimension else [0,0]) for i in dims], mode=padding) - df_dq.append(padded) - return math.stack(df_dq, axis=-1) - - -def _forward_diff_nd(field, dims, padding): - df_dq = [] - for dimension in dims: - upper_slices = tuple([(slice(1, None) if i==dimension else slice(None)) for i in dims]) - lower_slices = tuple([(slice(-1) if i==dimension else slice(None)) for i in dims]) - diff = field[(slice(None),) + upper_slices] - field[(slice(None),) + lower_slices] - padded = math.pad(diff, [[0,0]]+[([0,1] if i == dimension else [0,0]) for i in dims], mode=padding) - df_dq.append(padded) - return math.stack(df_dq, axis=-1) - - -def _central_diff_nd(field, dims, padding): - field = math.pad(field, [[0,0]] + [[1,1]]*spatial_rank(field) + [[0, 0]], mode=padding) - df_dq = [] - for dimension in dims: - upper_slices = tuple([(slice(2, None) if i==dimension else slice(1,-1)) for i in dims]) - lower_slices = tuple([(slice(-2) if i==dimension else slice(1,-1)) for i in dims]) - diff = field[(slice(None),) + upper_slices + (0,)] - field[(slice(None),) + lower_slices + (0,)] - df_dq.append(diff) - return math.stack(df_dq, axis=-1) +def _gradient_nd(tensor, padding, relative_shifts): + rank = spatial_rank(tensor) + tensor = math.pad(tensor, _get_pad_width(rank, (-relative_shifts[0], relative_shifts[1])), mode=padding) + components = [] + for dimension in range(rank): + lower, upper = _dim_shifted(tensor, dimension, relative_shifts, diminish_others=(-relative_shifts[0], relative_shifts[1])) + components.append(upper - lower) + return math.concat(components, axis=-1) def axis_gradient(tensor, spatial_axis): @@ -259,7 +194,7 @@ def axis_gradient(tensor, spatial_axis): upper_slices = tuple([(slice(1, None) if i == spatial_axis else slice(None)) for i in dims]) lower_slices = tuple([(slice(-1) if i == spatial_axis else slice(None)) for i in dims]) diff = tensor[(slice(None),) + upper_slices + (slice(None),)] \ - - tensor[(slice(None),) + lower_slices + (slice(None),)] + - tensor[(slice(None),) + lower_slices + (slice(None),)] return diff @@ -271,16 +206,18 @@ def laplace(tensor, padding='replicate', axes=None): If a vector field is passed, the laplace is computed component-wise. :param tensor: n-dimensional field of shape (batch, spacial dimensions..., components) - :param padding: 'valid', 'constant', 'reflect', 'replicate', 'cyclic' + :param padding: 'valid', 'constant', 'reflect', 'replicate', 'circular' :param axes: The second derivative along these axes is summed over :type axes: list :return: tensor of same shape """ - if padding.lower() == 'cyclic': - return fourier_laplace(tensor) rank = spatial_rank(tensor) - if padding.lower() in ('constant', 'reflect', 'replicate'): - tensor = math.pad(tensor, [[0,0]] + [([1,1] if _contains_axis(axes, i, rank) else [0,0]) for i in range(rank)] + [[0,0]], padding) + if padding is None or padding == 'valid': + pass # do not pad tensor + elif padding in ['circular', 'wrap']: + return fourier_laplace(tensor) + else: + tensor = math.pad(tensor, _get_pad_width_axes(rank, axes, val_true=[1, 1], val_false=[0, 0]), padding) # --- convolutional laplace --- if axes is not None: return _sliced_laplace_nd(tensor, axes) @@ -293,59 +230,62 @@ def laplace(tensor, padding='replicate', axes=None): def _conv_laplace_2d(tensor): - kernel = np.zeros((3, 3, 1, 1), np.float32) - kernel[1, 1, 0, 0] = -4 - kernel[(0,1,1,2), (1,0,2,1), 0, 0] = 1 + kernel = np.array([[0., 1., 0.], [1., -4., 1.], [0., 1., 0.]], dtype=np.float32) + kernel = kernel.reshape((3, 3, 1, 1)) if tensor.shape[-1] == 1: return math.conv(tensor, kernel, padding='VALID') else: - return math.concat([math.conv(tensor[..., i:i+1], kernel, padding='VALID') for i in range(tensor.shape[-1])], -1) + return math.concat([math.conv(tensor[..., i:i + 1], kernel, padding='VALID') for i in range(tensor.shape[-1])], -1) def _conv_laplace_3d(tensor): - kernel = np.zeros((3, 3, 3, 1, 1), np.float32) - kernel[1, 1, 1, 0, 0] = -6 - kernel[(0,1,1,1,1,2), (1,0,2,1,1,1), (1,1,1,0,2,1), 0, 0] = 1 + """ + 3D/Cube laplace stencil in 3D+2D [3,3,3,1,1] + array([[[[[ 0.]], [[ 0.]], [[ 0.]]], + [[[ 0.]], [[ 1.]], [[ 0.]]], + [[[ 0.]], [[ 0.]], [[ 0.]]]], + [[[[ 0.]], [[ 1.]], [[ 0.]]], + [[[ 1.]], [[-6.]], [[ 1.]]], + [[[ 0.]], [[ 1.]], [[ 0.]]]], + [[[[ 0.]], [[ 0.]], [[ 0.]]], + [[[ 0.]], [[ 1.]], [[ 0.]]], + [[[ 0.]], [[ 0.]], [[ 0.]]]]] + returns ... + + padding explicitly done in laplace(), hence here not needed + """ + kernel = np.array([[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 1., 0.], [1., -6., 1.], [0., 1., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]]], + dtype=np.float32) + kernel = kernel.reshape((3, 3, 3, 1, 1)) if tensor.shape[-1] == 1: return math.conv(tensor, kernel, padding='VALID') else: - return math.concat([math.conv(tensor[..., i:i+1], kernel, padding='VALID') for i in range(tensor.shape[-1])], -1) + return math.concat([math.conv(tensor[..., i:i + 1], kernel, padding='VALID') + for i in range(tensor.shape[-1])], -1) def _sliced_laplace_nd(tensor, axes=None): - # Laplace code for n dimensions + """ + Laplace Stencil for N-Dimensions + aggregated from (c)enter, (u)pper, and (l)ower parts + """ rank = spatial_rank(tensor) dims = range(rank) components = [] for ax in dims: if _contains_axis(axes, ax, rank): - center_slices = tuple([(slice(1, -1) if i == ax else (slice(1,-1)) if _contains_axis(axes, i, rank) else slice(None)) for i in dims]) - upper_slices = tuple([(slice(2, None) if i == ax else (slice(1,-1)) if _contains_axis(axes, i, rank) else slice(None)) for i in dims]) - lower_slices = tuple([(slice(-2) if i == ax else (slice(1,-1)) if _contains_axis(axes, i, rank) else slice(None)) for i in dims]) - diff = tensor[(slice(None),) + upper_slices + (slice(None),)] \ - + tensor[(slice(None),) + lower_slices + (slice(None),)] \ - - 2 * tensor[(slice(None),) + center_slices + (slice(None),)] - components.append(diff) + lower, center, upper = _dim_shifted(tensor, ax, (-1, 0, 1), diminish_others=(1, 1), diminish_other_condition=lambda other_ax: _contains_axis(axes, other_ax, rank)) + components.append(upper + lower - 2 * center) return math.sum(components, 0) -def _contains_axis(axes, axis, sp_rank): - assert -sp_rank <= axis < sp_rank - return axes is None or axis in axes or axis+sp_rank in axes - - -def map_for_axes(function, obj, axes, rank): - if axes is None: - return function(obj) - else: - return [(function(collapsed_gather_nd(obj, i)) if _contains_axis(axes, i, rank) else collapsed_gather_nd(obj, i)) for i in range(rank)] - - def fourier_laplace(tensor): frequencies = math.fft(math.to_complex(tensor)) k = fftfreq(math.staticshape(tensor)[1:-1], mode='square') - fft_laplace = -(2*np.pi)**2 * k - return math.ifft(frequencies * fft_laplace) + fft_laplace = -(2 * np.pi)**2 * k + return math.real(math.ifft(frequencies * fft_laplace)) def fftfreq(resolution, mode='vector', dtype=np.float32): @@ -355,7 +295,7 @@ def fftfreq(resolution, mode='vector', dtype=np.float32): k = k.astype(dtype) if mode == 'vector': return k - k = math.sum(k ** 2, axis=-1, keepdims=True) + k = math.sum(k**2, axis=-1, keepdims=True) if mode == 'square': return k else: @@ -366,19 +306,21 @@ def fftfreq(resolution, mode='vector', dtype=np.float32): def downsample2x(tensor, interpolation='linear'): if struct.isstruct(tensor): - return struct.map(lambda s: downsample2x(s, interpolation), tensor, recursive=False) + return struct.map(lambda s: downsample2x(s, interpolation), + tensor, recursive=False) if interpolation.lower() != 'linear': raise ValueError('Only linear interpolation supported') dims = range(spatial_rank(tensor)) - tensor = math.pad(tensor, [[0,0]]+ - [([0, 1] if (dim % 2) != 0 else [0,0]) for dim in tensor.shape[1:-1]] - + [[0,0]], 'replicate') + tensor = math.pad(tensor, + [[0, 0]] + + [([0, 1] if (dim % 2) != 0 else [0, 0]) for dim in tensor.shape[1:-1]] + + [[0, 0]], 'replicate') for dimension in dims: - upper_slices = tuple([(slice(1, None, 2) if i==dimension else slice(None)) for i in dims]) - lower_slices = tuple([(slice(0, None, 2) if i==dimension else slice(None)) for i in dims]) - sum = tensor[(slice(None),)+upper_slices+(slice(None),)] + tensor[(slice(None),)+lower_slices+(slice(None),)] - tensor = sum / 2 + upper_slices = tuple([(slice(1, None, 2) if i == dimension else slice(None)) for i in dims]) + lower_slices = tuple([(slice(0, None, 2) if i == dimension else slice(None)) for i in dims]) + tensor_sum = tensor[(slice(None),) + upper_slices + (slice(None),)] + tensor[(slice(None),) + lower_slices + (slice(None),)] + tensor = tensor_sum / 2 return tensor @@ -391,15 +333,11 @@ def upsample2x(tensor, interpolation='linear'): dims = range(spatial_rank(tensor)) vlen = tensor.shape[-1] spatial_dims = tensor.shape[1:-1] - tensor = math.pad(tensor, [[0, 0]] + [[1, 1]]*spatial_rank(tensor) + [[0, 0]], 'replicate') + rank = spatial_rank(tensor) + tensor = math.pad(tensor, _get_pad_width(rank), 'replicate') for dim in dims: - left_slices_1 = tuple([(slice(2, None) if i==dim else slice(None)) for i in dims]) - left_slices_2 = tuple([(slice(1,-1) if i==dim else slice(None)) for i in dims]) - right_slices_1 = tuple([(slice(1, -1) if i==dim else slice(None)) for i in dims]) - right_slices_2 = tuple([(slice(-2) if i==dim else slice(None)) for i in dims]) - left = 0.75 * tensor[(slice(None),)+left_slices_2+(slice(None),)] + 0.25 * tensor[(slice(None),)+left_slices_1+(slice(None),)] - right = 0.25 * tensor[(slice(None),)+right_slices_2+(slice(None),)] + 0.75 * tensor[(slice(None),)+right_slices_1+(slice(None),)] - combined = math.stack([right, left], axis=2+dim) + lower, center, upper = _dim_shifted(tensor, dim, (-1, 0, 1)) + combined = math.stack([0.25 * lower + 0.75 * center, 0.75 * center + 0.25 * upper], axis=2 + dim) tensor = math.reshape(combined, [-1] + [spatial_dims[dim] * 2 if i == dim else tensor.shape[i + 1] for i in dims] + [vlen]) return tensor @@ -424,5 +362,5 @@ def interpolate_linear(tensor, upper_weight, dimensions): if dimension in dimensions: upper_slices = tuple([(slice(1, None) if i == dimension else slice(None)) for i in all_dimensions(tensor)]) lower_slices = tuple([(slice(-1) if i == dimension else slice(None)) for i in all_dimensions(tensor)]) - tensor = math.mul(tensor[upper_slices], upper_weight[...,dimension-1]) + math.mul(tensor[lower_slices], lower_weight[...,dimension-1]) + tensor = math.mul(tensor[upper_slices], upper_weight[..., dimension - 1]) + math.mul(tensor[lower_slices], lower_weight[..., dimension - 1]) return tensor diff --git a/phi/physics/collective.py b/phi/physics/collective.py index f59addbdc..994fbeabc 100644 --- a/phi/physics/collective.py +++ b/phi/physics/collective.py @@ -1,59 +1,64 @@ +import warnings + import six from phi.struct.context import skip_validate +from phi.struct.structdef import Item from .physics import Physics, State, struct -@struct.definition() -class StateCollection(struct.Struct): - - def __init__(self, states=None, **kwargs): - struct.Struct.__init__(self, **struct.kwargs(locals())) +class StateCollection(dict): - @struct.variable() - def states(self, states): + def __init__(self, states=None): + """ +Create a state collection from a dictionary of states. + :param states: dict mapping from state names to states + :type states: dict or list or tuple None + """ if states is None: - return {} - if isinstance(states, (tuple, list)): - return {state.name: state for state in states} - assert isinstance(states, dict) - return states.copy() + states = {} + elif not isinstance(states, dict): + states = {state.name: state for state in states} + dict.__init__(self, states) + + def __setitem__(self, key, val): + raise AttributeError('StateCollections are immutable') def all_with_tag(self, tag): - return [s for s in self.states.values() if tag in s.tags] + return [s for s in self.values() if tag in s.tags] def all_instances(self, cls): - return [s for s in self.states.values() if isinstance(s, cls)] + return [s for s in self.values() if isinstance(s, cls)] def state_added(self, state): - assert state.name not in self.states,\ - 'A state with name "%s" is already present. Use state_replaced() to replace it.' % state.name - new_states = self.states.copy() + assert state.name not in self, 'A state with name "%s" is already present. Use state_replaced() to replace it.' % state.name + new_states = self.copy() new_states[state.name] = state - return self.copied_with(states=new_states) + return StateCollection(new_states) def state_replaced(self, new_state): - assert new_state.name in self.states, 'No state found with name "%s"' % new_state.name - new_states = {state.name: (state if state.name != new_state.name else new_state) for state in self.states.values()} - return self.copied_with(states=new_states) + assert new_state.name in self, 'No state found with name "%s"' % new_state.name + new_states = dict(self) + new_states[new_state.name] = new_state + return StateCollection(new_states) def state_removed(self, state): name = state if isinstance(state, six.string_types) else state.name - new_states = self.states.copy() + new_states = dict(self) del new_states[name] - return self.copied_with(states=new_states) + return StateCollection(new_states) def find(self, name): - return self.states[name] + warnings.warn("StateCollection.find is deprecated. Use statecollection[name] instead.", DeprecationWarning) + return dict.__getitem__(self, name) def __getitem__(self, item): if isinstance(item, State): return self[item.name] if isinstance(item, six.string_types): - return self.find(item) + return dict.__getitem__(self, item) if struct.isstruct(item): - with struct.unsafe(): - return struct.map(lambda x: self[x], item) + return struct.map(lambda x: self[x], item, content_type=struct.INVALID) try: return self[item.name] except AttributeError as e: @@ -61,58 +66,48 @@ def __getitem__(self, item): raise ValueError('Illegal argument: %s' % item) def __getattr__(self, item): - if item.startswith('_'): - return struct.Struct.__getattribute__(self, item) - if item in self.states: - return self.states[item] - return struct.Struct.__getattribute__(self, item) + return self[item] def default_physics(self): + warnings.warn("StateCollection will be removed in the future.", DeprecationWarning) return CollectivePhysics() def __repr__(self): - return '[' + ', '.join((str(s) for s in self.states)) + ']' - - def __len__(self): - return len(self.states) + return '[' + ', '.join((str(s) for s in self)) + ']' def __contains__(self, item): if isinstance(item, State): - return item.name in self.states + return item.name in self if isinstance(item, six.string_types): - return item in self.states + return dict.__contains__(self, item) raise ValueError('Illegal type: %s' % type(item)) - def _set_items(self, **kwargs): - for name, value in kwargs.items(): - if name in ('states', 'age'): - getattr(self.__class__, name).set(self, value) - else: - self._states = self.states.copy() - if not skip_validate(): - assert isinstance(value, State) - assert value.name == name, 'Inconsisten names: trying to assign state "%s" to name "%s"' % (value.name, name) - assert 'states' not in kwargs - self.states[name] = value + def __hash__(self): + return 0 + + @property + def states(self): return self - def __to_dict__(self, item_condition=None): - return self.states.copy() + def copied_with(self, **kwargs): + if len(kwargs) == 0: + return self + assert len(kwargs) == 1 + name, value = next(iter(kwargs.items())) + assert name == 'states' + return StateCollection(value) - def __properties__(self): - return {} + @property + def shape(self): + return StateCollection({name: state.shape for name, state in self.items()}) - def __properties_dict__(self): - result = {} - for state in self.states.values(): - result[state.name] = struct.properties_dict(state) - result['type'] = str(self.__class__.__name__) - result['module'] = str(self.__class__.__module__) - return result + @property + def staticshape(self): + return StateCollection({name: state.staticshape for name, state in self.items()}) @property - def shape(self): - return struct.map(lambda state: state.shape, self, recursive=False) + def dtype(self): + return StateCollection({name: state.dtype for name, state in self.items()}) CollectiveState = StateCollection @@ -128,7 +123,7 @@ def step(self, state_collection, dt=1.0, **dependent_states): assert len(dependent_states) == 0 if len(state_collection) == 0: return state_collection - unhandled_states = list(state_collection.states.values()) + unhandled_states = list(state_collection.values()) next_states = {} partial_next_state_collection = StateCollection(next_states) @@ -137,13 +132,15 @@ def step(self, state_collection, dt=1.0, **dependent_states): physics = self.for_(state) if self._all_dependencies_fulfilled(physics.blocking_dependencies, state_collection, partial_next_state_collection): next_state = self.substep(state, state_collection, dt, partial_next_state_collection=partial_next_state_collection) + assert next_state is not None, "step() called on %s returned None for state '%s'" % (type(physics).__name__, state) + assert isinstance(next_state, State), "step() called on %s dit not return a State but '%s' for state '%s'" % (type(physics).__name__, next_state, state) assert next_state.name == state.name, "The state name must remain constant during step(). Caused by '%s' on state '%s'." % (type(physics).__name__, state) next_states[next_state.name] = next_state unhandled_states.remove(state) partial_next_state_collection = StateCollection(next_states) if len(unhandled_states) == 0: - ordered_states = [partial_next_state_collection[state] for state in state_collection.states] - return partial_next_state_collection.copied_with(states=ordered_states) + ordered_states = [partial_next_state_collection[state] for state in state_collection] + return StateCollection(ordered_states) # Error errstr = 'Cyclic blocking_dependencies in simulation: %s' % unhandled_states diff --git a/phi/physics/domain.py b/phi/physics/domain.py index 887489ebd..db128d548 100644 --- a/phi/physics/domain.py +++ b/phi/physics/domain.py @@ -66,6 +66,12 @@ def boundaries(self, boundaries): def rank(self): return len(self.resolution) + def __repr__(self): + if self.is_valid: + return '(%s, size=%s)' % (self.resolution, self.box.size) + else: + return struct.Struct.__repr__(self) + def cell_index(self, global_position): local_position = self.box.global_to_local(global_position) * self.resolution position = math.to_int(local_position - 0.5) @@ -104,22 +110,20 @@ def equal(grid1, grid2): def centered_shape(self, components=1, batch_size=1, name=None, extrapolation=None, age=0.0): warnings.warn("Domain.centered_shape and Domain.centered_grid are deprecated. Use CenteredGrid.sample() instead.", DeprecationWarning) - with struct.unsafe(): - from phi.physics.field import CenteredGrid - return CenteredGrid(tensor_shape(batch_size, self.resolution, components), age=age, box=self.box, extrapolation=extrapolation, name=name, batch_size=batch_size, flags=()) + from phi.physics.field import CenteredGrid + return CenteredGrid(tensor_shape(batch_size, self.resolution, components), age=age, box=self.box, extrapolation=extrapolation, name=name, batch_size=batch_size, flags=(), content_type=struct.Struct.shape) def staggered_shape(self, batch_size=1, name=None, extrapolation=None, age=0.0): - with struct.unsafe(): - grids = [] - for axis in range(self.rank): - shape = _extend1(tensor_shape(batch_size, self.resolution, 1), axis) - from phi.physics.field.staggered_grid import staggered_component_box - box = staggered_component_box(self.resolution, axis, self.box) - from phi.physics.field import CenteredGrid - grid = CenteredGrid(shape, box, age=age, extrapolation=extrapolation, name=None, batch_size=batch_size, flags=()) - grids.append(grid) - from phi.physics.field import StaggeredGrid - return StaggeredGrid(grids, age=age, box=self.box, name=name, batch_size=batch_size, extrapolation=extrapolation, flags=()) + grids = [] + for axis in range(self.rank): + shape = _extend1(tensor_shape(batch_size, self.resolution, 1), axis) + from phi.physics.field.staggered_grid import staggered_component_box + box = staggered_component_box(self.resolution, axis, self.box) + from phi.physics.field import CenteredGrid + grid = CenteredGrid(shape, box, age=age, extrapolation=extrapolation, name=None, batch_size=batch_size, flags=(), content_type=struct.Struct.shape) + grids.append(grid) + from phi.physics.field import StaggeredGrid + return StaggeredGrid(grids, age=age, box=self.box, name=name, batch_size=batch_size, extrapolation=extrapolation, flags=(), content_type=struct.Struct.shape) def centered_grid(self, data, components=1, dtype=np.float32, name=None, batch_size=None, extrapolation=None): warnings.warn("Domain.centered_shape and Domain.centered_grid are deprecated. Use CenteredGrid.sample() instead.", DeprecationWarning) diff --git a/phi/physics/field/field.py b/phi/physics/field/field.py index 7f66e4c1d..7de2e8b55 100644 --- a/phi/physics/field/field.py +++ b/phi/physics/field/field.py @@ -25,10 +25,6 @@ def __init__(self, data, name=None, **kwargs): def with_data(self, data): return self.copied_with(data=data, flags=()) - @property - def dtype(self): - return math.dtype(self.data) - @struct.variable() def data(self, data): """ diff --git a/phi/physics/field/grid.py b/phi/physics/field/grid.py index ee08870a0..130606e43 100644 --- a/phi/physics/field/grid.py +++ b/phi/physics/field/grid.py @@ -4,7 +4,7 @@ from phi import math, struct from phi.geom import AABox from phi.geom.geometry import assert_same_rank -from phi.math.nd import map_for_axes +from phi.math.helper import map_for_axes from phi.physics.domain import Domain from phi.physics.material import Material from phi.struct.functions import mappable @@ -25,6 +25,17 @@ def _crop_for_interpolation(data, offset_float, window_resolution): class CenteredGrid(Field): def __init__(self, data, box=None, extrapolation='boundary', name=None, **kwargs): + """Create new CenteredGrid from array like data + + :param data: numerical values to be set as values of CenteredGrid (immutable) + :type data: array-like + :param box: numerical values describing the surrounding area of the CenteredGrid, defaults to None + :type box: domain.box, optional + :param extrapolation: set conditions for boundaries, defaults to 'boundary' + :type extrapolation: str, optional + :param name: give CenteredGrid a custom name (immutable), defaults to None + :type name: string, optional + """ Field.__init__(self, **struct.kwargs(locals())) self._sample_points = None @@ -51,6 +62,7 @@ def data(self, data): while math.ndims(data) < 2: data = math.expand_dims(data) return data + data.override(struct.staticshape, lambda self, data: (self._batch_size,) + math.staticshape(data)[1:]) @property def resolution(self): @@ -85,8 +97,7 @@ def sample_at(self, points, collapse_dimensions=True): return self._padded_resample(points) local_points = self.box.global_to_local(points) local_points = math.mul(local_points, math.to_float(self.resolution)) - 0.5 - boundary = {'periodic': 'circular', 'boundary': 'replicate', 'constant': 'constant'}[self.extrapolation] - resampled = math.resample(self.data, local_points, boundary=boundary, interpolation=self.interpolation) + resampled = math.resample(self.data, local_points, boundary=_pad_mode(self.extrapolation), interpolation=self.interpolation) return resampled def at(self, other_field, collapse_dimensions=True, force_optimization=False, return_self_if_compatible=False): @@ -138,14 +149,13 @@ def compatible(self, other_field): return False def __repr__(self): - try: + if self.is_valid: return 'Grid[%s(%d), size=%s]' % ('x'.join([str(r) for r in self.resolution]), self.component_count, self.box.size) - except: - return 'Grid[invalid]' + else: + return struct.Struct.__repr__(self) def padded(self, widths): - extrapolation = self.extrapolation if isinstance(self.extrapolation, six.string_types) else ['constant'] + list(self.extrapolation) + ['constant'] - data = math.pad(self.data, [[0, 0]]+widths+[[0, 0]], _pad_mode(extrapolation)) + data = math.pad(self.data, [[0, 0]]+widths+[[0, 0]], _pad_mode(self.extrapolation)) w_lower, w_upper = np.transpose(widths) box = AABox(self.box.lower - w_lower * self.dx, self.box.upper + w_upper * self.dx) return self.copied_with(data=data, box=box) @@ -202,18 +212,32 @@ def _required_paddings_transposed(box, dx, target): return [lower, upper] -@mappable() def _pad_mode(extrapolation): - if extrapolation == 'periodic': - return 'wrap' - elif extrapolation == 'boundary': - return 'replicate' + """ Inserts 'constant' padding for batch dimension and channel dimension. """ + if isinstance(extrapolation, six.string_types): + return _pad_mode_str(extrapolation) else: - return extrapolation + return _pad_mode_str(['constant'] + list(extrapolation) + ['constant']) + +@mappable() +def _pad_mode_str(extrapolation): + """ +Converts an extrapolation string (or struct of strings) to a string that can be passed to math functions like math.pad or math.resample. + :param extrapolation: field extrapolation + :return: padding mode, same type as extrapolation + """ + return {'periodic': 'circular', + 'boundary': 'replicate', + 'constant': 'constant'}[extrapolation] + @mappable() def _gradient_extrapolation(extrapolation): - if extrapolation == 'boundary': - return 'constant' - else: - return extrapolation + """ +Given the extrapolation of a field, returns the extrapolation mode of the corresponding gradient field. + :param extrapolation: string or struct of strings + :return: same type as extrapolation + """ + return {'periodic': 'periodic', + 'boundary': 'constant', + 'constant': 'constant'}[extrapolation] diff --git a/phi/physics/field/sampled.py b/phi/physics/field/sampled.py index 2e03ce428..c76b22832 100644 --- a/phi/physics/field/sampled.py +++ b/phi/physics/field/sampled.py @@ -101,6 +101,7 @@ def data(self, data): if isinstance(data, (tuple, list, np.ndarray)): data = math.zeros_like(self.sample_points) + data return data + data.override(struct.staticshape, lambda self, data: (self._batch_size, self._point_count, self.component_count) if math.ndims(self.data) > 0 else ()) @struct.constant(default='add') def mode(self, mode): @@ -111,15 +112,7 @@ def mode(self, mode): def sample_points(self, sample_points): assert math.ndims(sample_points) == 3, sample_points.shape return sample_points - - @property - def shape(self): - with struct.unsafe(): - if math.ndims(self.data) > 0: - data_shape = (self._batch_size, self._point_count, self.component_count) - else: - data_shape = () - return self.copied_with(data=data_shape, sample_points=(self._batch_size, self._point_count, self.rank)) + sample_points.override(struct.staticshape, lambda self, data: (self._batch_size, self._point_count, self.rank)) @property def rank(self): @@ -151,6 +144,8 @@ def __repr__(self): return '%s[%sx(%d), %dD]' % (self.__class__.__name__, self._point_count if self._point_count is not None else '?', self.component_count, self.rank) + + def batch_indices(indices): """ Reshapes the indices such that, aside from indices, they also contain batch number. diff --git a/phi/physics/field/staggered_grid.py b/phi/physics/field/staggered_grid.py index 41c16523c..73b2574d9 100644 --- a/phi/physics/field/staggered_grid.py +++ b/phi/physics/field/staggered_grid.py @@ -78,7 +78,8 @@ def _component_grid(self, grid, axis): assert grid.component_count == 1 assert grid.rank == self.rank assert grid.box == box - assert grid.extrapolation == self.extrapolation + if grid.extrapolation != self.extrapolation: + grid = grid.copied_with(extrapolation=self.extrapolation) else: grid = CenteredGrid(data=grid, box=box, extrapolation=self.extrapolation, name=_subname(self.name, axis), batch_size=self._batch_size, flags=propagate_flags_children(self.flags, box.rank, 1)) @@ -135,6 +136,18 @@ def component_count(self): def unstack(self): return self.data + @struct.derived() + def x(self): + return self.data[-1] + + @struct.derived() + def y(self): + return self.data[-2] + + @struct.derived() + def z(self): + return self.data[-3] + @property def points(self): raise StaggeredSamplePoints(self) @@ -144,7 +157,10 @@ def center_points(self): return CenteredGrid.getpoints(self.box, self.resolution) def __repr__(self): - return 'StaggeredGrid[%s, size=%s]' % ('x'.join([str(r) for r in self.resolution]), self.box.size) + if self.is_valid: + return 'StaggeredGrid[%s, size=%s]' % ('x'.join([str(r) for r in self.resolution]), self.box.size) + else: + return struct.Struct.__repr__(self) def compatible(self, other_field): if not other_field.has_points: @@ -178,10 +194,6 @@ def divergence(self, physical_units=True): data = math.sum(components, 0) return CenteredGrid(data, self.box, name='div(%s)' % self.name, batch_size=self._batch_size) - @property - def dtype(self): - return self.data[0].dtype - @staticmethod def gradient(scalar_field, padding_mode='replicate'): assert isinstance(scalar_field, CenteredGrid) diff --git a/phi/physics/field/util.py b/phi/physics/field/util.py index 5dbe5f1cc..5f72e67dc 100644 --- a/phi/physics/field/util.py +++ b/phi/physics/field/util.py @@ -1,17 +1,33 @@ +# coding=utf-8 import itertools import numpy as np from numpy import pi -from phi import math +from phi import math, struct from phi.geom import AABox -from phi.physics.field import StaggeredGrid -from .field import StaggeredSamplePoints +from phi.physics.field import StaggeredGrid, ConstantField +from .field import StaggeredSamplePoints, Field from .grid import CenteredGrid def diffuse(field, amount, substeps=1): - assert isinstance(field, CenteredGrid) - if field.extrapolation == 'periodic': + u""" +Simulate a finite-time diffusion process of the form dF/dt = α · ΔF on a given `Field` F with diffusion coefficient α. + +If `field` is periodic (set via `extrapolation='periodic'`), diffusion may be simulated in Fourier space. +Otherwise, finite differencing is used to approximate the + :param field: CenteredGrid, StaggeredGrid or ConstantField + :param amount: number of Field, typically α · dt + :param substeps: number of iterations to use + :return: Field of same type as `field` + :rtype: Field + """ + if isinstance(field, ConstantField): + return field + if isinstance(field, StaggeredGrid): + return struct.map(lambda grid: diffuse(grid, amount, substeps=substeps), field, leaf_condition=lambda x: isinstance(x, CenteredGrid)) + assert isinstance(field, CenteredGrid), "Cannot diffuse field of type '%s'" % type(field) + if field.extrapolation == 'periodic' and not isinstance(amount, Field): frequencies = math.fft(field.data) k = math.fftfreq(field.resolution) / field.dx k = math.sum(k ** 2, axis=-1, keepdims=True) @@ -20,6 +36,8 @@ def diffuse(field, amount, substeps=1): data = math.ifft(frequencies * diffuse_kernel) data = math.real(data) else: + if isinstance(amount, Field): + amount = amount.at(field).data data = field.data for i in range(substeps): data += amount / substeps * field.laplace().data diff --git a/phi/physics/fluid.py b/phi/physics/fluid.py index 2460394ea..ffae87764 100644 --- a/phi/physics/fluid.py +++ b/phi/physics/fluid.py @@ -11,8 +11,7 @@ from .field.effect import Gravity, effect_applied, gravity_tensor from .material import OPEN, Material from .physics import Physics, StateDependency -from .pressuresolver.solver_api import FluidDomain -from .pressuresolver.sparse import SparseCG +from .pressuresolver.solver_api import FluidDomain, poisson_solve @struct.definition() @@ -123,20 +122,14 @@ def _is_div_free(velocity, is_div_free): def solve_pressure(divergence, fluiddomain, pressure_solver=None): """ - Computes the pressure from the given velocity or velocity divergence using the specified solver. +Computes the pressure from the given velocity divergence using the specified solver. :param divergence: CenteredGrid :param fluiddomain: FluidDomain instance :param pressure_solver: PressureSolver to use, None for default :return: pressure field, iteration count :rtype: CenteredGrid, int """ - assert isinstance(divergence, CenteredGrid) - if pressure_solver is None: - pressure_solver = SparseCG() - pressure, iteration = pressure_solver.solve(divergence.data, fluiddomain, pressure_guess=None) - if isinstance(divergence, CenteredGrid): - pressure = CenteredGrid(pressure, divergence.box, name='pressure') - return pressure, iteration + return poisson_solve(divergence, fluiddomain, solver=pressure_solver) def divergence_free(velocity, domain=None, obstacles=(), pressure_solver=None, return_info=False): @@ -168,4 +161,4 @@ def divergence_free(velocity, domain=None, obstacles=(), pressure_solver=None, r pressure *= velocity.dx[0] gradp = StaggeredGrid.gradient(pressure) velocity -= fluiddomain.with_hard_boundary_conditions(gradp) - return velocity if not return_info else (velocity, {'pressure': pressure, 'iterations': iterations}) + return velocity if not return_info else (velocity, {'pressure': pressure, 'iterations': iterations, 'divergence': divergence_field}) diff --git a/phi/physics/physics.py b/phi/physics/physics.py index 36036754f..162c76228 100644 --- a/phi/physics/physics.py +++ b/phi/physics/physics.py @@ -3,7 +3,6 @@ """ from phi import struct -from phi.math import staticshape @struct.definition() @@ -58,22 +57,6 @@ def default_physics(self): """ return STATIC - @property - def shape(self): - """ -Similar to phi.math.shape(self) but respects unknown dimensions. - """ - def tensorshape(tensor): - if tensor is None: - return None - default_batched_shape = staticshape(tensor) - if len(default_batched_shape) >= 2: - return (self._batch_size,) + default_batched_shape[1:] - else: - return default_batched_shape - with struct.unsafe(): - return struct.map(tensorshape, self, item_condition=struct.VARIABLES) - @property def state(self): """ @@ -156,7 +139,7 @@ def step(self, state, dt=1.0, **dependent_states): """ Does not alter the state except for increasing its age. """ - return state.copied_with(age=state.age + dt) + return state.map_item(State.age, lambda age: age + dt) STATIC = Static() diff --git a/phi/physics/pressuresolver/geom.py b/phi/physics/pressuresolver/geom.py index ba94effff..633724468 100644 --- a/phi/physics/pressuresolver/geom.py +++ b/phi/physics/pressuresolver/geom.py @@ -2,22 +2,21 @@ from phi import math from phi.math.blas import conjugate_gradient +from phi.math.helper import _dim_shifted from phi.physics.field import CenteredGrid -from .solver_api import PressureSolver, FluidDomain +from .solver_api import PoissonDomain, PoissonSolver -# ToDo can cause NaNs, unsafe - - -class GeometricCG(PressureSolver): +class GeometricCG(PoissonSolver): def __init__(self, accuracy=1e-5, gradient_accuracy='same', max_iterations=2000, max_gradient_iterations='same', autodiff=False): - ''' - Conjugate gradient solver that geometrically calculates laplace pressure in each iteration. - Unlike most other solvers, this algorithm is TPU compatible but usually performs worse than SparseCG. - At the moment, boundary conditions are only partly supported. + """ +Conjugate gradient solver that geometrically calculates laplace pressure in each iteration. +Unlike most other solvers, this algorithm is TPU compatible but usually performs worse than SparseCG. + +Obstacles are allowed to vary between examples but the same number of iterations is performed for each example in one batch. :param accuracy: the maximally allowed error on the divergence channel for each cell :param gradient_accuracy: accuracy applied during backpropagation, number of 'same' to use forward accuracy @@ -29,10 +28,10 @@ def __init__(self, accuracy=1e-5, gradient_accuracy='same', The intermediate results of each loop iteration will be permanently stored if backpropagation is used. If False, replaces autodiff by a forward pressure solve in reverse accumulation backpropagation. This requires less memory but is only accurate if the solution is fully converged. - ''' - PressureSolver.__init__(self, 'Single-Phase Conjugate Gradient', - supported_devices=('CPU', 'GPU', 'TPU'), - supports_guess=True, supports_loop_counter=True, supports_continuous_masks=True) + """ + PoissonSolver.__init__(self, 'Single-Phase Conjugate Gradient', + supported_devices=('CPU', 'GPU', 'TPU'), + supports_guess=True, supports_loop_counter=True, supports_continuous_masks=True) assert isinstance(accuracy, Number), 'invalid accuracy: %s' % accuracy assert gradient_accuracy == 'same' or isinstance(gradient_accuracy, Number), 'invalid gradient_accuracy: %s' % gradient_accuracy assert max_gradient_iterations in ['same', 'mirror'] or isinstance(max_gradient_iterations, Number), 'invalid max_gradient_iterations: %s' % max_gradient_iterations @@ -48,19 +47,19 @@ def __init__(self, accuracy=1e-5, gradient_accuracy='same', assert not autodiff, 'Cannot specify max_gradient_iterations when autodiff=True' self.autodiff = autodiff - def solve(self, divergence, domain, pressure_guess): - assert isinstance(domain, FluidDomain) + def solve(self, divergence, domain, guess): + assert isinstance(domain, PoissonDomain) fluid_mask = domain.accessible_tensor(extend=1) if self.autodiff: - return solve_pressure_forward(divergence, fluid_mask, self.max_iterations, pressure_guess, self.accuracy, domain, back_prop=True) + return solve_pressure_forward(divergence, fluid_mask, self.max_iterations, guess, self.accuracy, domain, back_prop=True) else: def pressure_gradient(op, grad): return solve_pressure_forward(grad, fluid_mask, max_gradient_iterations, None, self.gradient_accuracy, domain)[0] pressure, iteration = math.with_custom_gradient( solve_pressure_forward, - [divergence, fluid_mask, self.max_iterations, pressure_guess, self.accuracy, domain], + [divergence, fluid_mask, self.max_iterations, guess, self.accuracy, domain], pressure_gradient, input_index=0, output_index=0, name_base='geom_solve' ) @@ -70,12 +69,13 @@ def pressure_gradient(op, grad): def solve_pressure_forward(divergence, fluid_mask, max_iterations, guess, accuracy, domain, back_prop=False): + from phi.physics.material import Material + extrapolation = Material.extrapolation_mode(domain.domain.boundaries) def apply_A(pressure): - from phi.physics.material import Material - mode = 'replicate' if Material.solid(domain.domain.boundaries) else 'constant' - padded = math.pad(pressure, [[0,0]] + [[1,1]]*(math.ndims(pressure)-2) + [[0,0]], mode=mode) - return _weighted_sliced_laplace_nd(padded, weights=fluid_mask) + pressure = CenteredGrid(pressure, extrapolation=extrapolation) + pressure_padded = pressure.padded([[1, 1]] * pressure.rank) + return _weighted_sliced_laplace_nd(pressure_padded.data, weights=fluid_mask) return conjugate_gradient(divergence, apply_A, guess, accuracy, max_iterations, back_prop=back_prop) @@ -86,20 +86,8 @@ def _weighted_sliced_laplace_nd(tensor, weights): dims = range(math.spatial_rank(tensor)) components = [] for dimension in dims: - center_slices = tuple([(slice(1, -1) if i == dimension else slice(1,-1)) for i in dims]) - upper_slices = tuple([(slice(2, None) if i == dimension else slice(1,-1)) for i in dims]) - lower_slices = tuple([(slice(-2) if i == dimension else slice(1,-1)) for i in dims]) - - lower_weights = weights[(slice(None),) + lower_slices + (slice(None),)] * weights[(slice(None),) + center_slices + (slice(None),)] - upper_weights = weights[(slice(None),) + upper_slices + (slice(None),)] * weights[(slice(None),) + center_slices + (slice(None),)] - center_weights = - lower_weights - upper_weights - - lower_values = tensor[(slice(None),) + lower_slices + (slice(None),)] - upper_values = tensor[(slice(None),) + upper_slices + (slice(None),)] - center_values = tensor[(slice(None),) + center_slices + (slice(None),)] - - diff = math.mul(upper_values, upper_weights) + \ - math.mul(lower_values, lower_weights) + \ - math.mul(center_values, center_weights) + lower_weights, center_weights, upper_weights = _dim_shifted(weights, dimension, (-1, 0, 1), diminish_others=(1, 1)) + lower_values, center_values, upper_values = _dim_shifted(tensor, dimension, (-1, 0, 1), diminish_others=(1, 1)) + diff = math.mul(upper_values, upper_weights * center_weights) + math.mul(lower_values, lower_weights * center_weights) + math.mul(center_values, - lower_weights - upper_weights) components.append(diff) return math.sum(components, 0) diff --git a/phi/physics/pressuresolver/solver_api.py b/phi/physics/pressuresolver/solver_api.py index 33240057c..1bccad99b 100644 --- a/phi/physics/pressuresolver/solver_api.py +++ b/phi/physics/pressuresolver/solver_api.py @@ -4,11 +4,12 @@ from phi.physics.domain import Domain from phi.physics.field import CenteredGrid from phi.physics.material import Material +from phi.struct.functions import mappable -class PressureSolver(object): +class PoissonSolver(object): """ - Base class for solvers + Base class for Poisson solvers """ def __init__(self, name, supported_devices, supports_guess, supports_loop_counter, supports_continuous_masks): @@ -19,14 +20,14 @@ def __init__(self, name, supported_devices, supports_guess, supports_loop_counte self.supports_loop_counter = supports_loop_counter self.supports_continuous_masks = supports_continuous_masks - def solve(self, divergence, domain, pressure_guess): + def solve(self, field, domain, guess): """ - Solves the pressure equation Δp = ∇·v for all active fluid cells where active cells are given by the active_mask. - The resulting pressure is expected to fulfill (Δp-∇·v) ≤ accuracy for every active cell. + Solves the Poisson equation Δp = d for p for all active fluid cells where active cells are given by the active_mask. + p is expected to fulfill (Δp-d) ≤ accuracy for every active cell. - :param divergence: the scalar divergence of the velocity channel, ∇·v + :param field: scalar input field to the solve, e.g. the divergence of the velocity channel, ∇·v :param domain: DomainState object specifying boundary conditions and active/fluid masks. The domain must be equal for all examples (batch dimension equal to 1). - :param pressure_guess: (Optional) Pressure channel which can be used as an initial state for the solver + :param guess: (Optional) Pressure channel which can be used as an initial state for the solver :return: pressure tensor (same shape as divergence tensor), number of iterations (integer, 1D integer tensor or None if unknown) """ raise NotImplementedError(self.__class__) @@ -36,8 +37,11 @@ def __repr__(self): return self.name +PressureSolver = PoissonSolver + + @struct.definition() -class FluidDomain(struct.Struct): +class PoissonDomain(struct.Struct): def __init__(self, domain, valid_state=(), active=None, accessible=None, **kwargs): struct.Struct.__init__(self, **struct.kwargs(locals(), ignore='valid_state')) @@ -50,13 +54,15 @@ def domain(self, domain): @struct.constant(dependencies='domain') def active(self, active): + extrapolation = _active_extrapolation(Material.extrapolation_mode(self.domain.boundaries)) if active is not None: assert isinstance(active, CenteredGrid) assert active.rank == self.domain.rank - assert active.extrapolation == 'constant' + if active.extrapolation != extrapolation: + active = active.copied_with(extrapolation=extrapolation) return active else: - return self.domain.centered_grid(1, extrapolation='constant') + return self.domain.centered_grid(1, extrapolation=extrapolation) @struct.constant(dependencies='domain') def accessible(self, accessible): @@ -67,14 +73,10 @@ def accessible(self, accessible): else: return self.domain.centered_grid(1, extrapolation=Material.extrapolation_mode(self.domain.boundaries)) - @property def rank(self): return self.domain.rank - def is_valid(self, state): - return self._valid_state == state - def active_tensor(self, extend=0): """ Scalar channel encoding active cells as ones and inactive (open/obstacle) as zero. @@ -82,7 +84,7 @@ def active_tensor(self, extend=0): :param extend: Extend the grid in all directions beyond the grid size specified by the domain """ - return math.pad(self.active.data, [[0, 0]] + [[extend, extend]] * self.rank + [[0, 0]], "constant") + return self.active.padded([[extend, extend]] * self.rank).data def accessible_tensor(self, extend=0): """ @@ -107,3 +109,34 @@ def _frictionless_velocity_mask(self, velocity): lower = self.accessible.padded([[1, 0] if ax == axis else [0, 0] for ax in range(self.rank)]) tensors.append(math.minimum(upper.data, lower.data)) return velocity.with_data(tensors) + + +FluidDomain = PoissonDomain + + +@mappable() +def _active_extrapolation(boundaries): + return 'periodic' if boundaries == 'periodic' else 'constant' + + +def poisson_solve(input_field, poisson_domain, solver=None): + """ +Solves the Poisson equation Δp = input_field for p. + :param input_field: CenteredGrid + :param poisson_domain: PoissonDomain instance + :param solver: PoissonSolver to use, None for default + :return: p as CenteredGrid, iteration count as int or None if not available + :rtype: CenteredGrid, int + """ + from .sparse import SparseSciPy, SparseCG + assert isinstance(input_field, CenteredGrid) + if isinstance(poisson_domain, Domain): + poisson_domain = PoissonDomain(poisson_domain) + if solver is None: + if math.choose_backend([input_field.data, poisson_domain.active.data, poisson_domain.accessible.data]).matches_name('SciPy'): + solver = SparseSciPy() + else: + solver = SparseCG() + pressure, iteration = solver.solve(input_field.data, poisson_domain, guess=None) + pressure = CenteredGrid(pressure, input_field.box, name='pressure') + return pressure, iteration diff --git a/phi/physics/pressuresolver/sparse.py b/phi/physics/pressuresolver/sparse.py index e683916c7..d5502f8f4 100644 --- a/phi/physics/pressuresolver/sparse.py +++ b/phi/physics/pressuresolver/sparse.py @@ -7,24 +7,25 @@ from phi import math from phi.math.blas import conjugate_gradient -from .solver_api import PressureSolver, FluidDomain +from phi.math.helper import _dim_shifted +from phi.physics.material import Material +from phi.struct.tensorop import collapsed_gather_nd +from .solver_api import PoissonSolver, FluidDomain -class SparseSciPy(PressureSolver): +class SparseSciPy(PoissonSolver): def __init__(self): """ The SciPy solver uses the function scipy.sparse.linalg.spsolve to determine the pressure. It does not support initial guesses for the pressure and does not keep track of a loop counter. """ - PressureSolver.__init__(self, 'SciPy sparse solver', - supported_devices=('CPU',), - supports_guess=False, supports_loop_counter=False, supports_continuous_masks=True) + PoissonSolver.__init__(self, 'SciPy sparse solver', supported_devices=('CPU',), supports_guess=False, supports_loop_counter=False, supports_continuous_masks=True) - def solve(self, divergence, domain, pressure_guess): + def solve(self, field, domain, guess): assert isinstance(domain, FluidDomain) - dimensions = list(divergence.shape[1:-1]) - A = sparse_pressure_matrix(dimensions, domain.active_tensor(extend=1), domain.accessible_tensor(extend=1)) + dimensions = list(field.shape[1:-1]) + A = sparse_pressure_matrix(dimensions, domain.active_tensor(extend=1), domain.accessible_tensor(extend=1), Material.periodic(domain.domain.boundaries)) def np_solve_p(div): div_vec = div.reshape([-1, A.shape[0]]) @@ -32,75 +33,19 @@ def np_solve_p(div): return np.array(pressure).reshape(div.shape).astype(np.float32) def np_solve_p_gradient(op, grad_in): - return math.py_func(np_solve_p, [grad_in], np.float32, divergence.shape) + return math.py_func(np_solve_p, [grad_in], np.float32, field.shape) - pressure = math.py_func(np_solve_p, [divergence], np.float32, divergence.shape, grad=np_solve_p_gradient) + pressure = math.py_func(np_solve_p, [field], np.float32, field.shape, grad=np_solve_p_gradient) return pressure, None -def sparse_pressure_matrix(dimensions, extended_active_mask, extended_fluid_mask): - """ - Builds a sparse matrix such that when applied to a flattened pressure channel, it calculates the laplace - of that channel, taking into account obstacles and empty cells. - - :param dimensions: valid simulation dimensions. Pressure channel should be of shape (batch size, dimensions..., 1) - :param extended_active_mask: Binary tensor with 2 more entries in every dimension than 'dimensions'. - :param extended_fluid_mask: Binary tensor with 2 more entries in every dimension than 'dimensions'. - :return: SciPy sparse matrix that acts as a laplace on a flattened pressure channel given obstacles and empty cells - """ - N = int(np.prod(dimensions)) - d = len(dimensions) - A = scipy.sparse.lil_matrix((N, N), dtype=np.float32) - dims = range(d) - - center_values = None # diagonal matrix entries +class SparseCG(PoissonSolver): - gridpoints_linear = np.arange(N) - gridpoints = np.stack(np.unravel_index(gridpoints_linear, dimensions)) # d * (N^2) array mapping from linear to spatial frames - - for dim in dims: - upper_indices = tuple([slice(None)] + [slice(2, None) if i == dim else slice(1, -1) for i in dims] + [slice(None)]) - center_indices = tuple([slice(None)] + [slice(1, -1) if i == dim else slice(1, -1) for i in dims] + [slice(None)]) - lower_indices = tuple([slice(None)] + [slice(0, -2) if i == dim else slice(1, -1) for i in dims] + [slice(None)]) - - self_active = extended_active_mask[center_indices] - stencil_upper = extended_active_mask[upper_indices] * self_active - stencil_lower = extended_active_mask[lower_indices] * self_active - stencil_center = - extended_fluid_mask[upper_indices] - extended_fluid_mask[lower_indices] - - if center_values is None: - center_values = math.flatten(stencil_center) - else: - center_values = center_values + math.flatten(stencil_center) - - # Find entries in matrix - dim_direction = np.zeros_like(gridpoints) - dim_direction[dim] = 1 - # Upper frames - upper_indices = gridpoints + dim_direction - upper_in_range_inx = np.nonzero(upper_indices[dim] < dimensions[dim]) - upper_indices_linear = np.ravel_multi_index(upper_indices[:, upper_in_range_inx], dimensions) - A[gridpoints_linear[upper_in_range_inx], upper_indices_linear] = stencil_upper.flatten()[upper_in_range_inx] - # Lower frames - lower_indices = gridpoints - dim_direction - lower_in_range_inx = np.nonzero(lower_indices[dim] >= 0) - lower_indices_linear = np.ravel_multi_index(lower_indices[:, lower_in_range_inx], dimensions) - A[gridpoints_linear[lower_in_range_inx], lower_indices_linear] = stencil_lower.flatten()[lower_in_range_inx] - - A[gridpoints_linear, gridpoints_linear] = math.minimum(center_values, -1) - - return scipy.sparse.csc_matrix(A) - - -class SparseCG(PressureSolver): - - def __init__(self, accuracy=1e-5, gradient_accuracy='same', - max_iterations=2000, max_gradient_iterations='same', - autodiff=False): + def __init__(self, accuracy=1e-5, gradient_accuracy='same', max_iterations=2000, max_gradient_iterations='same', autodiff=False): """ Conjugate gradient solver using sparse matrix multiplications. - :param accuracy: the maximally allowed error on the divergence channel for each cell + :param accuracy: the maximally allowed error for each cell, measured in terms of field values. :param gradient_accuracy: accuracy applied during backpropagation, number of 'same' to use forward accuracy :param max_iterations: integer specifying maximum conjugent gradient loop iterations or None for no limit :param max_gradient_iterations: maximum loop iterations during backpropagation, @@ -111,9 +56,7 @@ def __init__(self, accuracy=1e-5, gradient_accuracy='same', If False, replaces autodiff by a forward pressure solve in reverse accumulation backpropagation. This requires less memory but is only accurate if the solution is fully converged. """ - PressureSolver.__init__(self, 'Sparse Conjugate Gradient', - supported_devices=('CPU', 'GPU'), - supports_guess=True, supports_loop_counter=True, supports_continuous_masks=True) + PoissonSolver.__init__(self, 'Sparse Conjugate Gradient', supported_devices=('CPU', 'GPU'), supports_guess=True, supports_loop_counter=True, supports_continuous_masks=True) assert isinstance(accuracy, Number), 'invalid accuracy: %s' % accuracy assert gradient_accuracy == 'same' or isinstance(gradient_accuracy, Number), 'invalid gradient_accuracy: %s' % gradient_accuracy assert max_gradient_iterations in ['same', 'mirror'] or isinstance(max_gradient_iterations, Number), 'invalid max_gradient_iterations: %s' % max_gradient_iterations @@ -129,28 +72,29 @@ def __init__(self, accuracy=1e-5, gradient_accuracy='same', assert not autodiff, 'Cannot specify max_gradient_iterations when autodiff=True' self.autodiff = autodiff - def solve(self, divergence, domain, pressure_guess): + def solve(self, field, domain, guess): assert isinstance(domain, FluidDomain) active_mask = domain.active_tensor(extend=1) fluid_mask = domain.accessible_tensor(extend=1) - dimensions = list(divergence.shape[1:-1]) + dimensions = math.staticshape(field)[1:-1] N = int(np.prod(dimensions)) + periodic = Material.periodic(domain.domain.boundaries) - if math.choose_backend(divergence).matches_name('SciPy'): - A = sparse_pressure_matrix(dimensions, active_mask, fluid_mask) + if math.choose_backend([field, active_mask, fluid_mask]).matches_name('SciPy'): + A = sparse_pressure_matrix(dimensions, active_mask, fluid_mask, periodic) else: - sidx, sorting = sparse_indices(dimensions) - sval_data = sparse_values(dimensions, active_mask, fluid_mask, sorting) - A = math.choose_backend(divergence).sparse_tensor(indices=sidx, values=sval_data, shape=[N, N]) + sidx, sorting = sparse_indices(dimensions, periodic) + sval_data = sparse_values(dimensions, active_mask, fluid_mask, sorting, periodic) + A = math.choose_backend(field).sparse_tensor(indices=sidx, values=sval_data, shape=[N, N]) if self.autodiff: - return sparse_cg(divergence, A, self.max_iterations, pressure_guess, self.accuracy, back_prop=True) + return sparse_cg(field, A, self.max_iterations, guess, self.accuracy, back_prop=True) else: def pressure_gradient(op, grad): return sparse_cg(grad, A, max_gradient_iterations, None, self.gradient_accuracy)[0] pressure, iteration = math.with_custom_gradient(sparse_cg, - [divergence, A, self.max_iterations, pressure_guess, self.accuracy], + [field, A, self.max_iterations, guess, self.accuracy], pressure_gradient, input_index=0, output_index=0, name_base='scg_pressure_solve') @@ -158,49 +102,81 @@ def pressure_gradient(op, grad): return pressure, iteration -def sparse_cg(divergence, A, max_iterations, guess, accuracy, back_prop=False): - div_vec = math.reshape(divergence, [-1, int(np.prod(divergence.shape[1:]))]) +def sparse_cg(field, A, max_iterations, guess, accuracy, back_prop=False): + div_vec = math.reshape(field, [-1, int(np.prod(field.shape[1:]))]) if guess is not None: - guess = math.reshape(guess, [-1, int(np.prod(divergence.shape[1:]))]) + guess = math.reshape(guess, [-1, int(np.prod(field.shape[1:]))]) apply_A = lambda pressure: math.matmul(A, pressure) result_vec, iterations = conjugate_gradient(div_vec, apply_A, guess, accuracy, max_iterations, back_prop) - return math.reshape(result_vec, math.shape(divergence)), iterations + return math.reshape(result_vec, math.shape(field)), iterations -def sparse_indices(dimensions): +def sparse_pressure_matrix(dimensions, extended_active_mask, extended_fluid_mask, periodic=False): + """ +Builds a sparse matrix such that when applied to a flattened pressure channel, it calculates the laplace +of that channel, taking into account obstacles and empty cells. + + :param dimensions: valid simulation dimensions. Pressure channel should be of shape (batch size, dimensions..., 1) + :param extended_active_mask: Binary tensor with 2 more entries in every dimension than 'dimensions'. + :param extended_fluid_mask: Binary tensor with 2 more entries in every dimension than 'dimensions'. + :return: SciPy sparse matrix that acts as a laplace on a flattened pressure channel given obstacles and empty cells + """ N = int(np.prod(dimensions)) d = len(dimensions) + A = scipy.sparse.lil_matrix((N, N), dtype=np.float32) dims = range(d) - gridpoints_linear = np.arange(N) - gridpoints = np.stack(np.unravel_index(gridpoints_linear, dimensions)) # d * (N^2) array mapping from linear to spatial frames + diagonal_entries = np.zeros(N, extended_active_mask.dtype) # diagonal matrix entries - indices_list = [np.stack([gridpoints_linear] * 2, axis=-1)] + gridpoints_linear = np.arange(N) + gridpoints = np.stack(np.unravel_index(gridpoints_linear, dimensions)) # d * (N^2) array mapping from linear to spatial frames for dim in dims: - dim_direction = np.zeros_like(gridpoints) - dim_direction[dim] = 1 - # Upper frames - upper_indices = gridpoints + dim_direction - upper_in_range_inx = np.nonzero(upper_indices[dim] < dimensions[dim]) - upper_indices_linear = np.ravel_multi_index(upper_indices[:, upper_in_range_inx], dimensions)[0, :] - indices_list.append(np.stack([gridpoints_linear[upper_in_range_inx], upper_indices_linear], axis=-1)) - # Lower frames - lower_indices = gridpoints - dim_direction - lower_in_range_inx = np.nonzero(lower_indices[dim] >= 0) - lower_indices_linear = np.ravel_multi_index(lower_indices[:, lower_in_range_inx], dimensions)[0, :] - indices_list.append(np.stack([gridpoints_linear[lower_in_range_inx], lower_indices_linear], axis=-1)) + lower_active, self_active, upper_active = _dim_shifted(extended_active_mask, dim, (-1, 0, 1), diminish_others=(1,1)) + lower_accessible, upper_accessible = _dim_shifted(extended_fluid_mask, dim, (-1, 1), diminish_others=(1, 1)) - indices = np.concatenate(indices_list, axis=0) + stencil_upper = upper_active * self_active + stencil_lower = lower_active * self_active + stencil_center = - lower_accessible - upper_accessible - sorting = np.lexsort(np.transpose(indices)[:, ::-1]) + diagonal_entries += math.flatten(stencil_center) - sorted_indices = indices[sorting] + dim_direction = math.expand_dims([1 if i == dim else 0 for i in range(d)], axis=-1) + # --- Stencil upper cells --- + upper_points, upper_idx = wrap_or_discard(gridpoints + dim_direction, dim, dimensions, periodic=collapsed_gather_nd(periodic, [dim, 1])) + A[gridpoints_linear[upper_idx], upper_points] = stencil_upper.flatten()[upper_idx] + # --- Stencil lower cells --- + lower_points, lower_idx = wrap_or_discard(gridpoints - dim_direction, dim, dimensions, periodic=collapsed_gather_nd(periodic, [dim, 0])) + A[gridpoints_linear[lower_idx], lower_points] = stencil_lower.flatten()[lower_idx] + + A[gridpoints_linear, gridpoints_linear] = math.minimum(diagonal_entries, -1) # avoid 0, could lead to NaN + return scipy.sparse.csc_matrix(A) + + +def sparse_indices(dimensions, periodic=False): + N = int(np.prod(dimensions)) + d = len(dimensions) + dims = range(d) + gridpoints_linear = np.arange(N) + gridpoints = np.stack(np.unravel_index(gridpoints_linear, dimensions)) # d * (N^2) array mapping from linear to spatial frames + indices_list = [np.stack([gridpoints_linear] * 2, axis=-1)] + for dim in dims: + dim_direction = math.expand_dims([1 if i == dim else 0 for i in range(d)], axis=-1) + # --- Stencil upper cells --- + upper_points, upper_idx = wrap_or_discard(gridpoints + dim_direction, dim, dimensions, periodic=collapsed_gather_nd(periodic, [dim, 1])) + indices_list.append(np.stack([gridpoints_linear[upper_idx], upper_points], axis=-1)) + # --- Stencil lower cells --- + lower_points, lower_idx = wrap_or_discard(gridpoints - dim_direction, dim, dimensions, periodic=collapsed_gather_nd(periodic, [dim, 0])) + indices_list.append(np.stack([gridpoints_linear[lower_idx], lower_points], axis=-1)) + indices = np.concatenate(indices_list, axis=0) + # --- Sort indices --- + sorting = np.lexsort(np.transpose(indices)[:, ::-1]) + sorted_indices = indices[sorting] return sorted_indices, sorting -def sparse_values(dimensions, extended_active_mask, extended_fluid_mask, sorting=None): +def sparse_values(dimensions, extended_active_mask, extended_fluid_mask, sorting=None, periodic=False): """ Builds a sparse matrix such that when applied to a flattened pressure channel, it calculates the laplace of that channel, taking into account obstacles and empty cells. @@ -215,41 +191,51 @@ def sparse_values(dimensions, extended_active_mask, extended_fluid_mask, sorting dims = range(d) values_list = [] - center_values = None # diagonal matrix entries + diagonal_entries = 0 # diagonal matrix entries gridpoints_linear = np.arange(N) gridpoints = np.stack(np.unravel_index(gridpoints_linear, dimensions)) # d * (N^2) array mapping from linear to spatial frames for dim in dims: - upper_indices = tuple([slice(None)] + [slice(2, None) if i == dim else slice(1, -1) for i in dims] + [slice(None)]) - center_indices = tuple([slice(None)] + [slice(1, -1) if i == dim else slice(1, -1) for i in dims] + [slice(None)]) - lower_indices = tuple([slice(None)] + [slice(0, -2) if i == dim else slice(1, -1) for i in dims] + [slice(None)]) + lower_active, self_active, upper_active = _dim_shifted(extended_active_mask, dim, (-1, 0, 1), diminish_others=(1, 1)) + lower_accessible, upper_accessible = _dim_shifted(extended_fluid_mask, dim, (-1, 1), diminish_others=(1, 1)) - self_active = extended_active_mask[center_indices] - stencil_upper = extended_active_mask[upper_indices] * self_active - stencil_lower = extended_active_mask[lower_indices] * self_active - stencil_center = - extended_fluid_mask[upper_indices] - extended_fluid_mask[lower_indices] + stencil_upper = upper_active * self_active + stencil_lower = lower_active * self_active + stencil_center = - lower_accessible - upper_accessible - if center_values is None: - center_values = math.flatten(stencil_center) - else: - center_values = center_values + math.flatten(stencil_center) - - dim_direction = np.zeros_like(gridpoints) - dim_direction[dim] = 1 - # Upper frames - upper_indices = gridpoints + dim_direction - upper_in_range_inx = np.nonzero(upper_indices[dim] < dimensions[dim])[0] - values_list.append(math.gather(math.flatten(stencil_upper), upper_in_range_inx)) - # Lower frames - lower_indices = gridpoints - dim_direction - lower_in_range_inx = np.nonzero(lower_indices[dim] >= 0)[0] - values_list.append(math.gather(math.flatten(stencil_lower), lower_in_range_inx)) - - center_values = math.minimum(center_values, -1.) - values_list.insert(0, center_values) + diagonal_entries += math.flatten(stencil_center) + dim_direction = math.expand_dims([1 if i == dim else 0 for i in range(d)], axis=-1) + # --- Stencil upper cells --- + upper_points, upper_idx = wrap_or_discard(gridpoints + dim_direction, dim, dimensions, periodic=collapsed_gather_nd(periodic, [dim, 1])) + values_list.append(math.gather(math.flatten(stencil_upper), upper_idx)) + # --- Stencil lower cells --- + lower_points, lower_idx = wrap_or_discard(gridpoints - dim_direction, dim, dimensions, periodic=collapsed_gather_nd(periodic, [dim, 0])) + values_list.append(math.gather(math.flatten(stencil_lower), lower_idx)) + + values_list.insert(0, math.minimum(diagonal_entries, -1.)) values = math.concat(values_list, axis=0) if sorting is not None: values = math.gather(values, sorting) return values + + +def wrap_or_discard(points, check_bounds_dim, dimensions, periodic=False): + """ +Handles points that lie outside the domain by either discarding them or wrapping them, depending on periodic. + :param points: grid indices, typically of shape (dimensions, cell_count) + :param check_bounds_dim: int + :param dimensions: domain resolution + :param periodic: if False: discard indices outside domain, if True: wrap indices outside domain + :return: + """ + if not periodic: + upper_in_range_inx = np.nonzero((points[check_bounds_dim] < dimensions[check_bounds_dim]) & (points[check_bounds_dim] >= 0))[0] + new_points = points[:, upper_in_range_inx] # discard points outside domain + else: + upper_in_range_inx = slice(None) + new_points = points % math.expand_dims(dimensions, -1) # wrap points + + linear_points = np.ravel_multi_index(new_points, dimensions) + return linear_points, upper_in_range_inx diff --git a/phi/struct/__init__.py b/phi/struct/__init__.py index 3b64bcbb5..54928ff80 100644 --- a/phi/struct/__init__.py +++ b/phi/struct/__init__.py @@ -1,7 +1,8 @@ from .context import unsafe +from .item_condition import DATA, VARIABLES, CONSTANTS, ALL_ITEMS from .trait import Trait -from .structdef import definition, variable, constant, derived, DATA, VARIABLES, CONSTANTS, ALL_ITEMS -from .struct import Struct, kwargs, to_dict, variables, constants, properties_dict, copy_with, isstruct, equal +from .structdef import definition, variable, constant, derived +from .struct import Struct, kwargs, to_dict, variables, constants, properties_dict, copy_with, isstruct, equal, VALID, INVALID # pylint: disable-msg = redefined-builtin -from .functions import flatten, names, map, zip, Trace, compare, print_differences +from .functions import flatten, names, map, map_item, zip, Trace, compare, print_differences, shape, staticshape, dtype diff --git a/phi/struct/context.py b/phi/struct/context.py index ffb5207ed..d4ed77d4f 100644 --- a/phi/struct/context.py +++ b/phi/struct/context.py @@ -1,3 +1,4 @@ +import warnings from contextlib import contextmanager @@ -5,13 +6,22 @@ @contextmanager -def unsafe(): - _STRUCT_CONTEXT_STACK.append('unsafe') +def _struct_context(object): + _STRUCT_CONTEXT_STACK.append(object) try: yield None finally: _STRUCT_CONTEXT_STACK.pop(-1) +def unsafe(): + warnings.warn("struct.unsafe() is deprecated. Use map() with new_type argument to avoid validation.") + return _struct_context('unsafe') + + +def _unsafe(): + return _struct_context('unsafe') + + def skip_validate(): return 'unsafe' in _STRUCT_CONTEXT_STACK diff --git a/phi/struct/functions.py b/phi/struct/functions.py index 25ff04b33..4e2c6db3b 100644 --- a/phi/struct/functions.py +++ b/phi/struct/functions.py @@ -1,18 +1,28 @@ -import six +import warnings -from .context import unsafe -from .struct import copy_with, equal, isstruct, to_dict -from .structdef import ALL_ITEMS, DATA +import six +from ..backend.dynamic_backend import DYNAMIC_BACKEND as math, NoBackendFound +from .context import _unsafe +from .item_condition import ALL_ITEMS, context_item_condition +from .structdef import Item +from .struct import copy_with, equal, isstruct, to_dict, Struct, VALID, INVALID, items -def flatten(struct, leaf_condition=None, trace=False, item_condition=DATA): - result = [] +def flatten(struct, leaf_condition=None, trace=False, item_condition=None): + """ +Generates a list of all leaves by recursively iterating over the given struct. + :param struct: struct or leaf + :param leaf_condition: (optional) function that determines which structs are treated as leaves. Non-structs are always treated as leaves. + :param trace: If True, returns a list of Trace objects instead of values. + :param item_condition: (optional) ItemCondition or boolean function that filters which Items are accumulated. + :return: list containing all leaves in the struct hierarchy + """ def map_leaf(value): result.append(value) return value - with unsafe(): - map(map_leaf, struct, leaf_condition, recursive=True, trace=trace, item_condition=item_condition) + result = [] + map(map_leaf, struct, leaf_condition, recursive=True, trace=trace, item_condition=item_condition, content_type=INVALID) return result @@ -22,21 +32,37 @@ def to_name(trace): return trace.name if basename is None else basename + separator + trace.name else: return trace.path(separator) if basename is None else basename + separator + trace.path(separator) - with unsafe(): - return map(to_name, struct, leaf_condition, recursive=True, trace=True) + return map(to_name, struct, leaf_condition, recursive=True, trace=True, content_type=names) -def zip(structs, leaf_condition=None, item_condition=DATA, zip_parents_if_incompatible=False): +def zip(structs, leaf_condition=None, item_condition=None, zip_parents_if_incompatible=False): + """ +Builds a single struct containing LeaefZip entries from a list of compatible structs. +Passing zipped structs to 'map' will call the mapping function with the all leaves at equal positions in the structure. + +Example `struct.map(lambda x, y: x+y, struct.zip([{0: 'Hello'}, {0: ' World'}]))` returns `{0: 'Hello World'}`. + :param structs: iterable collection of structs or leaves + :param leaf_condition: (optional) function that determines which structs are treated as leaves. Non-structs are always treated as leaves. + :param item_condition: (optional) ItemCondition or boolean function that filters which Items are zipped. Excluded items should have the same values among all structs. + :param zip_parents_if_incompatible: If True, suppresses IncompatibleStructs errors if structs with non-matching excluded items are encountered. Instead, these structs are treated as leaves and zipped. + :return: Single struct matching the structure of any of the given structs and holding LeafZip objects as leaves for non-excluded items + :raise IncompatibleStructs: If structs with non-matching excluded items are encountered and zip_parents_if_incompatible=False + """ # pylint: disable-msg = redefined-builtin assert len(structs) > 0 first = structs[0] if isstruct(first, leaf_condition): for struct in structs[1:]: + if not isstruct(struct): + if zip_parents_if_incompatible: + return LeafZip(structs) + else: + raise IncompatibleStructs('Cannot zip %s and %s because the latter is not a struct.' % (first, struct)) if set(to_dict(struct, item_condition=item_condition).keys()) != set(to_dict(first, item_condition=item_condition).keys()): if zip_parents_if_incompatible: return LeafZip(structs) else: - raise IncompatibleStructs('Cannot zip %s and %s because keys vary:\n%s\n%s' % (struct, first, to_dict(struct, item_condition=item_condition).keys(), to_dict(first, item_condition=item_condition).keys())) + raise IncompatibleStructs('Cannot zip %s and %s because keys vary:\n%s\n%s' % (first, struct, to_dict(first, item_condition=item_condition).keys(), to_dict(struct, item_condition=item_condition).keys())) if not isstruct(first, leaf_condition): return LeafZip(structs) @@ -48,15 +74,14 @@ def zip(structs, leaf_condition=None, item_condition=DATA, zip_parents_if_incomp values = [d[key] for d in dicts] values = zip(values, leaf_condition, item_condition=item_condition, zip_parents_if_incompatible=zip_parents_if_incompatible) new_dict[key] = values - with unsafe(): - return copy_with(first, new_dict) + return copy_with(first, new_dict, change_type=zip) class LeafZip(object): """ Created by struct.zip to replace data. +When a LeafZip is mapped using 'map', the values are passed as multiple arguments (*args). """ - def __init__(self, values): self.values = values @@ -72,17 +97,32 @@ def __str__(self): class IncompatibleStructs(Exception): """ -Thrown when two or more structs are required to have the same structure but do not. +Thrown when two or more structs are required to have the same structure but do not, e.g. when trying to zip incompatible structs. """ - def __init__(self, *args): Exception.__init__(self, *args) -def map(function, struct, leaf_condition=None, recursive=True, trace=False, item_condition=DATA): +def map(function, struct, leaf_condition=None, recursive=True, trace=False, item_condition=None, content_type=None): + """ +Iterates over all items of the struct and maps their values according to the specified function. +Preserves the hierarchical structure of struct, returning an object of the same type and leaving struct untouched. + :param function: function mapping from leaf values to new values. If not otherwise specified, the new values will be validated before map returns. If trace=True, Trace objects will be passed instead of values. For zipped structs, multiple values or a Trace containing multiple values is passed to function. + :param struct: struct or leaf value + :param leaf_condition: (optional) function that determines which structs are treated as leaves. Non-structs are always treated as leaves. Leaf structs are not iterated over but directly passed to function. + :param recursive: If True, recursively iterates over all non-leaf sub-structs, passing only leaves to function. Otherwise only iterates over direct items of struct; all sub-structs are treated as leaves. + :param trace: If True, passes a Trace object to function instead of the value. Traces contain additional information. + :param item_condition: (optional) ItemCondition or boolean function that filters which Items are iterated over. Excluded items are left untouched. If None, the context item condition is used (data-holding items by default). + :param content_type: (optional) Type key to use for new Structs. Defaults to VALID. Item-specific overrides can be defined by calling Item.override using the content_type as key. Override functions must have the signature (parent_struct, value). + :return: object of the same type and hierarchy as struct + """ # pylint: disable-msg = redefined-builtin if trace is True: trace = Trace(struct, None, None) + if item_condition is None: + item_condition = context_item_condition + if content_type is None: + content_type = VALID if not isstruct(struct, leaf_condition): if trace is False: if isinstance(struct, LeafZip): @@ -92,16 +132,32 @@ def map(function, struct, leaf_condition=None, recursive=True, trace=False, item else: return function(trace) else: - old_values = to_dict(struct, item_condition=item_condition) new_values = {} if not recursive: - leaf_condition = lambda x: True - for key, value in old_values.items(): - new_values[key] = map(function, value, leaf_condition, recursive, - Trace(value, key, trace) if trace is not False else False, - item_condition=item_condition) + def leaf_condition(_): return True + for item in items(struct): + if item_condition(item): + old_value = item.get(struct) + if content_type is not VALID and content_type is not INVALID and item.has_override(content_type): + new_value = item.get_override(content_type)(struct, old_value) + else: + new_value = map(function, old_value, leaf_condition, recursive, + Trace(old_value, item.name, trace) if trace is not False else False, + item_condition, + content_type) + new_values[item.name] = new_value + return copy_with(struct, new_values, change_type=content_type) - return copy_with(struct, new_values) + +def map_item(item, function, struct, leaf_condition=None, recursive=True, content_type=None): + assert isinstance(item, Item) or isinstance(item, six.string_types) + + def item_condition(item_): + if isinstance(item, six.string_types): + return item_.name == item + else: + return item_.name == item.name + return map(function, struct, leaf_condition=leaf_condition, recursive=recursive, trace=False, item_condition=item_condition, content_type=content_type) class Trace(object): @@ -128,7 +184,7 @@ def path(self, separator='.'): if self.parent is not None and self.parent.key is not None: return self.parent.path(separator) + separator + self.name else: - return self.name + return self.name if self.name is not None else '' def __repr__(self): return "%s = %s" % (self.path(), self.value) @@ -154,8 +210,7 @@ def check(trace): result.add(trace) except (ValueError, KeyError, TypeError): result.add(trace) - with unsafe(): - map(check, structs[0], leaf_condition=leaf_condition, recursive=recursive, trace=True, item_condition=item_condition) + map(check, structs[0], leaf_condition=leaf_condition, recursive=recursive, trace=True, item_condition=item_condition, content_type=INVALID) return result @@ -183,15 +238,74 @@ def print_differences(struct1, struct2, level=0): print(indent+'Item "%s" is missing from %s.' % (key2, struct1)) -def mappable(leaf_condition=None, recursive=True, item_condition=DATA, unsafe_context=False): +def mappable(leaf_condition=None, recursive=True, item_condition=None, unsafe_context=False, content_type=None): + if unsafe_context: + warnings.warn("unsafe_context is deprecated. Use content_type=INVALID instead.") def decorator(function): def broadcast_function(obj, *args, **kwargs): def function_with_args(x): return function(x, *args, **kwargs) if unsafe_context: - with unsafe(): - result = map(function_with_args, obj, leaf_condition=leaf_condition, recursive=recursive, item_condition=item_condition) + with _unsafe(): + result = map(function_with_args, obj, leaf_condition=leaf_condition, recursive=recursive, item_condition=item_condition, content_type=content_type) else: - result = map(function_with_args, obj, leaf_condition=leaf_condition, recursive=recursive, item_condition=item_condition) + result = map(function_with_args, obj, leaf_condition=leaf_condition, recursive=recursive, item_condition=item_condition, content_type=content_type) return result return broadcast_function return decorator + + +def shape(obj, leaf_condition=None, item_condition=None): + """ +Maps all values of a struct to their respective dynamic shapes using `math.shape()`. +To specify custom shapes, add an override with key struct.shape to the Item. + :param obj: struct or leaf + :param leaf_condition: (optional) leaf_condition passed to `map` + :param item_condition: (optional) item_condition passed to `map` + :return: Struct of same type holding shapes instead of data + """ + def get_shape(obj): + try: + return math.shape(obj) + except NoBackendFound: + return () + if isinstance(obj, Struct): + assert obj.content_type is VALID or obj.content_type is INVALID, "shape can only be accessed on data structs but '%s' has content type '%s'" % (type(obj).__name__, obj.content_type) + return map(get_shape, obj, leaf_condition=leaf_condition, item_condition=item_condition, content_type=shape) + + +def staticshape(obj, leaf_condition=None, item_condition=None): + """ +Maps all values of a struct to their respective static shapes using `math.staticshape()`. +To specify custom static shapes, add an override with key struct.staticshape to the Item. + :param obj: struct or leaf + :param leaf_condition: (optional) leaf_condition passed to `map` + :param item_condition: (optional) item_condition passed to `map` + :return: Struct of same type holding shapes instead of data + """ + def get_staticshape(obj): + try: + return math.staticshape(obj) + except NoBackendFound: + return () + if isinstance(obj, Struct): + assert obj.content_type is VALID or obj.content_type is INVALID, "staticshape can only be accessed on data structs but '%s' has content type '%s'" % (type(obj).__name__, obj.content_type) + return map(get_staticshape, obj, leaf_condition=leaf_condition, item_condition=item_condition, content_type=staticshape) + + +def dtype(obj, leaf_condition=None, item_condition=None): + """ +Maps all values of a struct to their respective data types using `math.dtype()`. +To specify custom dtypes, add an override with key struct.dtype to the Item. + :param obj: struct or leaf + :param leaf_condition: (optional) leaf_condition passed to `map` + :param item_condition: (optional) item_condition passed to `map` + :return: Struct of same type holding data types instead of data + """ + def get_dtype(obj): + try: + return math.dtype(obj) + except NoBackendFound: + return type(obj) + if isinstance(obj, Struct): + assert obj.content_type is VALID or obj.content_type is INVALID, "dtype can only be accessed on data structs but '%s' has content type '%s'" % (type(obj).__name__, obj.content_type) + return map(get_dtype, obj, leaf_condition=leaf_condition, item_condition=item_condition, content_type=dtype) diff --git a/phi/struct/item_condition.py b/phi/struct/item_condition.py new file mode 100644 index 000000000..0a042b464 --- /dev/null +++ b/phi/struct/item_condition.py @@ -0,0 +1,64 @@ +from .context import _struct_context, _STRUCT_CONTEXT_STACK + + +class ItemCondition(object): + """ +ItemConditions are used to filter struct items. +They represent a named boolean function on items. + +In addition, they can be used in 'with ItemCondition:' blocks, adding the condition to the thread context for all actions within that block. +In particular, struct.map, Struct.shape, Struct.staticshape are affected by context conditions. + +This module provides some standard conditions like ALL_ITEMS, DATA, VARIABLES, CONSTANTS. + """ + + def __init__(self, item_condition, name=None): + assert item_condition is None or callable(item_condition), item_condition + self.item_condition = item_condition + if name is not None: + self.name = name + else: + self.name = item_condition.__name__ if item_condition is not None else 'ALL' + + def condition_check(self, item): + return True if self.item_condition is None else self.item_condition(item) + + __call__ = condition_check + + def __enter__(self): + self.context = _struct_context(self) + return self.context.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + result = self.context.__exit__(exc_type, exc_val, exc_tb) + self.context = None + return result + + def __repr__(self): + return self.name + + +CONSTANTS = ItemCondition(lambda item: not item.is_variable, 'CONSTANTS') +VARIABLES = ItemCondition(lambda item: item.is_variable, 'VARIABLES') +DATA = ItemCondition(lambda item: item.holds_data, 'DATA') +ALL_ITEMS = ItemCondition(None) + + +def context_item_condition(item): + """ +Checks all thread-global item conditions. +Conditions can be specified using 'with ItemCondition:' blocks. +If no condition was specified, this function defaults to testing whether the item holds data. + :param item: item to be checked + :return: True if the item passes all conditions, False otherwise + """ + user_specified = False + for context in _STRUCT_CONTEXT_STACK: + if isinstance(context, ItemCondition): + user_specified = True + if not context.condition_check(item): + return False + if user_specified: + return True + else: + return item.holds_data # default condition if none specified diff --git a/phi/struct/struct.py b/phi/struct/struct.py index 226c6c058..26bf91b89 100644 --- a/phi/struct/struct.py +++ b/phi/struct/struct.py @@ -5,8 +5,10 @@ import numpy as np import six +from ..backend.dynamic_backend import DYNAMIC_BACKEND as math, NoBackendFound from .context import skip_validate -from .structdef import CONSTANTS, VARIABLES, Item +from .item_condition import context_item_condition, VARIABLES, CONSTANTS +from .structdef import Item, derived, _IndexItem def kwargs(locals, include_self=False, ignore=()): @@ -26,6 +28,18 @@ def kwargs(locals, include_self=False, ignore=()): return locals +class _DataType(object): + def __init__(self, name): + self.name = name + + def __repr__(self): + return self.name + + +INVALID = _DataType('invalid') +VALID = _DataType('valid') + + class Struct(object): """ Base class for all custom structs. @@ -38,35 +52,114 @@ class Struct(object): __traits__ = None __initialized_class__ = None - def __init__(self, **kwargs): + def __init__(self, content_type=VALID, **kwargs): assert isinstance(self, Struct), 'Struct.__init__() called on %s. Maybe you forgot **' % type(self) assert self.__initialized_class__ == self.__class__, "Instancing %s before struct class is initialized. Maybe you forgot to decorate the class with @struct.definition()" % self.__class__.__name__ + self.__content_type__ = INVALID if content_type is VALID else content_type # VALID, INVALID, Item for property, string for custom for item in self.__items__: if item.name not in kwargs: kwargs[item.name] = item.default_value self._set_items(**kwargs) for trait in self.__traits__: trait.endow(self) - self.validate() - - def copied_with(self, **kwargs): + if content_type is VALID: + self.validate() + + @derived() + def shape(self): + """ +Retrieves the dynamic shapes of items specified through the context (see :class:`phi.struct.item_condition.ItemCondition`). +Shapes of sub-structs are obtained using `struct.shape` while shapes of non-structs are obtained using `math.shape()`. + +To override the shapes of items, use `Item.override` with key `struct.shape` instead of overriding this method. + +The result of `x.shape` is equivalent to calling `struct.shape(x)`. + :return: Struct of same type holding shapes instead of data + """ + from .functions import shape + return shape(self) + + @derived() + def staticshape(self): + """ +Retrieves the static shapes of items specified through the context (see :class:`phi.struct.item_condition.ItemCondition`). +Shapes of sub-structs are obtained using `struct.staticshape` while shapes of non-structs are obtained using `math.staticshape()`. + +To override the static shapes of items, use `Item.override` with key `struct.staticshape` instead of overriding this method. + +The result of `x.staticshape` is equivalent to calling `struct.staticshape(x)`. + :return: Struct of same type holding shapes instead of data + """ + from .functions import staticshape + return staticshape(self) + + @derived() + def dtype(self): + """ +Retrieves the data types of items specified through the context (see :class:`phi.struct.item_condition.ItemCondition`). +Data types of sub-structs are obtained using `struct.dtype` while types of non-structs are obtained using `math.dtype()`. + +To override the dtype of items, use `Item.override` with key `struct.dtype` instead of overriding this method. + +The result of `x.dtype` is equivalent to calling `struct.dtype(x)`. + :return: Struct of same type holding data types instead of data + """ + from .functions import dtype + return dtype(self) + + def map(self, function, leaf_condition=None, recursive=True, trace=False, item_condition=None, content_type=None): + """Alias for struct.map()""" + from .functions import map + return map(function, self, leaf_condition=leaf_condition, recursive=recursive, trace=trace, item_condition=item_condition, content_type=content_type) + + def map_item(self, item, function, leaf_condition=None, recursive=True, content_type=None): + """Alias for struct.map_item()""" + from .functions import map_item + return map_item(item, function, self, leaf_condition=leaf_condition, recursive=recursive, content_type=content_type) + + def copied_with(self, change_type=None, **kwargs): + """ +Returns a copy of this Struct with some items values changed. +The Struct, this method is invoked on, remains unaltered. +The returned struct will be validated unless this struct is not valid or the content_type is set to something different than VALID. + :param change_type: content type of the returned struct + :param kwargs: Items to change, in the form item_name=new_value. + :return: Altered copy of this object + """ duplicate = copy(self) duplicate._set_items(**kwargs) # pylint: disable-msg = protected-access - duplicate.validate() + target_type = change_type if change_type is not None else self.__content_type__ + if target_type is VALID and not duplicate.is_valid: + duplicate.__content_type__ = INVALID + duplicate.validate() + else: + duplicate.__content_type__ = target_type return duplicate def _set_items(self, **kwargs): + if len(kwargs) == 0: + return + if self.is_valid: + self.__content_type__ = INVALID for name, value in kwargs.items(): try: item = getattr(self.__class__, name) except (KeyError, TypeError): raise TypeError('Struct %s has no property %s' % (self, name)) item.set(self, value) - return self def validate(self): - if not skip_validate(): + """ +Performs validation on this struct if it holds data and is invalid. +Data-holding structs should always be valid while structs holding non-data content such as shapes or data types are not regarded as valid. + :return: True if validation was performed, False otherwise + """ + if not skip_validate() and self.__content_type__ is INVALID: self.__validate__() + self.__content_type__ = VALID + return True + else: + return False def __validate__(self): for trait in self.__traits__: @@ -76,6 +169,14 @@ def __validate__(self): for trait in self.__traits__: trait.post_validate_struct(self) + @property + def is_valid(self): + return self.__content_type__ is VALID + + @property + def content_type(self): + return self.__content_type__ + def __to_dict__(self, item_condition): if item_condition is not None: return {item.name: item.get(self) for item in self.__items__ if item_condition(item)} @@ -111,20 +212,32 @@ def __hash__(self): pass return hash_value + def __repr__(self): + return "%s[%s]" % (type(self).__name__, self.content_type) + def to_dict(struct, item_condition=None): + if item_condition is None: + item_condition = context_item_condition if isinstance(struct, Struct): return struct.__to_dict__(item_condition) if isinstance(struct, (list, tuple, np.ndarray)): - if item_condition is None: - return {i: struct[i] for i in range(len(struct))} - else: - return {i: struct[i] for i in range(len(struct)) if item_condition(Item(name=i, validation_function=None, is_variable=True, default_value=None, dependencies=(), holds_data=True))} + return {i: struct[i] for i in range(len(struct)) if item_condition(Item(name=i, validation_function=None, is_variable=True, default_value=None, dependencies=(), holds_data=True))} if isinstance(struct, dict): return struct raise ValueError("Not a struct: %s" % struct) +def items(struct): + if isinstance(struct, Struct): + return struct.__items__ + if isinstance(struct, (list, tuple, np.ndarray)): + return [_IndexItem(i) for i in range(len(struct))] + if isinstance(struct, dict): + return [_IndexItem(key) for key in struct.keys()] + raise ValueError("Not a struct: '%s'" % struct) + + def variables(struct): return to_dict(struct, VARIABLES) @@ -154,9 +267,9 @@ def properties_dict(struct): return {'type': str(struct.__class__.__name__), 'module': str(struct.__class__.__module__)} -def copy_with(struct, new_values_dict): +def copy_with(struct, new_values_dict, change_type=None): if isinstance(struct, Struct): - return struct.copied_with(**new_values_dict) + return struct.copied_with(change_type=change_type, **new_values_dict) if isinstance(struct, tuple): duplicate = list(struct) for key, value in new_values_dict.items(): @@ -176,7 +289,10 @@ def copy_with(struct, new_values_dict): duplicate = dict(struct) for key, value in new_values_dict.items(): duplicate[key] = value - return duplicate + if type(struct) is dict: + return duplicate + else: + return type(struct)(duplicate) raise ValueError("Not a struct: %s" % struct) @@ -207,3 +323,4 @@ def equal(obj1, obj2): if obj1 is not obj2: return False return True + diff --git a/phi/math/struct_backend.py b/phi/struct/struct_backend.py similarity index 74% rename from phi/math/struct_backend.py rename to phi/struct/struct_backend.py index 28df92bdb..853bf8e6d 100644 --- a/phi/math/struct_backend.py +++ b/phi/struct/struct_backend.py @@ -1,21 +1,22 @@ -from phi import struct -from .base_backend import Backend +from phi.backend.backend import Backend +from . import context, struct, functions class StructBroadcastBackend(Backend): # Abstract mehtods are overridden generically. # pylint: disable-msg = abstract-method - def __init__(self, backend): + def __init__(self, backend, target_content_type=struct.VALID): Backend.__init__(self, 'StructBroadcast') self.backend = backend + self.target_content_type = target_content_type for fname in dir(self): if fname not in ('__init__', 'is_applicable', 'broadcast_function') and not fname.startswith('__'): function = getattr(self, fname) if callable(function): def context(fname=fname): def proxy(*args, **kwargs): - return broadcast_function(self.backend, fname, args, kwargs) + return self.broadcast_function(self.backend, fname, args, kwargs) return proxy setattr(self, fname, context()) @@ -25,17 +26,15 @@ def is_applicable(self, values): return True return False + def broadcast_function(self, backend, func, args, kwargs): + backend_func = getattr(backend, func) + obj, build_arguments = argument_assembler(args, kwargs) -def broadcast_function(backend, func, args, kwargs): - backend_func = getattr(backend, func) - obj, build_arguments = argument_assembler(args, kwargs) - - def f(*values): - args, kwargs = build_arguments(values) - result = backend_func(*args, **kwargs) - return result - with struct.unsafe(): - return struct.map(f, obj) + def f(*values): + args, kwargs = build_arguments(values) + result = backend_func(*args, **kwargs) + return result + return functions.map(f, obj, content_type=self.target_content_type) def argument_assembler(args, kwargs): @@ -44,7 +43,7 @@ def argument_assembler(args, kwargs): if len(structs) == 1: obj = structs[0] else: - obj = struct.zip(structs) + obj = functions.zip(structs) def assemble_arguments(items): args = [] diff --git a/phi/struct/structdef.py b/phi/struct/structdef.py index 2e41768e1..2420ad6fc 100644 --- a/phi/struct/structdef.py +++ b/phi/struct/structdef.py @@ -36,7 +36,8 @@ def decorator(struct_class, traits=traits): for trait in base.__traits__: if trait not in traits: inherited_traits += (trait,) - traits = inherited_traits + traits + traits = inherited_traits + tuple([t for t in traits if t not in inherited_traits]) + assert len(set(traits)) == len(traits), "Duplicate traits on struct class '%s'" % struct_class # --- Initialize & Decorate --- struct_class.__traits__ = traits for item in items.values(): @@ -124,6 +125,7 @@ def __init__(self, name, validation_function, is_variable, default_value, depend self.holds_data = holds_data self.trait_kwargs = trait_kwargs self.struct_class = None + self._overrides = {} def __initialize_for__(self, struct_class): self.struct_class = struct_class @@ -150,6 +152,40 @@ def validate(self, struct): value = trait.post_validated(struct, self, value) self.set(struct, value) + def has_override(self, content_type): + if content_type is None: + return False + return self._attribute_name(content_type) in self._overrides + + def get_override(self, content_type): + return self._overrides[self._attribute_name(content_type)] + + def override(self, content_type, override_function): + """ +Override a property or behaviour of this item and/or its values. +This affects all instances of the associated Struct. +The override function is called instead of the usual function in `struct.map` to obtain a leaf value. + +Overrides can also be used to specify custom property getters, e.g. to override shape, staticshape, dtype. +As this method is called on an Item, it must be invoked outside the item it affects. + +Example: to override the shape of an item, put the following just below its declaration: `item.override(struct.shape, lambda self, value: custom_shape)` + :param content_type: custom name or Item/DerivedItem reference + :param override_function: function, signature depends on the overridden property. + """ + self._overrides[self._attribute_name(content_type)] = override_function + + @staticmethod + def _attribute_name(name_or_attribute): + if isinstance(name_or_attribute, (Item, DerivedProperty)): + name = name_or_attribute.name + elif callable(name_or_attribute): + name = name_or_attribute.__name__ + else: + name = name_or_attribute + assert isinstance(name, six.string_types), 'Not an attribute: %s' % name + return name + def __get__(self, instance, owner): if instance is not None: return getattr(instance, '_' + self.name) @@ -171,16 +207,17 @@ def __repr__(self): return self.name -def CONSTANTS(item): return not item.is_variable - - -def VARIABLES(item): return item.is_variable +class _IndexItem(Item): + def __init__(self, index, is_variable=True, holds_data=True): + Item.__init__(self, name=index, validation_function=None, is_variable=is_variable, default_value=None, dependencies=(), holds_data=holds_data) + self.index = index -def DATA(item): return item.holds_data - + def get(self, struct): + return struct[self.index] -ALL_ITEMS = None + def set(self, struct, value): + struct[self.index] = value class DerivedProperty(object): diff --git a/phi/struct/tensorop.py b/phi/struct/tensorop.py index 8611e2e81..0c0044f0e 100644 --- a/phi/struct/tensorop.py +++ b/phi/struct/tensorop.py @@ -1,64 +1,2 @@ -import numpy as np - - -def _is_leaf(tensor_like, leaf_condition): - if not isinstance(tensor_like, (tuple, list, np.ndarray)): - return True - if leaf_condition is not None and leaf_condition(tensor_like): - return True - return False - - -def collapse(tensor_like, leaf_condition=None): - if _is_leaf(tensor_like, leaf_condition): - return tensor_like - collapsed_elements = tuple([collapse(element, leaf_condition) for element in tensor_like]) - first = collapsed_elements[0] - for element in collapsed_elements[1:]: - if element != first: - return collapsed_elements - return first - - -def collapsed_gather_nd(collapsed, nd_index, leaf_condition=None): - if isinstance(collapsed, (tuple, list, np.ndarray)): - if leaf_condition is not None and leaf_condition(collapsed): - return collapsed - # collapsed = np.array(collapsed) - if len(nd_index) == 1: - return collapsed[nd_index[0]] - else: - return collapsed_gather_nd(collapsed[nd_index[0]], nd_index[1:]) - else: - return collapsed - - -def expand(collapsed, shape): - if len(shape) == 0: - return collapsed - if isinstance(collapsed, (tuple, list, np.ndarray)): - if len(collapsed) == shape[0]: - return [expand(item, shape[1:]) for item in collapsed] - elif len(collapsed) == 1: - item = expand(collapsed[0], shape[1:]) - return [item] * shape[0] - else: - raise ValueError('Cannot match shape: requested %d but actual %d' % (shape[0], len(collapsed))) - else: - return [expand(collapsed, shape[1:])] * shape[0] - - -class CollapsedTensor(object): - - def __init__(self, collapsed, leaf_condition=None, shape=None): - self.collapsed = collapsed - self.leaf_condition = leaf_condition - self.shape = shape - - def __getitem__(self, item): - return collapsed_gather_nd(self.collapsed, item, self.leaf_condition) - - def expand(self, shape=None): - shape = self.shape if shape is None else shape - assert shape is not None - return expand(self.collapsed, shape) +# Alias for phi.backend.tensorop for compatibility +from ..backend.tensorop import * \ No newline at end of file diff --git a/phi/struct/trait.py b/phi/struct/trait.py index 5d8e04129..4ad329f01 100644 --- a/phi/struct/trait.py +++ b/phi/struct/trait.py @@ -72,3 +72,6 @@ def __ne__(self, other): def __repr__(self): return '%s (Trait)' % self.__class__.__name__ + + def __hash__(self): + return hash(self.__class__) diff --git a/phi/tf/app.py b/phi/tf/app.py index 86f1f34b6..87f0d7959 100644 --- a/phi/tf/app.py +++ b/phi/tf/app.py @@ -98,6 +98,9 @@ def __init__(self, name='TensorFlow application', subtitle='', self.log_scalars = log_scalars def prepare(self): + if self.prepared: + return + scalars = [tf.summary.scalar(self.scalar_names[i], self.scalars[i]) for i in range(len(self.scalars))] self.merged_scalars = tf.summary.merge(scalars) diff --git a/phi/tf/data.py b/phi/tf/data.py index fbc3d6d22..ea62de502 100644 --- a/phi/tf/data.py +++ b/phi/tf/data.py @@ -2,6 +2,7 @@ from phi.data.fluidformat import _transform_for_writing from phi.physics.physics import State from phi.physics.world import StateProxy +from phi.struct.context import _unsafe from .util import placeholder @@ -12,7 +13,7 @@ def load_state(state): assert isinstance(state, State) state = _transform_for_writing(state) names = struct.names(state) - with struct.unsafe(): + with _unsafe(): placeholders = placeholder(state.shape) state_in = struct.map(lambda x: x, placeholders) # validates fields, splits staggered tensors return state_in, {placeholders: names} diff --git a/phi/tf/session.py b/phi/tf/session.py index 2672e1a20..f959c54b6 100644 --- a/phi/tf/session.py +++ b/phi/tf/session.py @@ -5,6 +5,10 @@ import numpy as np import tensorflow as tf +if tf.__version__[0] == '2': + logging.info('Adjusting for tensorflow 2.0') + tf = tf.compat.v1 + tf.disable_eager_execution() from phi import struct from .profiling import Timeliner @@ -13,10 +17,10 @@ class Session(object): - def __init__(self, scene, session=tf.Session()): + def __init__(self, scene, session=None): self._scene = scene - self._session = session - assert self._session.graph == tf.get_default_graph() + self._session = session if session is not None else tf.Session() + assert self._session.graph == tf.get_default_graph(), 'Session %s does not reference the current TensorFlow graph.' self.graph = tf.get_default_graph() self.summary_writers = {} self.summary_directory = os.path.abspath(scene.subpath('summary')) if scene is not None else None @@ -25,11 +29,6 @@ def __init__(self, scene, session=tf.Session()): self.saver = None def initialize_variables(self): - import tensorflow as tf - if tf.__version__[0] == '2': - logging.info('Adjusting for tensorflow 2.0') - tf = tf.compat.v1 - tf.disable_eager_execution() self._session.run(tf.global_variables_initializer()) self.saver = tf.train.Saver(max_to_keep=100, allow_empty=True) @@ -49,8 +48,7 @@ def add_to_dict(key_tensor, value_tensor): if isplaceholder(key_tensor): tensor_feed_dict[key_tensor] = value_tensor return None - with struct.unsafe(): - struct.map(add_to_dict, pairs, item_condition=struct.ALL_ITEMS) + struct.map(add_to_dict, pairs, item_condition=struct.ALL_ITEMS, content_type=struct.INVALID) tensor_fetches = struct.flatten(fetches, item_condition=struct.ALL_ITEMS) if isinstance(fetches, (tuple, list)): diff --git a/phi/tf/tf_backend.py b/phi/tf/tf_backend.py index b44dbe019..05960a7ae 100644 --- a/phi/tf/tf_backend.py +++ b/phi/tf/tf_backend.py @@ -7,8 +7,8 @@ import tensorflow as tf from packaging import version -from phi.math.base_backend import Backend -from phi.struct.tensorop import expand, collapsed_gather_nd +from phi.backend.backend import Backend +from phi.backend.tensorop import expand, collapsed_gather_nd if tf.__version__[0] == '2': logging.info('Adjusting for tensorflow 2.0') @@ -267,6 +267,8 @@ def to_complex(self, x): return tf.to_complex64(x) def gather(self, values, indices): + if isinstance(indices, slice): + return values[indices] return tf.gather(values, indices) def gather_nd(self, values, indices): @@ -459,14 +461,19 @@ def _boundary_circular(sample_coords, input_size): def _boundary_symmetric(sample_coords, input_size): - circular_size = input_size + input_size - 2 - return (input_size - 1) - tf.abs( - (input_size - 1) - _boundary_circular(sample_coords, circular_size)) + sample_coords = _boundary_circular(sample_coords, 2 * input_size) + return ((2 * input_size - 1) - tf.abs((2 * input_size - 1) - 2 * sample_coords)) // 2 + + +def _boundary_reflect(sample_coords, input_size): + sample_coords = _boundary_circular(sample_coords, 2 * input_size - 2) + return (input_size - 1) - tf.abs((input_size - 1) - sample_coords) SUPPORTED_BOUNDARY = { 'zero': _boundary_replicate, 'replicate': _boundary_replicate, 'circular': _boundary_circular, - 'symmetric': _boundary_symmetric + 'symmetric': _boundary_symmetric, + 'reflect': _boundary_reflect, } diff --git a/phi/tf/util.py b/phi/tf/util.py index eb50b389a..852a90051 100644 --- a/phi/tf/util.py +++ b/phi/tf/util.py @@ -17,25 +17,27 @@ def _tf_name(trace, basename): - if basename is None: - return trace.path('/') - else: - return basename + '/' + trace.path('/') + path = trace.path('/') + if basename is None and len(path) == 0: + return None + result = path if basename is None else basename + '/' + path + print(result) + return result -def placeholder(shape, dtype=np.float32, basename=None, item_condition=struct.VARIABLES): +def placeholder(shape, dtype=np.float32, basename='Placeholder'): if struct.isstruct(dtype): def placeholder_map(trace): shape, dtype = trace.value return tf.placeholder(dtype, shape, _tf_name(trace, basename)) - zipped = struct.zip([shape, dtype], leaf_condition=is_static_shape, item_condition=item_condition) - return struct.map(placeholder_map, zipped, leaf_condition=is_static_shape, trace=True, item_condition=item_condition) + zipped = struct.zip([shape, dtype], leaf_condition=is_static_shape) + return struct.map(placeholder_map, zipped, leaf_condition=is_static_shape, trace=True) else: def f(trace): return tf.placeholder(dtype, trace.value, _tf_name(trace, basename)) - return struct.map(f, shape, leaf_condition=is_static_shape, trace=True, item_condition=item_condition) + return struct.map(f, shape, leaf_condition=is_static_shape, trace=True) -def placeholder_like(obj, basename=None): +def placeholder_like(obj, basename='Placeholder'): warnings.warn("placeholder_like may not respect the batch dimension. " "For State objects, use placeholder(state.shape) instead.", DeprecationWarning, stacklevel=2) @@ -43,12 +45,12 @@ def f(attr): return tf.placeholder(attr.value.dtype, attr.value.shape, _tf_name( return struct.map(f, obj, leaf_condition=is_static_shape, trace=True) -def variable(initial_value, dtype=np.float32, basename=None, trainable=True, item_condition=struct.VARIABLES): +def variable(initial_value, dtype=np.float32, basename='Variable', trainable=True): def f(attr): return tf.Variable(attr.value, name=_tf_name(attr, basename), dtype=dtype, trainable=trainable) - return struct.map(f, initial_value, trace=True, item_condition=item_condition) + return struct.map(f, initial_value, trace=True) -def variable_generator(initializer, dtype=np.float32, basename=None, trainable=True): +def variable_generator(initializer, dtype=np.float32, basename='Variable', trainable=True): def create_variable(shape): initial_value = initializer(shape) return variable(initial_value, dtype, basename, trainable) diff --git a/phi/tf/world.py b/phi/tf/world.py index 12e9aba14..f04333565 100644 --- a/phi/tf/world.py +++ b/phi/tf/world.py @@ -1,18 +1,20 @@ import numpy as np +from phi import struct from phi.physics.world import World from phi.physics import Physics -from phi.physics.collective import CollectivePhysics -from phi import math, struct +from phi.physics.collective import CollectivePhysics, StateCollection +from phi.struct import VARIABLES from phi.struct.functions import mappable from .util import placeholder def tf_bake_graph(world, session): # --- Build placeholder state --- - shape = world.state.shape - dtype = _32_bit(math.types(world.state)) - state_in = placeholder(shape, dtype=dtype) + with VARIABLES: + shape = world.state.staticshape + dtype = _32_bit(world.state.dtype) + state_in = placeholder(shape, dtype=dtype) dt = placeholder(()) # --- Build graph --- state_out = world.physics.step(state_in, dt=dt) @@ -28,9 +30,10 @@ def tf_bake_subgraph(tracker, session): tfworld = World() tfworld.add(tracker.state) # --- Build placeholder state --- - dtype = _32_bit(math.types(tracker.state)) - shape = tracker.state.shape - state_in = placeholder(shape, dtype=dtype) + with VARIABLES: + shape = tracker.state.staticshape + dtype = _32_bit(tracker.state.dtype) + state_in = placeholder(shape, dtype=dtype) dt = placeholder(()) # --- Build graph --- state_out = tracker.world.physics.substep(state_in, tracker.world.state, dt) @@ -68,7 +71,7 @@ def step(self, state_collection, dt=1.0, **dependent_states): return result -@mappable(item_condition=None, unsafe_context=True) +@mappable(content_type=struct.dtype) def _32_bit(dtype): if dtype == np.float64: return np.float32 diff --git a/phi/torch/torch_backend.py b/phi/torch/torch_backend.py index a8189f589..a01ea0a3e 100644 --- a/phi/torch/torch_backend.py +++ b/phi/torch/torch_backend.py @@ -5,7 +5,7 @@ import torch import torch.nn.functional as torchf -from phi.math.base_backend import Backend +from phi.backend.backend import Backend class TorchBackend(Backend): diff --git a/setup.py b/setup.py index e5756bad4..7101277de 100644 --- a/setup.py +++ b/setup.py @@ -114,16 +114,20 @@ def finalize_options(self): assert os.path.isfile(self.nvcc) or self.nvcc == 'nvcc' -with open("documentation/Package_Info.md", "r") as readme: - long_description = readme.read() +try: + with open(os.path.join(os.path.dirname(__file__), 'documentation/Package_Info.md'), 'r') as readme: + long_description = readme.read() +except FileNotFoundError: + pass setup( name='phiflow', - version='1.0.2', - download_url='https://github.com/tum-pbs/PhiFlow/archive/1.0.2.tar.gz', + version='1.0.3', + download_url='https://github.com/tum-pbs/PhiFlow/archive/1.0.3.tar.gz', packages=['phi', 'phi.app', + 'phi.backend', 'phi.data', 'phi.geom', 'phi.local', @@ -133,14 +137,15 @@ def finalize_options(self): 'phi.physics.pressuresolver', 'phi.struct', 'phi.tf', - 'phi.viz', 'phi.viz.dash', + 'phi.viz', + 'phi.viz.dash', 'webglviewer'], cmdclass={ 'tf_cuda': CudaCommand, }, description='Research-oriented differentiable fluid simulation framework', long_description=long_description, - long_description_content_type="text/markdown", + long_description_content_type='text/markdown', keywords=['Differentiable', 'Simulation', 'Fluid', 'Machine Learning', 'Deep Learning'], license='MIT', author='Philipp Holl', @@ -157,6 +162,8 @@ def finalize_options(self): 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', ], extras_require={ 'gui': ['dash', diff --git a/tests/test_demos.py b/tests/test_demos.py index da304282b..934fd9682 100644 --- a/tests/test_demos.py +++ b/tests/test_demos.py @@ -72,6 +72,9 @@ def test_moving_inflow(self): def test_simpleplume(self): demo_run('simpleplume') + def test_simpleplume_3d(self): + demo_run('simpleplume_3d') + def test_smoke_datagen_commandline(self): demo_run('smoke_datagen_commandline') diff --git a/tests/test_fluid.py b/tests/test_fluid.py index 2b3f741ff..7ec76b2b8 100644 --- a/tests/test_fluid.py +++ b/tests/test_fluid.py @@ -3,12 +3,13 @@ import numpy from phi import struct, math -from phi.geom import Sphere, AABox +from phi.geom import Sphere, AABox, box from phi.physics.domain import Domain from phi.physics.field import StaggeredGrid from phi.physics.field.effect import Fan, Inflow from phi.physics.material import CLOSED, OPEN from phi.physics.fluid import Fluid, INCOMPRESSIBLE_FLOW, IncompressibleFlow +from phi.physics.obstacle import Obstacle from phi.physics.pressuresolver.sparse import SparseCG from phi.physics.world import World @@ -54,10 +55,12 @@ def typetest(fluid): def test_effects(self): world = World() - world.add(Fluid(Domain([16, 16]))) - world.add(Fan(Sphere((10, 8), 5), [-1, 0])) - world.step() - world.step() + fluid = world.add(Fluid(Domain([16, 16]))) + fan = world.add(Fan(Sphere((10, 8), 5), [-1, 0])) + obstacle = world.add(Obstacle(box[0:1, 0:1])) + world.step(dt=1) + world.step(dt=0.5) + assert fluid.age == fan.age == obstacle.age == 1.5 def test_properties_dict(self): world = World() diff --git a/tests/test_fluid_tf.py b/tests/test_fluid_tf.py index 74b430f36..05fb10cd1 100644 --- a/tests/test_fluid_tf.py +++ b/tests/test_fluid_tf.py @@ -16,6 +16,7 @@ class TestFluidTF(TestCase): def test_fluid_tf(self): + tf.reset_default_graph() world = World() fluid = Fluid(Domain([16, 16])) world.add(fluid) @@ -30,6 +31,7 @@ def test_fluid_tf(self): self.assertIsInstance(fluid, Fluid) def test_tf_subgraph(self): + tf.reset_default_graph() world = World() fluid = world.add(Fluid(Domain([16, 16]))) tf_bake_subgraph(fluid, Session(Scene.create('data', copy_calling_script=False))) @@ -38,6 +40,7 @@ def test_tf_subgraph(self): self.assertIsInstance(fluid.state.density.data, numpy.ndarray) def test_tf_worldgraph(self): + tf.reset_default_graph() world = World() fluid = world.add(Fluid(Domain([16, 16]))) tf_bake_graph(world, Session(Scene.create('data', copy_calling_script=False))) diff --git a/tests/test_initializers.py b/tests/test_initializers.py index 208d66aa2..21c63cba8 100644 --- a/tests/test_initializers.py +++ b/tests/test_initializers.py @@ -18,9 +18,7 @@ def test_direct_initializers(self): self.assertEqual(math.randn([1, 4]).dtype, np.float32) def test_struct_initializers(self): - bounds = box[0:1] # outside unsafe - with struct.unsafe(): - obj = ([4], CenteredGrid([1, 4, 1], bounds), ([9], [8, 2])) + obj = ([4], CenteredGrid([1, 4, 1], box[0:1], content_type=struct.shape), ([9], [8, 2])) z = math.zeros(obj) self.assertIsInstance(z, tuple) np.testing.assert_equal(z[0], np.zeros([4])) diff --git a/tests/test_math.py b/tests/test_math.py index a8f3c0305..49e40934c 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -3,6 +3,7 @@ import numpy as np from phi.geom import AABox +from phi.math.nd import _dim_shifted from phi.tf import tf # pylint: disable-msg = redefined-builtin, redefined-outer-name, unused-wildcard-import, wildcard-import @@ -37,7 +38,7 @@ def test_fft(self): self.assertLess(max(abs(x_np - x_tf.eval())), 1e-3) - def test_laplace(self): + def test_laplace_padding(self): tf.InteractiveSession() for dims in range(1, 4): shape = [2] + [4]*dims + [3] @@ -48,7 +49,7 @@ def test_laplace(self): l = laplace(a, padding='reflect') np.testing.assert_equal(l, 0) np.testing.assert_equal(l.shape, a.shape) - l = laplace(a, padding='cyclic') + l = laplace(a, padding='circular') np.testing.assert_equal(l, 0) np.testing.assert_equal(l.shape, a.shape) l = laplace(a, padding='valid') @@ -106,3 +107,33 @@ def test_div_no_nan(self): y = tf.convert_to_tensor(y) result = divide_no_nan(x, y).eval() np.testing.assert_equal(result, [1, -0.5, 0, 0, 0]) + + def test_dim_shifted(self): + # --- 1D --- + tensor = np.expand_dims(np.expand_dims(np.arange(10), axis=-1), axis=0) + lower, center, upper = _dim_shifted(tensor, 0, (-1, 0, 1), components=0) + np.testing.assert_equal(lower[0,:,0], np.arange(8)) + np.testing.assert_equal(center[0,:,0], np.arange(1,9)) + np.testing.assert_equal(upper[0,:,0], np.arange(2,10)) + # --- 2D --- + tensor = np.ones([1, 4, 4, 2]) + lower, upper = _dim_shifted(tensor, 0, (0, 1), diminish_others=(0, 1), components=0) + np.testing.assert_equal(lower.shape, (1, 3, 3, 1)) + np.testing.assert_equal(upper.shape, (1, 3, 3, 1)) + + def test_gradient(self): + # --- 1D --- + tensor = np.expand_dims(np.expand_dims(np.arange(5), axis=-1), axis=0) + grad = gradient(tensor, padding='replicate') + np.testing.assert_equal(grad[0,:,0], [1, 1, 1, 1, 0]) + grad = gradient(tensor, padding='circular') + np.testing.assert_equal(grad[0,:,0], [1, 1, 1, 1, -4]) + grad = gradient(tensor, dx=0.1, padding='replicate') + np.testing.assert_equal(grad[0,:,0], [10, 10, 10, 10, 0]) + + def test_upsample_downsample(self): + # --- 1D --- + tensor = np.expand_dims(np.expand_dims(np.arange(5), axis=-1), axis=0) + up = upsample2x(tensor) + inverted = downsample2x(up) + np.testing.assert_equal(inverted[:, 1:-1, :], tensor[:, 1:-1, :]) diff --git a/tests/test_poisson_solve.py b/tests/test_poisson_solve.py new file mode 100644 index 000000000..c912cf67d --- /dev/null +++ b/tests/test_poisson_solve.py @@ -0,0 +1,70 @@ +from unittest import TestCase + +import numpy as np +from phi import math + +from phi.flow import CLOSED, PERIODIC, OPEN, Domain, poisson_solve +from phi.physics.pressuresolver.geom import GeometricCG +from phi.physics.pressuresolver.sparse import SparseCG, SparseSciPy + + +def _generate_examples(): + # --- Example 1 --- + ex1 = np.tile(np.linspace(1, 0, 5), [4, 1]) + ex1 = math.expand_dims(math.expand_dims(ex1, -1), 0) - math.mean(ex1) + # --- Example 2 --- + ex2 = np.zeros([1, 4, 5, 1]) + ex2[0, :, 2, 0] = 1 + ex2 -= math.mean(ex2) + # --- Stack examples to batch --- + return math.concat([ex1, ex2], axis=0) + + +def _test_solve_no_obstacles(domain, solver): + print('Testing domain with boundaries: %s' % (domain.boundaries,)) + data_in = _generate_examples() + p = poisson_solve(domain.centered_grid(data_in), domain, solver=solver)[0] + np.testing.assert_almost_equal(p.laplace().data[:, 1:-1, 1:-1, :], data_in[:, 1:-1, 1:-1, :], decimal=5) + if domain.boundaries is CLOSED: + np.testing.assert_almost_equal(p.laplace().data, data_in, decimal=5) + # rows = math.unstack(p.data, 1) + # for row in rows[1:]: + # np.testing.assert_almost_equal(row, rows[0], decimal=5) + + +DOMAINS = [ + Domain([4, 5], boundaries=CLOSED), + Domain([4, 5], boundaries=OPEN), + Domain([4, 5], boundaries=PERIODIC), + Domain([4, 5], boundaries=[PERIODIC, CLOSED]), + Domain([4, 5], boundaries=[CLOSED, OPEN]), + ] + +SOLVERS = [ + SparseCG(), GeometricCG() +] + + +class TestPoissonSolve(TestCase): + + def test_equal_results(self): + data_in = _generate_examples() + for domain in DOMAINS: + pressure_fields = [poisson_solve(domain.centered_grid(data_in), domain, solver=solver)[0].data for solver in SOLVERS] + for field in pressure_fields[1:]: + np.testing.assert_almost_equal(field, pressure_fields[0], decimal=4) + + def test_sparse_cg(self): + solver = SparseCG() + for domain in DOMAINS: + _test_solve_no_obstacles(domain, solver) + + # def test_sparse_scipy(self): + # solver = SparseSciPy() + # for domain in DOMAINS: + # _test_solve_no_obstacles(domain, solver) + + def test_geometric_cg(self): + solver = GeometricCG() + for domain in DOMAINS: + _test_solve_no_obstacles(domain, solver) diff --git a/tests/test_struct.py b/tests/test_struct.py index 36c5066e2..1589d1645 100644 --- a/tests/test_struct.py +++ b/tests/test_struct.py @@ -2,6 +2,7 @@ import numpy +from phi import math from phi.geom import box from phi.physics.collective import StateCollection from phi.physics.domain import Domain @@ -24,13 +25,12 @@ class TestStruct(TestCase): def test_identity(self): for obj in generate_test_structs(): - with struct.unsafe(): - obj2 = struct.map(lambda s: s, obj, recursive=False) - self.assertEqual(obj, obj2) - obj3 = struct.map(lambda t: t, obj, recursive=True) - self.assertEqual(obj, obj3) - obj4 = struct.map(lambda t: t, obj, item_condition=struct.ALL_ITEMS) - self.assertEqual(obj, obj4) + obj2 = struct.map(lambda s: s, obj, recursive=False) + self.assertEqual(obj, obj2) + obj3 = struct.map(lambda t: t, obj, recursive=True) + self.assertEqual(obj, obj3) + obj4 = struct.map(lambda t: t, obj, item_condition=struct.ALL_ITEMS) + self.assertEqual(obj, obj4) def test_flatten(self): for obj in generate_test_structs(): @@ -42,37 +42,34 @@ def test_flatten(self): def test_names(self): for obj in generate_test_structs(): - with struct.unsafe(): - names = struct.flatten(struct.map(lambda attr: attr.name, obj, trace=True)) - self.assertGreater(len(names), 0) - for name in names: - self.assertIsInstance(name, str) + names = struct.flatten(struct.map(lambda attr: attr.name, obj, trace=True, content_type='name')) + self.assertGreater(len(names), 0) + for name in names: + self.assertIsInstance(name, str) def test_paths(self): obj = {'Vels': [CenteredGrid(numpy.zeros([1, 4, 1]), box[0:1], name='v')]} - with struct.unsafe(): - names = struct.flatten(struct.map(lambda attr: attr.path(), obj, trace=True)) + names = struct.flatten(struct.map(lambda attr: attr.path(), obj, trace=True, content_type='name')) self.assertEqual(names[0], 'Vels.0.data') def test_copy(self): - with struct.unsafe(): - fluid = Fluid(Domain([4]), density='Density', velocity='Velocity') - v = fluid.copied_with(velocity='V2') - self.assertEqual(v.velocity, 'V2') - self.assertEqual(v.density, 'Density') - - try: - fluid.copied_with(velocity='D2') - self.fail() - except AssertionError: - pass + fluid = Fluid(Domain([4]), density='Density', velocity='Velocity', content_type=struct.INVALID) + v = fluid.copied_with(velocity='V2') + self.assertEqual(v.velocity, 'V2') + self.assertEqual(v.density, 'Density') + + try: + fluid.copied_with(velocity='D2') + self.fail() + except AssertionError: + pass def test_zip(self): - with struct.unsafe(): - a = CenteredGrid('a') - b = CenteredGrid('b') - stacked = struct.map(lambda *x: x, struct.zip([a, b])) - numpy.testing.assert_equal(stacked.data, ('a', 'b')) + a = CenteredGrid('a', content_type='name') + b = CenteredGrid('b', content_type='name') + zipped = struct.zip([a, b]) + stacked = struct.map(lambda *x: x, zipped, content_type='name') + numpy.testing.assert_equal(stacked.data, ('a', 'b')) def test_collapse(self): self.assertEqual(0, collapse(numpy.zeros([2, 2]))) @@ -90,9 +87,24 @@ def test_expand(self): def test_mappable(self): x = [0] + @mappable(item_condition=VARIABLES) def act_on_variables(x): return x + 1 + @mappable(item_condition=CONSTANTS) def act_on_constants(x): return x + 1 + self.assertEqual([1], act_on_variables(x)) - self.assertEqual([0], act_on_constants(x)) \ No newline at end of file + self.assertEqual([0], act_on_constants(x)) + + def test_content_types(self): + dom = Domain([4]) + assert dom.is_valid + # --- CenteredGrid --- + assert dom.centered_shape().content_type is struct.Struct.shape + assert dom.centered_grid(math.zeros).content_type is struct.VALID + # --- StaggeredGrid --- + assert dom.staggered_shape().content_type is struct.Struct.shape + assert dom.staggered_shape().x.content_type is struct.Struct.shape + assert dom.staggered_grid(math.zeros).content_type is struct.VALID + assert dom.staggered_grid(math.zeros).x.content_type is struct.VALID diff --git a/tests/test_tensorflow.py b/tests/test_tensorflow.py index a219bdacb..3d866d26b 100644 --- a/tests/test_tensorflow.py +++ b/tests/test_tensorflow.py @@ -25,11 +25,9 @@ def test_direct_placeholders(self): self.assertEqual(v.name, 'Variable:0') def test_struct_placeholders(self): - bounds = box[0:1] # outside unsafe - with struct.unsafe(): - obj = ([4], CenteredGrid([1, 4, 1], bounds), ([9], [8, 2])) + obj = ([4], CenteredGrid([1, 4, 1], box[0:1], content_type=struct.shape), ([9], [8, 2])) tensorflow.reset_default_graph() p = placeholder(obj) - self.assertEqual(p[0].name, '0:0') - self.assertEqual(p[1].data.name, '1/data:0') + self.assertEqual('Placeholder/0:0', p[0].name) + self.assertEqual('Placeholder/1/data:0', p[1].data.name) self.assertIsInstance(p, tuple) diff --git a/tests/test_world.py b/tests/test_world.py index 6adecbd38..619ef9d15 100644 --- a/tests/test_world.py +++ b/tests/test_world.py @@ -1,7 +1,9 @@ from unittest import TestCase import numpy +import six +from phi import struct from phi.physics.collective import StateCollection from phi.physics.domain import Domain from phi.physics.fluid import Fluid @@ -24,3 +26,33 @@ def test_names(self): world = World(add_default_objects=True) assert world.gravity.state is world.state.gravity + + def test_state_collection(self): + fluid = Fluid(Domain([1, 1])) + fluid2 = Fluid(Domain([2, 2])) + + c1 = StateCollection([fluid]) + assert c1.fluid is fluid + assert fluid in c1 + assert c1[fluid] is fluid + assert isinstance(repr(c1), six.string_types) + assert len(c1) == len(c1.shape) == len(c1.staticshape) == len(c1.dtype) + assert c1.shape.fluid.density.data == (1, 1, 1, 1) + self.assertIsInstance(c1.dtype.fluid.density.data, numpy.dtype) + + c2 = StateCollection() + assert len(c2) == 0 + c2 = c2.state_added(fluid) + assert c2 == c1 + assert hash(c2) == hash(c1) + + c3 = c2.state_replaced(fluid2) + assert c3 != c2 + assert c3.fluid is fluid2 + + c4 = c3.state_removed(fluid2) + assert len(c4) == 0 + + c5 = struct.map(lambda x: x, c1) + assert isinstance(c5, StateCollection) + assert c5 == c1