Merge pull request #17 from DrChainsaw/spiraldemo
Add first set of classes for spiral demo
DrChainsaw authored Mar 27, 2019
2 parents 43ca828 + ab89237 commit 81f12ef
Showing 137 changed files with 7,836 additions and 387 deletions.
126 changes: 118 additions & 8 deletions README.md
@@ -1,15 +1,20 @@
# neuralODE4j

Implementation of neural ordinary differential equations built for deeplearning4j.
Travis [![Build Status](https://travis-ci.org/DrChainsaw/NeuralODE4j.svg?branch=master)](https://travis-ci.org/DrChainsaw/NeuralODE4j)
AppVeyor [![Build status](https://ci.appveyor.com/api/projects/status/wjdi11f4cmx32ir8?svg=true)](https://ci.appveyor.com/project/DrChainsaw/neuralode4j)

[![codebeat badge](https://codebeat.co/badges/d9e719b4-5465-4f08-9c14-f924691cdd86)](https://codebeat.co/projects/github-com-drchainsaw-neuralode4j-master)
[![Codacy Badge](https://api.codacy.com/project/badge/Grade/d491774f94944895b6aa3e22b7aae8b3)](https://www.codacy.com/app/DrChainsaw/neuralODE4j?utm_source=github.com&utm_medium=referral&utm_content=DrChainsaw/neuralODE4j&utm_campaign=Badge_Grade)
[![Maintainability](https://api.codeclimate.com/v1/badges/c0d216da01a0c8b8d615/maintainability)](https://codeclimate.com/github/DrChainsaw/neuralODE4j/maintainability)
[![Test Coverage](https://api.codeclimate.com/v1/badges/c0d216da01a0c8b8d615/test_coverage)](https://codeclimate.com/github/DrChainsaw/neuralODE4j/test_coverage)

Implementation of Neural Ordinary Differential Equations (ODE) built for [deeplearning4j](https://deeplearning4j.org/).

[[Arxiv](https://arxiv.org/abs/1806.07366)]

[[PyTorch repo by paper authors](https://github.com/rtqichen/torchdiffeq)]

NOTE: This is very much a work in progress, and given that I haven't touched a differential equation since school, chances
are that there are conceptual misunderstandings.

The performance of the MNIST example is in line with the results presented in the paper, but given the simplicity of that dataset this is no guarantee of a correct implementation.
[[Very good blog post](https://julialang.org/blog/2019/01/fluxdiffeq)]

## Getting Started

@@ -21,20 +26,125 @@ cd neuralODE4j
mvn install
```

Currently only the MNIST toy experiment from the paper is implemented [[link]](./src/main/java/examples)
I will try to create a Maven artifact whenever I find the time. Please file an issue if you are interested.

Implementations of the MNIST and spiral generation toy experiments from the paper can be found under examples [[link]](./src/main/java/examples)

## Usage

The class [OdeVertex](./src/main/java/ode/vertex/conf/OdeVertex.java) is used to add an arbitrary graph of Layers or GraphVertices as an ODE block in a ComputationGraph.

OdeVertex extends GraphVertex and can be added to a GraphBuilder just like any other vertex. It has an API similar to GraphBuilder's for adding
layers and vertices.

Example:
```
final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder()
        .graphBuilder()
        .addInputs("input")
        .setInputTypes(InputType.convolutional(9, 9, 3))
        .addLayer("normalLayer0",
                new Convolution2D.Builder(3, 3)
                        .nOut(32)
                        .convolutionMode(ConvolutionMode.Same).build(), "input")

        // Add an ODE block called "odeBlock" to the graph
        .addVertex("odeBlock",
                new OdeVertex.Builder(new NeuralNetConfiguration.Builder(), "odeLayer0", new BatchNormalization.Builder().build())

                        // OdeVertex has a similar API to GraphBuilder for adding new layers/vertices to the ODE block
                        .addLayer("odeLayer1", new Convolution2D.Builder(3, 3)
                                .nOut(32)
                                .convolutionMode(ConvolutionMode.Same).build(), "odeLayer0")

                        // Add more layers and vertices as desired

                        // Build the OdeVertex. The resulting "inner graph" will be treated as an ODE
                        .build(), "normalLayer0")

        // Layers/vertices can be added to the graph after the ODE block
        .addLayer("normalLayer1", new BatchNormalization.Builder().build(), "odeBlock")
        .addLayer("output", new CnnLossLayer(), "normalLayer1")
        .setOutputs("output")
        .build());
```

An inherent constraint of the method itself is that the output of the last layer in the OdeVertex must have exactly the same
shape as the input to the first layer in the OdeVertex.

Note that OdeVertex.Builder requires a NeuralNetConfiguration.Builder as constructor input. This is because DL4J does not set graph-wise
default values for things like updaters and weight initialization for vertices, so the only way to apply them to the
layers of the OdeVertex is to pass in the global configuration. Making it a required constructor argument will
hopefully make this harder to forget. It is of course possible to give the layers of the OdeVertex a separate set of
default values by passing in another NeuralNetConfiguration.Builder.
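
For example (a sketch; the updater and weight-init choices here are illustrative, not defaults from this project):

```
final NeuralNetConfiguration.Builder globalConf = new NeuralNetConfiguration.Builder()
        .updater(new Adam(0.005))     // becomes the default updater for layers in the OdeVertex
        .weightInit(WeightInit.RELU); // likewise for weight initialization

new OdeVertex.Builder(globalConf, "odeLayer0", new BatchNormalization.Builder().build());
```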

The method for solving the ODE can be configured:

```
new OdeVertex.Builder(...)
        .odeConf(new FixedStep(
                new DormandPrince54Solver(),
                Nd4j.arange(0, 2))) // Integrate between t = 0 and t = 1
```

Currently, the only ODE solver implementation integrated with Nd4j is [DormandPrince54Solver](./src/main/java/ode/solve/impl/DormandPrince54Solver.java).
It is however possible to use FirstOrderIntegrators from org.apache.commons:commons-math3 through [FirstOrderSolverAdapter](./src/main/java/ode/solve/commons/FirstOrderSolverAdapter.java)
at the cost of slower training and inference.
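
For example, a minimal sketch of wrapping a fixed-step integrator from commons-math3 (the FirstOrderSolver type and the adapter's one-argument constructor are assumptions based on the file layout):

```
// Classical RK4 from commons-math3 with step size 0.1, adapted to the project's solver API
final FirstOrderIntegrator rk4 = new ClassicalRungeKuttaIntegrator(0.1);
final FirstOrderSolver solver = new FirstOrderSolverAdapter(rk4);
```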

Time can also be input from another vertex in the graph:
```
new OdeVertex.Builder(...)
        .odeConf(new InputStep(solverConf, 1)) // Number "1" refers to input "time" on the line below
        .build(), "someLayer", "time");
```

Note that time must be a vector, meaning it cannot be minibatched; it has to be the same for all examples in a minibatch. This is because the implementation uses the minibatching approach from
section 6 in the paper, where all examples in the batch are concatenated into one state. If one time sequence per example is desired, this
can be achieved by using a minibatch size of 1, as sketched below.
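
A sketch of that workaround, assuming a graph built as above with one feature input and one time input (all names and shapes here are illustrative):

```
final INDArray features = Nd4j.randn(4, 10); // 4 examples with 10 features each
for (int i = 0; i < features.size(0); i++) {
    // One example at a time, so each example can have its own time vector
    final INDArray example = features.get(NDArrayIndex.interval(i, i + 1), NDArrayIndex.all());
    final INDArray time = Nd4j.linspace(0, 1 + i, 5); // this example's 5 time steps
    final INDArray out = graph.output(example, time)[0];
}
```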

Gradients of the loss with respect to time will be output from the vertex when time is used as input, but they are set to 0 by default to save computation. To have them computed, set needTimeGradient to true:

```
final boolean needTimeGradient = true;
new OdeVertex.Builder(...)
        .odeConf(new InputStep(solverConf, 1, true, needTimeGradient))
        .build(), "someLayer", "time");
```

I have not seen these used for anything in the original implementation, and if they are used, some extra measure is most likely required to ensure that time is always strictly increasing or decreasing.

In either case, the minimum number of elements in the time vector is two. If more than two elements are given, the output of the OdeVertex
will have one more dimension than the input (corresponding to the time elements).

For example, if the graph in the OdeVertex is the function `f = dz/dt` and `time` is the sequence `t0, t1, ..., tN-1`
with `N > 2` then the output of the OdeVertex will be (an approximation of) the sequence `z(t0), z(t1), ... , z(tN-1)`.
Note that `z(t0)` is also the input to the OdeVertex.

The exact mapping to dimensions depends on the shape of the input. Currently the following mappings are supported:

| Input shape                | Output shape                   |
|----------------------------|--------------------------------|
| `B x H` (dense/FF)         | `B x H x t` (RNN)              |
| `B x H x T` (RNN)          | Not supported                  |
| `B x D x H x W` (conv 2D)  | `B x D x H x W x t` (conv 3D)  |
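
As a concrete sketch of the first row (names and sizes are illustrative, assuming a graph with one feature input and one time input):

```
final INDArray features = Nd4j.randn(32, 10);         // B = 32, H = 10 (dense/FF)
final INDArray time = Nd4j.linspace(0, 4, 5);         // 5 time steps
final INDArray out = graph.output(features, time)[0]; // shape [32, 10, 5] (RNN format)
```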


### Prerequisites

Maven and Git. The project uses ND4J's CUDA 10 backend by default, which requires [CUDA 10](https://deeplearning4j.org/docs/latest/deeplearning4j-config-cudnn).
To use CPU backend instead, set the maven property backend-CPU (e.g. through the -P flag when running from command line).
To use the CPU backend instead, activate the Maven profile backend-CPU:

```
mvn install -P backend-CPU
```

## Contributing

All contributions are welcome. Head over to the issues page and either add a new issue or pick up an existing one.

## Versioning

TBD
TBD.

## Authors

2 changes: 1 addition & 1 deletion pom.xml
@@ -6,7 +6,7 @@

<groupId>io.github.drchainsaw</groupId>
<artifactId>neuralODE4j</artifactId>
<version>0.0.1-SNAPSHOT</version>
<version>0.8.0</version>

<properties>

33 changes: 32 additions & 1 deletion src/main/java/examples/README.md
@@ -1,6 +1,6 @@
# Examples

# MNIST
## MNIST

Reimplementation of the MNIST experiment from the [original repo](https://github.com/rtqichen/torchdiffeq/tree/master/examples).

@@ -16,6 +16,11 @@ mvn exec:java -Dexec.mainClass="examples.mnist.Main" -Dexec.args="odenet"

Running from the IDE is also possible, in which case resnet/odenet must be set as a program argument.

Use -help for a full list of command line arguments:

```
mvn exec:java -Dexec.mainClass="examples.mnist.Main" -Dexec.args="-help"
```

Performance (approx):

@@ -26,3 +31,29 @@ Performance (approx):
| stem | 0.5% |

Model "stem" is using the resnet option with zero resblocks after the downsampling layers. This indicates that neither the residual blocks nor the ode block seems to be contributing much to the performance in this simple experiment. Performance also varies about +-0.1% for each run of the same model.

## Spiral demo

Reimplementation of the spiral generation experiment from the [original repo](https://github.com/rtqichen/torchdiffeq/tree/master/examples).

To run the ODE net model, use the following command:

```
mvn exec:java -Dexec.mainClass="examples.spiral.Main" -Dexec.args="odenet"
```

Running from the IDE is also possible, in which case odenet must be set as a program argument.

Use -help for a full list of command line arguments:

```
mvn exec:java -Dexec.mainClass="examples.spiral.Main" -Dexec.args="-help"
```

Note that this example tends to run faster on CPU than on GPU, probably due to the relatively low number of parameters. Example:

```
mvn -P backend-CPU exec:java -Dexec.mainClass="examples.spiral.Main" -Dexec.args="odenet"
```

Furthermore, the original implementation does not use the adjoint method for backpropagation in this example; instead it backpropagates through the operations of the ODE solver. Backpropagation through the ODE solver is not yet supported in this project.
10 changes: 9 additions & 1 deletion src/main/java/examples/mnist/Main.java
@@ -37,6 +37,9 @@ class Main {

private static final Logger log = LoggerFactory.getLogger(Main.class);

@Parameter(names = {"-help", "-h"}, description = "Prints help message")
private boolean help = false;

@Parameter(names = "-trainBatchSize", description = "Batch size to use for training")
private int trainBatchSize = 128;

@@ -85,6 +88,11 @@ private static Main parseArgs(String[] args) {
JCommander jCommander = parbuilder.build();
jCommander.parse(args);

if(main.help) {
jCommander.usage();
System.exit(0);
}

final ModelFactory factory = modelCommands.get(jCommander.getParsedCommand());

main.init(factory.create(), factory.name());
@@ -103,7 +111,7 @@ private void init(ComputationGraph model, String modelName) {
}

private void addListeners() {
final File savedir = new File("savedmodels" + File.separator + modelName);
final File savedir = new File("savedmodels" + File.separator + "MNIST" + File.separator + modelName);
log.info("Models will be saved in: " + savedir.getAbsolutePath());
savedir.mkdirs();
model.addListeners(
6 changes: 5 additions & 1 deletion src/main/java/examples/mnist/OdeNetModel.java
@@ -1,13 +1,16 @@
package examples.mnist;

import com.beust.jcommander.Parameter;
import com.beust.jcommander.Parameters;
import com.beust.jcommander.ParametersDelegate;
import ode.solve.api.FirstOrderSolverConf;
import ode.solve.conf.DormandPrince54Solver;
import ode.solve.conf.SolverConfig;
import ode.vertex.conf.OdeVertex;
import ode.vertex.conf.helper.FixedStep;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import util.listen.step.Mask;
@@ -20,6 +23,7 @@
*
* @author Christian Skarby
*/
@Parameters(commandDescription = "Configuration for image classification using an ODE block")
public class OdeNetModel implements ModelFactory {

private static final Logger log = LoggerFactory.getLogger(OdeNetModel.class);
@@ -77,7 +81,7 @@ private String addOdeBlock(String prev, FirstOrderSolverConf solver) {
conv3x3Same(nrofKernels), "normSecond")
.addLayer("normThird",
norm(nrofKernels), "convSecond")
.odeSolver(solver)
.odeConf(new FixedStep(solver, Nd4j.arange(2)))
.build(), prev);
return "odeBlock";
}
2 changes: 2 additions & 0 deletions src/main/java/examples/mnist/ResNetReferenceModel.java
@@ -1,6 +1,7 @@
package examples.mnist;

import com.beust.jcommander.Parameter;
import com.beust.jcommander.Parameters;
import com.beust.jcommander.ParametersDelegate;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
@@ -15,6 +16,7 @@
*
* @author Christian Skarby
*/
@Parameters(commandDescription = "Configuration for image classification using a number of residual blocks")
public class ResNetReferenceModel implements ModelFactory {

private static final Logger log = LoggerFactory.getLogger(ResNetReferenceModel.class);
36 changes: 36 additions & 0 deletions src/main/java/examples/spiral/AddKLDLabel.java
@@ -0,0 +1,36 @@
package examples.spiral;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
* Adds a label for KLD loss
*
* @author Christian Skarby
*/
public class AddKLDLabel implements MultiDataSetPreProcessor {

    private final double mean;
    private final double logvar;
    private final long nrofLatentDims;

    public AddKLDLabel(double mean, double var, long nrofLatentDims) {
        this.mean = mean;
        this.logvar = Math.log(var);
        this.nrofLatentDims = nrofLatentDims;
    }

    @Override
    public void preProcess(MultiDataSet multiDataSet) {
        final INDArray label0 = multiDataSet.getLabels(0);
        final long batchSize = label0.size(0);
        // KLD label layout: mean in the first nrofLatentDims columns, log(var) in the last nrofLatentDims
        final INDArray kldLabel = Nd4j.zeros(batchSize, 2 * nrofLatentDims);
        kldLabel.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, nrofLatentDims)}, mean);
        kldLabel.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(nrofLatentDims, 2 * nrofLatentDims)}, logvar);
        // Keep the original label and append the KLD label as a second output target
        multiDataSet.setLabels(new INDArray[]{label0, kldLabel});
    }
}
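
A usage sketch (shapes are illustrative): with nrofLatentDims = 4, the preprocessor appends an 8 x 8 second label holding the mean in columns 0-3 and log(var) in columns 4-7.

```
final INDArray features = Nd4j.randn(new int[]{8, 2, 100}); // 8 spirals, 2 coordinates, 100 time steps
final org.nd4j.linalg.dataset.MultiDataSet data =
        new org.nd4j.linalg.dataset.MultiDataSet(features, features.dup());
new AddKLDLabel(0.0, 1.0, 4).preProcess(data); // zero mean, unit variance, 4 latent dims
```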
19 changes: 19 additions & 0 deletions src/main/java/examples/spiral/Block.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package examples.spiral;

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;

/**
* A simple block of layers
*
* @author Christian Skarby
*/
interface Block {

    /**
     * Add layers to the given builder
     *
     * @param builder Builder to add layers to
     * @param prev previous layers
     * @return name of the last layer added
     */
    String add(ComputationGraphConfiguration.GraphBuilder builder, String... prev);
}
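
A minimal (hypothetical) implementation for illustration:

```
// Sketch: a Block that adds a single dense layer and reports it as the last layer
class SingleDenseBlock implements Block {

    @Override
    public String add(ComputationGraphConfiguration.GraphBuilder builder, String... prev) {
        builder.addLayer("single", new DenseLayer.Builder().nOut(10).build(), prev);
        return "single";
    }
}
```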
43 changes: 43 additions & 0 deletions src/main/java/examples/spiral/DenseDecoderBlock.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package examples.spiral;

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.graph.PreprocessorVertex;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.activations.impl.ActivationReLU;

/**
* Simple decoder using {@link DenseLayer}s. Also uses a {@link FeedForwardToRnnPreProcessor} as it assumes 3D input.
*
* @author Christian Skarby
*/
public class DenseDecoderBlock implements Block {

    private final long nrofHidden;
    private final long nrofOutputs;

    public DenseDecoderBlock(long nrofHidden, long nrofOutputs) {
        this.nrofHidden = nrofHidden;
        this.nrofOutputs = nrofOutputs;
    }

    @Override
    public String add(ComputationGraphConfiguration.GraphBuilder builder, String... prev) {
        builder
                .addLayer("dec0", new DenseLayer.Builder()
                        .nOut(nrofHidden)
                        .activation(new ActivationReLU())
                        .build(), prev)
                .addLayer("dec1", new DenseLayer.Builder()
                        .nOut(nrofOutputs)
                        .activation(new ActivationIdentity())
                        .build(), "dec0")
                .addVertex("decodedOutput",
                        new PreprocessorVertex(
                                new FeedForwardToRnnPreProcessor()),
                        "dec1");

        return "decodedOutput";
    }
}
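
A usage sketch (the input name and latent size are assumptions):

```
// Decode 4 latent dimensions into 2D coordinates through 20 hidden units
final ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder()
        .graphBuilder()
        .addInputs("latent")
        .setInputTypes(InputType.recurrent(4));
final String last = new DenseDecoderBlock(20, 2).add(builder, "latent"); // returns "decodedOutput"
```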
