diff --git a/README.md b/README.md
index 1ea2268..55a7dcf 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,20 @@
# neuralODE4j

-Implementation of neural ordinary differential equations built for deeplearning4j.
+Travis [![Build Status](https://travis-ci.org/DrChainsaw/NeuralODE4j.svg?branch=master)](https://travis-ci.org/DrChainsaw/NeuralODE4j)
+AppVeyor [![Build status](https://ci.appveyor.com/api/projects/status/wjdi11f4cmx32ir8?svg=true)](https://ci.appveyor.com/project/DrChainsaw/neuralode4j)
+
+[![codebeat badge](https://codebeat.co/badges/d9e719b4-5465-4f08-9c14-f924691cdd86)](https://codebeat.co/projects/github-com-drchainsaw-neuralode4j-master)
+[![Codacy Badge](https://api.codacy.com/project/badge/Grade/d491774f94944895b6aa3e22b7aae8b3)](https://www.codacy.com/app/DrChainsaw/neuralODE4j?utm_source=github.com&utm_medium=referral&utm_content=DrChainsaw/neuralODE4j&utm_campaign=Badge_Grade)
+[![Maintainability](https://api.codeclimate.com/v1/badges/c0d216da01a0c8b8d615/maintainability)](https://codeclimate.com/github/DrChainsaw/neuralODE4j/maintainability)
+[![Test Coverage](https://api.codeclimate.com/v1/badges/c0d216da01a0c8b8d615/test_coverage)](https://codeclimate.com/github/DrChainsaw/neuralODE4j/test_coverage)
+
+Implementation of neural Ordinary Differential Equations (ODE) built for [deeplearning4j](https://deeplearning4j.org/).

[[Arxiv](https://arxiv.org/abs/1806.07366)]
[[Pytorch repo by paper authors](https://github.com/rtqichen/torchdiffeq)]
-NOTE: This is very much a work in progress and given that I haven't touched a differential equation since school chances
-are that there are conceptual misunderstandings.
-
-The performance of the MNIST example is in line with the results presented in the paper, but given the simplicity of that dataset this is no guarantee of correct implementation.
+[[Very good blog post](https://julialang.org/blog/2019/01/fluxdiffeq)]

## Getting Started

@@ -21,12 +26,117 @@
cd neuralODE4j
mvn install
```
-Currently only the MNIST toy experiment from the paper is implemented [[link]](./src/main/java/examples)
+I will try to publish a Maven artifact when I find the time. Please file an issue if you are interested in one.
+
+Implementations of the MNIST and spiral generation toy experiments from the paper can be found under examples [[link]](./src/main/java/examples).
+
+## Usage
+
+The class [OdeVertex](./src/main/java/ode/vertex/conf/OdeVertex.java) is used to add an arbitrary graph of Layers or GraphVertices as an ODE block in a ComputationGraph.
+
+OdeVertex extends GraphVertex and can be added to a GraphBuilder just like any other vertex. Its API for adding
+layers and vertices is similar to that of GraphBuilder.
+
+Example:
+```
+final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder()
+        .graphBuilder()
+        .addInputs("input")
+        .setInputTypes(InputType.convolutional(9, 9, 3))
+        .addLayer("normalLayer0",
+                new Convolution2D.Builder(3, 3)
+                        .nOut(32)
+                        .convolutionMode(ConvolutionMode.Same).build(), "input")
+
+        // Add an ODE block called "odeBlock" to the graph.
+        .addVertex("odeBlock",
+                new OdeVertex.Builder(new NeuralNetConfiguration.Builder(), "odeLayer0", new BatchNormalization.Builder().build())
+
+                        // OdeVertex has an API similar to GraphBuilder for adding new layers/vertices to the ODE block
+                        .addLayer("odeLayer1", new Convolution2D.Builder(3, 3)
+                                .nOut(32)
+                                .convolutionMode(ConvolutionMode.Same).build(), "odeLayer0")
+
+                        // Add more layers and vertices as desired
+
+                        // Build the OdeVertex. The resulting "inner graph" will be treated as an ODE
+                        .build(), "normalLayer0")
+
+        // Layers/vertices can be added to the graph after the ODE block
+        .addLayer("normalLayer1", new BatchNormalization.Builder().build(), "odeBlock")
+        .addLayer("output", new CnnLossLayer(), "normalLayer1")
+        .setOutputs("output")
+        .build());
+```
+
+An inherent constraint of the method itself is that the output of the last layer in the OdeVertex must have exactly the same
+shape as the input to the first layer in the OdeVertex.
+
+Note that OdeVertex.Builder requires a NeuralNetConfiguration.Builder as constructor input. This is because DL4J does not set graph-wise
+default values for things like updaters and weight initialization for vertices, so the only way to apply them to the
+Layers of the OdeVertex is to pass in the global configuration. Making it a required constructor argument will
+hopefully make it harder to forget. It is of course possible to use a separate set of default values for the layers
+of the OdeVertex by passing in another NeuralNetConfiguration.Builder.
+
+The method for solving the ODE can be configured:
+
+```
+new OdeVertex.Builder(...)
+        .odeConf(new FixedStep(
+                new DormandPrince54Solver(),
+                Nd4j.arange(0, 2))) // Integrate between t = 0 and t = 1
+```
+
+Currently, the only ODE solver implementation integrated with Nd4j is [DormandPrince54Solver](./src/main/java/ode/solve/impl/DormandPrince54Solver.java).
+It is however possible to use FirstOrderIntegrators from apache.commons:commons-math3 through [FirstOrderSolverAdapter](./src/main/java/ode/solve/commons/FirstOrderSolverAdapter.java)
+at the cost of slower training and inference.
+
+Time can also be input from another vertex in the graph:
+```
+new OdeVertex.Builder(...)
+        .odeConf(new InputStep(solverConf, 1)) // Number "1" refers to input "time" on the line below
+        .build(), "someLayer", "time");
+```
+
+Note that time must be a vector, meaning it cannot be minibatched; it has to be the same for all examples in a minibatch. This is because the implementation uses the minibatching approach from
+section 6 of the paper, where all examples in the batch are concatenated into one state. If one time sequence per example is desired, this
+can be achieved by using a minibatch size of 1.
+
+Gradients of the loss with respect to time will be output from the vertex when time is used as input, but they are set to 0 by default to save computation. To have them computed, set needTimeGradient to true:
+
+```
+final boolean needTimeGradient = true;
+new OdeVertex.Builder(...)
+        .odeConf(new InputStep(solverConf, 1, true, needTimeGradient))
+        .build(), "someLayer", "time");
+```
+
+I have not seen these gradients used for anything in the original implementation, and if they are used, some extra measure is most likely required to ensure that time is always strictly increasing or decreasing.
+
+In either case, the minimum number of elements in the time vector is two. If more than two elements are given, the output of the OdeVertex
+will have one more dimension than the input (corresponding to the time elements).
+
+For example, if the graph in the OdeVertex is the function `f = dz/dt` and `time` is the sequence `t0, t1, ..., tN-1`
+with `N > 2`, then the output of the OdeVertex will be (an approximation of) the sequence `z(t0), z(t1), ... , z(tN-1)`.
+Note that `z(t0)` is also the input to the OdeVertex.
+
+The exact mapping to dimensions depends on the shape of the input. Currently the following mappings are supported:
+
+| Input shape                | Output shape                  |
+|----------------------------|-------------------------------|
+| `B x H (dense/FF)`         | `B x H x t (RNN)`             |
+| `B x H x T (RNN)`          | `Not supported`               |
+| `B x D x H x W (conv 2D)`  | `B x D x H x W x t (conv 3D)` |
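+
+For illustration, here is a minimal sketch of a dense ODE block evaluated at three points in time. It is not taken from the repository; layer names and sizes are made up, but it follows the API shown above:
+
+```
+final INDArray time = Nd4j.create(new double[] {0, 0.5, 1});
+
+final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder()
+        .graphBuilder()
+        .addInputs("input")
+        .setInputTypes(InputType.feedForward(10))
+        .addVertex("odeBlock",
+                new OdeVertex.Builder(new NeuralNetConfiguration.Builder(), "odeLayer0",
+                        new DenseLayer.Builder()
+                                .nIn(10) // Must equal nOut below as the shape may not change through the block
+                                .nOut(10)
+                                .build())
+                        .odeConf(new FixedStep(new DormandPrince54Solver(), time))
+                        .build(), "input")
+
+        // Output of "odeBlock" is (an approximation of) z(0), z(0.5) and z(1) stacked along a
+        // third dimension, i.e. shape [B, 10, 3], so an RNN-type layer is used on top of it
+        .addLayer("output", new RnnLossLayer.Builder(LossFunctions.LossFunction.MSE).build(), "odeBlock")
+        .setOutputs("output")
+        .build());
+```
+
+Had `time` contained only two elements, the output of `odeBlock` would instead have had the same shape as the input (only `z(t1)` is returned).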
+
### Prerequisites

Maven and Git.

Project uses ND4J's CUDA 10 backend as default, which requires [CUDA 10](https://deeplearning4j.org/docs/latest/deeplearning4j-config-cudnn).
-To use CPU backend instead, set the maven property backend-CPU (e.g. through the -P flag when running from command line).
+To use the CPU backend instead, activate the Maven profile backend-CPU:
+
+```
+mvn install -P backend-CPU
+```

## Contributing

@@ -34,7 +144,7 @@
All contributions are welcome. Head over to the issues page and either add a new

## Versioning

-TBD
+TBD.

## Authors
diff --git a/pom.xml b/pom.xml
index 3d92a00..a51f42a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -6,7 +6,7 @@
    <groupId>io.github.drchainsaw</groupId>
    <artifactId>neuralODE4j</artifactId>
-    <version>0.0.1-SNAPSHOT</version>
+    <version>0.8.0</version>
diff --git a/src/main/java/examples/README.md b/src/main/java/examples/README.md
index b262db9..2a13b0e 100644
--- a/src/main/java/examples/README.md
+++ b/src/main/java/examples/README.md
@@ -1,6 +1,6 @@
# Examples

-# MNIST
+## MNIST

Reimplementation of the MNIST experiment from the [original repo](https://github.com/rtqichen/torchdiffeq/tree/master/examples).

@@ -16,6 +16,11 @@
mvn exec:java -Dexec.mainClass="examples.mnist.Main" -Dexec.args="odenet"
```

Running from the IDE is also possible, in which case resnet/odenet must be set as a program argument.

+Use -help for a full list of command line arguments:
+
+```
+mvn exec:java -Dexec.mainClass="examples.mnist.Main" -Dexec.args="-help"
+```

Performance (approx):

@@ -26,3 +31,29 @@
| stem | 0.5% |

Model "stem" uses the resnet option with zero resblocks after the downsampling layers. This indicates that neither the residual blocks nor the ODE block contributes much to performance in this simple experiment. Performance also varies by about ±0.1% between runs of the same model.
+
+## Spiral demo
+
+Reimplementation of the spiral generation experiment from the [original repo](https://github.com/rtqichen/torchdiffeq/tree/master/examples).
+
+To run the ODE net model, use the following command:
+
+```
+mvn exec:java -Dexec.mainClass="examples.spiral.Main" -Dexec.args="odenet"
+```
+
+Running from the IDE is also possible, in which case odenet must be set as a program argument.
+
+Use -help for a full list of command line arguments:
+
+```
+mvn exec:java -Dexec.mainClass="examples.spiral.Main" -Dexec.args="-help"
+```
+
+Note that this example tends to run faster on CPU than on GPU, probably due to the relatively low number of parameters. Example:
+
+```
+mvn -P backend-CPU exec:java -Dexec.mainClass="examples.spiral.Main" -Dexec.args="odenet"
+```
+
+Furthermore, the original implementation does not use the adjoint method for backpropagation in this example; instead it backpropagates through the operations of the ODE solver. Backpropagation through the ODE solver is not yet supported in this project.
diff --git a/src/main/java/examples/mnist/Main.java b/src/main/java/examples/mnist/Main.java index cdd332d..f007b29 100644 --- a/src/main/java/examples/mnist/Main.java +++ b/src/main/java/examples/mnist/Main.java @@ -37,6 +37,9 @@ class Main { private static final Logger log = LoggerFactory.getLogger(Main.class); + @Parameter(names = {"-help", "-h"}, description = "Prints help message") + private boolean help = false; + @Parameter(names = "-trainBatchSize", description = "Batch size to use for training") private int trainBatchSize = 128; @@ -85,6 +88,11 @@ private static Main parseArgs(String[] args) { JCommander jCommander = parbuilder.build(); jCommander.parse(args); + if(main.help) { + jCommander.usage(); + System.exit(0); + } + final ModelFactory factory = modelCommands.get(jCommander.getParsedCommand()); main.init(factory.create(), factory.name()); @@ -103,7 +111,7 @@ private void init(ComputationGraph model, String modelName) { } private void addListeners() { - final File savedir = new File("savedmodels" + File.separator + modelName); + final File savedir = new File("savedmodels" + File.separator + "MNIST" + File.separator + modelName); log.info("Models will be saved in: " + savedir.getAbsolutePath()); savedir.mkdirs(); model.addListeners( diff --git a/src/main/java/examples/mnist/OdeNetModel.java b/src/main/java/examples/mnist/OdeNetModel.java index 2d808b8..723e383 100644 --- a/src/main/java/examples/mnist/OdeNetModel.java +++ b/src/main/java/examples/mnist/OdeNetModel.java @@ -1,13 +1,16 @@ package examples.mnist; import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; import com.beust.jcommander.ParametersDelegate; import ode.solve.api.FirstOrderSolverConf; import ode.solve.conf.DormandPrince54Solver; import ode.solve.conf.SolverConfig; import ode.vertex.conf.OdeVertex; +import ode.vertex.conf.helper.FixedStep; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.factory.Nd4j; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import util.listen.step.Mask; @@ -20,6 +23,7 @@ * * @author Christian Skarby */ +@Parameters(commandDescription = "Configuration for image classification using an ODE block") public class OdeNetModel implements ModelFactory { private static final Logger log = LoggerFactory.getLogger(OdeNetModel.class); @@ -77,7 +81,7 @@ private String addOdeBlock(String prev, FirstOrderSolverConf solver) { conv3x3Same(nrofKernels), "normSecond") .addLayer("normThird", norm(nrofKernels), "convSecond") - .odeSolver(solver) + .odeConf(new FixedStep(solver, Nd4j.arange(2))) .build(), prev); return "odeBlock"; } diff --git a/src/main/java/examples/mnist/ResNetReferenceModel.java b/src/main/java/examples/mnist/ResNetReferenceModel.java index 297af36..51e5d3e 100644 --- a/src/main/java/examples/mnist/ResNetReferenceModel.java +++ b/src/main/java/examples/mnist/ResNetReferenceModel.java @@ -1,6 +1,7 @@ package examples.mnist; import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; import com.beust.jcommander.ParametersDelegate; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder; import org.deeplearning4j.nn.conf.graph.ElementWiseVertex; @@ -15,6 +16,7 @@ * * @author Christian Skarby */ +@Parameters(commandDescription = "Configuration for image classification using a number of residual blocks") public class ResNetReferenceModel implements ModelFactory { private static final Logger log = 
LoggerFactory.getLogger(ResNetReferenceModel.class); diff --git a/src/main/java/examples/spiral/AddKLDLabel.java b/src/main/java/examples/spiral/AddKLDLabel.java new file mode 100644 index 0000000..cceb10d --- /dev/null +++ b/src/main/java/examples/spiral/AddKLDLabel.java @@ -0,0 +1,36 @@ +package examples.spiral; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.INDArrayIndex; +import org.nd4j.linalg.indexing.NDArrayIndex; + +/** + * Adds a label for KLD loss + * + * @author Christian Skarby + */ +public class AddKLDLabel implements MultiDataSetPreProcessor { + + private final double mean; + private final double logvar; + private final long nrofLatentDims; + + public AddKLDLabel(double mean, double var, long nrofLatentDims) { + this.mean = mean; + this.logvar = Math.log(var); + this.nrofLatentDims = nrofLatentDims; + } + + @Override + public void preProcess(MultiDataSet multiDataSet) { + final INDArray label0 = multiDataSet.getLabels(0); + final long batchSize = label0.size(0); + final INDArray kldLabel = Nd4j.zeros(batchSize, 2*nrofLatentDims); + kldLabel.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, nrofLatentDims)}, mean); + kldLabel.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(nrofLatentDims, 2*nrofLatentDims)}, logvar); + multiDataSet.setLabels(new INDArray[]{label0, kldLabel}); + } +} diff --git a/src/main/java/examples/spiral/Block.java b/src/main/java/examples/spiral/Block.java new file mode 100644 index 0000000..095e3e3 --- /dev/null +++ b/src/main/java/examples/spiral/Block.java @@ -0,0 +1,19 @@ +package examples.spiral; + +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; + +/** + * A simple block of layers + * + * @author Christian Skarby + */ +interface Block { + + /** + * Add layers to given builder + * @param builder Builder to add layers to + * @param prev previous layers + * @return name of last layer added + */ + String add(ComputationGraphConfiguration.GraphBuilder builder, String... prev); +} diff --git a/src/main/java/examples/spiral/DenseDecoderBlock.java b/src/main/java/examples/spiral/DenseDecoderBlock.java new file mode 100644 index 0000000..3acafe8 --- /dev/null +++ b/src/main/java/examples/spiral/DenseDecoderBlock.java @@ -0,0 +1,43 @@ +package examples.spiral; + +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; +import org.deeplearning4j.nn.conf.graph.PreprocessorVertex; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor; +import org.nd4j.linalg.activations.impl.ActivationIdentity; +import org.nd4j.linalg.activations.impl.ActivationReLU; + +/** + * Simple decoder using {@link DenseLayer}s. Also uses a {@link FeedForwardToRnnPreProcessor} as it assumes 3D input. + * + * @author Christian Skarby + */ +public class DenseDecoderBlock implements Block { + + private final long nrofHidden; + private final long nrofOutputs; + + public DenseDecoderBlock(long nrofHidden, long nrofOutputs) { + this.nrofHidden = nrofHidden; + this.nrofOutputs = nrofOutputs; + } + + @Override + public String add(ComputationGraphConfiguration.GraphBuilder builder, String... 
prev) { + builder + .addLayer("dec0", new DenseLayer.Builder() + .nOut(nrofHidden) + .activation(new ActivationReLU()) + .build(), prev) + .addLayer("dec1", new DenseLayer.Builder() + .nOut(nrofOutputs) + .activation(new ActivationIdentity()) + .build(), "dec0") + .addVertex("decodedOutput", + new PreprocessorVertex( + new FeedForwardToRnnPreProcessor()), + "dec1"); + + return "decodedOutput"; + } +} diff --git a/src/main/java/examples/spiral/KldLossBlock.java b/src/main/java/examples/spiral/KldLossBlock.java new file mode 100644 index 0000000..d5e46dd --- /dev/null +++ b/src/main/java/examples/spiral/KldLossBlock.java @@ -0,0 +1,23 @@ +package examples.spiral; + +import examples.spiral.loss.NormKLDLoss; +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; +import org.deeplearning4j.nn.conf.layers.LossLayer; +import org.nd4j.linalg.activations.impl.ActivationIdentity; + +/** + * Adds a {@link LossLayer} using {@link NormKLDLoss} for the mean and log(var) of the latent variable + * + * @author Christian Skarby + */ +class KldLossBlock implements Block { + + @Override + public String add(ComputationGraphConfiguration.GraphBuilder builder, String... qz0_meanAndLogvar) { + builder.addLayer("kld", new LossLayer.Builder() + .activation(new ActivationIdentity()) + .lossFunction(new NormKLDLoss()) + .build(), qz0_meanAndLogvar); + return "kld"; + } +} diff --git a/src/main/java/examples/spiral/LatentOdeBlock.java b/src/main/java/examples/spiral/LatentOdeBlock.java new file mode 100644 index 0000000..e6284b5 --- /dev/null +++ b/src/main/java/examples/spiral/LatentOdeBlock.java @@ -0,0 +1,48 @@ +package examples.spiral; + +import ode.vertex.conf.OdeVertex; +import ode.vertex.conf.helper.OdeHelper; +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.nd4j.linalg.activations.impl.ActivationELU; +import org.nd4j.linalg.activations.impl.ActivationIdentity; + +/** + * {@link Block} which uses an {@link OdeVertex} to calculate a latent variable z(t) from z(0) and t + * + * @author Christian Skarby + */ +class LatentOdeBlock implements Block { + + public static final String name = "zt"; + + private final long nrofHidden; + private final long nrofLatentDims; + private final OdeHelper solverConf; + + LatentOdeBlock(long nrofHidden, long nrofLatentDims, OdeHelper solverConf) { + this.nrofHidden = nrofHidden; + this.nrofLatentDims = nrofLatentDims; + this.solverConf = solverConf; + } + + @Override + public String add(ComputationGraphConfiguration.GraphBuilder builder, String... 
prev) { + builder.addVertex(name, new OdeVertex.Builder( + builder.getGlobalConfiguration(), + "fc1", + new DenseLayer.Builder() + .nIn(nrofLatentDims) // Fail fast if previous layer is incorrect + .nOut(nrofHidden) + .activation(new ActivationELU()).build()) + .addLayer("fc2", new DenseLayer.Builder() + .nOut(nrofHidden) + .activation(new ActivationELU()).build(), "fc1") + .addLayer("fc3", new DenseLayer.Builder() + .nOut(nrofLatentDims) + .activation(new ActivationIdentity()).build(), "fc2") + .odeConf(solverConf) + .build(), prev); + return name; + } +} diff --git a/src/main/java/examples/spiral/LayerUtil.java b/src/main/java/examples/spiral/LayerUtil.java new file mode 100644 index 0000000..9c72d04 --- /dev/null +++ b/src/main/java/examples/spiral/LayerUtil.java @@ -0,0 +1,80 @@ +package examples.spiral; + +import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.params.DefaultParamInitializer; +import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.nn.weights.WeightInitUtil; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.learning.config.Adam; + +import java.util.Map; + +/** + * Utils for creating layers + * + * @author Christian Skarby + */ +class LayerUtil { + + /** + * Initialize a GraphBuilder for 2D spiral generation + * + * @return a GraphBuilder for 2D spiral generation + */ + public static ComputationGraphConfiguration.GraphBuilder initGraphBuilder(long seed, long nrofSamples) { + return new NeuralNetConfiguration.Builder() + .seed(seed) + // At first glance, Pytorch seems to use RELU_UNIFORM for dense layers by default. However, a combination + // of magic numbers and odd default values results in the equivalent of UNIFORM. The spiral experiment + // is stupidly sensitive to hyper parameters and weight (and bias) init is no exception to this. 
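+                // In DL4J this draws weights from U(-1/sqrt(fanIn), 1/sqrt(fanIn)), i.e. the same interval as the
+                // effective Pytorch default described above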
+                .weightInit(WeightInit.UNIFORM)
+                // Original implementation does not seem to use any regularization, but I could not make this experiment work without it
+                .l2(0.0005)
+                .l2Bias(0.001)
+                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
+                .updater(new Adam(0.01))
+                .graphBuilder()
+                .setInputTypes(InputType.recurrent(2, nrofSamples), InputType.feedForward(nrofSamples));
+    }
+
+    /**
+     * Initialize biases with the same strategy as the weights, for vertices which have a bias parameter
+     *
+     * @param graph Graph to init
+     * @param weightInit Method for weight init
+     */
+    public static void initBiases(ComputationGraph graph, WeightInit weightInit) {
+
+        Map<String, INDArray> paramTable = graph.paramTable(false);
+        for (String parName : paramTable.keySet()) {
+            if (parName.endsWith(DefaultParamInitializer.BIAS_KEY)) {
+                initWeights(parName, paramTable, weightInit);
+            }
+        }
+    }
+
+    private static void initWeights(String biasKey, Map<String, INDArray> paramTable, WeightInit weightInit) {
+        final String weightKey = biasKey.substring(0, biasKey.length() - DefaultParamInitializer.BIAS_KEY.length()) + DefaultParamInitializer.WEIGHT_KEY;
+
+        final INDArray weight = paramTable.get(weightKey);
+        double fanIn = 0;
+        if (weight.rank() == 2) {
+            fanIn = weight.size(0);
+        }
+
+        final INDArray bias = paramTable.get(biasKey);
+
+        WeightInitUtil.initWeights(
+                fanIn,
+                (double) bias.length(),
+                bias.shape(),
+                weightInit,
+                null,
+                bias);
+    }
+}
diff --git a/src/main/java/examples/spiral/Main.java b/src/main/java/examples/spiral/Main.java
new file mode 100644
index 0000000..ed255ab
--- /dev/null
+++ b/src/main/java/examples/spiral/Main.java
@@ -0,0 +1,267 @@
+package examples.spiral;
+
+import ch.qos.logback.classic.Level;
+import com.beust.jcommander.JCommander;
+import com.beust.jcommander.Parameter;
+import examples.spiral.listener.IterationHook;
+import examples.spiral.listener.PlotActivations;
+import examples.spiral.listener.PlotDecodedOutput;
+import examples.spiral.listener.SpiralPlot;
+import org.apache.commons.io.filefilter.OrFileFilter;
+import org.apache.commons.io.filefilter.WildcardFileFilter;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.optimize.listeners.CheckpointListener;
+import org.deeplearning4j.optimize.listeners.PerformanceListener;
+import org.deeplearning4j.util.ModelSerializer;
+import org.jetbrains.annotations.NotNull;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.api.MultiDataSet;
+import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
+import org.nd4j.linalg.factory.Nd4j;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import util.listen.training.NanScoreWatcher;
+import util.listen.training.ZeroGrad;
+import util.plot.Plot;
+import util.plot.RealTimePlot;
+import util.random.SeededRandomFactory;
+
+import java.io.File;
+import java.io.FilenameFilter;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.nio.file.StandardCopyOption;
+import java.util.*;
+
+/**
+ * Main class for spiral example.
Reimplementation of https://github.com/rtqichen/torchdiffeq/blob/master/examples/latent_ode.py + * + * @author Christian Skarby + */ +class Main { + + private static final Logger log = LoggerFactory.getLogger(Main.class); + + @Parameter(names = {"-help", "-h"}, description = "Prints help message") + private boolean help = false; + + @Parameter(names = "-trainBatchSize", description = "Batch size to use for training") + private int trainBatchSize = 1000; + + @Parameter(names = "-nrofTimeStepsForTraining", description = "Number of time steps per spiral when training") + private int nrofTimeStepsForTraining = 100; + + @Parameter(names = "-nrofTrainIters", description = "Number of iterations for training") + private int nrofTrainIters = 2000; + + @Parameter(names = "-noiseSigma", description = "How much noise to add to generated spirals") + private double noiseSigma = 0.3; + + @Parameter(names = "-nrofLatentDims", description = "Number of latent dimensions to use") + private long nrofLatentDims = 4; + + @Parameter(names = "-newModel", description = "Load latest checkpoint (if available) if set to false. If true or if " + + "no checkpoint exists, a new model will be created") + private boolean newModel = false; + + private TimeVae model; + private String modelName; + private SpiralIterator iterator; + + public static void main(String[] args) throws IOException { + ch.qos.logback.classic.Logger root = (ch.qos.logback.classic.Logger) LoggerFactory.getLogger(Logger.ROOT_LOGGER_NAME); + root.setLevel(Level.INFO); + + SeededRandomFactory.setNd4jSeed(0); + + final Main main = new Main(); + final ModelFactory factory = parseArgs(main, args); + + if(!main.help) { + createModel(main, factory); + main.addListeners(); + main.run(); + } + } + + private static ModelFactory parseArgs(Main main, String[] args) throws IOException { + + final Map modelCommands = new HashMap<>(); + modelCommands.put("odenet", new OdeNetModel()); + + JCommander.Builder parbuilder = JCommander.newBuilder() + .addObject(main); + + for (Map.Entry command : modelCommands.entrySet()) { + parbuilder.addCommand(command.getKey(), command.getValue()); + } + + JCommander jCommander = parbuilder.build(); + jCommander.parse(args); + + return modelCommands.get(jCommander.getParsedCommand()); + } + + @NotNull + private static Main createModel(Main main, ModelFactory factory) throws IOException { + final File saveDir = saveDir(factory.name()); + + if (!main.newModel) { + + final File[] files = ageOrder(saveDir.listFiles((FilenameFilter) + new OrFileFilter( + new WildcardFileFilter("checkpoint_*_ComputationGraph.zip"), + new WildcardFileFilter("checkpoint_*_ComputationGraph_bck.zip")))); + if (files != null && files.length > 0) { + final Path modelFile = Paths.get(files[files.length - 1].getAbsolutePath()); + log.info("Restoring model from file: " + modelFile); + + if (!modelFile.getFileName().toString().matches(".*_bck\\.zip")) { + // Because checkpoint listener deletes all files matching the checkpoint_*_ComputationGraph.zip pattern. 
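+                    // e.g. "checkpoint_100_ComputationGraph.zip" is copied to "checkpoint_100_ComputationGraph_bck.zip"
+                    // so that the file the model was restored from survives the CheckpointListener's cleanup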
+ final Path backupFile = Paths.get(modelFile.toString().replace(".", "_bck.")); + Files.copy(modelFile, backupFile, StandardCopyOption.REPLACE_EXISTING); + } + + final ComputationGraph graph = ModelSerializer.restoreComputationGraph(modelFile.toFile(), true); + main.init(factory.createFrom(graph), + factory.name(), + factory.getPreProcessor(main.nrofLatentDims)); + return main; + } + } + + // Else, delete all saved plots and initialize a new model + final File[] plotFiles = saveDir.listFiles((FilenameFilter) new WildcardFileFilter("*.plt")); + if(plotFiles != null) { + for (File plotFile : plotFiles) { + Files.delete(Paths.get(plotFile.getAbsolutePath())); + } + } + + main.init( + factory.createNew(main.nrofTimeStepsForTraining, main.noiseSigma, main.nrofLatentDims), + factory.name(), + factory.getPreProcessor(main.nrofLatentDims)); + return main; + } + + private static File[] ageOrder(File[] files) { + if (files == null) { + return null; + } + + Arrays.sort(files, new Comparator() { + @Override + public int compare(File o1, File o2) { + return Long.compare(o1.lastModified(), o2.lastModified()); + } + }); + return files; + } + + private static File saveDir(String modelName) { + return new File("savedmodels" + File.separator + "spiral" + File.separator + modelName); + } + + private void init(TimeVae model, String modelName, MultiDataSetPreProcessor preProcessor) { + this.model = model; + this.modelName = modelName; + + final SpiralFactory spiralFactory = new SpiralFactory(0, 0.3, 0, 6 * Math.PI, 500); + this.iterator = new SpiralIterator( + new SpiralIterator.Generator(spiralFactory, noiseSigma, nrofTimeStepsForTraining, new Random(Nd4j.getRandom().nextLong())), + trainBatchSize); + iterator.setPreProcessor(preProcessor); + } + + private void addListeners() { + final File savedir = saveDir(modelName); + log.info("Models will be saved in: " + savedir.getAbsolutePath()); + savedir.mkdirs(); + + setupOutputPlotting(savedir); + + final Plot meanAndLogVarPlot = new RealTimePlot<>("Mean and log(var) of z0", savedir.getAbsolutePath()); + + final int saveEveryNIterations = 20; + model.trainingModel().addListeners( + new ZeroGrad(), + new PerformanceListener(1, true), + new CheckpointListener.Builder(savedir.getAbsolutePath()) + .keepLast(1) + .deleteExisting(true) + .saveEveryNIterations(saveEveryNIterations, true) + .build(), + new NanScoreWatcher(() -> { + throw new IllegalStateException("NaN score!"); + }), + new PlotActivations(meanAndLogVarPlot, model.qzMeanAndLogVarName(), new String[] {"qz0Mean" , "qz0Log(Var)"}), + new IterationHook(saveEveryNIterations, () -> { + try { + meanAndLogVarPlot.storePlotData(); + } catch (IOException e) { + log.error("Could not save plot data! 
Exception:" + e.getMessage()); + } + })); + } + + private void setupOutputPlotting(File savedir) { + final SpiralPlot outputPlot = new SpiralPlot(new RealTimePlot<>("Training Output", savedir.getAbsolutePath())); + for (int batchNrToPlot = 0; batchNrToPlot < Math.min(trainBatchSize, 4); batchNrToPlot++) { + outputPlot.plot("True output " + batchNrToPlot, iterator.next().getLabels(0), batchNrToPlot); + model.trainingModel().addListeners(new PlotDecodedOutput(outputPlot, model.outputName(), batchNrToPlot)); + } + } + + private void run() throws IOException { + final ComputationGraph trainingModel = model.trainingModel(); + final Plot samplePlot = new RealTimePlot<>("Reconstruction", saveDir(modelName).getAbsolutePath()); + for (int i = trainingModel.getIterationCount(); i < nrofTrainIters; i++) { + trainingModel.fit(iterator.next()); + + if (i > 0 && i % 100 == 0) { + drawSample(0, samplePlot); + samplePlot.savePicture("_iter" + trainingModel.getIterationCount()); + } + } + + for (int i = 0; i < Math.min(trainBatchSize, 8); i++) { + final Plot plot = new RealTimePlot<>("Reconstruction " + i, saveDir(modelName).getAbsolutePath()); + drawSample(i, plot); + plot.savePicture(""); + } + } + + private void drawSample(final int toSample, Plot reconstructionPlot) { + log.info("Sampling model..."); + + final SpiralIterator.SpiralSet spiralSet = iterator.getCurrent(); + final MultiDataSet mds = spiralSet.getMds(); + final INDArray sample = mds.getFeatures(0).tensorAlongDimension(toSample, 1, 2).reshape(1, 2, nrofTimeStepsForTraining); + + final INDArray z0 = model.encode(sample); + + final INDArray tsPos = Nd4j.linspace(0, 2 * Math.PI, 2000); + final INDArray tsNeg = Nd4j.linspace(0, -Math.PI, 2000); + + final INDArray zsPos = model.timeDependency(z0, tsPos); + final INDArray zsNeg = model.timeDependency(z0, tsNeg); + + final INDArray xsPos = model.decode(zsPos); + final INDArray xsNeg = model.decode(zsNeg); + + final SpiralPlot spiralPlot = new SpiralPlot(reconstructionPlot); + + reconstructionPlot.clearData("True trajectory"); + reconstructionPlot.clearData("Sampled data"); + reconstructionPlot.clearData("Learned trajectory (t > 0)"); + reconstructionPlot.clearData("Learned trajectory (t < 0)"); + + spiralSet.getSpirals().get(toSample).plotBase(reconstructionPlot, "True trajectory"); + spiralPlot.plot("Sampled data", sample, 0); // Always dim 0 as shape is [1, 2, nrofTimeSteps] + spiralPlot.plot("Learned trajectory (t > 0)", xsPos, 0); // Always dim 0 as shape is [1, 2, 2000] + spiralPlot.plot("Learned trajectory (t < 0)", xsNeg, 0); // Always dim 0 as shape is [1, 2, 2000] + } +} diff --git a/src/main/java/examples/spiral/ModelFactory.java b/src/main/java/examples/spiral/ModelFactory.java new file mode 100644 index 0000000..68619b8 --- /dev/null +++ b/src/main/java/examples/spiral/ModelFactory.java @@ -0,0 +1,40 @@ +package examples.spiral; + +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor; + +/** + * Interface for models + * + * @author Christian Skarby + */ +public interface ModelFactory { + + /** + * Create a new model to use + * @param nrofSamples The number of samples in each spiral + * @param noiseSigma Noise std for training spirals + * @param nrofLatentDims How many dimensions for latent variable + * @return a {@link TimeVae} for the model + */ + TimeVae createNew(long nrofSamples, double noiseSigma, long nrofLatentDims); + + /** + * Create a new {@link TimeVae} from an existing {@link ComputationGraph} + * @param graph 
Computation graph to use + * @return a {@link TimeVae} for the model + */ + TimeVae createFrom(ComputationGraph graph); + + /** + * Return the name of the model built to use e.g. for saving models + * @return the name of the models + */ + String name(); + + /** + * Return a {@link MultiDataSetPreProcessor} which needs to be applied to the input + * @return a {@link MultiDataSetPreProcessor} + */ + MultiDataSetPreProcessor getPreProcessor(long nrofLatentDims); +} diff --git a/src/main/java/examples/spiral/OdeNetModel.java b/src/main/java/examples/spiral/OdeNetModel.java new file mode 100644 index 0000000..2c1c980 --- /dev/null +++ b/src/main/java/examples/spiral/OdeNetModel.java @@ -0,0 +1,93 @@ +package examples.spiral; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; +import examples.spiral.loss.NormLogLikelihoodLoss; +import examples.spiral.vertex.conf.SampleGaussianVertex; +import ode.solve.conf.DormandPrince54Solver; +import ode.solve.conf.SolverConfig; +import ode.vertex.conf.helper.InputStep; +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.weights.WeightInit; +import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor; +import org.nd4j.linalg.factory.Nd4j; + +/** + * Model used for spiral generation using neural ODE. Equivalent to model used in + * https://github.com/rtqichen/torchdiffeq/blob/master/examples/latent_ode.py + * + * @author Christian Skarby + */ +@Parameters(commandDescription = "Configuration for spiral generating VAE using latent ODE") +class OdeNetModel implements ModelFactory { + + @Parameter(names = "-dontInterpolateOdeForward", description = "Don't use interpolation when solving latent ODE in forward " + + "direction if set. 
Default is to use interpolation as this is the method used in original implementation") + private boolean interpolateOdeForward = true; + + @Parameter(names = "-encoderNrofHidden", description = "Number of hidden units in encoder") + private long encoderNrofHidden = 25; + + @Parameter(names = "-latentNrofHidden", description = "Number of hidden units in latent ODE function") + private long latentNrofHidden = 20; + + @Parameter(names = "-decoderNrofHidden", description = "Number of hidden units in decoder") + private long decoderNrofHidden = 20; + + @Override + public TimeVae createNew(long nrofSamples, double noiseSigma, long nrofLatentDims) { + + final Block enc = new RnnEncoderBlock(nrofLatentDims, encoderNrofHidden, "spiral"); + final Block dec = new DenseDecoderBlock(decoderNrofHidden, 2); + final Block ode = new LatentOdeBlock(latentNrofHidden, nrofLatentDims, + new InputStep( + new DormandPrince54Solver( + new SolverConfig(1e-12, 1e-6, 1e-20, 1e2)), + 1, interpolateOdeForward)); + final Block outReconstruction = new ReconstructionLossBlock(new NormLogLikelihoodLoss(noiseSigma)); + final Block outKld = new KldLossBlock(); + + final GraphBuilder builder = LayerUtil.initGraphBuilder(Nd4j.getRandom().nextLong(), nrofSamples); + builder.addInputs("spiral", "time"); + + String next = enc.add(builder, "spiral"); + final String qz0_meanAndLogvar = next; + + // Add sampling of a gaussian with the encoded mean and log(var) + final String z0 = "z0"; + builder.addVertex(z0, new SampleGaussianVertex(Nd4j.getRandom().nextLong()), next); + + next = ode.add(builder, z0, "time"); // Position of "time" is dependent on argument in constructor to InputStep above + final String zt = next; + next = dec.add(builder, next); + + // Steps after this is just for ELBO calculation + String output0 = outReconstruction.add(builder, next); + String output1 = outKld.add(builder, qz0_meanAndLogvar); + + builder.setOutputs(output0, output1); + + final ComputationGraph graph = new ComputationGraph(builder.build()); + graph.init(); + + LayerUtil.initBiases(graph, WeightInit.UNIFORM); + return new TimeVae(graph, z0, zt); + } + + @Override + public TimeVae createFrom(ComputationGraph graph) { + return new TimeVae(graph, "z0", LatentOdeBlock.name); + } + + @Override + public String name() { + return "odenet_enc" + encoderNrofHidden + "_lat" + latentNrofHidden + "_dec" + decoderNrofHidden; + } + + + @Override + public MultiDataSetPreProcessor getPreProcessor(long nrofLatentDims) { + return new AddKLDLabel(0, 1, nrofLatentDims); + } +} diff --git a/src/main/java/examples/spiral/ReconstructionLossBlock.java b/src/main/java/examples/spiral/ReconstructionLossBlock.java new file mode 100644 index 0000000..fdd7a3d --- /dev/null +++ b/src/main/java/examples/spiral/ReconstructionLossBlock.java @@ -0,0 +1,31 @@ +package examples.spiral; + +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; +import org.deeplearning4j.nn.conf.layers.RnnLossLayer; +import org.nd4j.linalg.activations.impl.ActivationIdentity; +import org.nd4j.linalg.lossfunctions.ILossFunction; + +/** + * Adds a {@link RnnLossLayer} for the decoder output + * + * @author Christian Skarby + */ +class ReconstructionLossBlock implements Block { + + private final ILossFunction loss; + + ReconstructionLossBlock(ILossFunction loss) { + this.loss = loss; + } + + + @Override + public String add(ComputationGraphConfiguration.GraphBuilder builder, String... 
decoderOutput) {
+        builder.addLayer("reconstruction", new RnnLossLayer.Builder()
+                .activation(new ActivationIdentity())
+                .lossFunction(loss)
+                .build(), decoderOutput);
+
+        return "reconstruction";
+    }
+}
diff --git a/src/main/java/examples/spiral/RnnEncoderBlock.java b/src/main/java/examples/spiral/RnnEncoderBlock.java
new file mode 100644
index 0000000..939ecf0
--- /dev/null
+++ b/src/main/java/examples/spiral/RnnEncoderBlock.java
@@ -0,0 +1,51 @@
+package examples.spiral;
+
+import examples.spiral.vertex.conf.ConcatRnn;
+import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
+import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
+import org.nd4j.linalg.activations.impl.ActivationIdentity;
+import org.nd4j.linalg.activations.impl.ActivationTanH;
+
+/**
+ * Simple RNN encoder. The structure differs from the one used in
+ * https://github.com/rtqichen/torchdiffeq/blob/master/examples/latent_ode.py as dl4j does not seem to offer the possibility
+ * of having a different number of recurrent weights than the number of outputs. Instead, a {@link DenseLayer} is added
+ * after the {@link SimpleRnn} to set the output size equal to the number of latent dimensions.
+ *
+ * Furthermore, the RNN in the original implementation concatenates the input and the previous state of the RNN before weighting, while
+ * SimpleRnn adds the input to the previous state after the input is weighted.
+ *
+ * @author Christian Skarby
+ */
+public class RnnEncoderBlock implements Block {
+
+    private final long nrofLatentDims;
+    private final long nrofHidden;
+    private final String inputName;
+
+    public RnnEncoderBlock(long nrofLatentDims, long nrofHidden, String inputName) {
+        this.nrofLatentDims = nrofLatentDims;
+        this.nrofHidden = nrofHidden;
+        this.inputName = inputName;
+    }
+
+    @Override
+    public String add(ComputationGraphConfiguration.GraphBuilder builder, String... prev) {
+        builder
+                // Note that reversed time series input is assumed. ReverseTimeSeriesVertex adds significant processing time
+                .addLayer("encRnn", new ConcatRnn.Builder()
+                        .nOut(nrofHidden)
+                        .activation(new ActivationTanH())
+                        .build(), prev)
+                .addVertex("encLastStep", new LastTimeStepVertex(inputName), "encRnn")
+                .addLayer("encOut", new DenseLayer.Builder()
+                        .activation(new ActivationIdentity())
+                        .nOut(2 * nrofLatentDims)
+                        .build(), "encLastStep");
+
+        return "encOut";
+    }
+
+}
diff --git a/src/main/java/examples/spiral/SpiralFactory.java b/src/main/java/examples/spiral/SpiralFactory.java
new file mode 100644
index 0000000..b462057
--- /dev/null
+++ b/src/main/java/examples/spiral/SpiralFactory.java
@@ -0,0 +1,129 @@
+package examples.spiral;
+
+import examples.spiral.listener.SpiralPlot;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.ops.transforms.Transforms;
+import util.plot.Plot;
+import util.plot.RealTimePlot;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.BooleanSupplier;
+import java.util.function.DoubleSupplier;
+
+/**
+ * Creates spirals, equivalent to generate_spiral2d in https://github.com/rtqichen/torchdiffeq/blob/master/examples/latent_ode.py.
+ * Formula for a spiral is {@code r = a + b * theta} where theta is linearly spaced between specified start and stop values.
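+ * Both a clockwise and a counter-clockwise base spiral are created; {@code sample} returns random fragments of them.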
+ * + * @author Christian Skarby + */ +class SpiralFactory { + + private final Spiral baseCw; + private final Spiral baseCc; + private final INDArray baseTs; + + static class Spiral { + private final INDArray trajectory; + private final INDArray theta; + + private Spiral(INDArray trajectory, INDArray theta) { + this.trajectory = trajectory; + this.theta = theta; + } + + void plot(Plot plot, String series) { + new SpiralPlot(plot).plot(series, trajectory); + } + + void plotBase(Plot plot, String series) { + plot(plot, series); + } + + INDArray trajectory() { + return trajectory; + } + + INDArray theta() { + return theta; + } + } + + static class SpiralFragment extends Spiral { + + private final Spiral base; + + public SpiralFragment(Spiral base, INDArray trajectory, INDArray theta) { + super(trajectory, theta); + this.base = base; + } + + @Override + void plotBase(Plot plot, String series) { + base.plotBase(plot, series); + } + } + + + SpiralFactory(double a, double b, double startTheta, double stopTheta, long nrofSamples) { + final INDArray thetaCc = Nd4j.linspace(startTheta, stopTheta, nrofSamples); + final INDArray rCc = thetaCc.mul(b).addi(a); + final INDArray trajectoryCc = Nd4j.vstack(rCc.mul(Transforms.cos(thetaCc)).add(5), rCc.mul(Transforms.sin(thetaCc))); + this.baseCc = new Spiral(trajectoryCc, thetaCc); + + final INDArray thetaCw = thetaCc.rsub(1 + stopTheta); + final INDArray rCw = thetaCw.rdiv(50).mul(b).addi(a); + final INDArray trajectoryCw = Nd4j.vstack(rCw.mul(Transforms.cos(thetaCw)).sub(5), rCw.mul(Transforms.sin(thetaCw))); + this.baseCw = new Spiral(trajectoryCw, thetaCw); + baseTs = thetaCc; + } + + void plotClockWise(Plot plot, String label) { + baseCw.plot(plot, label); + } + + void plotCounterClock(Plot plot, String label) { + baseCc.plot(plot, label); + } + + long baseNrofSamples() { + return baseTs.length(); + } + + INDArray baseTs() { + return baseTs; + } + + List sample(long nrofSpirals, long nrofSamples, DoubleSupplier startSupplier, BooleanSupplier cwOrCc) { + final List output = new ArrayList<>(); + + for(int i = 0; i < nrofSpirals; i++) { + Spiral base = cwOrCc.getAsBoolean() ? 
baseCw : baseCc; + long start = (long)Math.min(base.theta.length() - nrofSamples, startSupplier.getAsDouble() * base.theta.length()); + + output.add(new SpiralFragment( + base, + base.trajectory.get(NDArrayIndex.all(), NDArrayIndex.interval(start, start + nrofSamples)), + base.theta.get(NDArrayIndex.all(), NDArrayIndex.interval(start, start + nrofSamples))) + ); + } + + return output; + } + + + public static void main(String[] args) { + final Plot plot = new RealTimePlot<>("Spiral test", ""); + final SpiralFactory factory = new SpiralFactory(0, 0.3, 0, 6 * Math.PI, 500); + + final String cw = "ClockWise"; + final String cc = "CounterClock"; + plot.createSeries(cw); + plot.createSeries(cc); + factory.plotClockWise(plot, cw); + factory.plotCounterClock(plot, cc); + } + +} diff --git a/src/main/java/examples/spiral/SpiralIterator.java b/src/main/java/examples/spiral/SpiralIterator.java new file mode 100644 index 0000000..4c0be10 --- /dev/null +++ b/src/main/java/examples/spiral/SpiralIterator.java @@ -0,0 +1,134 @@ +package examples.spiral; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.deeplearning4j.util.TimeSeriesUtils; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; +import org.nd4j.linalg.dataset.api.preprocessor.CompositeMultiDataSetPreProcessor; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; + +import java.util.Collections; +import java.util.List; +import java.util.Random; + +/** + * {@link MultiDataSetIterator} for spirals. Note that the same {@link MultiDataSet} instance will be used until + * reset is called. This is what the original implementation does as well + * + * @author Christian Skarby + */ +public class SpiralIterator implements MultiDataSetIterator { + + private final Generator generator; + private final int batchSize; + private SpiralSet current; + private MultiDataSetPreProcessor preProcessor = new CompositeMultiDataSetPreProcessor(); // Noop + + @Getter @AllArgsConstructor + public static class SpiralSet { + private final MultiDataSet mds; + private final List spirals; + } + + /** + * Generates {@link SpiralSet}s from a {@link SpiralFactory}. 
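+     * Gaussian noise with standard deviation {@code noiseSigma} is added to each generated trajectory.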
+ */ + public static class Generator { + private final SpiralFactory factory; + private final double noiseSigma; + private final long nrofSamples; + private final Random rng; + + + public Generator(SpiralFactory factory, double noiseSigma, long nrofSamples, Random rng) { + this.factory = factory; + this.noiseSigma = noiseSigma; + this.nrofSamples = nrofSamples; + this.rng = rng; + } + + SpiralSet generate(int batchSize) { + + final double sampoffset = nrofSamples / (double)factory.baseNrofSamples(); + final double samprange = 1 - 2*sampoffset; + + final List spirals = factory.sample( + batchSize, + nrofSamples, + () -> sampoffset + rng.nextDouble()*samprange, + rng::nextBoolean); + + final INDArray trajFeature = Nd4j.createUninitialized( new long[] {batchSize, 2, nrofSamples}, 'f'); + final INDArray tFeature = factory.baseTs().get(NDArrayIndex.all(), NDArrayIndex.interval(0, nrofSamples)).dup('f'); + + for(int i = 0; i < batchSize; i++) { + trajFeature.tensorAlongDimension(i, 1,2).assign(spirals.get(i).trajectory()); + } + + trajFeature.addi(Nd4j.randn(trajFeature.shape(), Nd4j.getRandomFactory().getNewRandomInstance(rng.nextLong())).muli(noiseSigma)); + return new SpiralSet(new org.nd4j.linalg.dataset.MultiDataSet( + // Reverse trajectory so last time step of RNN represents first element of trajectory. + // Unsure if really needed since RNN anyways mangles whole sequence into something (mean and var) + new INDArray[] {TimeSeriesUtils.reverseTimeSeries(trajFeature), tFeature}, + new INDArray[] {trajFeature}), + Collections.unmodifiableList(spirals)); + } + } + + public SpiralIterator(Generator generator, int batchSize) { + this.generator = generator; + this.batchSize = batchSize; + } + + @Override + public MultiDataSet next(int num) { + if(current == null || num != current.getMds().getFeatures(0).size(0)) { + current = generator.generate(num); + preProcessor.preProcess(current.getMds()); + } + return current.getMds(); + } + + @Override + public void setPreProcessor(MultiDataSetPreProcessor preProcessor) { + this.preProcessor = preProcessor; + } + + @Override + public MultiDataSetPreProcessor getPreProcessor() { + return preProcessor; + } + + @Override + public boolean resetSupported() { + return false; + } + + @Override + public boolean asyncSupported() { + return true; + } + + @Override + public void reset() { + current = null; + } + + @Override + public boolean hasNext() { + return true; + } + + @Override + public MultiDataSet next() { + return next(batchSize); + } + + public SpiralSet getCurrent() { + return current; + } +} diff --git a/src/main/java/examples/spiral/TimeVae.java b/src/main/java/examples/spiral/TimeVae.java new file mode 100644 index 0000000..cb9df5f --- /dev/null +++ b/src/main/java/examples/spiral/TimeVae.java @@ -0,0 +1,123 @@ +package examples.spiral; + +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.graph.vertex.GraphVertex; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.indexing.NDArrayIndex; + +/** + * A variational auto encoder which takes time as input. Implemented as multiple {@link ComputationGraph}s; one encoder, + * one latentTime and one decoder. 
These three share weights with the input {@link ComputationGraph} which has all three + * parts connected for training purposes + * + * @author Christian Skarby + */ +public class TimeVae { + + private final ComputationGraph trainingModel; + private final ComputationGraph encoder; + private final GraphVertex latentTime; + private final ComputationGraph decoder; + + public TimeVae(ComputationGraph model, String z0, String zt) { + trainingModel = model; + encoder = createEncoder(model, z0); + latentTime = createLatentTime(model, zt, encoder.numParams()); + decoder = createDecoder(model, zt, encoder.numParams() + latentTime.numParams()); + } + + private ComputationGraph createEncoder(ComputationGraph model, String z0) { + final ComputationGraphConfiguration conf = model.getConfiguration(); + + final ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder(model.conf()).graphBuilder(); + long nrofParams = 0; + for (String vertexName : conf.getTopologicalOrderStr()) { + if (!conf.getNetworkInputs().contains(vertexName) && !conf.getNetworkOutputs().contains(vertexName)) { + builder.addVertex(vertexName, conf.getVertices().get(vertexName), conf.getVertexInputs().get(vertexName).toArray(new String[0])); + nrofParams += model.getVertex(vertexName).numParams(); + } + if (vertexName.equals(z0)) { + builder.setOutputs(z0); + builder.addInputs(conf.getNetworkInputs().get(0)); + break; + } + } + + final ComputationGraph encoder = new ComputationGraph(builder.build()); + encoder.init(model.params().get(NDArrayIndex.interval(0, nrofParams)), false); + + return encoder; + } + + private GraphVertex createLatentTime(ComputationGraph model, String zt, long paramStart) { + final ComputationGraphConfiguration conf = model.getConfiguration(); + + final ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder(model.conf()).graphBuilder(); + + builder.addVertex(zt, conf.getVertices().get(zt), conf.getVertexInputs().get(zt).toArray(new String[0])); + builder.setOutputs(zt); + builder.addInputs("z0", conf.getNetworkInputs().get(1)); + final long nrofParams = model.getVertex(zt).numParams(); + + final ComputationGraph latentTime = new ComputationGraph(builder.build()); + latentTime.init(model.params().get(NDArrayIndex.interval(paramStart, paramStart + nrofParams)), false); + + return latentTime.getVertex(zt); + } + + private ComputationGraph createDecoder(ComputationGraph model, String zt, long paramStart) { + final ComputationGraphConfiguration conf = model.getConfiguration(); + + final ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder(model.conf()).graphBuilder(); + + long nrofParams = 0; + boolean include = false; + String last = ""; + for (String vertexName : conf.getTopologicalOrderStr()) { + + if (include && !conf.getNetworkOutputs().contains(vertexName)) { + builder.addVertex(vertexName, conf.getVertices().get(vertexName), conf.getVertexInputs().get(vertexName).toArray(new String[0])); + nrofParams += model.getVertex(vertexName).numParams(); + last = vertexName; + } + include |= vertexName.equals(zt); + } + builder.setOutputs(last); + builder.addInputs(zt); + + final ComputationGraph decoder = new ComputationGraph(builder.build()); + decoder.init(model.params().get(NDArrayIndex.interval(paramStart, paramStart + nrofParams)), false); + + return decoder; + } + + INDArray encode(INDArray... 
inputs) { + return encoder.outputSingle(inputs); + } + + INDArray timeDependency(INDArray z0, INDArray time) { + latentTime.setInputs(z0, time); + return latentTime.doForward(true, LayerWorkspaceMgr.noWorkspacesImmutable()); + } + + INDArray decode(INDArray... inputs) { + return decoder.outputSingle(inputs); + } + + ComputationGraph trainingModel() { + return trainingModel; + } + + String outputName() { + return decoder.getConfiguration().getNetworkOutputs().get(0); + } + + String qzMeanAndLogVarName() { + final String z0 = encoder.getConfiguration().getNetworkOutputs().get(0); + return encoder.getConfiguration().getVertexInputs().get(z0).iterator().next(); + } + +} diff --git a/src/main/java/examples/spiral/listener/IterationHook.java b/src/main/java/examples/spiral/listener/IterationHook.java new file mode 100644 index 0000000..54dfe09 --- /dev/null +++ b/src/main/java/examples/spiral/listener/IterationHook.java @@ -0,0 +1,27 @@ +package examples.spiral.listener; + +import org.deeplearning4j.nn.api.Model; +import org.deeplearning4j.optimize.api.BaseTrainingListener; + +/** + * Call a {@link Runnable} every iterPeriod number of iterations. + * + * @author Christian Skarby + */ +public class IterationHook extends BaseTrainingListener { + + private final int iterPeriod; + private final Runnable callback; + + public IterationHook(int iterPeriod, Runnable callback) { + this.iterPeriod = iterPeriod; + this.callback = callback; + } + + @Override + public void iterationDone(Model model, int iteration, int epoch) { + if(iteration > 0 && iteration % iterPeriod == 0) { + callback.run(); + } + } +} diff --git a/src/main/java/examples/spiral/listener/PlotActivations.java b/src/main/java/examples/spiral/listener/PlotActivations.java new file mode 100644 index 0000000..11b3a80 --- /dev/null +++ b/src/main/java/examples/spiral/listener/PlotActivations.java @@ -0,0 +1,39 @@ +package examples.spiral.listener; + +import org.deeplearning4j.nn.api.Model; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.optimize.api.BaseTrainingListener; +import org.nd4j.linalg.api.ndarray.INDArray; +import util.plot.Plot; + +import java.util.Map; + +/** + * Plots activations as a function of iteration number + * + * @author Christian Skarby + */ +public class PlotActivations extends BaseTrainingListener { + + private final Plot plot; + private final String activationName; + private final String[] labels; + + public PlotActivations(Plot plot, String activationName, String[] labels) { + this.plot = plot; + this.activationName = activationName; + this.labels = labels; + } + + @Override + public void onForwardPass(Model model, Map activations) { + final int iteration = ((ComputationGraph)model).getIterationCount(); + + final INDArray toPlot = activations.get(activationName).mean(0); + + int labelSwitch = (int)toPlot.length() / labels.length; + for(int i = 0; i < toPlot.length(); i++) { + plot.plotData( labels[i / labelSwitch] + "_" + (i % labelSwitch), iteration, toPlot.getDouble(i)); + } + } +} diff --git a/src/main/java/examples/spiral/listener/PlotDecodedOutput.java b/src/main/java/examples/spiral/listener/PlotDecodedOutput.java new file mode 100644 index 0000000..38cc553 --- /dev/null +++ b/src/main/java/examples/spiral/listener/PlotDecodedOutput.java @@ -0,0 +1,39 @@ +package examples.spiral.listener; + +import org.deeplearning4j.nn.api.Model; +import org.deeplearning4j.optimize.api.BaseTrainingListener; +import org.nd4j.linalg.api.ndarray.INDArray; +import util.plot.Plot; + +import 
java.util.Map;
+
+/**
+ * Plots the decoded output
+ *
+ * @author Christian Skarby
+ */
+public class PlotDecodedOutput extends BaseTrainingListener {
+
+    private final SpiralPlot plot;
+    private final String outputName;
+    private final String plotLabel;
+    private final int batchNrToPlot;
+
+    public PlotDecodedOutput(Plot plot, String outputName, int batchNrToPlot) {
+        this(new SpiralPlot(plot), outputName, batchNrToPlot);
+    }
+
+    public PlotDecodedOutput(SpiralPlot plot, String outputName, int batchNrToPlot) {
+        this.plot = plot;
+        this.outputName = outputName;
+        this.batchNrToPlot = batchNrToPlot;
+        this.plotLabel = outputName + " " + batchNrToPlot;
+        this.plot.createSeries(plotLabel);
+    }
+
+    @Override
+    public void onForwardPass(Model model, Map<String, INDArray> activations) {
+        final INDArray toPlot = activations.get(outputName);
+        plot.plot(plotLabel, toPlot, batchNrToPlot);
+    }
+}
diff --git a/src/main/java/examples/spiral/listener/SpiralPlot.java b/src/main/java/examples/spiral/listener/SpiralPlot.java
new file mode 100644
index 0000000..46e073b
--- /dev/null
+++ b/src/main/java/examples/spiral/listener/SpiralPlot.java
@@ -0,0 +1,62 @@
+package examples.spiral.listener;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+import util.plot.Plot;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Simple plot util for plotting spirals from a 2D {@link org.nd4j.linalg.api.ndarray.INDArray} where each column is the x
+ * and y coordinates for a time index.
+ *
+ * @author Christian Skarby
+ */
+public class SpiralPlot {
+
+    private final Plot plot;
+
+    public SpiralPlot(Plot plot) {
+        this.plot = plot;
+    }
+
+    /**
+     * Create a series in the plot. Call this method before plotting anything to any series in the same plot to avoid
+     * null pointer exceptions in the plot thread.
+     * @param label Series to create
+     */
+    public void createSeries(String label) {
+        plot.createSeries(label);
+    }
+
+    /**
+     * Plot data assuming row 0 of the given {@link INDArray} holds the x coordinates and row 1 the y coordinates.
+     * @param label Label for curve to plot
+     * @param toPlot Data to plot
+     */
+    public void plot(String label, INDArray toPlot) {
+        plot.clearData(label);
+        final List<Double> x = toDoubleList(toPlot, 0);
+        final List<Double> y = toDoubleList(toPlot, 1);
+        plot.plotData(label, x, y);
+    }
+
+    /**
+     * Plot data assuming a 3D {@link INDArray} where each element along the first dimension is a set of
+     * x and y coordinate pairs.
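+     * <p>
+     * For example, given input of shape {@code [batch, 2, tsLength]}, {@code batchNr = b} selects the
+     * {@code 2 x tsLength} matrix {@code toPlot.tensorAlongDimension(b, 1, 2)}, which is then plotted the
+     * same way as in the 2D overload above.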
+ * @param label Label for curve to plot + * @param toPlot Data to plot + * @param batchNr which set of x,y pairs to plot + */ + public void plot(String label, INDArray toPlot, int batchNr) { + plot(label, toPlot.tensorAlongDimension(batchNr, 1,2)); + } + + private static List toDoubleList(INDArray toPlot, int row) { + final List out = new ArrayList<>(); + for(double d: toPlot.getRow(row).toDoubleVector()) { + out.add(d); + } + return out; + } +} diff --git a/src/main/java/examples/spiral/loss/NormKLDLoss.java b/src/main/java/examples/spiral/loss/NormKLDLoss.java new file mode 100644 index 0000000..fa0eeb3 --- /dev/null +++ b/src/main/java/examples/spiral/loss/NormKLDLoss.java @@ -0,0 +1,103 @@ +package examples.spiral.loss; + +import lombok.Data; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.ops.transforms.Transforms; +import org.nd4j.linalg.primitives.Pair; + +/** + * Kullback-Leibler divergence assuming gaussian distribution + * Reimplementation of https://github.com/rtqichen/torchdiffeq/blob/master/examples/latent_ode.py + * + * @author Christian Skarby + */ +@Data +public class NormKLDLoss implements ILossFunction { + + @Override + public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) { + if(average) { + return scoreArray(labels, preOutput, activationFn, mask).meanNumber().doubleValue(); + } + return scoreArray(labels, preOutput, activationFn, mask).sumNumber().doubleValue(); + } + + @Override + public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + return scoreArray(labels, preOutput, activationFn, mask); + } + + @Override + public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + final INDArray output = activationFn.getActivation(preOutput.dup(), true); + + long size = output.size(1) / 2; + INDArray mean = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, size)).dup(); + INDArray logVar = output.get(NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size)).dup(); + + final INDArray meanGrad = normalKlGradMu1(mean); + final INDArray logvarGrad = normalKlGradLv1(logVar); + + return activationFn.backprop(output, Nd4j.hstack(meanGrad, logvarGrad)).getFirst(); + } + + @Override + public Pair computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) { + return new Pair<>( + computeScore(labels, preOutput, activationFn, mask, average), + computeGradient(labels, preOutput, activationFn, mask)); + } + + private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + + if(labels.amaxNumber().doubleValue() > 0) { + // Note that variance is given in log scale + // Should be straight forward to implement, but not needed + throw new UnsupportedOperationException("Targets other than N(0,1) not supported! 
Got: " + labels); + } + + if(mask != null) { + // Should be straight forward to implement, but not needed + throw new UnsupportedOperationException("Masking not supported"); + } + + final INDArray output = activationFn.getActivation(preOutput.dup(), true); + + long size = output.size(1) / 2; + INDArray mean = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, size)).dup(); + INDArray logVar = output.get(NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size)).dup(); + + return normalKl(mean, logVar).sum(1); + } + + @Override + public String name() { + return "NormKLDLoss()"; + } + + private INDArray normalKl(INDArray mu1, INDArray lv1 ) { + final double mu2 = 0; + final INDArray v1 = Transforms.exp(lv1); + final double v2 = 1; + final INDArray lstd1 = lv1.div(2); + final double lstd2 = 0; + return lstd1.rsubi(lstd2).addi( + v1.addi(Transforms.pow(mu1.sub(mu2), 2, false)).divi(2 * v2) + ).subi(0.5); + } + + private INDArray normalKlGradMu1(INDArray mu1) { + final double mu2 = 0; + final double v2 = 1; + return mu1.sub(mu2).muli(v2); + } + + private INDArray normalKlGradLv1(INDArray lv1) { + final double lv2 = 0; + return Transforms.exp(lv1.sub(lv2)).subi(1).divi(2); + } +} diff --git a/src/main/java/examples/spiral/loss/NormLogLikelihoodLoss.java b/src/main/java/examples/spiral/loss/NormLogLikelihoodLoss.java new file mode 100644 index 0000000..cbac640 --- /dev/null +++ b/src/main/java/examples/spiral/loss/NormLogLikelihoodLoss.java @@ -0,0 +1,105 @@ +package examples.spiral.loss; + +import lombok.Data; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.ops.transforms.Transforms; +import org.nd4j.linalg.primitives.Pair; +import org.nd4j.shade.jackson.annotation.JsonProperty; + +/** + * Log-likelihood(ish?) loss under gaussian assumptions. 
+ * Reimplementation of https://github.com/rtqichen/torchdiffeq/blob/master/examples/latent_ode.py + * + * @author Christian Skarby + */ +@Data +public class NormLogLikelihoodLoss implements ILossFunction { + + private final static double log2pi = Math.log(2 * Math.PI); + + private final double logNoiseVar; + + public NormLogLikelihoodLoss(@JsonProperty("noiseSigma") double noiseSigma) { + this.logNoiseVar = 2 * Math.log(noiseSigma); + } + + @Override + public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) { + if(average) { + return scoreArray(labels, preOutput, activationFn, mask).meanNumber().doubleValue(); + } + return scoreArray(labels, preOutput, activationFn, mask).sumNumber().doubleValue(); + } + + @Override + public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + return scoreArray(labels, preOutput, activationFn, mask); + } + + @Override + public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + final INDArray output = activationFn.getActivation(preOutput.dup(), true); + + if (!labels.equalShapes(output)) { + Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), output.shape()); + } + + final INDArray predGrad = logNormalPdfGradient(labels, output); + + return activationFn.backprop(output, predGrad).getFirst(); + } + + @Override + public Pair computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) { + return new Pair<>( + computeScore(labels, preOutput, activationFn, mask, average), + computeGradient(labels, preOutput, activationFn, mask)); + } + + private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + + if(mask != null) { + // Should be straight forward to implement, but not needed + throw new UnsupportedOperationException("Masking not supported"); + } + + final INDArray output = activationFn.getActivation(preOutput.dup(), true); + + final int[] sumDims = new int[output.rank()-1]; + for(int i = 0; i < sumDims.length; i++) { + sumDims[i] = i+1; + } + + if (!labels.equalShapes(output)) { + Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), output.shape()); + } + + return logNormalPdf(labels, output).sum(sumDims).negi(); + } + + + @Override + public String name() { + return "NormLogLikelihoodLoss(" + logNoiseVar +")"; + } + + private INDArray logNormalPdf(INDArray labels, INDArray output) { + // Expression from original repo. + // Similar to log-likelihood assuming parameters are from a gaussian distribution, but not 100% same? + return Transforms.pow(output.rsub(labels), 2, false) + .divi(Math.exp(logNoiseVar)) + .addi(log2pi) + .addi(logNoiseVar) + .muli(-0.5); + } + + private INDArray logNormalPdfGradient(INDArray labels, INDArray output) { + // 2 from derivative of exponent and -0.5 constant cancel out + // Original implementation uses mean of loss along batch dimension. 
Dl4j does however divide gradients by minibatch + // in BaseMultiLayerUpdater + return output.sub(labels).muli(Math.exp(-logNoiseVar)); + } +} diff --git a/src/main/java/examples/spiral/vertex/conf/ConcatRnn.java b/src/main/java/examples/spiral/vertex/conf/ConcatRnn.java new file mode 100644 index 0000000..a982a54 --- /dev/null +++ b/src/main/java/examples/spiral/vertex/conf/ConcatRnn.java @@ -0,0 +1,108 @@ +package examples.spiral.vertex.conf; + +import lombok.Data; +import lombok.val; +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; +import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; +import org.deeplearning4j.nn.conf.layers.LayerValidation; +import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; +import org.deeplearning4j.nn.params.DefaultParamInitializer; +import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.util.Collection; +import java.util.Map; + +/** + * Very simple RNN which concatenates input and a hidden state instead of adding them. + * It implements {@code out_t = activationFn( concat(in_t, out_(t-1)) * inWeight + bias)}. + * + * Same type of RNN as used in https://github.com/rtqichen/torchdiffeq/blob/master/examples/latent_ode.py + * + * @author Christian Skarby + */ +@Data +public class ConcatRnn extends BaseRecurrentLayer { + + private boolean addHiddenToNin = false; + + protected ConcatRnn(ConcatRnn.Builder builder) { + super(builder); + } + + private ConcatRnn() { + + } + + @Override + public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, + int layerIndex, INDArray layerParamsView, boolean initializeParams) { + LayerValidation.assertNInNOutSet("ConcatRnn", getLayerName(), layerIndex, getNIn(), getNOut()); + + examples.spiral.vertex.impl.ConcatRnn ret = + new examples.spiral.vertex.impl.ConcatRnn(conf); + ret.setListeners(trainingListeners); + ret.setIndex(layerIndex); + ret.setParamsViewArray(layerParamsView); + setNIn(getNIn() + getNOut()); + Map paramTable = DefaultParamInitializer.getInstance().init(conf, layerParamsView, initializeParams); + setNIn(getNIn() - getNOut()); + ret.setParamTable(paramTable); + ret.setConf(conf); + return ret; + } + + @Override + public ParamInitializer initializer() { + return new DefaultParamInitializer() { + @Override + public long numParams(org.deeplearning4j.nn.conf.layers.Layer l) { + FeedForwardLayer layerConf = (FeedForwardLayer) l; + val nOut = layerConf.getNOut(); + val nIn = layerConf.getNIn() + nOut; + return (nIn * nOut + (hasBias(l) ? 
nOut : 0)); //weights + bias
+            }
+        };
+    }
+
+    @Override
+    public double getL1ByParam(String paramName) {
+        switch (paramName) {
+            case DefaultParamInitializer.WEIGHT_KEY:
+                return l1;
+            case DefaultParamInitializer.BIAS_KEY:
+                return l1Bias;
+            default:
+                throw new IllegalStateException("Unknown parameter: \"" + paramName + "\"");
+        }
+    }
+
+    @Override
+    public double getL2ByParam(String paramName) {
+        switch (paramName) {
+            case DefaultParamInitializer.WEIGHT_KEY:
+                return l2;
+            case DefaultParamInitializer.BIAS_KEY:
+                return l2Bias;
+            default:
+                throw new IllegalStateException("Unknown parameter: \"" + paramName + "\"");
+        }
+    }
+
+    @Override
+    public LayerMemoryReport getMemoryReport(InputType inputType) {
+        return null;
+    }
+
+    public static class Builder extends BaseRecurrentLayer.Builder<Builder> {
+
+        @Override
+        public ConcatRnn build() {
+            return new ConcatRnn(this);
+        }
+    }
+}
diff --git a/src/main/java/examples/spiral/vertex/conf/SampleGaussianVertex.java b/src/main/java/examples/spiral/vertex/conf/SampleGaussianVertex.java
new file mode 100644
index 0000000..6c46362
--- /dev/null
+++ b/src/main/java/examples/spiral/vertex/conf/SampleGaussianVertex.java
@@ -0,0 +1,99 @@
+package examples.spiral.vertex.conf;
+
+import lombok.Data;
+import org.deeplearning4j.nn.conf.graph.GraphVertex;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
+import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
+import org.deeplearning4j.nn.conf.memory.MemoryReport;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.rng.Random;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.shade.jackson.annotation.JsonProperty;
+
+
+/**
+ * Takes samples from a Gaussian distribution where mean and log variance are inputs, typically from a set of layers
+ * acting as a variational auto encoder.
+ *
+ * @author Christian Skarby
+ */
+@Data
+public class SampleGaussianVertex extends GraphVertex {
+
+    private final long seed;
+
+    public SampleGaussianVertex(@JsonProperty("seed") long seed) {
+        this.seed = seed;
+    }
+
+    @Override
+    public GraphVertex clone() {
+        return new SampleGaussianVertex(seed);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if(o instanceof SampleGaussianVertex) {
+            SampleGaussianVertex other = (SampleGaussianVertex)o;
+            return other.seed == this.seed;
+        }
+        return false;
+    }
+
+    @Override
+    public int hashCode() {
+        return (int)seed;
+    }
+
+    @Override
+    public long numParams(boolean backprop) {
+        return 0;
+    }
+
+    @Override
+    public int minVertexInputs() {
+        return 1;
+    }
+
+    @Override
+    public int maxVertexInputs() {
+        return 1;
+    }
+
+    @Override
+    public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams) {
+        final Random rng = Nd4j.getRandomFactory().getNewRandomInstance(seed);
+        return new examples.spiral.vertex.impl.SampleGaussianVertex(graph, name, idx, shape -> Nd4j.randn(shape, rng));
+    }
+
+    @Override
+    public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
+        if(vertexInputs.length != 1) {
+            throw new IllegalArgumentException(this.getClass().getSimpleName() + " must have one input!");
+        }
+
+        for(InputType inputType: vertexInputs) {
+            if(inputType.getType() != InputType.Type.FF) {
+                throw new IllegalArgumentException(this.getClass().getSimpleName() + " only supports feedforward input!
Got: " + inputType); + } + + if(inputType.arrayElementsPerExample() % 2 != 0) { + throw new IllegalArgumentException(this.getClass().getSimpleName() + " input size must be even! Got: " + inputType); + } + } + + return InputType.feedForward(vertexInputs[0].arrayElementsPerExample() / 2); + } + + @Override + public MemoryReport getMemoryReport(InputType... inputTypes) { + InputType outputType = getOutputType(-1, inputTypes); + + return new LayerMemoryReport.Builder(null, SampleGaussianVertex.class, inputTypes[0], outputType).standardMemory(0, 0) //No params + .workingMemory(0, 0, 0, 0) //No working memory in addition to activations/epsilons + .cacheMemory(0, 0) //No caching + .build(); + } +} diff --git a/src/main/java/examples/spiral/vertex/impl/ConcatRnn.java b/src/main/java/examples/spiral/vertex/impl/ConcatRnn.java new file mode 100644 index 0000000..42eb25d --- /dev/null +++ b/src/main/java/examples/spiral/vertex/impl/ConcatRnn.java @@ -0,0 +1,205 @@ +package examples.spiral.vertex.impl; + +import lombok.val; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.gradient.DefaultGradient; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer; +import org.deeplearning4j.nn.params.DefaultParamInitializer; +import org.deeplearning4j.nn.workspace.ArrayType; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.memory.MemoryWorkspace; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp; +import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp; +import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +import static org.nd4j.linalg.indexing.NDArrayIndex.*; + +/** + * Very simple RNN which concatenates input and a hidden state instead of adding them. + * It implements {@code out_t = activationFn( concat(in_t, out_(t-1)) * inWeight + bias)}. 
+ * + * Same type of RNN as used in https://github.com/rtqichen/torchdiffeq/blob/master/examples/latent_ode.py + * + * @author Christian Skarby + */ +public class ConcatRnn extends BaseRecurrentLayer { + public static final String STATE_KEY_PREV_ACTIVATION = "prevAct"; + + public ConcatRnn(NeuralNetConfiguration conf) { + super(conf); + } + + @Override + public INDArray rnnTimeStep(INDArray input, LayerWorkspaceMgr workspaceMgr) { + setInput(input, workspaceMgr); + INDArray last = stateMap.get(STATE_KEY_PREV_ACTIVATION); + INDArray out = activateHelper(last, false, false, workspaceMgr).getFirst(); + try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){ + stateMap.put(STATE_KEY_PREV_ACTIVATION, out.get(all(), all(), point(out.size(2)-1)).dup()); + } + return out; + } + + @Override + public INDArray rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT, LayerWorkspaceMgr workspaceMgr) { + setInput(input, workspaceMgr); + INDArray last = tBpttStateMap.get(STATE_KEY_PREV_ACTIVATION); + INDArray out = activateHelper(last, training, false, workspaceMgr).getFirst(); + if(storeLastForTBPTT){ + try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){ + tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, out.get(all(), all(), point(out.size(2)-1))); + } + } + return out; + } + + @Override + public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { + return tbpttBackpropGradient(epsilon, -1, workspaceMgr); + } + + @Override + public Pair tbpttBackpropGradient(INDArray epsilonIn, int tbpttBackLength, LayerWorkspaceMgr workspaceMgr) { + assertInputSet(true); + + INDArray epsilon = epsilonIn; + if(epsilon.ordering() != 'f' || !Shape.hasDefaultStridesForShape(epsilon)) + epsilon = epsilon.dup('f'); + + //First: Do forward pass to get gate activations and Zs + Pair p = activateHelper(Nd4j.zeros(new long[] {input.size(0), layerConf().getNOut()}, 'f'), true, true, workspaceMgr); + + INDArray w = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, true, workspaceMgr); + + INDArray wg = gradientViews.get(DefaultParamInitializer.WEIGHT_KEY); + INDArray bg = gradientViews.get(DefaultParamInitializer.BIAS_KEY); + gradientsFlattened.assign(0); + + IActivation a = layerConf().getActivationFn(); + + val tsLength = input.size(2); + + INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.shape(), 'f'); + + INDArray eps = Nd4j.zeros(new long[] {epsilon.size(0), layerConf().getNOut() + layerConf().getNIn()}, 'f'); + long end; + if(tbpttBackLength > 0){ + end = Math.max(0, tsLength-tbpttBackLength); + } else { + end = 0; + } + for( long i = tsLength-1; i>= end; i--){ + INDArray dldaCurrent = epsilon.get(all(), all(), point(i)); + INDArray zCurrent = p.getSecond().get(all(), all(), point(i)); + INDArray inCurrent = input.get(all(), all(), point(i)); + INDArray epsOutCurrent = epsOut.get(all(), all(), point(i)); + INDArray dldzCurrent = a.backprop(zCurrent.dup(), dldaCurrent.dup()).getFirst(); + + //Handle masking + INDArray maskCol = null; + if( maskArray != null){ + //Mask array: shape [minibatch, tsLength] + //If mask array is present (for example, with bidirectional RNN) -> need to zero out these errors to + // avoid using errors from a masked time step to calculate the parameter gradients + maskCol = maskArray.getColumn(i); + dldzCurrent.muliColumnVector(maskCol); + } + + //weight gradients: + Nd4j.gemm(inCurrent, dldzCurrent, wg, true, false, 1.0, 1.0); + + //Bias gradients + 
bg.addi(dldzCurrent.sum(0));
+
+            //Epsilon out to layer below (i.e., dL/dIn)
+
+            Nd4j.gemm(dldzCurrent, w, eps, false, true, 1.0, 0.0);
+            epsOutCurrent.assign(eps.get(all(), interval(0, input.size(1))));
+
+            if( maskArray != null){
+                //If mask array is present: Also need to zero out errors to avoid sending anything but 0s to layer below for masked steps
+                epsOutCurrent.muliColumnVector(maskCol);
+            }
+        }
+
+        weightNoiseParams.clear();
+
+        Gradient g = new DefaultGradient(gradientsFlattened);
+        g.gradientForVariable().put(DefaultParamInitializer.WEIGHT_KEY, wg);
+        g.gradientForVariable().put(DefaultParamInitializer.BIAS_KEY, bg);
+
+        epsOut = backpropDropOutIfPresent(epsOut);
+        return new Pair<>(g, epsOut);
+    }
+
+    @Override
+    public boolean isPretrainLayer() {
+        return false;
+    }
+
+    @Override
+    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){
+        return activateHelper(Nd4j.zeros(new long[] {input.size(0), layerConf().getNOut()}, 'f'), training, false, workspaceMgr).getFirst();
+    }
+
+    private Pair<INDArray, INDArray> activateHelper(INDArray prevStepOut, boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr){
+        assertInputSet(false);
+        Preconditions.checkState(input.rank() == 3,
+                "3D input expected to RNN layer, got " + input.rank());
+
+        applyDropOutIfNecessary(training, workspaceMgr);
+        val m = input.size(0);
+        val tsLength = input.size(2);
+        val nOut = layerConf().getNOut();
+
+        INDArray w = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, training, workspaceMgr);
+        INDArray b = getParamWithNoise(DefaultParamInitializer.BIAS_KEY, training, workspaceMgr);
+
+        if(input.ordering() != 'f' || Shape.strideDescendingCAscendingF(input))
+            input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'f');
+
+        //TODO implement 'mmul across time' optimization
+
+        //Minor performance optimization: do the "add bias" first:
+        INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new long[]{m, nOut, tsLength}, 'f');
+        INDArray outZ = (forBackprop ? workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, out.shape()) : null);
+
+        Nd4j.getExecutioner().exec(new BroadcastCopyOp(out, b, out, 1));
+
+        IActivation a = layerConf().getActivationFn();
+
+        INDArray prevStep = prevStepOut;
+        for( int i = 0; i < tsLength; i++){
+            final INDArray currIn = input.get(all(), all(), point(i));
+            final INDArray currOut = out.get(all(), all(), point(i)); // Bias already added through broadcast above
+
+            // out_t = concat(in_t, out_(t-1)) * inWeight + bias
+            currOut.addi(Nd4j.hstack(currIn, prevStep).mmul(w));
+
+            if(forBackprop) {
+                // Store pre-activation values for use in backward pass
+                outZ.get(all(), all(), point(i)).assign(currOut);
+            }
+
+            a.getActivation(currOut, training);
+            prevStep = currOut;
+        }
+
+        return new Pair<>(out, outZ);
+    }
+}
+
diff --git a/src/main/java/examples/spiral/vertex/impl/SampleGaussianVertex.java b/src/main/java/examples/spiral/vertex/impl/SampleGaussianVertex.java
new file mode 100644
index 0000000..d91cd47
--- /dev/null
+++ b/src/main/java/examples/spiral/vertex/impl/SampleGaussianVertex.java
@@ -0,0 +1,97 @@
+package examples.spiral.vertex.impl;
+
+import org.deeplearning4j.nn.api.Layer;
+import org.deeplearning4j.nn.api.MaskState;
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
+import org.deeplearning4j.nn.workspace.ArrayType;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.ops.transforms.Transforms;
+import org.nd4j.linalg.primitives.Pair;
+
+/**
+ * Takes samples from a Gaussian distribution where mean and log variance are inputs, typically from a set of layers
+ * acting as a variational auto encoder.
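+ * <p>
+ * The forward pass implements the reparameterization trick: the input is split along dimension 1 into
+ * {@code [mean, logVar]} and the output is {@code mean + eps * exp(0.5 * logVar)} where {@code eps} is
+ * drawn from N(0, I), so that gradients can flow back through both mean and log variance.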
+ *
+ * @author Christian Skarby
+ */
+public class SampleGaussianVertex extends BaseGraphVertex {
+
+    private final EpsSupplier rng;
+    private INDArray lastEps;
+
+    public interface EpsSupplier {
+        INDArray get(long[] shape);
+    }
+
+    public SampleGaussianVertex(ComputationGraph graph, String name, int vertexIndex, EpsSupplier rng) {
+        super(graph, name, vertexIndex, null, null);
+        this.rng = rng;
+    }
+
+    @Override
+    public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
+        if (!canDoForward())
+            throw new IllegalStateException("Cannot do forward pass: inputs not set");
+
+        final INDArray input = getInputs()[0].dup();
+        final long size = input.size(1) / 2;
+
+        // Dup due to dl4j issue #7263
+        INDArray mean = input.get(NDArrayIndex.all(), NDArrayIndex.interval(0, size)).dup();
+        INDArray logVar = input.get(NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size)).dup();
+
+        lastEps = rng.get(mean.shape()).mul(Transforms.exp(logVar.mul(0.5)));
+        if (training) {
+            lastEps = workspaceMgr.leverageTo(ArrayType.INPUT, lastEps);
+        }
+
+        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, lastEps.add(mean));
+    }
+
+    @Override
+    public Pair<Gradient, INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) {
+        if (!canDoBackward())
+            throw new IllegalStateException("Cannot do backward pass: errors not set");
+
+        final INDArray epsMean = getEpsilon();
+
+        // dL/dz * dz/dlogVar = epsilon * d/dlogVar(lastEps * e^0.5logVar + mean) = epsilon*0.5*lastEps*e^0.5logVar
+        final INDArray epsLogVar = getEpsilon().dup().mul(lastEps).mul(0.5);
+
+        final INDArray combinedEps = Nd4j.hstack(epsMean, epsLogVar);
+        return new Pair<>(null, new INDArray[]{
+                workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, combinedEps)
+        });
+    }
+
+    @Override
+    public boolean hasLayer() {
+        return false;
+    }
+
+    @Override
+    public Layer getLayer() {
+        return null;
+    }
+
+    @Override
+    public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
+        if (backpropGradientsViewArray != null)
+            throw new IllegalArgumentException("Vertex does not have gradients; gradients view array cannot be set here");
+    }
+
+    @Override
+    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize) {
+        throw new UnsupportedOperationException("Not implemented yet!");
+    }
+
+    @Override
+    public String toString() {
+        return this.getClass().getSimpleName() + "(id=" + this.getVertexIndex() + ",name=\"" + this.getVertexName() + "\")";
+    }
+}
diff --git a/src/main/java/ode/solve/api/FirstOrderMultiStepSolver.java b/src/main/java/ode/solve/api/FirstOrderMultiStepSolver.java
new file mode 100644
index 0000000..f4565d9
--- /dev/null
+++ b/src/main/java/ode/solve/api/FirstOrderMultiStepSolver.java
@@ -0,0 +1,25 @@
+package ode.solve.api;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+/**
+ * Marker interface for solvers which can handle multiple time steps (i.e. t.length() > 2). Does not add any new
+ * functionality, just adds API clarity.
+ *
+ * @author Christian Skarby
+ */
+public interface FirstOrderMultiStepSolver extends FirstOrderSolver {
+
+    /**
+     * Compute estimated values of Y(t1), Y(t2), ..., Y(tN) for which
+     *
+     * @param equation is the function F(Y(t)) for which dY/dt = F(Y(t))
+     * @param t is a vector of times t0, t1, t2, ..., tN for which Y(t1), Y(t2), ..., Y(tN) is sought
+     * @param y0 is the value of Y(t) at t = t0, i.e. Y(t0)
+     * @param yOut will contain the estimated values of Y(t1), Y(t2), ..., Y(tN).
+     * @return yOut, same instance as input param. Note that yOut does not contain Y(t0) in order to comply
Note that yOut does not contain Y(t0) in order to comply + * with {@link FirstOrderSolver} API in the sense that if t has only two values, only Y(t1) is returned + */ + @Override + INDArray integrate(FirstOrderEquation equation, INDArray t, INDArray y0, INDArray yOut); +} diff --git a/src/main/java/ode/solve/api/FirstOrderSolver.java b/src/main/java/ode/solve/api/FirstOrderSolver.java index 505bdb1..d0fed27 100644 --- a/src/main/java/ode/solve/api/FirstOrderSolver.java +++ b/src/main/java/ode/solve/api/FirstOrderSolver.java @@ -11,25 +11,28 @@ public interface FirstOrderSolver { /** - * Compute estimated value of Y(t1) for which - * @param equation is the function F(Y(t) for which dY/dt = F(Y(t) - * @param t is a vector with initial value t0 and desired value t1 - * @param y0 is the value of Y(t) at t = t0 - * @param yOut will contain the estimated value of Y(t1). Implementations shall assume that y0 and yOut - * are the same instance. + * Compute estimated value of Y(t1)) for which + * + * @param equation is the function F(Y(t)) for which dY/dt = F(Y(t)) + * @param t is a vector with initial time t0 and time t1 for which Y(t1) is sought + * @param y0 is the value of Y(t) at t = t0, i.e. Y(t0) + * @param yOut will contain the estimated value of Y(t1). Implementations may not make any assumption + * as to whether y0 and yOut are the same instance or not. * @return yOut, same instance as input param */ INDArray integrate(FirstOrderEquation equation, INDArray t, INDArray y0, INDArray yOut); /** * Add {@link StepListener}s which will be notified of steps taken + * * @param listeners listeners to add */ void addListener(StepListener... listeners); /** * Clear the given listeners. Clear all listeners if empty + * * @param listeners listeners to remove */ - void clearListeners(StepListener ... listeners); + void clearListeners(StepListener... 
listeners); } diff --git a/src/main/java/ode/solve/api/StepListener.java b/src/main/java/ode/solve/api/StepListener.java index 5067818..4f44fd4 100644 --- a/src/main/java/ode/solve/api/StepListener.java +++ b/src/main/java/ode/solve/api/StepListener.java @@ -1,5 +1,6 @@ package ode.solve.api; +import ode.solve.impl.util.SolverState; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -18,12 +19,11 @@ public interface StepListener { /** * Indicates a step has been taken - * @param currTime Current time (after step has been taken + * @param solverState Current state of the solver * @param step Taken step * @param error Estimated error - * @param y current state (i.e estimated state at currTime) */ - void step(INDArray currTime, INDArray step, INDArray error, INDArray y); + void step(SolverState solverState, INDArray step, INDArray error); /** * Indicates that previous step was the last step taken diff --git a/src/main/java/ode/solve/commons/StepListenerAdapter.java b/src/main/java/ode/solve/commons/StepListenerAdapter.java index d7079a7..cfabd4b 100644 --- a/src/main/java/ode/solve/commons/StepListenerAdapter.java +++ b/src/main/java/ode/solve/commons/StepListenerAdapter.java @@ -1,6 +1,7 @@ package ode.solve.commons; import ode.solve.api.StepListener; +import ode.solve.impl.util.StateContainer; import org.apache.commons.math3.exception.MaxCountExceededException; import org.apache.commons.math3.ode.sampling.StepHandler; import org.apache.commons.math3.ode.sampling.StepInterpolator; @@ -22,10 +23,12 @@ public void init(double t0, double[] y0, double t) { @Override public void handleStep(StepInterpolator interpolator, boolean isLast) throws MaxCountExceededException { wrappedListener.step( - Nd4j.create(1).assign(interpolator.getCurrentTime()), + new StateContainer( + interpolator.getCurrentTime(), + interpolator.getInterpolatedState(), + interpolator.getInterpolatedDerivatives()), Nd4j.create(1).assign(interpolator.getCurrentTime() - interpolator.getPreviousTime()), - null, // error not observable :( - Nd4j.create(interpolator.getInterpolatedState()) + null // error not observable :( ); } } diff --git a/src/main/java/ode/solve/impl/AdaptiveRungeKuttaSolver.java b/src/main/java/ode/solve/impl/AdaptiveRungeKuttaSolver.java index 09747ba..0b87fe7 100644 --- a/src/main/java/ode/solve/impl/AdaptiveRungeKuttaSolver.java +++ b/src/main/java/ode/solve/impl/AdaptiveRungeKuttaSolver.java @@ -115,7 +115,7 @@ public INDArray integrate(FirstOrderEquation equation, INDArray t, INDArray y0, equation, t.getScalar(0).dup(), yOut.assign(y0), - tableu.c.length() + 1); + tableu.cMid); listener.begin(t, y0); @@ -134,6 +134,7 @@ private void solve(FirstOrderEquationWithState equation, INDArray t) { final INDArray error = Nd4j.create(1); // Alg variable used for new steps final INDArray step = stepPolicy.initializeStep(equation, t); + // Alg variable for where next step starts final TimeLimit timeLimit = t.argMax().getInt(0) == 1 ? 
new TimeLimitForwards(t.getDouble(1), equation.time()) : @@ -170,7 +171,9 @@ private boolean acceptStep(FirstOrderEquationWithState equation, INDArray step, // local error is small enough: accept the step, equation.update(); - listener.step(equation.time(), step, error, equation.getCurrentState()); + listener.step(equation, step, error); + + equation.shiftDerivative(); return true; } diff --git a/src/main/java/ode/solve/impl/DormandPrince54Solver.java b/src/main/java/ode/solve/impl/DormandPrince54Solver.java index deb77c5..2a8e1f5 100644 --- a/src/main/java/ode/solve/impl/DormandPrince54Solver.java +++ b/src/main/java/ode/solve/impl/DormandPrince54Solver.java @@ -42,7 +42,10 @@ public class DormandPrince54Solver implements FirstOrderSolver { }) .c(new double[]{ 1.0 / 5.0, 3.0 / 10.0, 4.0 / 5.0, 8.0 / 9.0, 1.0, 1.0 - }); + }) + .cMid(new double[]{ + 6025192743d / 30085553152d / 2d, 0, 51252292925d / 65400821598d / 2d, -2691868925d / 45128329728d / 2d, + 187940372067d / 1594534317056d / 2d, -1776094331 / 19743644256d / 2d, 11237099 / 235043384d / 2d}); private final AdaptiveRungeKuttaSolver solver; diff --git a/src/main/java/ode/solve/impl/DummyIteration.java b/src/main/java/ode/solve/impl/DummyIteration.java index f447cb2..6e8c5e0 100644 --- a/src/main/java/ode/solve/impl/DummyIteration.java +++ b/src/main/java/ode/solve/impl/DummyIteration.java @@ -4,6 +4,7 @@ import ode.solve.api.FirstOrderSolver; import ode.solve.api.StepListener; import ode.solve.impl.util.AggStepListener; +import ode.solve.impl.util.StateContainer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -32,7 +33,11 @@ public INDArray integrate(FirstOrderEquation equation, INDArray t, INDArray y0, for (int i = 0; i < nrofIters; i++) { next = equation.calculateDerivative(next, t.getColumn(0), yOut); - listener.step(t.getColumn(0), Nd4j.zeros(1), Nd4j.zeros(1), next); + listener.step( + new StateContainer(t.getColumn(0), + next, + Nd4j.zeros(next.shape())), + Nd4j.zeros(1), Nd4j.zeros(1)); } listener.done(); diff --git a/src/main/java/ode/solve/impl/InterpolatingMultiStepSolver.java b/src/main/java/ode/solve/impl/InterpolatingMultiStepSolver.java new file mode 100644 index 0000000..544e35a --- /dev/null +++ b/src/main/java/ode/solve/impl/InterpolatingMultiStepSolver.java @@ -0,0 +1,68 @@ +package ode.solve.impl; + +import ode.solve.api.FirstOrderEquation; +import ode.solve.api.FirstOrderMultiStepSolver; +import ode.solve.api.FirstOrderSolver; +import ode.solve.api.StepListener; +import ode.solve.impl.util.InterpolatingStepListener; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.indexing.SpecifiedIndex; + +import java.util.Arrays; + +/** + * {@link FirstOrderMultiStepSolver} which uses an {@link InterpolatingStepListener} to estimate intermediate time + * steps. Will invoke the the wrapped {@link FirstOrderSolver} with the first and last time steps and return the + * interpolated state (i.e not the output from the wrapped solver). + */ +public class InterpolatingMultiStepSolver implements FirstOrderMultiStepSolver { + + private final FirstOrderSolver solver; + + public InterpolatingMultiStepSolver(FirstOrderSolver solver) { + this.solver = solver; + } + + @Override + public INDArray integrate(FirstOrderEquation equation, INDArray t, INDArray y0, INDArray yOut) { + if(!t.isVector()) { + throw new IllegalStateException("t must be a vector! 
Was shape " + Arrays.toString(t.shape())); + } + + final INDArray wantedTimes = t.get(NDArrayIndex.interval(1, t.length())).reshape(getTimeShapeForLength(t, t.length()-1)); + + final InterpolatingStepListener interpolation = new InterpolatingStepListener(wantedTimes, yOut); + + addListener(interpolation); + + // Note: skip first time index in order to comply with API description + final INDArray tStartEnd = t.get(new SpecifiedIndex(0, t.length()-1)).reshape(getTimeShapeForLength(t, 2)); + solver.integrate(equation, tStartEnd, y0, y0); + + clearListeners(interpolation); + + // interpolation has made sure yOut contains estimated solutions at desired times + return yOut; + } + + private long[] getTimeShapeForLength(INDArray t, long length) { + final long[] shape = t.shape().clone(); + for(int i = 0; i < shape.length; i++) { + if(shape[i] != 1) { + shape[i] = length; + } + } + return shape; + } + + @Override + public void addListener(StepListener... listeners) { + solver.addListener(listeners); + } + + @Override + public void clearListeners(StepListener... listeners) { + solver.clearListeners(listeners); + } +} diff --git a/src/main/java/ode/solve/impl/SingleSteppingMultiStepSolver.java b/src/main/java/ode/solve/impl/SingleSteppingMultiStepSolver.java new file mode 100644 index 0000000..7f7a338 --- /dev/null +++ b/src/main/java/ode/solve/impl/SingleSteppingMultiStepSolver.java @@ -0,0 +1,55 @@ +package ode.solve.impl; + +import ode.solve.api.FirstOrderEquation; +import ode.solve.api.FirstOrderMultiStepSolver; +import ode.solve.api.FirstOrderSolver; +import ode.solve.api.StepListener; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.indexing.INDArrayIndex; +import org.nd4j.linalg.indexing.NDArrayIndex; + +/** + * {@link FirstOrderSolver} for multiple time steps. Will invoke the wrapped {@link FirstOrderSolver} multiple times for + * every pair {@code t[n], t[n+1], 0 < n < t.length()-1} and concatenate the solution. + * + * @author Christian Skarby + */ +public class SingleSteppingMultiStepSolver implements FirstOrderMultiStepSolver { + + private final FirstOrderSolver solver; + + public SingleSteppingMultiStepSolver(FirstOrderSolver solver) { + this.solver = solver; + } + + @Override + public INDArray integrate(FirstOrderEquation equation, INDArray t, INDArray y0, INDArray yOut) { + if(yOut.size(0) + 1 != t.length()) { + throw new IllegalArgumentException("yOut must have one element less in first dimension compared to t!"); + } + final INDArrayIndex[] yOutAccess = new INDArrayIndex[yOut.rank()]; + for(int dim = 0; dim < yOutAccess.length; dim++) { + yOutAccess[dim] = NDArrayIndex.all(); + } + + for (int step = 1; step < t.length(); step++) { + yOutAccess[0] = NDArrayIndex.point(step -2); + final INDArray yPrev = step == 1 ? y0 : yOut.get(yOutAccess); + yOutAccess[0] = NDArrayIndex.point(step-1); + final INDArray yCurr = yOut.get(yOutAccess); + solver.integrate(equation, t.get(NDArrayIndex.interval(step - 1, step+1)), yPrev, yCurr); + } + + return yOut; + } + + @Override + public void addListener(StepListener... listeners) { + solver.addListener(listeners); + } + + @Override + public void clearListeners(StepListener... 
listeners) { + solver.clearListeners(listeners); + } +} diff --git a/src/main/java/ode/solve/impl/util/AdaptiveRungeKuttaStepPolicy.java b/src/main/java/ode/solve/impl/util/AdaptiveRungeKuttaStepPolicy.java index bf54de8..fbaf574 100644 --- a/src/main/java/ode/solve/impl/util/AdaptiveRungeKuttaStepPolicy.java +++ b/src/main/java/ode/solve/impl/util/AdaptiveRungeKuttaStepPolicy.java @@ -119,6 +119,7 @@ public INDArray initializeStep(FirstOrderEquationWithState equation, INDArray t) // step size is computed such that // step^order * max (||y'/tol||, ||y''/tol||) = 0.01 + // TODO: Should be abs(yDDotOnScale) for when negative step? final INDArray maxInv2 = max(sqrt(yDotOnScale2), yDDotOnScale); final INDArray step1 = maxInv2.getDouble(0) < 1e-15 ? max(MIN_H, abs(step).muli(0.001)) : diff --git a/src/main/java/ode/solve/impl/util/AggStepListener.java b/src/main/java/ode/solve/impl/util/AggStepListener.java index 5386b4a..06d6f1e 100644 --- a/src/main/java/ode/solve/impl/util/AggStepListener.java +++ b/src/main/java/ode/solve/impl/util/AggStepListener.java @@ -24,9 +24,9 @@ public void begin(INDArray t, INDArray y0) { } @Override - public void step(INDArray currTime, INDArray step, INDArray error, INDArray y) { + public void step(SolverState solverState, INDArray step, INDArray error) { for (StepListener listener : listeners) { - listener.step(currTime, step, error, y); + listener.step(solverState, step, error); } } diff --git a/src/main/java/ode/solve/impl/util/ButcherTableu.java b/src/main/java/ode/solve/impl/util/ButcherTableu.java index 96bfe1e..5cf9898 100644 --- a/src/main/java/ode/solve/impl/util/ButcherTableu.java +++ b/src/main/java/ode/solve/impl/util/ButcherTableu.java @@ -18,6 +18,7 @@ public class ButcherTableu { public final INDArray b; public final INDArray bStar; public final INDArray c; + public final double[] cMid; /** * Create a new {@link Builder} instance @@ -27,11 +28,12 @@ public static Builder builder() { return new Builder(); } - private ButcherTableu(INDArray[] a, INDArray b, INDArray bStar, INDArray c) { + private ButcherTableu(INDArray[] a, INDArray b, INDArray bStar, INDArray c, double[] cMid) { this.a = a; this.b = b; this.bStar = bStar; this.c = c; + this.cMid = cMid; } /** @@ -46,6 +48,7 @@ public static class Builder { private double[] b; private double[] bStar; private double[] c; + private double[] cMid; public Builder a(double[][] a) { cache.clear(); @@ -67,6 +70,11 @@ public Builder c(double[] c) { this.c = c; return this; } + public Builder cMid(double[] cMid) { + cache.clear(); + this.cMid = cMid; return this; + } + public ButcherTableu build() { ButcherTableu tableu = cache.get(Nd4j.dataType()); if(tableu == null) { @@ -78,7 +86,8 @@ public ButcherTableu build() { aArr, Nd4j.create(b), Nd4j.create(bStar), - Nd4j.create(c)); + Nd4j.create(c), + cMid); cache.put(Nd4j.dataType(), tableu); } return tableu; diff --git a/src/main/java/ode/solve/impl/util/FirstOrderEquationWithState.java b/src/main/java/ode/solve/impl/util/FirstOrderEquationWithState.java index d25a4a2..109fa36 100644 --- a/src/main/java/ode/solve/impl/util/FirstOrderEquationWithState.java +++ b/src/main/java/ode/solve/impl/util/FirstOrderEquationWithState.java @@ -11,11 +11,12 @@ * * @author Christian Skarby */ -public class FirstOrderEquationWithState { +public class FirstOrderEquationWithState implements SolverState { private final FirstOrderEquation equation; private final INDArray time; private final State state; + private final double[] midPointCoeffs; private final static class State { 
        private final INDArray y; // Last value of y. May be of any shape
@@ -48,13 +49,14 @@ public FirstOrderEquationWithState(
             FirstOrderEquation equation,
             INDArray time,
             INDArray state,
-            long nrofStages) {
+            double[] midPointCoeffs) {
         if (!time.isScalar()) {
             throw new IllegalArgumentException("Expected time to be a scalar! Was: " + time);
         }
         this.equation = equation;
         this.time = time;
-        this.state = new State(state, nrofStages);
+        this.state = new State(state, midPointCoeffs.length);
+        this.midPointCoeffs = midPointCoeffs;
     }

@@ -64,10 +66,12 @@ public FirstOrderEquationWithState(
      * @param stage which stage to update
      */
     public void calculateDerivative(long stage) {
         equation.calculateDerivative(
                 state.yWorking,
                 time.add(state.timeOffset),
                 state.getStateDot(stage)); // Note, stateDot of the given stage will be updated by this operation
     }

     /**
@@ -75,6 +79,7 @@ public void calculateDerivative(long stage) {
      * @param stage wanted stage of derivative
      * @return The derivative of the given stage
      */
+    @Override
     public INDArray getStateDot(long stage) {
         return state.getStateDot(stage);
     }
@@ -83,10 +88,25 @@ public INDArray getStateDot(long stage) {
      * Return the current state
      * @return the current state
      */
+    @Override
     public INDArray getCurrentState() {
         return state.y;
     }

+    /**
+     * Return the current time state
+     * @return the current time state
+     */
+    @Override
+    public INDArray time() {
+        return time;
+    }
+
+    @Override
+    public double[] getInterpolationMidpoints() {
+        return midPointCoeffs;
+    }
+
     /**
      * Estimate the error using the given {@link AdaptiveRungeKuttaSolver.MseComputation}
      * @param mseComputation Strategy for computing the error
@@ -96,14 +116,6 @@ public INDArray estimateError(AdaptiveRungeKuttaSolver.MseComputa
         return mseComputation.estimateMse(state.yDotK, state.y, state.yWorking, state.timeOffset);
     }

-    /**
-     * Return the current time state
-     * @return
-     */
-    public INDArray time() {
-        return time;
-    }
-
     /**
      * Update the working state by taking a step accumulated over all stages up to the given stage. The base step is
      * weighted with the given coefficients for each stage.
@@ -115,12 +127,18 @@ public void step(INDArray stepCoeffPerStage, INDArray step) {
     }

     /**
-     * Update the current state to the working state. This also includes shifting the derivative so that previous
-     * stage 1 becomes new stage 0.
+     * Update the current state to the working state.
      */
     public void update() {
         time.addi(state.timeOffset);
         state.y.assign(state.yWorking);
     }

+    /**
+     * Shift the derivative so that previous stage 1 becomes new stage 0 in preparation for next step.
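+     * <p>
+     * This exploits the first-same-as-last (FSAL) property of tableaus like Dormand-Prince 5(4): the derivative
+     * evaluated at the accepted step's endpoint equals the first stage of the next step, so one function
+     * evaluation per step is saved by copying it instead of recomputing it.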
+ */ + public void shiftDerivative() { state.yDotK.putRow(0, state.yDotK.getRow(state.yDotK.rows()-1)); } } diff --git a/src/main/java/ode/solve/impl/util/InterpolatingStepListener.java b/src/main/java/ode/solve/impl/util/InterpolatingStepListener.java new file mode 100644 index 0000000..0395a01 --- /dev/null +++ b/src/main/java/ode/solve/impl/util/InterpolatingStepListener.java @@ -0,0 +1,151 @@ +package ode.solve.impl.util; + +import ode.solve.api.StepListener; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.INDArrayIndex; +import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.indexing.conditions.And; +import org.nd4j.linalg.indexing.conditions.GreaterThan; +import org.nd4j.linalg.indexing.conditions.LessThan; + +import java.util.Arrays; + +/** + * Samples {@link SolverState} at given time indexes using an {@link Interpolation}. Useful to extract multiple time + * steps from a single time step solver. Advantage compared to calling the solver once per time step pair is fewer + * function evaluations. + * + * @author Christian Skarby + */ +public class InterpolatingStepListener implements StepListener { + + private final INDArray wantedTimeInds; + private final INDArray yInterpol; + private final INDArrayIndex[] yInterpolAccess; + private State state; + + private final class State { + private Interpolation interpolation = new Interpolation(); + private INDArray y0; + private INDArray t0; + } + + /** + * Create an {@link InterpolatingStepListener}. Output for the each wanted time will be assigned along dimension 0 + * of the provided yInterpol. In other words, output for wantedTimes[x] can be accessed through + * yInterpol.get(NDArrayIndex.point(x), NDArrayIndex.all(), NDArrayIndex.all(), ...) + * + * @param wantedTimes Time samples for which output is desired. + * @param yInterpol Will contain output from the provided {@link SolverState} at the desired times. + */ + public InterpolatingStepListener(INDArray wantedTimes, INDArray yInterpol) { + if (wantedTimes.length() != yInterpol.size(0)) { + throw new IllegalArgumentException("Must have one wanted time per element in dimension 0 of yInterpol! 
" + + "wantedTimes shape: " + Arrays.toString(wantedTimes.shape()) + ", yInterpol shape: " + + Arrays.toString(yInterpol.shape())); + } + + this.wantedTimeInds = wantedTimes; + this.yInterpol = yInterpol; + + this.yInterpolAccess = new INDArrayIndex[yInterpol.rank()]; + for (int dim = 0; dim < yInterpolAccess.length; dim++) { + yInterpolAccess[dim] = NDArrayIndex.all(); + } + } + + @Override + public void begin(INDArray t, INDArray y0) { + this.state = new State(); + this.state.t0 = t.getScalar(0); + this.state.y0 = y0; + + // Edge case: The first wanted time index is the start time -> user wants the starting state to be added to output + if (state.t0.equalsWithEps(wantedTimeInds.getScalar(0), 1e-10)) { + yInterpolAccess[0] = NDArrayIndex.point(0); + yInterpol.put(yInterpolAccess, state.y0); + } + } + + @Override + public void step(SolverState solverState, INDArray step, INDArray error) { + final INDArray greaterThanTime; + final INDArray lessThanTime; + if (step.getDouble(0) > 0) { + greaterThanTime = state.t0; + lessThanTime = solverState.time(); + } else { + greaterThanTime = solverState.time(); + lessThanTime = state.t0; + } + + final INDArray timeInds = wantedTimeInds.cond( + new And( + new GreaterThan(greaterThanTime.getDouble(0)), + new LessThan(lessThanTime.getDouble(0))) + ); + + if (timeInds.sumNumber().doubleValue() > 0) { + fitInterpolationCoeffs(solverState, step); + doInterpolation(timeInds, solverState.time()); + } + + state.t0 = solverState.time().dup(); + state.y0 = solverState.getCurrentState().dup(); + } + + private void fitInterpolationCoeffs(SolverState solverState, INDArray step) { + final INDArray[] yDotStages = new INDArray[solverState.getInterpolationMidpoints().length]; + for (int i = 0; i < yDotStages.length; i++) { + yDotStages[i] = solverState.getStateDot(i); + } + + final INDArray state0 = state.y0; + final INDArray state1 = solverState.getCurrentState(); + + final INDArray yMid = state.y0.add(scaledDotProduct( + Nd4j.createUninitialized(state0.shape()), + solverState.getInterpolationMidpoints(), + yDotStages, + step)); + + state.interpolation.fitCoeffs( + state0, + state1, + yMid, + yDotStages[0], + yDotStages[yDotStages.length - 1], + step); + } + + private INDArray scaledDotProduct(INDArray output, double[] factors, INDArray[] inputs, INDArray scale) { + output.assign((inputs[0].mul(factors[0])).mul(scale)); + for (int i = 1; i < inputs.length; i++) { + output.addi(inputs[i].mul(factors[i]).mul(scale)); + } + return output; + } + + private void doInterpolation(INDArray timeInds, INDArray tNew) { + final int startInd = timeInds.argMax().getInt(0); + final int stopInd = startInd + timeInds.sumNumber().intValue(); + + for (int i = startInd; i < stopInd; i++) { + yInterpolAccess[0] = NDArrayIndex.point(i); + + yInterpol.put(yInterpolAccess, + state.interpolation.interpolate(state.t0.getDouble(0), tNew.getDouble(0), wantedTimeInds.getDouble(i))); + } + } + + @Override + public void done() { + + // Edge case: User wants last time step to be added to interpolation + if (state.t0.equalsWithEps(wantedTimeInds.getScalar(wantedTimeInds.length() - 1), 1e-10)) { + yInterpolAccess[0] = NDArrayIndex.point(wantedTimeInds.length() - 1); + yInterpol.put(yInterpolAccess, state.y0); + } + } +} diff --git a/src/main/java/ode/solve/impl/util/Interpolation.java b/src/main/java/ode/solve/impl/util/Interpolation.java new file mode 100644 index 0000000..7c17009 --- /dev/null +++ b/src/main/java/ode/solve/impl/util/Interpolation.java @@ -0,0 +1,85 @@ +package ode.solve.impl.util; + 
+import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +/** + * Performs interpolation using a fourth order polynomial. Useful for reducing the number of function evaluations + * when multiple (closely spaced) time steps are used. + * Reimplementation of https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/interp.py + * + * @author Christian Skarby + */ +public class Interpolation { + + private final INDArray[] coeffs = new INDArray[5]; + + private void initCoeffs(long[] shape) { + for (int i = 0; i < coeffs.length; i++) { + coeffs[i] = Nd4j.createUninitialized(shape); + } + } + + /** + * Fit coefficients for a fourth order polynomial: p = coeffs[0] * x^4 + coeffs[1] * x^3 + coeffs[2] * x^2 + coeffs[3] * x + coeffs[4] + * + * @param y0 Function value at start of interval + * @param y1 Function value at end of interval + * @param yMid Function value at midpoint of interval + * @param f0 Derivative value at start of interval + * @param f1 Derivative value at end of interval + * @param dt Time between start and end of interval + */ + public void fitCoeffs(INDArray y0, INDArray y1, INDArray yMid, INDArray f0, INDArray f1, INDArray dt) { + if (coeffs[0] == null) { + initCoeffs(y0.shape()); + } + + final INDArray[] inputs = {f0, f1, y0, y1, yMid}; + + final double dtdub = dt.getDouble(0); + dotProduct(coeffs[0], new double[]{-2 * dtdub, 2 * dtdub, -8, -8, 16}, inputs); + + dotProduct(coeffs[1], new double[]{5 * dtdub, -3 * dtdub, 18, 14, -32}, inputs); + + dotProduct(coeffs[2], new double[]{-4 * dtdub, dtdub, -11, -5, 16}, inputs); + + coeffs[3] = f0.mul(dt); + coeffs[4] = y0.dup(); + } + + private INDArray dotProduct(INDArray output, double[] factors, INDArray[] inputs) { + output.assign((inputs[0].mul(factors[0]))); + for (int i = 1; i < inputs.length; i++) { + output.addi(inputs[i].mul(factors[i])); + } + return output; + } + + /** + * Evaluate interpolation of a fourth order polynomial: p = coeffs[0] * x^4 + coeffs[1] * x^3 + coeffs[2] * x^2 + coeffs[3] * x + coeffs[4] + * + * @param t0 Start of interval + * @param t1 End of interval + * @param t Wanted time + * @return Result of the interpolation + */ + public INDArray interpolate(double t0, double t1, double t) { + + if (Math.max(t0, t1) < t || Math.min(t0, t1) > t) { + throw new IllegalArgumentException("t0 < t < t1 or t1 < t < t0 not satisfied! 
t0: " + t0 + ", t: " + t + ", t1: " + t1); + } + + final double x = ((t - t0) / (t1 - t0)); + + // Create array {x^4, x^3, x^2, x, 1}; + final double[] xs = new double[coeffs.length]; + xs[coeffs.length - 1] = 1; + for (int i = coeffs.length - 2; i > -1; i--) { + xs[i] = xs[i + 1] * x; + } + + return dotProduct(Nd4j.createUninitialized(coeffs[0].shape()), xs, coeffs); + } + +} diff --git a/src/main/java/ode/solve/impl/util/SolverState.java b/src/main/java/ode/solve/impl/util/SolverState.java new file mode 100644 index 0000000..a9b2e98 --- /dev/null +++ b/src/main/java/ode/solve/impl/util/SolverState.java @@ -0,0 +1,36 @@ +package ode.solve.impl.util; + +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * Interface for providing state of a solver + * + * @author Christian Skarby + */ +public interface SolverState { + + /** + * Return the given stage of the derivative of the state + * @param stage wanted stage of derivative + * @return The derivative of the given stage + */ + INDArray getStateDot(long stage); + + /** + * Return the current state + * @return the current state + */ + INDArray getCurrentState(); + + /** + * Return the current time state + * @return the current time state + */ + INDArray time(); + + /** + * Return coefficients for calculating midpoints for the given solver + * @return coefficients for calculating midpoints for the given solver + */ + double[] getInterpolationMidpoints(); +} diff --git a/src/main/java/ode/solve/impl/util/StateContainer.java b/src/main/java/ode/solve/impl/util/StateContainer.java new file mode 100644 index 0000000..1ce89a9 --- /dev/null +++ b/src/main/java/ode/solve/impl/util/StateContainer.java @@ -0,0 +1,46 @@ +package ode.solve.impl.util; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +/** + * Simple container for solver state. 
+ * + * @author Christian Skarby + */ +public class StateContainer implements SolverState { + + private final INDArray currentTime; + private final INDArray currentState; + private final INDArray currentStateDot; + + public StateContainer(double currentTime, double[] currentState, double[] currentStateDot) { + this(Nd4j.scalar(currentTime), Nd4j.create(currentState), Nd4j.create(currentStateDot)); + } + + public StateContainer(INDArray currentTime, INDArray currentState, INDArray currentStateDot) { + this.currentTime = currentTime; + this.currentState = currentState; + this.currentStateDot = currentStateDot; + } + + @Override + public INDArray getStateDot(long stage) { + return currentStateDot; + } + + @Override + public INDArray getCurrentState() { + return currentState; + } + + @Override + public INDArray time() { + return currentTime; + } + + @Override + public double[] getInterpolationMidpoints() { + return new double[0]; + } +} diff --git a/src/main/java/ode/vertex/conf/DefaultTrainingConfig.java b/src/main/java/ode/vertex/conf/DefaultTrainingConfig.java index a41690e..3510ad3 100644 --- a/src/main/java/ode/vertex/conf/DefaultTrainingConfig.java +++ b/src/main/java/ode/vertex/conf/DefaultTrainingConfig.java @@ -1,8 +1,12 @@ package ode.vertex.conf; +import ode.vertex.impl.gradview.parname.ParamNameMapping; import org.deeplearning4j.nn.api.TrainingConfig; import org.deeplearning4j.nn.conf.GradientNormalization; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.graph.vertex.GraphVertex; import org.nd4j.linalg.learning.config.IUpdater; +import org.nd4j.linalg.primitives.Pair; /** * Basic {@link TrainingConfig} for vertices which do not have a {@link org.deeplearning4j.nn.conf.layers.Layer} @@ -11,13 +15,14 @@ */ public class DefaultTrainingConfig implements TrainingConfig { - private final String name; - private final IUpdater updater; + private final ComputationGraph graph; + private final ParamNameMapping paramNameMapping; - public DefaultTrainingConfig(String name, IUpdater updater) { + public DefaultTrainingConfig(ComputationGraph graph, String name, ParamNameMapping paramNameMapping) { this.name = name; - this.updater = updater; + this.graph = graph; + this.paramNameMapping = paramNameMapping; } @Override @@ -32,22 +37,30 @@ public boolean isPretrain() { @Override public double getL1ByParam(String paramName) { - return 0; + final Pair vertexAndParam = paramNameMapping.reverseMap(paramName); + final GraphVertex vertex = graph.getVertex(vertexAndParam.getFirst()); + return vertex.getConfig().getL1ByParam(vertexAndParam.getSecond()); } @Override public double getL2ByParam(String paramName) { - return 0; + final Pair vertexAndParam = paramNameMapping.reverseMap(paramName); + final GraphVertex vertex = graph.getVertex(vertexAndParam.getFirst()); + return vertex.getConfig().getL2ByParam(vertexAndParam.getSecond()); } @Override public boolean isPretrainParam(String paramName) { - return false; + final Pair vertexAndParam = paramNameMapping.reverseMap(paramName); + final GraphVertex vertex = graph.getVertex(vertexAndParam.getFirst()); + return vertex.getConfig().isPretrainParam(vertexAndParam.getSecond()); } @Override public IUpdater getUpdaterByParam(String paramName) { - return updater; + final Pair vertexAndParam = paramNameMapping.reverseMap(paramName); + final GraphVertex vertex = graph.getVertex(vertexAndParam.getFirst()); + return vertex.getConfig().getUpdaterByParam(vertexAndParam.getSecond()); } @Override diff --git 
a/src/main/java/ode/vertex/conf/OdeVertex.java b/src/main/java/ode/vertex/conf/OdeVertex.java
index f01a660..1150723 100644
--- a/src/main/java/ode/vertex/conf/OdeVertex.java
+++ b/src/main/java/ode/vertex/conf/OdeVertex.java
@@ -1,23 +1,40 @@
 package ode.vertex.conf;

 import lombok.Data;
-import ode.solve.api.FirstOrderSolver;
-import ode.solve.api.FirstOrderSolverConf;
 import ode.solve.conf.DormandPrince54Solver;
+import ode.vertex.conf.helper.OdeHelper;
+import ode.vertex.conf.helper.backward.FixedStepAdjoint;
+import ode.vertex.conf.helper.backward.OdeHelperBackward;
+import ode.vertex.conf.helper.forward.FixedStep;
+import ode.vertex.conf.helper.forward.OdeHelperForward;
+import ode.vertex.impl.gradview.GradientViewFactory;
+import ode.vertex.impl.gradview.GradientViewSelectionFromBlacklisted;
+import ode.vertex.impl.helper.OdeGraphHelper;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.graph.GraphVertex;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
-import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
 import org.deeplearning4j.nn.conf.layers.Layer;
 import org.deeplearning4j.nn.conf.memory.MemoryReport;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.shade.jackson.annotation.JsonProperty;

 /**
- * Configuration of an ODE block.
+ * Configuration of an ODE block. Contains a {@link ComputationGraphConfiguration} which defines the structure of the
+ * learnable function {@code f = dz(t)/dt} for which the {@link ode.vertex.impl.OdeVertex} will output an estimate
+ * of z(t) for given t(s).
+ * <p>
+ * A {@link Builder} is used to add {@link Layer}s and {@link GraphVertex GraphVertices} to the internal
+ * {@link ComputationGraphConfiguration}.
+ * <p>
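+ * Minimal sketch (the names and the layer are arbitrary examples; {@code globalConf} is assumed to be a
+ * {@code NeuralNetConfiguration.Builder} holding graph-wide defaults):
+ * <pre>
+ * OdeVertex odeBlock = new OdeVertex.Builder(globalConf, "f0", new DenseLayer.Builder().nOut(8).build())
+ *     .build();
+ * </pre>
+ * <p>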
+ * Note that the internal {@code ComputationGraphConfiguration} is not the same as the "outer" + * {@code ComputationGraphConfiguration} which houses the OdeVertex itself. This understandably confusing composition + * comes from the fact that the {@code OdeVertex} needs to operate on an arbitrary graph and I didn't want to + * reimplement all the routing for doing this. If dl4j had something similar to pytorch's nn.Module I would rather have + * used that. * * @author Christian Skarby */ @@ -26,23 +43,26 @@ public class OdeVertex extends GraphVertex { protected ComputationGraphConfiguration conf; protected String firstVertex; - protected String lastVertex; - protected FirstOrderSolverConf odeSolver; + protected OdeHelperForward odeForwardConf; + protected OdeHelperBackward odeBackwardConf; + protected GradientViewFactory gradientViewFactory; public OdeVertex( @JsonProperty("conf") ComputationGraphConfiguration conf, @JsonProperty("firstVertex") String firstVertex, - @JsonProperty("lastVertex") String lastVertex, - @JsonProperty("odeSolver") FirstOrderSolverConf odeSolver) { + @JsonProperty("odeForwardConf") OdeHelperForward odeForwardConf, + @JsonProperty("odeBackwardConf") OdeHelperBackward odeBackwardConf, + @JsonProperty("gradientViewFactory") GradientViewFactory gradientViewFactory) { this.conf = conf; this.firstVertex = firstVertex; - this.lastVertex = lastVertex; - this.odeSolver = odeSolver; + this.odeForwardConf = odeForwardConf; + this.odeBackwardConf = odeBackwardConf; + this.gradientViewFactory = gradientViewFactory; } @Override public GraphVertex clone() { - return new OdeVertex(conf.clone(), firstVertex, lastVertex, odeSolver.clone()); + return new OdeVertex(conf.clone(), firstVertex, odeForwardConf.clone(), odeBackwardConf.clone(), gradientViewFactory.clone()); } @Override @@ -50,11 +70,12 @@ public boolean equals(Object o) { if (!(o instanceof OdeVertex)) { return false; } - final OdeVertex other = (OdeVertex)o; + final OdeVertex other = (OdeVertex) o; return conf.equals(other.conf) && firstVertex.equals(other.firstVertex) - && lastVertex.equals(other.lastVertex) - && odeSolver.equals(other.odeSolver); + && odeForwardConf.equals(other.odeForwardConf) + && odeBackwardConf.equals(other.odeBackwardConf) + && gradientViewFactory.equals(other.gradientViewFactory); } @Override @@ -71,12 +92,12 @@ public long numParams(boolean backprop) { @Override public int minVertexInputs() { - return conf.getVertices().get(firstVertex).minVertexInputs(); + return conf.getVertices().get(firstVertex).minVertexInputs() + odeForwardConf.nrofTimeInputs(); } @Override public int maxVertexInputs() { - return conf.getVertices().get(firstVertex).maxVertexInputs(); + return conf.getVertices().get(firstVertex).maxVertexInputs() + odeForwardConf.nrofTimeInputs(); } @Override @@ -110,18 +131,27 @@ public void setBackpropGradientsViewArray(INDArray gradient) { innerGraph.init(paramsView, false); // This does not update any parameters, just sets them - return new ode.vertex.impl.OdeVertex( - graph, - name, - idx, + final DefaultTrainingConfig trainingConfig = new DefaultTrainingConfig( innerGraph, - odeSolver.instantiate(), - new DefaultTrainingConfig(name, graph.getVertices()[1].getConfig().getUpdaterByParam("W").clone())); + name, + gradientViewFactory.paramNameMapping()); + + return new ode.vertex.impl.OdeVertex( + new ode.vertex.impl.OdeVertex.BaseGraphVertexInputs(graph, name, idx), + new OdeGraphHelper( + odeForwardConf.instantiate(), + odeBackwardConf.instantiate(), + new 
OdeGraphHelper.CompGraphAsOdeFunction( + innerGraph, + // Hacky handling for legacy models. To be removed... + gradientViewFactory == null ? new GradientViewSelectionFromBlacklisted() : gradientViewFactory) + ), + trainingConfig); } @Override public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException { - return conf.getLayerActivationTypes(vertexInputs).get(lastVertex); + return odeForwardConf.getOutputType(conf, vertexInputs); } @Override @@ -131,17 +161,25 @@ public MemoryReport getMemoryReport(InputType... inputTypes) { public static class Builder { - private final String inputName = this.toString() + "_input"; - private final String outputName = this.toString() + "_output"; - private final ComputationGraphConfiguration.GraphBuilder graphBuilder = new NeuralNetConfiguration.Builder() - .graphBuilder(); - - private String first = null; - private String last; + private final ComputationGraphConfiguration.GraphBuilder graphBuilder; + private final String first; + private OdeHelperForward odeForwardConf = new FixedStep(new DormandPrince54Solver(), Nd4j.arange(2), true); + private OdeHelperBackward odeBackwardConf = new FixedStepAdjoint(new DormandPrince54Solver(), Nd4j.arange(2)); + private GradientViewFactory gradientViewFactory = new GradientViewSelectionFromBlacklisted(); - private FirstOrderSolverConf odeSolver = new DormandPrince54Solver(); public Builder(String name, Layer layer) { + graphBuilder = new NeuralNetConfiguration.Builder().graphBuilder(); + final String inputName = this.toString() + "_input"; + graphBuilder + .addInputs(inputName) + .addLayer(name, layer, inputName); + first = name; + } + + public Builder(NeuralNetConfiguration.Builder globalConf, String name, Layer layer) { + graphBuilder = globalConf.clone().graphBuilder(); + final String inputName = this.toString() + "_input"; graphBuilder .addInputs(inputName) .addLayer(name, layer, inputName); @@ -153,8 +191,6 @@ public Builder(String name, Layer layer) { */ public Builder addLayer(String name, Layer layer, String... inputs) { graphBuilder.addLayer(name, layer, inputs); - checkFirst(name); - last = name; return this; } @@ -163,25 +199,51 @@ public Builder addLayer(String name, Layer layer, String... inputs) { */ public Builder addVertex(String name, GraphVertex vertex, String... 
inputs) { graphBuilder.addVertex(name, vertex, inputs); - checkFirst(name); - last = name; return this; } /** - * Sets the {@link FirstOrderSolver} to use - * @param odeSolver solver instance + * Set the {@link OdeHelper} to use + * + * @param odeConf ODE configuration + * @return the Builder for fluent API + */ + public Builder odeConf(OdeHelper odeConf) { + odeForward(odeConf.forward()); + return odeBackward(odeConf.backward()); + } + + /** + * Sets the {@link OdeHelperForward} to use + * + * @param odeForwardConf Configuration of forward helper + * @return the Builder for fluent API + */ + public Builder odeForward(OdeHelperForward odeForwardConf) { + this.odeForwardConf = odeForwardConf; + return this; + } + + /** + * Sets the {@link OdeHelperBackward} to use + * + * @param odeBackwardConf Configuration of backward helper * @return the Builder for fluent API */ - public Builder odeSolver(FirstOrderSolverConf odeSolver) { - this.odeSolver = odeSolver; + public Builder odeBackward(OdeHelperBackward odeBackwardConf) { + this.odeBackwardConf = odeBackwardConf; return this; } - private void checkFirst(String name) { - if (first == null) { - first = name; - } + /** + * Sets the {@link GradientViewFactory} to use + * + * @param gradientViewFactory Factory for gradient views + * @return the Builder for fluent API + */ + public Builder gradientViewFactory(GradientViewFactory gradientViewFactory) { + this.gradientViewFactory = gradientViewFactory; + return this; } /** @@ -191,9 +253,12 @@ private void checkFirst(String name) { */ public OdeVertex build() { return new OdeVertex(graphBuilder - .setOutputs(outputName) - .addLayer(outputName, new CnnLossLayer(), last) - .build(), first, last, odeSolver); + .allowNoOutput(true) + .build(), + first, + odeForwardConf, + odeBackwardConf, + gradientViewFactory); } } diff --git a/src/main/java/ode/vertex/conf/helper/FixedStep.java b/src/main/java/ode/vertex/conf/helper/FixedStep.java new file mode 100644 index 0000000..9b6f0be --- /dev/null +++ b/src/main/java/ode/vertex/conf/helper/FixedStep.java @@ -0,0 +1,40 @@ +package ode.vertex.conf.helper; + +import ode.solve.api.FirstOrderSolverConf; +import ode.vertex.conf.helper.backward.FixedStepAdjoint; +import ode.vertex.conf.helper.backward.OdeHelperBackward; +import ode.vertex.conf.helper.forward.OdeHelperForward; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * Configuration for using a {@link ode.solve.api.FirstOrderSolver} inside a ComputationGraph when a fixed predefined + * set of time steps shall be used when solving the ODE. 
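+ * <p>
+ * Minimal sketch (solver and time grid are arbitrary examples): integrate from t = 0 to t = 1 with a
+ * Dormand-Prince 5(4) solver:
+ * <pre>
+ * OdeHelper odeConf = new FixedStep(new DormandPrince54Solver(), Nd4j.arange(2));
+ * </pre>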
+ * + * @author Christian Skarby + */ +public class FixedStep implements OdeHelper { + + private final FirstOrderSolverConf solver; + private final INDArray time; + private final boolean interpolateForwardIfMultiStep; + + public FixedStep(FirstOrderSolverConf solver, INDArray time) { + this(solver, time, false); + } + + public FixedStep(FirstOrderSolverConf solver, INDArray time, boolean interpolateForwardIfMultiStep) { + this.solver = solver; + this.time = time.dup(); + this.interpolateForwardIfMultiStep = interpolateForwardIfMultiStep; + } + + @Override + public OdeHelperForward forward() { + return new ode.vertex.conf.helper.forward.FixedStep(solver, time, interpolateForwardIfMultiStep); + } + + @Override + public OdeHelperBackward backward() { + return new FixedStepAdjoint(solver, time); + } +} diff --git a/src/main/java/ode/vertex/conf/helper/InputStep.java b/src/main/java/ode/vertex/conf/helper/InputStep.java new file mode 100644 index 0000000..9e20f58 --- /dev/null +++ b/src/main/java/ode/vertex/conf/helper/InputStep.java @@ -0,0 +1,53 @@ +package ode.vertex.conf.helper; + +import ode.solve.api.FirstOrderSolverConf; +import ode.vertex.conf.helper.backward.InputStepAdjoint; +import ode.vertex.conf.helper.backward.OdeHelperBackward; +import ode.vertex.conf.helper.forward.OdeHelperForward; + +/** + * Configuration for using a {@link ode.solve.api.FirstOrderSolver} inside a {@code ComputationGraph} when time steps for solving + * the ODE comes as inputs to the {@code GraphVertex} housing the ODE. + *
+ * Example:
+ * <pre>
+ * graphBuilder.addVertex("odeVertex",
+ *     new OdeVertex.Builder("0", new DenseLayer.Builder().nOut(4).build())
+ *         .odeConf(new InputStep(solverConf, 1)) // Refers to input "time" on the line below
+ *         .build(), "someLayer", "time");
+ * </pre>
+ * + * @author Christian Skarby + */ +public class InputStep implements OdeHelper { + + private final FirstOrderSolverConf solverConf; + private final int timeInputIndex; + private final boolean interpolateForwardIfMultiStep; + private final boolean needTimeGradient; + + public InputStep(FirstOrderSolverConf solverConf, int timeInputIndex) { + this(solverConf, timeInputIndex, false); + } + + public InputStep(FirstOrderSolverConf solverConf, int timeInputIndex, boolean interpolateForwardIfMultiStep) { + this(solverConf, timeInputIndex, interpolateForwardIfMultiStep, false); + } + + public InputStep(FirstOrderSolverConf solverConf, int timeInputIndex, boolean interpolateForwardIfMultiStep, boolean needTimeGradient) { + this.solverConf = solverConf; + this.timeInputIndex = timeInputIndex; + this.interpolateForwardIfMultiStep = interpolateForwardIfMultiStep; + this.needTimeGradient = needTimeGradient; + } + + @Override + public OdeHelperForward forward() { + return new ode.vertex.conf.helper.forward.InputStep(solverConf, timeInputIndex, interpolateForwardIfMultiStep); + } + + @Override + public OdeHelperBackward backward() { + return new InputStepAdjoint(solverConf, timeInputIndex, needTimeGradient); + } +} diff --git a/src/main/java/ode/vertex/conf/helper/OdeHelper.java b/src/main/java/ode/vertex/conf/helper/OdeHelper.java new file mode 100644 index 0000000..9a65e85 --- /dev/null +++ b/src/main/java/ode/vertex/conf/helper/OdeHelper.java @@ -0,0 +1,27 @@ +package ode.vertex.conf.helper; + +import ode.vertex.conf.helper.backward.OdeHelperBackward; +import ode.vertex.conf.helper.forward.OdeHelperForward; + +/** + * Convenience configuration of both {@link OdeHelperForward} and {@link OdeHelperBackward}. + * + * @author Christian Skarby + */ +public interface OdeHelper { + + /** + * Create the helper config in the forward direction + * + * @return helper in forward direction + */ + OdeHelperForward forward(); + + /** + * Create the helper config in the backward direction + * + * @return helper in backward direction + */ + OdeHelperBackward backward(); + +} diff --git a/src/main/java/ode/vertex/conf/helper/backward/FixedStepAdjoint.java b/src/main/java/ode/vertex/conf/helper/backward/FixedStepAdjoint.java new file mode 100644 index 0000000..618e46e --- /dev/null +++ b/src/main/java/ode/vertex/conf/helper/backward/FixedStepAdjoint.java @@ -0,0 +1,41 @@ +package ode.vertex.conf.helper.backward; + +import lombok.Data; +import ode.solve.api.FirstOrderSolverConf; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.serde.jackson.shaded.NDArrayDeSerializer; +import org.nd4j.serde.jackson.shaded.NDArraySerializer; +import org.nd4j.shade.jackson.annotation.JsonProperty; +import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; +import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; + +/** + * Serializable configuration for {@link ode.vertex.impl.helper.backward.BackpropagateAdjoint} + * + * @author Christian Skarby + */ +@Data +public class FixedStepAdjoint implements OdeHelperBackward { + + private final FirstOrderSolverConf solverConf; + @JsonSerialize(using = NDArraySerializer.class) + @JsonDeserialize(using = NDArrayDeSerializer.class) + private final INDArray time; + + public FixedStepAdjoint( + @JsonProperty("solverConf") FirstOrderSolverConf solverConf, + @JsonProperty("time") INDArray time) { + this.solverConf = solverConf; + this.time = time; + } + + @Override + public ode.vertex.impl.helper.backward.OdeHelperBackward instantiate() { + return new 
ode.vertex.impl.helper.backward.FixedStepAdjoint(solverConf.instantiate(), time);
+    }
+
+    @Override
+    public FixedStepAdjoint clone() {
+        return new FixedStepAdjoint(solverConf.clone(), time.dup());
+    }
+}
diff --git a/src/main/java/ode/vertex/conf/helper/backward/InputStepAdjoint.java b/src/main/java/ode/vertex/conf/helper/backward/InputStepAdjoint.java
new file mode 100644
index 0000000..022c87b
--- /dev/null
+++ b/src/main/java/ode/vertex/conf/helper/backward/InputStepAdjoint.java
@@ -0,0 +1,41 @@
+package ode.vertex.conf.helper.backward;
+
+import lombok.Data;
+import ode.solve.api.FirstOrderSolverConf;
+import org.nd4j.shade.jackson.annotation.JsonProperty;
+
+/**
+ * Serializable configuration of an {@link ode.vertex.impl.helper.backward.InputStepAdjoint}
+ *
+ * @author Christian Skarby
+ */
+@Data
+public class InputStepAdjoint implements OdeHelperBackward {
+
+    private final FirstOrderSolverConf solverConf;
+    private final int timeInputIndex;
+    private final boolean needTimeGradient;
+
+    public InputStepAdjoint(FirstOrderSolverConf solverConf, int timeInputIndex) {
+        this(solverConf, timeInputIndex, false);
+    }
+
+    public InputStepAdjoint(
+            @JsonProperty("solverConf") FirstOrderSolverConf solverConf,
+            @JsonProperty("timeInputIndex") int timeInputIndex,
+            @JsonProperty("needTimeGradient") boolean needTimeGradient) {
+        this.solverConf = solverConf;
+        this.timeInputIndex = timeInputIndex;
+        this.needTimeGradient = needTimeGradient;
+    }
+
+    @Override
+    public ode.vertex.impl.helper.backward.OdeHelperBackward instantiate() {
+        return new ode.vertex.impl.helper.backward.InputStepAdjoint(solverConf.instantiate(), timeInputIndex, needTimeGradient);
+    }
+
+    @Override
+    public InputStepAdjoint clone() {
+        return new InputStepAdjoint(solverConf.clone(), timeInputIndex, needTimeGradient); // keep needTimeGradient so cloning does not silently reset it to false
+    }
+}
diff --git a/src/main/java/ode/vertex/conf/helper/backward/OdeHelperBackward.java b/src/main/java/ode/vertex/conf/helper/backward/OdeHelperBackward.java
new file mode 100644
index 0000000..8c4c1f2
--- /dev/null
+++ b/src/main/java/ode/vertex/conf/helper/backward/OdeHelperBackward.java
@@ -0,0 +1,24 @@
+package ode.vertex.conf.helper.backward;
+
+import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
+
+/**
+ * Serializable configuration of an {@link ode.vertex.impl.helper.backward.OdeHelperBackward}
+ *
+ * @author Christian Skarby
+ */
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
+public interface OdeHelperBackward {
+
+    /**
+     * Instantiate the helper
+     * @return a new {@link ode.vertex.impl.helper.backward.OdeHelperBackward}
+     */
+    ode.vertex.impl.helper.backward.OdeHelperBackward instantiate();
+
+    /**
+     * Clone the configuration
+     * @return a clone of the configuration
+     */
+    OdeHelperBackward clone();
+}
diff --git a/src/main/java/ode/vertex/conf/helper/forward/FixedStep.java b/src/main/java/ode/vertex/conf/helper/forward/FixedStep.java
new file mode 100644
index 0000000..af5fbf4
--- /dev/null
+++ b/src/main/java/ode/vertex/conf/helper/forward/FixedStep.java
@@ -0,0 +1,66 @@
+package ode.vertex.conf.helper.forward;
+
+import lombok.Data;
+import ode.solve.api.FirstOrderSolverConf;
+import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.serde.jackson.shaded.NDArrayDeSerializer;
+import org.nd4j.serde.jackson.shaded.NDArraySerializer;
+import 
org.nd4j.shade.jackson.annotation.JsonProperty; +import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; +import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; + +/** + * Serializable configuration for {@link ode.vertex.impl.helper.forward.FixedStep} + * + * @author Christian Skarby + */ +@Data +public class FixedStep implements OdeHelperForward{ + + private final FirstOrderSolverConf solverConf; + @JsonSerialize(using = NDArraySerializer.class) + @JsonDeserialize(using = NDArrayDeSerializer.class) + private final INDArray time; + private final boolean interpolateIfMultiStep; + + public FixedStep( + @JsonProperty("solverConf") FirstOrderSolverConf solverConf, + @JsonProperty("time") INDArray time, + @JsonProperty("interpolateIfMultiStep") boolean interpolateIfMultiStep) { + this.solverConf = solverConf; + this.time = time; + this.interpolateIfMultiStep = interpolateIfMultiStep; + } + + @Override + public ode.vertex.impl.helper.forward.OdeHelperForward instantiate() { + return new ode.vertex.impl.helper.forward.FixedStep(solverConf.instantiate(), time, interpolateIfMultiStep); + } + + @Override + public int nrofTimeInputs() { + return 0; + } + + @Override + public FixedStep clone() { + return new FixedStep(solverConf.clone(), time.dup(), interpolateIfMultiStep); + } + + @Override + public InputType getOutputType(ComputationGraphConfiguration conf, InputType... vertexInputs) throws InvalidInputTypeException { + + final OutputTypeHelper confHelper = new OutputTypeFromConfig(); + if(time.length() == 2) { + return confHelper.getOutputType(conf, vertexInputs); + } + + final InputType[] withTime = new InputType[vertexInputs.length+1]; + System.arraycopy(vertexInputs, 0, withTime, 0, vertexInputs.length); + withTime[vertexInputs.length] = InputType.inferInputType(this.time); + return new OutputTypeAddTimeAsDimension(vertexInputs.length, confHelper).getOutputType(conf, withTime); + } +} diff --git a/src/main/java/ode/vertex/conf/helper/forward/InputStep.java b/src/main/java/ode/vertex/conf/helper/forward/InputStep.java new file mode 100644 index 0000000..948b1b4 --- /dev/null +++ b/src/main/java/ode/vertex/conf/helper/forward/InputStep.java @@ -0,0 +1,67 @@ +package ode.vertex.conf.helper.forward; + +import lombok.Data; +import ode.solve.api.FirstOrderSolverConf; +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; +import org.nd4j.shade.jackson.annotation.JsonProperty; + +import java.util.ArrayList; +import java.util.List; + +/** + * Serializable configuration of an {@link ode.vertex.impl.helper.forward.InputStep} + * + * @author Christian Skarby + */ +@Data +public class InputStep implements OdeHelperForward { + + private final FirstOrderSolverConf solverConf; + private final int timeInputIndex; + private final boolean interpolateIfMultiStep; + + public InputStep( + @JsonProperty("solverConf") FirstOrderSolverConf solverConf, + @JsonProperty("timeInputIndex") int timeInputIndex, + @JsonProperty("interpolateIfMultiStep") boolean interpolateIfMultiStep) { + this.solverConf = solverConf; + this.timeInputIndex = timeInputIndex; + this.interpolateIfMultiStep = interpolateIfMultiStep; + } + + @Override + public ode.vertex.impl.helper.forward.OdeHelperForward instantiate() { + return new ode.vertex.impl.helper.forward.InputStep(solverConf.instantiate(), timeInputIndex, interpolateIfMultiStep); + } + + @Override + public int nrofTimeInputs() { + 
return 1; + } + + @Override + public InputStep clone() { + return new InputStep(solverConf.clone(), timeInputIndex, interpolateIfMultiStep); + } + + @Override + public InputType getOutputType(ComputationGraphConfiguration conf, InputType... vertexInputs) throws InvalidInputTypeException { + if(vertexInputs.length <= timeInputIndex) { + throw new InvalidInputTypeException("Time input index was not part of input types!!"); + } + + final InputType timeInput = vertexInputs[timeInputIndex]; + if(timeInput.arrayElementsPerExample() > 2) { + return new OutputTypeAddTimeAsDimension(timeInputIndex, new OutputTypeFromConfig()).getOutputType(conf, vertexInputs); + } + List inputTypeList = new ArrayList<>(); + for (int i = 0; i < vertexInputs.length; i++) { + if (i != timeInputIndex) { + inputTypeList.add(vertexInputs[i]); + } + } + return new OutputTypeFromConfig().getOutputType(conf, inputTypeList.toArray(new InputType[0])); + } +} diff --git a/src/main/java/ode/vertex/conf/helper/forward/OdeHelperForward.java b/src/main/java/ode/vertex/conf/helper/forward/OdeHelperForward.java new file mode 100644 index 0000000..7c5eda8 --- /dev/null +++ b/src/main/java/ode/vertex/conf/helper/forward/OdeHelperForward.java @@ -0,0 +1,31 @@ +package ode.vertex.conf.helper.forward; + +import org.nd4j.shade.jackson.annotation.JsonTypeInfo; + +/** + * Serializable configuration of an {@link ode.vertex.impl.helper.forward.OdeHelperForward} + * + * @author Christian Skarby + */ +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") +public interface OdeHelperForward extends OutputTypeHelper { + + /** + * Instantiate the helper + * @return a New {@link ode.vertex.impl.helper.forward.OdeHelperForward} + */ + ode.vertex.impl.helper.forward.OdeHelperForward instantiate(); + + /** + * How many time inputs are needed + * @return the number of needed time inputs + */ + int nrofTimeInputs(); + + + /** + * Clone the configuration + * @return a clone of the configuration + */ + OdeHelperForward clone(); +} diff --git a/src/main/java/ode/vertex/conf/helper/forward/OutputTypeAddTimeAsDimension.java b/src/main/java/ode/vertex/conf/helper/forward/OutputTypeAddTimeAsDimension.java new file mode 100644 index 0000000..4f0d844 --- /dev/null +++ b/src/main/java/ode/vertex/conf/helper/forward/OutputTypeAddTimeAsDimension.java @@ -0,0 +1,58 @@ +package ode.vertex.conf.helper.forward; + +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; +import org.deeplearning4j.nn.conf.layers.Convolution3D; + +import java.util.ArrayList; +import java.util.List; + +/** + * Adds a time dimension to output types from a given {@link OutputTypeHelper} given an input time type + * + * @author Christian Skarby + */ +public class OutputTypeAddTimeAsDimension implements OutputTypeHelper { + + private final int timeInputIndex; + private final OutputTypeHelper sourceHelper; + + public OutputTypeAddTimeAsDimension(int timeInputIndex, OutputTypeHelper sourceHelper) { + this.timeInputIndex = timeInputIndex; + this.sourceHelper = sourceHelper; + } + + @Override + public InputType getOutputType(ComputationGraphConfiguration conf, InputType... 
vertexInputs) throws InvalidInputTypeException { + List inputTypeList = new ArrayList<>(); + InputType time = vertexInputs[timeInputIndex]; + for (int i = 0; i < vertexInputs.length; i++) { + if (i != timeInputIndex) { + inputTypeList.add(vertexInputs[i]); + } + } + + InputType outputs = sourceHelper.getOutputType(conf, inputTypeList.toArray(new InputType[0])); + + if(time.getType() != InputType.Type.FF) { + throw new IllegalArgumentException("Time must be 1D!"); + } + + return addTimeDim(outputs, time); + } + + private InputType addTimeDim(InputType type, InputType timeDim) { + switch (type.getType()) { + case FF: return InputType.recurrent(type.arrayElementsPerExample(), timeDim.arrayElementsPerExample()); + case CNN: + InputType.InputTypeConvolutional convType = (InputType.InputTypeConvolutional)type; + return InputType.convolutional3D(Convolution3D.DataFormat.NDHWC, + timeDim.arrayElementsPerExample(), + convType.getHeight(), + convType.getWidth(), + convType.getChannels()); + default: throw new InvalidInputTypeException("Input type not supported with time as input!"); + } + } +} diff --git a/src/main/java/ode/vertex/conf/helper/forward/OutputTypeFromConfig.java b/src/main/java/ode/vertex/conf/helper/forward/OutputTypeFromConfig.java new file mode 100644 index 0000000..d92cfb4 --- /dev/null +++ b/src/main/java/ode/vertex/conf/helper/forward/OutputTypeFromConfig.java @@ -0,0 +1,27 @@ +package ode.vertex.conf.helper.forward; + +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; + +import java.util.Map; + +/** + * Determines output type from the given configuration + * + * @author Christian Skarby + */ +public class OutputTypeFromConfig implements OutputTypeHelper { + + @Override + public InputType getOutputType(ComputationGraphConfiguration conf, InputType... vertexInputs) throws InvalidInputTypeException { + final Map inputTypeMap = conf.getLayerActivationTypes(vertexInputs); + inputTypeMap.keySet().removeAll(conf.getVertexInputs().keySet()); + + if(inputTypeMap.size() != 1) { + throw new IllegalStateException("Can only support one single output!"); + } + + return inputTypeMap.values().iterator().next(); + } +} diff --git a/src/main/java/ode/vertex/conf/helper/forward/OutputTypeHelper.java b/src/main/java/ode/vertex/conf/helper/forward/OutputTypeHelper.java new file mode 100644 index 0000000..7e6c77b --- /dev/null +++ b/src/main/java/ode/vertex/conf/helper/forward/OutputTypeHelper.java @@ -0,0 +1,22 @@ +package ode.vertex.conf.helper.forward; + +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; + +/** + * Interface for determining the shape of output given shape of inputs. + * + * @author Christian Skarby + */ +public interface OutputTypeHelper { + + /** + * Return the format of the input for the given InputTypes + * @param conf {@link ComputationGraphConfiguration} which will be used as the derivative + * @param vertexInputs Inputs to the vertex + * @return an {@link InputType} for the next vertex in the graph + * @throws InvalidInputTypeException + */ + InputType getOutputType(ComputationGraphConfiguration conf, InputType... 
vertexInputs) throws InvalidInputTypeException; +} diff --git a/src/main/java/ode/vertex/impl/OdeVertex.java b/src/main/java/ode/vertex/impl/OdeVertex.java index aeb757a..582395c 100644 --- a/src/main/java/ode/vertex/impl/OdeVertex.java +++ b/src/main/java/ode/vertex/impl/OdeVertex.java @@ -1,31 +1,27 @@ package ode.vertex.impl; -import ode.solve.api.FirstOrderEquation; -import ode.solve.api.FirstOrderSolver; +import lombok.AllArgsConstructor; +import lombok.Getter; +import ode.vertex.impl.helper.OdeGraphHelper; +import ode.vertex.impl.helper.backward.OdeHelperBackward; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.TrainingConfig; -import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; -import org.deeplearning4j.nn.graph.vertex.GraphVertex; -import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; import java.util.Map; /** - * Implementation of an ODE block. + * Implementation of an ODE block. Contains a {@link ComputationGraph} which defines the learnable function + * {@code f = z(t)/dt} for which the {@code OdeVertex} will output an estimate of z(t) for given t(s). * * @author Christian Skarby */ @@ -33,41 +29,28 @@ public class OdeVertex extends BaseGraphVertex { private static final Logger log = LoggerFactory.getLogger(OdeVertex.class); - private final static String parName = "params"; - - private final ComputationGraph graph; - private final FirstOrderSolver odeSolver; + private final OdeGraphHelper odeHelper; private final TrainingConfig trainingConfig; - private final Parameters parameters; - - private static class Parameters { - private final INDArray time; - private INDArray lastOutput; // z(t1) from paper - private final NonContiguous1DView realGradients; // Parts of graph.getFlattenedGradients() which are actually gradients - - public Parameters(INDArray time) { - this.time = time; - realGradients = new NonContiguous1DView(); - } + @AllArgsConstructor + @Getter + public static class BaseGraphVertexInputs { + private final ComputationGraph graph; + private final String name; + private final int vertexIndex; } - public OdeVertex(ComputationGraph actualGraph, - String name, - int vertexIndex, - ComputationGraph innerGraph, - FirstOrderSolver odeSolver, + public OdeVertex(BaseGraphVertexInputs baseGraphVertexInputs, + OdeGraphHelper odeHelper, TrainingConfig trainingConfig) { - super(actualGraph, name, vertexIndex, null, null); - this.graph = innerGraph; + super(baseGraphVertexInputs.getGraph(), baseGraphVertexInputs.getName(), baseGraphVertexInputs.getVertexIndex(), null, null); this.trainingConfig = trainingConfig; - this.odeSolver = odeSolver; - this.parameters = new Parameters(Nd4j.create(new double[]{0, 1})); + this.odeHelper = odeHelper; } @Override public String toString() { - return graph.toString(); + return odeHelper.getFunction().toString(); } @Override @@ -82,24 +65,23 @@ public Layer getLayer() { @Override public long numParams() { - return graph.numParams(); + 
return odeHelper.getFunction().numParams(); } @Override public INDArray params() { - return graph.params(); + return odeHelper.getFunction().params(); } @Override public void clear() { super.clear(); - graph.clearLayersStates(); - parameters.lastOutput = null; + odeHelper.clear(); } @Override public Map paramTable(boolean backpropOnly) { - return Collections.synchronizedMap(Collections.singletonMap(parName, params())); + return odeHelper.paramTable(backpropOnly); } @Override @@ -110,10 +92,6 @@ public TrainingConfig getConfig() { private void validateForward() { if (!canDoForward()) throw new IllegalStateException("Cannot do forward pass: inputs not set"); - - if (getInputs().length != 1) { - throw new IllegalStateException("Only one input supported!"); - } } private void validateBackprop() { @@ -137,26 +115,9 @@ public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) { leverageInputs(workspaceMgr); - final LayerWorkspaceMgr innerWorkspaceMgr = createWorkspaceMgr(workspaceMgr); - - final ForwardPass equation = new ForwardPass( - graph, - innerWorkspaceMgr, - true, // Always use training as batch norm running mean and var become messed up otherwise. Same effect seen in original pytorch repo. - getInputs()); - - // nrof outputs must be same as number of inputs due to resblock - parameters.lastOutput = workspaceMgr.createUninitialized(ArrayType.INPUT, getInputs()[0].shape()).detach(); - odeSolver.integrate(equation, parameters.time, getInputs()[0], parameters.lastOutput); + final INDArray output = odeHelper.doForward(workspaceMgr, getInputs()); - for (GraphVertex vertex : graph.getVertices()) { - final INDArray[] inputs = vertex.getInputs(); - for (int i = 0; i < inputs.length; i++) { - vertex.setInput(i, workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, inputs[i]), workspaceMgr); - } - } - - return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, parameters.lastOutput); + return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, output); } @Override @@ -164,51 +125,18 @@ public Pair doBackward(boolean tbptt, LayerWorkspaceMgr wo validateBackprop(); log.trace("Start backward"); - // Create augmented dynamics for adjoint method - // Initialization: S0: - // z(t1) = lastoutput - // a(t) = -dL/d(z(t1)) = -epsilon from next layer (i.e getEpsilon) - // parameters = zeros - // dL/dt1 = dL / dz(t1) dot z(t1) - final INDArray dL_dtN = getEpsilon().reshape(1, parameters.lastOutput.length()) - .mmul(parameters.lastOutput.reshape(parameters.lastOutput.length(), 1)).muli(-1); - - final INDArray zAug = Nd4j.create(1, parameters.lastOutput.length() + getEpsilon().length() + graph.numParams() + dL_dtN.length()); - - final NDArrayIndexAccumulator accumulator = new NDArrayIndexAccumulator(zAug); - accumulator.increment(parameters.lastOutput.reshape(new long[]{1, parameters.lastOutput.length()})) - .increment(getEpsilon().reshape(new long[]{1, getEpsilon().length()})) - .increment(Nd4j.zeros(parameters.realGradients.length()).reshape(new long[]{1, Nd4j.zeros(parameters.realGradients.length()).length()})) - .increment(dL_dtN.reshape(new long[]{1, dL_dtN.length()})); - - final AugmentedDynamics augmentedDynamics = new AugmentedDynamics( - zAug, - getEpsilon().shape(), - new long[]{parameters.realGradients.length()}, - dL_dtN.shape()); - - final LayerWorkspaceMgr innerWorkspaceMgr = createWorkspaceMgr(workspaceMgr); - - final FirstOrderEquation equation = new BackpropagateAdjoint( - augmentedDynamics, - new ForwardPass(graph, - innerWorkspaceMgr, - true, - getInputs()), - new 
BackpropagateAdjoint.GraphInfo(graph, parameters.realGradients, innerWorkspaceMgr, tbptt) - ); - - INDArray augAns = odeSolver.integrate(equation, Nd4j.reverse(parameters.time.dup()), zAug, zAug.dup()); - - augmentedDynamics.updateFrom(augAns); - - final INDArray epsilonOut = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, augmentedDynamics.zAdjoint()); + final Pair gradients = odeHelper.doBackward( + new OdeHelperBackward.MiscPar(tbptt, workspaceMgr), + getEpsilon(), + getInputs()); - parameters.realGradients.assignFrom(augmentedDynamics.paramAdjoint()); - final Gradient gradient = new DefaultGradient(graph.getFlattenedGradients()); - gradient.setGradientFor(parName, graph.getFlattenedGradients()); + final INDArray[] inputGrads = gradients.getSecond(); + final INDArray[] leveragedGrads = new INDArray[inputGrads.length]; + for (int i = 0; i < inputGrads.length; i++) { + leveragedGrads[i] = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, inputGrads[i]); + } - return new Pair<>(gradient, new INDArray[]{epsilonOut}); + return new Pair<>(gradients.getFirst(), leveragedGrads); } private void leverageInputs(LayerWorkspaceMgr workspaceMgr) { @@ -217,63 +145,9 @@ private void leverageInputs(LayerWorkspaceMgr workspaceMgr) { } } - private LayerWorkspaceMgr createWorkspaceMgr(final LayerWorkspaceMgr outerWsMgr) { - - return new ComputationGraph(graph.getConfiguration()) { - public LayerWorkspaceMgr spyWsConfigs() { - // A little bit too many methods to comfortably decorate. Try to copy config instead - final LayerWorkspaceMgr.Builder wsBuilder = LayerWorkspaceMgr.builder(); - for (ArrayType type : ArrayType.values()) { - if (outerWsMgr.hasConfiguration(type)) { - wsBuilder.with(type, outerWsMgr.getWorkspaceName(type), outerWsMgr.getConfiguration(type)); - } - } - - final LayerWorkspaceMgr wsMgr = wsBuilder - .with(ArrayType.FF_WORKING_MEM, "WS_ODE_VERTEX_LAYER_WORKING_MEM", WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.BP_WORKING_MEM, "WS_ODE_VERTEX_LAYER_WORKING_MEM", WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, "WS_ODE_VERTEX_RNN_LOOP_WORKING_MEM", WS_RNN_LOOP_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, "WS_ODE_VERTEX_RNN_LOOP_WORKING_MEM", WS_RNN_LOOP_WORKING_MEM_CONFIG) - .with(ArrayType.ACTIVATIONS, "WS_ODE_VERTEX_ALL_LAYERS_ACT", WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.ACTIVATION_GRAD, "WS_ODE_VERTEX_ALL_LAYERS_GRAD", WS_ALL_LAYERS_ACT_CONFIG) - .build(); - wsMgr.setHelperWorkspacePointers(outerWsMgr.getHelperWorkspacePointers()); - return wsMgr; - } - }.spyWsConfigs(); - - } - @Override public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) { - graph.setBackpropGradientsViewArray(backpropGradientsViewArray); - - // What is this about? Some layers "abuse" the gradient to perform updates of parameters for which no gradient - // is calculated and this screws up the ODE solvers idea of what the solution is. The following layers are known - // to do this: - // - // * BatchNormalization: The global variance and mean are just the (sliding) average of the batch dittos. - // However, in order to support distributed training the updates are performed by adding - // the change to the state as a gradient even through it is not really. - - // Maybe get these from config so user can specify others e.g. 
for custom layers - final List nonGradientParamNames = Arrays.asList( - BatchNormalizationParamInitializer.GLOBAL_VAR, - BatchNormalizationParamInitializer.GLOBAL_MEAN); - - parameters.realGradients.clear(); - for (Layer layer : graph.getLayers()) { - Map gradParams = layer.conf().getLayer().initializer().getGradientsFromFlattened(layer.conf(), layer.getGradientsViewArray()); - for (Map.Entry parNameAndGradView : gradParams.entrySet()) { - - final String parName = parNameAndGradView.getKey(); - final INDArray grad = parNameAndGradView.getValue(); - - if (!nonGradientParamNames.contains(parName)) { - parameters.realGradients.addView(grad.reshape(grad.length())); - } - } - } + odeHelper.setBackpropGradientsViewArray(backpropGradientsViewArray); } @Override diff --git a/src/main/java/ode/vertex/impl/gradview/Contiguous1DView.java b/src/main/java/ode/vertex/impl/gradview/Contiguous1DView.java new file mode 100644 index 0000000..cf31163 --- /dev/null +++ b/src/main/java/ode/vertex/impl/gradview/Contiguous1DView.java @@ -0,0 +1,49 @@ +package ode.vertex.impl.gradview; + +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * A simple contiguous gradient view + * + * @author Christian Skarby + */ +public class Contiguous1DView implements INDArray1DView { + + private final INDArray view; + + public Contiguous1DView(INDArray view) { + this.view = view; + } + + @Override + public void assignFrom(INDArray toAssign) { + if(toAssign.length() != view.length()) { + throw new IllegalArgumentException("Array to assignFrom must have same length! " + + "This length: " + view.length() +" array length: " + toAssign.length()); + } + + if(toAssign.rank() != 1) { + throw new IllegalArgumentException("Array toAssign must have rank 1!"); + } + + view.assign(toAssign); + } + + @Override + public void assignTo(INDArray assignTo) { + if(assignTo.length() != view.length()) { + throw new IllegalArgumentException("Array assignTo must have same length! " + + "This length: " + view.length() +" array length: " + assignTo.length()); + } + if(assignTo.rank() != 1) { + throw new IllegalArgumentException("Array assignTo must have rank 1!"); + } + + assignTo.assign(view); + } + + @Override + public long length() { + return view.length(); + } +} diff --git a/src/main/java/ode/vertex/impl/gradview/GradientViewFactory.java b/src/main/java/ode/vertex/impl/gradview/GradientViewFactory.java new file mode 100644 index 0000000..933db86 --- /dev/null +++ b/src/main/java/ode/vertex/impl/gradview/GradientViewFactory.java @@ -0,0 +1,41 @@ +package ode.vertex.impl.gradview; + +import ode.vertex.impl.gradview.parname.ParamNameMapping; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.shade.jackson.annotation.JsonTypeInfo; + +import java.io.Serializable; + +/** + * Creates an {@link INDArray1DView} of all gradients in a set of vertices. Reason why this exists is that there are + * instances of parameters which DL4J puts in the gradient view but which are not actually gradients. One example is the + * running mean and variance of batchnorm. Such parameters does not play well with adjoint backpropagation for Neural + * ODEs. + * + * @author Christian Skarby + */ +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") +public interface GradientViewFactory extends Serializable { + + /** + * Create a {@link ParameterGradientView} of the gradients of the given graph. 
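+ * <p>
+ * Sketch (assumes {@code graph} is an initialized {@code ComputationGraph}):
+ * <pre>
+ * ParameterGradientView view = new GradientViewSelectionFromBlacklisted().create(graph);
+ * INDArray1DView realGrads = view.realGradientView(); // excludes e.g. batchnorm running mean/variance
+ * </pre>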
+ * + * @param graph Graph to extract gradients from + * @return Views of the gradients + */ + ParameterGradientView create(ComputationGraph graph); + + /** + * Return the mapping used to create non-colliding names + * + * @return the mapping used to create non-colliding names + */ + ParamNameMapping paramNameMapping(); + + /** + * Clone the factory + * + * @return a clone + */ + GradientViewFactory clone(); +} diff --git a/src/main/java/ode/vertex/impl/gradview/GradientViewSelectionFromBlacklisted.java b/src/main/java/ode/vertex/impl/gradview/GradientViewSelectionFromBlacklisted.java new file mode 100644 index 0000000..d6e2d0f --- /dev/null +++ b/src/main/java/ode/vertex/impl/gradview/GradientViewSelectionFromBlacklisted.java @@ -0,0 +1,151 @@ +package ode.vertex.impl.gradview; + +import lombok.Data; +import ode.vertex.impl.gradview.parname.Concat; +import ode.vertex.impl.gradview.parname.ParamNameMapping; +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.gradient.DefaultGradient; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.graph.vertex.GraphVertex; +import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.shade.jackson.annotation.JsonProperty; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * {@link GradientViewFactory} which selects either a {@link Contiguous1DView} or a {@link NonContiguous1DView} based on + * presence of blacklisted parameters in the graph. + * + * @author Christian Skarby + */ +@Data +public class GradientViewSelectionFromBlacklisted implements GradientViewFactory { + + private final List nonGradientParamNames; + private final ParamNameMapping paramNameMapping; + + public GradientViewSelectionFromBlacklisted() { + this(Arrays.asList( + BatchNormalizationParamInitializer.GLOBAL_VAR, + BatchNormalizationParamInitializer.GLOBAL_MEAN)); + } + + public GradientViewSelectionFromBlacklisted(List nonGradientParamNames) { + this(nonGradientParamNames, + new Concat()); + } + + public GradientViewSelectionFromBlacklisted(@JsonProperty("nonGradientParamNames") List nonGradientParamNames, + @JsonProperty("paramNameMapping") ParamNameMapping paramNameMapping) { + this.nonGradientParamNames = nonGradientParamNames; + this.paramNameMapping = paramNameMapping; + } + + public ParameterGradientView create(ComputationGraph graph) { + + final Gradient gradient = getAllGradients(graph); + + for (GraphVertex vertex : graph.getVertices()) { + if (hasNonGradient(vertex)) { + return new ParameterGradientView(gradient, createNonContiguous1DView(graph)); + } + } + + return new ParameterGradientView(gradient, new Contiguous1DView(graph.getGradientsViewArray())); + } + + private boolean hasNonGradient(GraphVertex vertex) { + boolean anyNonGrad = false; + for (String parName : vertex.paramTable(false).keySet()) { + anyNonGrad |= nonGradientParamNames.contains(parName); + } + return anyNonGrad; + } + + private NonContiguous1DView createNonContiguous1DView(ComputationGraph graph) { + final NonContiguous1DView gradView = new NonContiguous1DView(); + + for (GraphVertex vertex : graph.getVertices()) { + addGradientView(gradView, vertex); + } + return gradView; + } + + private void addGradientView(NonContiguous1DView gradView, GraphVertex vertex) { + if (vertex.numParams() > 0 && hasNonGradient(vertex)) { + Layer layer = vertex.getLayer(); + + if (layer == 
null) {
+                // Only way I have found to get the mapping from gradient view to gradient view per parameter is through
+                // a ParameterInitializer as done below, and only Layers seem to be able to provide them
+                throw new UnsupportedOperationException("Can not (reliably) get correct gradient views from non-layer " +
+                        "vertices with blacklisted parameters!");
+            }
+
+            Map<String, INDArray> gradParams = layer.conf().getLayer().initializer().getGradientsFromFlattened(layer.conf(), layer.getGradientsViewArray());
+            for (Map.Entry<String, INDArray> parNameAndGradView : gradParams.entrySet()) {
+                final String parName = parNameAndGradView.getKey();
+                final INDArray grad = parNameAndGradView.getValue();
+
+                if (!nonGradientParamNames.contains(parName)) {
+                    gradView.addView(grad);
+                }
+            }
+        } else if (vertex.numParams() > 0) {
+            gradView.addView(vertex.getGradientsViewArray());
+        }
+    }
+
+    private Gradient getAllGradients(ComputationGraph graph) {
+        final Gradient allGradients = new DefaultGradient(graph.getGradientsViewArray());
+        for (GraphVertex vertex : graph.getVertices()) {
+            if (vertex.numParams() > 0) {
+                Layer layer = vertex.getLayer();
+
+                if (layer == null) {
+                    // Only way I have found to get the mapping from gradient view to gradient view per parameter is through
+                    // a ParameterInitializer as done below, and only Layers seem to be able to provide them
+                    throw new UnsupportedOperationException("Can not (reliably) get correct gradient views from non-layer " +
+                            "vertices with blacklisted parameters!");
+                }
+
+                Map<String, INDArray> gradParams = layer.conf().getLayer().initializer().getGradientsFromFlattened(layer.conf(), layer.getGradientsViewArray());
+                for (Map.Entry<String, INDArray> parNameAndGradView : gradParams.entrySet()) {
+                    final String parName = parNameAndGradView.getKey();
+                    final INDArray grad = parNameAndGradView.getValue();
+                    allGradients.setGradientFor(paramNameMapping.map(vertex.getVertexName(), parName), grad);
+                }
+            }
+        }
+        return allGradients;
+    }
+
+    @Override
+    public ParamNameMapping paramNameMapping() {
+        return paramNameMapping;
+    }
+
+    @Override
+    public GradientViewFactory clone() {
+        return new GradientViewSelectionFromBlacklisted(nonGradientParamNames, paramNameMapping);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (!(o instanceof GradientViewSelectionFromBlacklisted)) return false;
+        GradientViewSelectionFromBlacklisted that = (GradientViewSelectionFromBlacklisted) o;
+        return nonGradientParamNames.equals(that.nonGradientParamNames) && paramNameMapping.equals(that.paramNameMapping);
+    }
+
+    @Override
+    public int hashCode() {
+        // Must be consistent with equals, which also compares paramNameMapping
+        return Objects.hash(nonGradientParamNames, paramNameMapping);
+    }
+
+}
diff --git a/src/main/java/ode/vertex/impl/gradview/INDArray1DView.java b/src/main/java/ode/vertex/impl/gradview/INDArray1DView.java
new file mode 100644
index 0000000..43367dd
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/gradview/INDArray1DView.java
@@ -0,0 +1,29 @@
+package ode.vertex.impl.gradview;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+/**
+ * A 1D view of one or more {@link INDArray}s.
+ *
+ * @author Christian Skarby
+ */
+public interface INDArray1DView {
+
+    /**
+     * Sets the view to the given {@link INDArray}
+     * @param toAssign view will be set to this. Must be same size as view
+     */
+    void assignFrom(INDArray toAssign);
+
+    /**
+     * Sets the values of the given {@link INDArray} to the values of the view
+     * @param assignTo will be set to state of the view. 
Must be same size as view + */ + void assignTo(INDArray assignTo); + + /** + * Return the current length (total number of elements) of the view + * @return the current length of the view + */ + long length(); +} diff --git a/src/main/java/ode/vertex/impl/NonContiguous1DView.java b/src/main/java/ode/vertex/impl/gradview/NonContiguous1DView.java similarity index 61% rename from src/main/java/ode/vertex/impl/NonContiguous1DView.java rename to src/main/java/ode/vertex/impl/gradview/NonContiguous1DView.java index 2f981a5..4d72c79 100644 --- a/src/main/java/ode/vertex/impl/NonContiguous1DView.java +++ b/src/main/java/ode/vertex/impl/gradview/NonContiguous1DView.java @@ -1,11 +1,10 @@ -package ode.vertex.impl; +package ode.vertex.impl.gradview; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; /** @@ -13,28 +12,17 @@ * * @author Christian Skarby */ -public class NonContiguous1DView { +public class NonContiguous1DView implements INDArray1DView { private final List view = new ArrayList<>(); private long length = 0; - public void addView(INDArray array, long begin, long end) { - INDArray viewSlice = array.get(NDArrayIndex.interval(begin, end)); - addView(viewSlice); - } - public void addView(INDArray viewSlice) { - if(viewSlice.isColumnVectorOrScalar()) { - throw new IllegalArgumentException("Must be vector or scalar! Had shape: " + Arrays.toString(viewSlice.shape())); - } length += viewSlice.length(); view.add(viewSlice); } - /** - * Sets the view to the given {@link INDArray} - * @param toAssign view will be set to this. Must be same size as view - */ + @Override public void assignFrom(INDArray toAssign) { if(toAssign.length() != length) { throw new IllegalArgumentException("Array to assignFrom must have same length! " + @@ -47,15 +35,12 @@ public void assignFrom(INDArray toAssign) { long ptr = 0; for(INDArray viewSlice: view) { - viewSlice.assign(toAssign.get(NDArrayIndex.interval(ptr, ptr + viewSlice.length()))); + viewSlice.assign(toAssign.get(NDArrayIndex.interval(ptr, ptr + viewSlice.length())).reshape(viewSlice.shape())); ptr += viewSlice.length(); } } - /** - * Sets the values of the given {@link INDArray} to the values of the view - * @param assignTo will be set to state of the view. Must be same size as view - */ + @Override public void assignTo(INDArray assignTo) { if(assignTo.length() != length) { throw new IllegalArgumentException("Array assignTo must have same length! 
" + @@ -67,24 +52,19 @@ public void assignTo(INDArray assignTo) { long ptr = 0; for(INDArray viewSlice: view) { - assignTo.put(new INDArrayIndex[] {NDArrayIndex.interval(ptr, ptr + viewSlice.length())}, viewSlice); + assignTo.put(new INDArrayIndex[] {NDArrayIndex.interval(ptr, ptr + viewSlice.length())}, viewSlice.reshape(viewSlice.length())); ptr += viewSlice.length(); } } - /** - * Return the current length (total number of elements) of the view - * @return the current length of the view - */ + @Override public long length() { return length; } - /** - * Clears the view - */ - public void clear() { - view.clear(); - length = 0; + @Override + public String toString() { + return view.toString(); } + } diff --git a/src/main/java/ode/vertex/impl/gradview/ParameterGradientView.java b/src/main/java/ode/vertex/impl/gradview/ParameterGradientView.java new file mode 100644 index 0000000..917129a --- /dev/null +++ b/src/main/java/ode/vertex/impl/gradview/ParameterGradientView.java @@ -0,0 +1,37 @@ +package ode.vertex.impl.gradview; + +import org.deeplearning4j.nn.gradient.Gradient; + +/** + * Different views of the parameter gradients of a graph. + * + * @author Christian Skarby + */ +public class ParameterGradientView { + + private final Gradient allGradients; + private final INDArray1DView realGradientView; + + public ParameterGradientView(Gradient allGradients, INDArray1DView realGradientView) { + this.allGradients = allGradients; + this.realGradientView = realGradientView; + } + + + /** + * Returns all gradients (even those which are not actually gradients) per parameter + * @return a Gradient for all parameters + */ + public Gradient allGradientsPerParam() { + return allGradients; + } + + /** + * Returns an {@link INDArray1DView} of only the parts of the gradient view which are actually gradients. 
+ * @return an {@link INDArray1DView} + */ + public INDArray1DView realGradientView() { + return realGradientView; + } + +} diff --git a/src/main/java/ode/vertex/impl/gradview/parname/Concat.java b/src/main/java/ode/vertex/impl/gradview/parname/Concat.java new file mode 100644 index 0000000..c4d3681 --- /dev/null +++ b/src/main/java/ode/vertex/impl/gradview/parname/Concat.java @@ -0,0 +1,39 @@ +package ode.vertex.impl.gradview.parname; + +import lombok.Data; +import org.nd4j.linalg.primitives.Pair; +import org.nd4j.shade.jackson.annotation.JsonProperty; + +/** + * {@link ParamNameMapping} which concatenates the names + * + * @author Christian Skarby + */ +@Data +public class Concat implements ParamNameMapping { + + private final String concatStr; + + public Concat() { + this("-"); + } + + public Concat(@JsonProperty("concatStr") String concatStr) { + this.concatStr = concatStr; + } + + + @Override + public String map(String vertexName, String paramName) { + return vertexName + concatStr + paramName; + } + + @Override + public Pair reverseMap(String combinedName) { + final String[] split = combinedName.split(concatStr); + if(split.length != 2) { + throw new IllegalArgumentException("Cannot reverse mapping for " + combinedName); + } + return new Pair<>(split[0], split[1]); + } +} diff --git a/src/main/java/ode/vertex/impl/gradview/parname/ParamNameMapping.java b/src/main/java/ode/vertex/impl/gradview/parname/ParamNameMapping.java new file mode 100644 index 0000000..467e448 --- /dev/null +++ b/src/main/java/ode/vertex/impl/gradview/parname/ParamNameMapping.java @@ -0,0 +1,30 @@ +package ode.vertex.impl.gradview.parname; + +import org.nd4j.linalg.primitives.Pair; +import org.nd4j.shade.jackson.annotation.JsonTypeInfo; + +/** + * A mapping between parameter and vertex names to a combined name. Also capable of reverse mapping. + * + * @author Christian Skarby + */ +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") +public interface ParamNameMapping { + + /** + * Map vertexName and paramName to a new parameter name which is unique for the given input + * @param vertexName Name of vertex + * @param paramName Name of parameter + * @return A combined name + */ + String map(String vertexName, String paramName); + + /** + * Reverse mapping. In other words, {@code mapping.reverseMap(mapping.map(vertexName, paramName)); } returns + * {@code [vertexName, paramName]} + * @param combinedName Combined name to reverse map + * @return a Pair where vertexName is the first member and paramName is the second. 
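A quick round trip of this contract, using the Concat mapping from this patch (vertex and parameter names invented):

```
ParamNameMapping mapping = new Concat("-");
String combined = mapping.map("odeLayer0", "W");          // -> "odeLayer0-W"
Pair<String, String> back = mapping.reverseMap(combined);
// back.getFirst() -> "odeLayer0", back.getSecond() -> "W"
```

Note that Concat.reverseMap splits on concatStr (which String.split treats as a regex) and requires exactly two parts, so vertex or parameter names which themselves contain the concat string cannot be reverse mapped.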
+ */ + Pair reverseMap(String combinedName); + +} diff --git a/src/main/java/ode/vertex/impl/gradview/parname/Prefix.java b/src/main/java/ode/vertex/impl/gradview/parname/Prefix.java new file mode 100644 index 0000000..813c6bf --- /dev/null +++ b/src/main/java/ode/vertex/impl/gradview/parname/Prefix.java @@ -0,0 +1,33 @@ +package ode.vertex.impl.gradview.parname; + +import lombok.Data; +import org.nd4j.linalg.primitives.Pair; +import org.nd4j.shade.jackson.annotation.JsonProperty; + +/** + * {@link ParamNameMapping} which adds a prefix to another mapping + * + * @author Christian Skarby + */ +@Data +public class Prefix implements ParamNameMapping { + + private final ParamNameMapping mapping; + private final String prefix; + + public Prefix(@JsonProperty("prefix") String prefix, + @JsonProperty("mapping") ParamNameMapping mapping) { + this.mapping = mapping; + this.prefix = prefix; + } + + @Override + public String map(String vertexName, String paramName) { + return prefix + mapping.map(vertexName, paramName); + } + + @Override + public Pair reverseMap(String combinedName) { + return mapping.reverseMap(combinedName.substring(prefix.length())); + } +} diff --git a/src/main/java/ode/vertex/impl/NDArrayIndexAccumulator.java b/src/main/java/ode/vertex/impl/helper/NDArrayIndexAccumulator.java similarity index 80% rename from src/main/java/ode/vertex/impl/NDArrayIndexAccumulator.java rename to src/main/java/ode/vertex/impl/helper/NDArrayIndexAccumulator.java index 790473a..9873878 100644 --- a/src/main/java/ode/vertex/impl/NDArrayIndexAccumulator.java +++ b/src/main/java/ode/vertex/impl/helper/NDArrayIndexAccumulator.java @@ -1,4 +1,4 @@ -package ode.vertex.impl; +package ode.vertex.impl.helper; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.INDArrayIndex; @@ -10,12 +10,12 @@ * * @author Christian Skarby */ -class NDArrayIndexAccumulator { +public class NDArrayIndexAccumulator { private final INDArrayIndex[] state; private final INDArray array; - NDArrayIndexAccumulator(INDArray array) { + public NDArrayIndexAccumulator(INDArray array) { this.array = array; state = new INDArrayIndex[array.shape().length]; for(int i = 0; i < array.shape().length; i++) { @@ -23,7 +23,9 @@ class NDArrayIndexAccumulator { } } - NDArrayIndexAccumulator increment(INDArray toAdd) { + public NDArrayIndexAccumulator increment(INDArray toAdd) { + if(toAdd.isEmpty()) return this; + for(int dim = 0; dim < toAdd.shape().length; dim++) { if(toAdd.size(dim) != array.size(dim)) { final long curr = state[dim] instanceof NDArrayIndexAll ? 
0 : state[dim].end(); diff --git a/src/main/java/ode/vertex/impl/helper/OdeGraphHelper.java b/src/main/java/ode/vertex/impl/helper/OdeGraphHelper.java new file mode 100644 index 0000000..7bca036 --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/OdeGraphHelper.java @@ -0,0 +1,189 @@ +package ode.vertex.impl.helper; + +import ode.vertex.impl.gradview.GradientViewFactory; +import ode.vertex.impl.gradview.INDArray1DView; +import ode.vertex.impl.gradview.ParameterGradientView; +import ode.vertex.impl.helper.backward.OdeHelperBackward; +import ode.vertex.impl.helper.forward.OdeHelperForward; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.graph.vertex.GraphVertex; +import org.deeplearning4j.nn.workspace.ArrayType; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.primitives.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; + +/** + * Helper which jumps through the hoops so that a {@link ComputationGraph} can be seen as the function which provides + * the derivatives for an ODE solver. + * + * @author Christian Skarby + */ +public class OdeGraphHelper { + + private static final Logger log = LoggerFactory.getLogger(OdeGraphHelper.class); + + private final OdeHelperForward odeHelperForward; + private final OdeHelperBackward odeHelperBackward; + private final CompGraphAsOdeFunction odeFunction; + + public OdeGraphHelper(OdeHelperForward odeHelperForward, OdeHelperBackward odeHelperBackward, CompGraphAsOdeFunction odeFunction) { + this.odeHelperForward = odeHelperForward; + this.odeHelperBackward = odeHelperBackward; + this.odeFunction = odeFunction; + } + + public static class CompGraphAsOdeFunction { + + private INDArray lastOutput; // z(t1) from paper + private ParameterGradientView parameterGradientView; + private final ComputationGraph function; + private final GradientViewFactory gradientViewFactory; + + public CompGraphAsOdeFunction(ComputationGraph odeFunction, GradientViewFactory gradientViewFactory) { + this.function = odeFunction; + this.gradientViewFactory = gradientViewFactory; + } + + private INDArray lastOutput() { + return lastOutput; + } + + private INDArray1DView realGradients() { + return parameterGradientView.realGradientView(); + } + + private void setLastOutput(INDArray lastOutput) { + this.lastOutput = lastOutput; + } + + void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) { + function.setBackpropGradientsViewArray(backpropGradientsViewArray); + parameterGradientView = gradientViewFactory.create(function); + } + + Map paramTable(boolean backpropOnly) { + final Map output = new HashMap<>(); + for(GraphVertex vertex: function.getVertices()) { + final Map partable = vertex.paramTable(backpropOnly); + if(partable != null) { + for(Map.Entry parEntry: partable.entrySet()) { + output.put( + gradientViewFactory.paramNameMapping().map(vertex.getVertexName(), parEntry.getKey()), + parEntry.getValue()); + } + } + } + return output; + } + } + + /** + * Clears the current state wrt training. Gradient views are not touched. + */ + public void clear() { + getFunction().clearLayersStates(); + odeFunction.setLastOutput(null); + } + + public Map paramTable(boolean backpropOnly) { + return odeFunction.paramTable(backpropOnly); + } + + public ComputationGraph getFunction() { + return odeFunction.function; + } + + /** + * What is this about? 
Some layers "abuse" the gradient to perform updates of parameters for which no gradient + * is calculated and this screws up the ODE solvers idea of what the solution is. The following layers are known + * to do this: + *

+ * * BatchNormalization: The global variance and mean are just the (sliding) average of the batch dittos. + * However, in order to support distributed training the updates are performed by adding + * the change to the state as a gradient even though it is not really one. + * @param backpropGradientsViewArray View of parameter gradients + */ + public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) { + odeFunction.setBackpropGradientsViewArray(backpropGradientsViewArray); + } + + + public INDArray doForward(LayerWorkspaceMgr workspaceMgr, INDArray[] inputs) { + + final LayerWorkspaceMgr innerWorkspaceMgr = createWorkspaceMgr(workspaceMgr, getFunction()); + + getFunction().getConfiguration().setIterationCount(0); + final INDArray output = odeHelperForward.solve(getFunction(), innerWorkspaceMgr, inputs); + log.debug("Nrof func eval forward " + getFunction().getIterationCount()); + + odeFunction.setLastOutput(output.detach()); + + return output; + } + + public Pair doBackward( + OdeHelperBackward.MiscPar miscPars, + INDArray lossGradient, + INDArray[] lastInputs) { + + final OdeHelperBackward.InputArrays inputArrays = new OdeHelperBackward.InputArrays( + lastInputs, + odeFunction.lastOutput(), + lossGradient, + odeFunction.realGradients() + ); + + final OdeHelperBackward.MiscPar miscParNewWsMgr = new OdeHelperBackward.MiscPar( + miscPars.isUseTruncatedBackPropTroughTime(), + createWorkspaceMgr(miscPars.getWsMgr(), getFunction()) + ); + + getFunction().getConfiguration().setIterationCount(0); + final INDArray[] gradients = odeHelperBackward.solve(getFunction(), inputArrays, miscParNewWsMgr); + log.debug("Nrof func eval backward " + getFunction().getIterationCount()); + + return new Pair<>(odeFunction.parameterGradientView.allGradientsPerParam(), gradients); + } + + /** + * Changes names of workspaces associated with certain {@link ArrayType}s in order to avoid workspace conflicts + * due to "graph in graph". + * @param outerWsMgr workspace manager + * @return LayerWorkspaceMgr with new workspace names but using the same workspace configs as in {@link ComputationGraph} + */ + private LayerWorkspaceMgr createWorkspaceMgr(final LayerWorkspaceMgr outerWsMgr, ComputationGraph graph) { + if(outerWsMgr == LayerWorkspaceMgr.noWorkspacesImmutable()) { + // This can be handled better, but I just CBA to check presence for every array type right now... + return outerWsMgr; + } + + return new ComputationGraph(graph.getConfiguration()) { + LayerWorkspaceMgr spyWsConfigs() { + // A little bit too many methods to comfortably decorate. 
Try to copy config instead + final LayerWorkspaceMgr.Builder wsBuilder = LayerWorkspaceMgr.builder(); + for (ArrayType type : ArrayType.values()) { + if (outerWsMgr.hasConfiguration(type)) { + wsBuilder.with(type, outerWsMgr.getWorkspaceName(type), outerWsMgr.getConfiguration(type)); + } + } + + final LayerWorkspaceMgr wsMgr = wsBuilder + .with(ArrayType.FF_WORKING_MEM, "WS_ODE_VERTEX_LAYER_WORKING_MEM", WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.BP_WORKING_MEM, "WS_ODE_VERTEX_LAYER_WORKING_MEM", WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, "WS_ODE_VERTEX_RNN_LOOP_WORKING_MEM", WS_RNN_LOOP_WORKING_MEM_CONFIG) + .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, "WS_ODE_VERTEX_RNN_LOOP_WORKING_MEM", WS_RNN_LOOP_WORKING_MEM_CONFIG) + .with(ArrayType.ACTIVATIONS, "WS_ODE_VERTEX_ALL_LAYERS_ACT", WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.ACTIVATION_GRAD, "WS_ODE_VERTEX_ALL_LAYERS_GRAD", WS_ALL_LAYERS_ACT_CONFIG) + .build(); + wsMgr.setHelperWorkspacePointers(outerWsMgr.getHelperWorkspacePointers()); + return wsMgr; + } + }.spyWsConfigs(); + } +} diff --git a/src/main/java/ode/vertex/impl/AugmentedDynamics.java b/src/main/java/ode/vertex/impl/helper/backward/AugmentedDynamics.java similarity index 91% rename from src/main/java/ode/vertex/impl/AugmentedDynamics.java rename to src/main/java/ode/vertex/impl/helper/backward/AugmentedDynamics.java index f859712..e071582 100644 --- a/src/main/java/ode/vertex/impl/AugmentedDynamics.java +++ b/src/main/java/ode/vertex/impl/helper/backward/AugmentedDynamics.java @@ -1,4 +1,4 @@ -package ode.vertex.impl; +package ode.vertex.impl.helper.backward; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.INDArrayIndex; @@ -11,7 +11,7 @@ * * @author Christian Skarby */ -class AugmentedDynamics { +public class AugmentedDynamics { private final INDArray augStateFlat; private final INDArray z; @@ -20,7 +20,7 @@ class AugmentedDynamics { private final INDArray tAdjoint; - AugmentedDynamics(INDArray zAug, long[] zShape, long[] paramShape, long[] tShape) { + public AugmentedDynamics(INDArray zAug, long[] zShape, long[] paramShape, long[] tShape) { this( zAug, zAug.get(NDArrayIndex.interval(0, length(zShape))).reshape(zShape), @@ -45,7 +45,7 @@ private static long length(long[] shape) { return length; } - void updateFrom(INDArray zAug) { + public void updateFrom(INDArray zAug) { augStateFlat.assign(zAug); } diff --git a/src/main/java/ode/vertex/impl/BackpropagateAdjoint.java b/src/main/java/ode/vertex/impl/helper/backward/BackpropagateAdjoint.java similarity index 90% rename from src/main/java/ode/vertex/impl/BackpropagateAdjoint.java rename to src/main/java/ode/vertex/impl/helper/backward/BackpropagateAdjoint.java index 0fa0014..4173669 100644 --- a/src/main/java/ode/vertex/impl/BackpropagateAdjoint.java +++ b/src/main/java/ode/vertex/impl/helper/backward/BackpropagateAdjoint.java @@ -1,7 +1,8 @@ -package ode.vertex.impl; +package ode.vertex.impl.helper.backward; import lombok.AllArgsConstructor; import ode.solve.api.FirstOrderEquation; +import ode.vertex.impl.gradview.INDArray1DView; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.GraphVertex; @@ -25,7 +26,7 @@ *
  * f(z(t), theta) = output from forward pass through the layers of the ODE vertex (i.e. the layers of graph)
  * -a(t)*df/dz(t) = dL / dz(t) = epsilon from a backward pass through the layers of the ODE vertex (i.e. the layers of graph) wrt previous output.
- * -a(t) * df / dt = not used, set to 0
+ * -a(t) * df / dt = not used (as of now), set to 0
  * -a(t) df/dtheta = -dL / dtheta = Gradient from a backward pass through the layers of the ODE vertex (i.e. the layers of graph) wrt -epsilon.
  * </pre>
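For reference, the four rows above are the components of the augmented state from the paper that the backward helpers below integrate from t1 back to t0; restated compactly, with a(t) = dL/dz(t):

```
s(t) = \left[\, z(t),\; a(t),\; \frac{\partial L}{\partial \theta},\; \frac{\partial L}{\partial t} \,\right],
\qquad
\frac{\mathrm{d}s}{\mathrm{d}t} = \left[\, f(z(t), t, \theta),\; -a(t)^{\top}\frac{\partial f}{\partial z},\; -a(t)^{\top}\frac{\partial f}{\partial \theta},\; -a(t)^{\top}\frac{\partial f}{\partial t} \,\right]
```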
* @@ -40,7 +41,7 @@ public class BackpropagateAdjoint implements FirstOrderEquation { @AllArgsConstructor public static class GraphInfo { private final ComputationGraph graph; - private final NonContiguous1DView realGradients; + private final INDArray1DView realGradients; private final LayerWorkspaceMgr workspaceMgr; private final boolean truncatedBPTT; } @@ -59,7 +60,7 @@ public INDArray calculateDerivative(INDArray zAug, INDArray t, INDArray fzAug) { augmentedDynamics.updateFrom(zAug); // Note: Will also update z - forwardPass.calculateDerivative(augmentedDynamics.z(), t, augmentedDynamics.z()); + forwardPass.calculateDerivative(augmentedDynamics.z().dup(), t, augmentedDynamics.z()); try (WorkspacesCloseable ws = graphInfo.workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS, ArrayType.ACTIVATION_GRAD)) { @@ -86,20 +87,14 @@ private List backPropagate(INDArray epsilon) { final int[] topologicalOrder = graphInfo.graph.topologicalSortOrder(); final GraphVertex[] vertices = graphInfo.graph.getVertices(); + vertices[topologicalOrder[topologicalOrder.length-1]].setEpsilon(epsilon); + List outputEpsilons = new ArrayList<>(); boolean[] setVertexEpsilon = new boolean[topologicalOrder.length]; //If true: already set epsilon for this vertex; later epsilons should be *added* to the existing one, not set for (int i = topologicalOrder.length - 1; i >= 0; i--) { GraphVertex current = vertices[topologicalOrder[i]]; - if (current.isOutputVertex()) { - for (VertexIndices vertexIndices : current.getInputVertices()) { - final String inputName = vertices[vertexIndices.getVertexIndex()].getVertexName(); - graphInfo.graph.getVertex(inputName).setEpsilon(epsilon); - } - continue; - } - if (current.isInputVertex()) { continue; } diff --git a/src/main/java/ode/vertex/impl/helper/backward/FixedStepAdjoint.java b/src/main/java/ode/vertex/impl/helper/backward/FixedStepAdjoint.java new file mode 100644 index 0000000..ebe2fcc --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/backward/FixedStepAdjoint.java @@ -0,0 +1,30 @@ +package ode.vertex.impl.helper.backward; + +import ode.solve.api.FirstOrderSolver; +import ode.vertex.impl.helper.backward.timegrad.NoMultiStepTimeGrad; +import ode.vertex.impl.helper.backward.timegrad.NoTimeGrad; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * {@link OdeHelperBackward} with a fixed given sequence of time steps to evaluate the ODE for. 
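As a configuration sketch (solver as in the README example, time grids invented), the constructor below simply dispatches on the number of time points:

```
// Two time points -> SingleStepAdjoint, more than two -> MultiStepAdjoint
OdeHelperBackward single = new FixedStepAdjoint(new DormandPrince54Solver(), Nd4j.create(new double[] {0, 1}));
OdeHelperBackward multi = new FixedStepAdjoint(new DormandPrince54Solver(), Nd4j.linspace(0, 1, 5));
```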
+ * + * @author Christian Skarby + */ +public class FixedStepAdjoint implements OdeHelperBackward { + + private final OdeHelperBackward helper; + + public FixedStepAdjoint(FirstOrderSolver solver, INDArray time) { + if(time.length() > 2) { + helper = new MultiStepAdjoint(solver, time, NoMultiStepTimeGrad.factory); + } else { + helper = new SingleStepAdjoint(solver, time, NoTimeGrad.factory); + } + } + + @Override + public INDArray[] solve(ComputationGraph graph, InputArrays input, MiscPar miscPars) { + return helper.solve(graph, input, miscPars); + } +} diff --git a/src/main/java/ode/vertex/impl/helper/backward/InputStepAdjoint.java b/src/main/java/ode/vertex/impl/helper/backward/InputStepAdjoint.java new file mode 100644 index 0000000..9e5ded5 --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/backward/InputStepAdjoint.java @@ -0,0 +1,59 @@ +package ode.vertex.impl.helper.backward; + +import ode.solve.api.FirstOrderSolver; +import ode.vertex.impl.helper.backward.timegrad.*; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.util.ArrayList; +import java.util.List; + +/** + * {@link OdeHelperBackward} which uses one of the input {@link INDArray}s as the time steps to evaluate the ODE for + * + * @author Christian Skarby + */ +public class InputStepAdjoint implements OdeHelperBackward { + + private final FirstOrderSolver solver; + private final int timeIndex; + private final boolean needTimeGradient; + + public InputStepAdjoint(FirstOrderSolver solver, int timeIndex, boolean needTimeGradient) { + this.solver = solver; + this.timeIndex = timeIndex; + this.needTimeGradient = needTimeGradient; + } + + + @Override + public INDArray[] solve(ComputationGraph graph, InputArrays input, MiscPar miscPars) { + final INDArray time = input.getLastInputs()[timeIndex]; + final List notTimeInputs = new ArrayList<>(); + for (int i = 0; i < input.getLastInputs().length; i++) { + if (i != timeIndex) { + notTimeInputs.add(input.getLastInputs()[i]); + } + } + final InputArrays newInput = new InputArrays( + notTimeInputs.toArray(new INDArray[0]), + input.getLastOutput(), + input.getLossGradient(), + input.getRealGradientView() + ); + + if (time.length() > 2) { + final MultiStepTimeGrad.Factory factory = needTimeGradient ? + new CalcMultiStepTimeGrad.Factory(time, timeIndex) : + new ZeroMultiStepTimeGrad.Factory(time, timeIndex); + + return new MultiStepAdjoint(solver, time, factory).solve(graph, newInput, miscPars); + } + + final TimeGrad.Factory factory = needTimeGradient ? 
+ new CalcTimeGrad.Factory(input.getLossGradient(), timeIndex) : + new ZeroTimeGrad.Factory(timeIndex); + + return new SingleStepAdjoint(solver, time, factory).solve(graph, newInput, miscPars); + } +} diff --git a/src/main/java/ode/vertex/impl/helper/backward/MultiStepAdjoint.java b/src/main/java/ode/vertex/impl/helper/backward/MultiStepAdjoint.java new file mode 100644 index 0000000..4d93c8c --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/backward/MultiStepAdjoint.java @@ -0,0 +1,124 @@ +package ode.vertex.impl.helper.backward; + +import ode.solve.api.FirstOrderSolver; +import ode.vertex.impl.helper.backward.timegrad.MultiStepTimeGrad; +import ode.vertex.impl.helper.backward.timegrad.TimeGrad; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.indexing.INDArrayIndex; +import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.ops.transforms.Transforms; + +import java.util.Arrays; + +/** + * {@link OdeHelperBackward} using the adjoint method capable of handling multiple time steps. Gradients will be provided + * for the last time step only. + * + * @author Christian Skarby + */ +public class MultiStepAdjoint implements OdeHelperBackward { + + private final FirstOrderSolver solver; + private final INDArray time; + private final MultiStepTimeGrad.Factory timeGradFactory; + + public MultiStepAdjoint(FirstOrderSolver solver, INDArray time, MultiStepTimeGrad.Factory timeGradFactory) { + this.solver = solver; + this.time = time; + this.timeGradFactory = timeGradFactory; + + if(time.length() <= 2 || !time.isVector()) { + throw new IllegalArgumentException("time must be a vector of size > 2! Was of shape: " + Arrays.toString(time.shape())+ "!"); + } + assertSorted(time); + } + + private void assertSorted(INDArray time) { + int signDiffSum = 0; + for(int i = 0; i < time.length()-1; i++) { + signDiffSum += Transforms.sign(time.getScalar(i).sub(time.getScalar(i+1))).getDouble(0); + } + + if(Math.abs(signDiffSum)+1 != time.length()) { + throw new IllegalArgumentException("Time must be in ascending or descending order! 
Got: " + time); + } + } + + @Override + public INDArray[] solve(ComputationGraph graph, InputArrays input, MiscPar miscPars) { + final INDArray zt = alignInShapeToTimeFirst(input.getLastOutput()); + final INDArray dL_dzt = alignInShapeToTimeFirst(input.getLossGradient()); + + assertSizeVsTime(zt); + assertSizeVsTime(dL_dzt); + + final INDArrayIndex[] timeIndexer = createIndexer(time); + timeIndexer[1] = NDArrayIndex.interval(time.length()-2, time.length()); + final INDArrayIndex[] ztIndexer = createIndexer(input.getLastOutput()); + final INDArrayIndex[] dL_dztIndexer= createIndexer(input.getLossGradient()); + + INDArray[] gradients = null; + final MultiStepTimeGrad timeGrad = timeGradFactory.create(); + + // Go backwards in time + for (int step = (int)time.length()-1; step > 0; step--) { + final INDArray ztStep = getStep(ztIndexer,zt, step); + + final INDArray dL_dztStep = getStep(dL_dztIndexer, dL_dzt, step); + final TimeGrad.Factory stepTimeGradFactory = timeGrad.createSingleStepFactory(dL_dztStep.dup()); + + timeGrad.prepareStep(gradients, dL_dztStep); + + final InputArrays stepInput = new InputArrays( + input.getLastInputs(), + ztStep, + dL_dztStep, + input.getRealGradientView() + ); + timeIndexer[1] = NDArrayIndex.interval(step - 1, step+1); + + final OdeHelperBackward stepSolve = new SingleStepAdjoint(solver, time.get(timeIndexer), stepTimeGradFactory); + gradients = stepSolve.solve(graph, stepInput, miscPars); + + timeGrad.updateStep(timeIndexer, gradients); + } + + gradients = timeGrad.updateLastStep(timeIndexer, gradients, getStep(dL_dztIndexer, dL_dzt, 0)); + + return gradients; + } + + private void assertSizeVsTime(INDArray array) { + if(array.size(0) != time.length()) { + throw new IllegalArgumentException("Must have same number of in first dimension as there are time steps! Input: " + + array.size(0) + ", time: " + time.length()); + } + } + + private INDArrayIndex[] createIndexer(INDArray array) { + final INDArrayIndex[] indexer = new INDArrayIndex[array.rank()]; + for(int dim = 0; dim < indexer.length; dim++) { + indexer[dim] = NDArrayIndex.all(); + } + return indexer; + } + + private INDArray getStep(INDArrayIndex[] indexer, INDArray array, int step) { + indexer[0] = NDArrayIndex.point(step); + return array.get(indexer); + } + + private INDArray alignInShapeToTimeFirst(INDArray array) { + + final long[] shape = array.shape(); + switch (shape.length) { + case 3: // Assume recurrent output + return array.permute(2,0,1); + case 5: // Assume conv 3D output + return array.permute(1,0,2,3,4); + // Should not happen as conf throws exception for other types + default: throw new UnsupportedOperationException("Rank not supported: " + array.rank()); + } + } +} diff --git a/src/main/java/ode/vertex/impl/helper/backward/OdeHelperBackward.java b/src/main/java/ode/vertex/impl/helper/backward/OdeHelperBackward.java new file mode 100644 index 0000000..c35af70 --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/backward/OdeHelperBackward.java @@ -0,0 +1,55 @@ +package ode.vertex.impl.helper.backward; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import ode.vertex.impl.gradview.INDArray1DView; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * Helps with input/output handling when solving ODEs inside a neural network + * + * @author Christian Skarby + */ +public interface OdeHelperBackward { + + /** + * Input arrays needed to do backward pass. 
Has the following definitions:
+ * {@code lossGradient}: Gradient w.r.t loss from subsequent layers (typically called epsilon in dl4j)
+ * {@code lastOutput}: Last computed output from a forward pass used to calculate the loss gradient
* {@code realGradientView}: View of all array elements which are actually gradients in the given + * {@link ComputationGraph}'s gradient view array. Notable exceptions (i.e. things labeled as gradients which are not) + * are the running mean and variance of Batch Normalization layers. + */ + @Getter @AllArgsConstructor + class InputArrays { + + private final INDArray[] lastInputs; + private final INDArray lastOutput; + private final INDArray lossGradient; + private final INDArray1DView realGradientView; + } + + /** + * Misc parameters needed to jump through the hoops of doing back propagation + */ + @Getter @AllArgsConstructor + class MiscPar { + private final boolean useTruncatedBackPropTroughTime; + private final LayerWorkspaceMgr wsMgr; + } + + /** + * Return the solution to the ODE when assuming that a backwards pass through the layers of the given graph is + * the derivative of the sought function. Note that parameter gradient is set in given graph so it is not returned. + * + * @param graph Graph of layers to do backwards pass through + * @param input Input arrays + * @param miscPars Misc parameters needed to jump through the hoops of doing back propagation + * + * @return Loss gradients (a.k.a. epsilon in dl4j) w.r.t last input from previous layers. Note that parameter gradients + * are set in graph and can be accessed through graph.getGradientsViewArray() + */ + INDArray[] solve(ComputationGraph graph, InputArrays input, MiscPar miscPars); +} diff --git a/src/main/java/ode/vertex/impl/helper/backward/SingleStepAdjoint.java b/src/main/java/ode/vertex/impl/helper/backward/SingleStepAdjoint.java new file mode 100644 index 0000000..84f6097 --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/backward/SingleStepAdjoint.java @@ -0,0 +1,88 @@ +package ode.vertex.impl.helper.backward; + +import ode.solve.api.FirstOrderEquation; +import ode.solve.api.FirstOrderSolver; +import ode.vertex.impl.gradview.INDArray1DView; +import ode.vertex.impl.helper.NDArrayIndexAccumulator; +import ode.vertex.impl.helper.backward.timegrad.TimeGrad; +import ode.vertex.impl.helper.forward.ForwardPass; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.Arrays; + +/** + * {@link OdeHelperBackward} using the adjoint method capable of handling a single time step. Gradients for time steps + * will only be provided if required. + * + * @author Christian Skarby + */ +public class SingleStepAdjoint implements OdeHelperBackward { + + private final FirstOrderSolver solver; + private final INDArray time; + private final TimeGrad.Factory timeGradFactory; + + public SingleStepAdjoint(FirstOrderSolver solver, INDArray time, TimeGrad.Factory timeGradFactory) { + this.solver = solver; + this.time = time; + this.timeGradFactory = timeGradFactory; + if (time.length() != 2 && time.rank() != 1) { + throw new IllegalArgumentException("time must be a vector with two elements! Was of shape: " + Arrays.toString(time.shape()) + "!"); + } + } + + @Override + public INDArray[] solve(ComputationGraph graph, InputArrays input, MiscPar miscPars) { + + // Create augmented dynamics for adjoint method + // Initialization: S0: + // z(t1) = lastOutput + // a(t) = -dL/d(z(t1)) = -epsilon from next layer (i.e. getEpsilon). 
Use last row if more than one timestep + // parameters = zeros + // dL/dt1 = -dL / dz(t1) dot dz(t1) / dt1 + + final INDArray dL_dzt1 = input.getLossGradient(); + final INDArray zt1 = input.getLastOutput(); + final INDArray1DView realParamGrads = input.getRealGradientView(); + + final FirstOrderEquation forward = new ForwardPass(graph, + miscPars.getWsMgr(), + true, // Always use training as batch norm running mean and var become messed up otherwise. Same effect seen in original pytorch repo. + input.getLastInputs()); + + final TimeGrad timeGrad = timeGradFactory.create(); + final INDArray dL_dt1 = timeGrad.calcTimeGradT1(forward, zt1, time); + + final INDArray zAug = Nd4j.create(1, zt1.length() + dL_dzt1.length() + graph.numParams() + dL_dt1.length()); + final INDArray paramAdj = Nd4j.zeros(realParamGrads.length()); + realParamGrads.assignTo(paramAdj); + + final NDArrayIndexAccumulator accumulator = new NDArrayIndexAccumulator(zAug); + accumulator.increment(zt1.reshape(new long[]{1, zt1.length()})) + .increment(dL_dzt1.reshape(new long[]{1, dL_dzt1.length()})) + .increment(paramAdj.reshape(new long[]{1, paramAdj.length()})) + .increment(dL_dt1); + + final AugmentedDynamics augmentedDynamics = new AugmentedDynamics( + zAug, + dL_dzt1.shape(), + new long[]{realParamGrads.length()}, + dL_dt1.shape()); + + final FirstOrderEquation equation = new BackpropagateAdjoint( + augmentedDynamics, + forward, + new BackpropagateAdjoint.GraphInfo(graph, realParamGrads, miscPars.getWsMgr(), miscPars.isUseTruncatedBackPropTroughTime()) + ); + + INDArray augAns = solver.integrate(equation, Nd4j.reverse(time.dup()), zAug, zAug.dup()); + + augmentedDynamics.updateFrom(augAns); + + realParamGrads.assignFrom(augmentedDynamics.paramAdjoint()); + + return timeGrad.createLossGradient(augmentedDynamics.zAdjoint(), augmentedDynamics.tAdjoint()); + } +} diff --git a/src/main/java/ode/vertex/impl/helper/backward/timegrad/CalcMultiStepTimeGrad.java b/src/main/java/ode/vertex/impl/helper/backward/timegrad/CalcMultiStepTimeGrad.java new file mode 100644 index 0000000..1728fd0 --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/backward/timegrad/CalcMultiStepTimeGrad.java @@ -0,0 +1,70 @@ +package ode.vertex.impl.helper.backward.timegrad; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.INDArrayIndex; +import org.nd4j.linalg.indexing.NDArrayIndex; + +/** + * Calculates a time gradient for multiple time steps. + * + * @author Christian Skarby + */ +public class CalcMultiStepTimeGrad implements MultiStepTimeGrad { + + private final INDArray timeGradient; + private final int timeIndex; + + private final INDArray lastTime = Nd4j.scalar(0); + + /** + * Factory for this class. 
+ */ + public static class Factory implements MultiStepTimeGrad.Factory { + + private final INDArray time; + private final int timeIndex; + + public Factory(INDArray time, int timeIndex) { + this.time = time; + this.timeIndex = timeIndex; + } + + @Override + public MultiStepTimeGrad create() { + return new CalcMultiStepTimeGrad(time, timeIndex); + } + } + + public CalcMultiStepTimeGrad(INDArray time, int timeIndex) { + timeGradient = Nd4j.zeros(time.shape()); + this.timeIndex = timeIndex; + } + + @Override + public void prepareStep(INDArray[] lastGradients, INDArray dL_dzt) { + if(lastGradients == null) return; + + dL_dzt.addi(lastGradients[(timeIndex+1) % lastGradients.length]); + } + + @Override + public void updateStep(INDArrayIndex[] timeIndexer, INDArray[] gradients) { + timeGradient.put(timeIndexer, gradients[timeIndex]); + lastTime.subi(timeGradient.get(timeIndexer).getScalar(0)); + } + + @Override + public INDArray[] updateLastStep(INDArrayIndex[] timeIndexer, INDArray[] gradients, INDArray dL_dzt0) { + timeIndexer[1] = NDArrayIndex.point(0); + timeGradient.put(timeIndexer, lastTime); + gradients[timeIndex] = timeGradient; + gradients[(timeIndex + 1) % gradients.length].addi(dL_dzt0); + return gradients; + } + + @Override + public TimeGrad.Factory createSingleStepFactory(INDArray dL_dzt1_time) { + return new CalcTimeGrad.Factory(dL_dzt1_time, timeIndex); + } +} diff --git a/src/main/java/ode/vertex/impl/helper/backward/timegrad/CalcTimeGrad.java b/src/main/java/ode/vertex/impl/helper/backward/timegrad/CalcTimeGrad.java new file mode 100644 index 0000000..67439e3 --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/backward/timegrad/CalcTimeGrad.java @@ -0,0 +1,67 @@ +package ode.vertex.impl.helper.backward.timegrad; + +import ode.solve.api.FirstOrderEquation; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +/** + * Calculates a time gradient for a single time step. 
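The computation below is the chain rule dL/dt1 = dL/dz(t1) · f(z(t1), t1): since z(t1) moves at rate f(z(t1), t1) when t1 changes, one extra evaluation of the ODE function followed by a dot product gives the time gradient. A standalone sketch of just that step (method name hypothetical):

```
// dL/dt1 as a flattened dot product: 1 x n row times n x 1 column -> 1 x 1
static INDArray timeGradAtT1(FirstOrderEquation f, INDArray zt1, INDArray t1, INDArray dL_dzt1) {
    final INDArray dzt1_dt1 = f.calculateDerivative(zt1, t1, zt1.dup()); // f(z(t1), t1)
    return dL_dzt1.reshape(1, dzt1_dt1.length())
            .mmul(dzt1_dt1.reshape(dzt1_dt1.length(), 1));
}
```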
+ * + * @author Christian Skarby + */ +public class CalcTimeGrad implements TimeGrad { + + private final INDArray dL_dzt1_time; + private final int timeIndex; + + private INDArray dL_dt1; + + /** + * Factory for this class + */ + public static class Factory implements TimeGrad.Factory { + + private final INDArray dL_dzt1_time; + private final int timeIndex; + + public Factory(INDArray dL_dzt1_time, int timeIndex) { + this.dL_dzt1_time = dL_dzt1_time; + this.timeIndex = timeIndex; + } + + @Override + public TimeGrad create() { + return new CalcTimeGrad(dL_dzt1_time, timeIndex); + } + } + + public CalcTimeGrad(INDArray dL_dzt1_time, int timeIndex) { + this.dL_dzt1_time = dL_dzt1_time; + this.timeIndex = timeIndex; + } + + @Override + public INDArray calcTimeGradT1(FirstOrderEquation equation, INDArray zt1, INDArray time) { + final INDArray dzt1_dt1 = equation.calculateDerivative(zt1, time.getColumn(1), zt1.dup()); + + this.dL_dt1 = dL_dzt1_time.reshape(1, dzt1_dt1.length()) + .mmul(dzt1_dt1.reshape(dzt1_dt1.length(), 1)); + return dL_dt1; + } + + @Override + public INDArray[] createLossGradient(INDArray zAdjoint, INDArray tAdjoint) { + if(dL_dt1 == null) { + throw new IllegalStateException("Must compute dL / dt1 before creating loss gradient!"); + } + + final INDArray[] epsilons = new INDArray[2]; + for (int i = 0; i < epsilons.length; i++) { + if (i != timeIndex) { + epsilons[i] = zAdjoint; + } + } + epsilons[timeIndex] = Nd4j.hstack(dL_dt1, tAdjoint); + return epsilons; + } +} diff --git a/src/main/java/ode/vertex/impl/helper/backward/timegrad/MultiStepTimeGrad.java b/src/main/java/ode/vertex/impl/helper/backward/timegrad/MultiStepTimeGrad.java new file mode 100644 index 0000000..4252f7e --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/backward/timegrad/MultiStepTimeGrad.java @@ -0,0 +1,55 @@ +package ode.vertex.impl.helper.backward.timegrad; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.indexing.INDArrayIndex; + +/** + * Interface for handling time gradient w.r.t loss when multiple time steps are used. Main use case is to be able to + * not calculate time gradients when not needed. + * + * @author Christian Skarby + */ +public interface MultiStepTimeGrad { + + /** + * Factory for MultiStepTimeGrads + */ + interface Factory { + /** + * Return a {@link MultiStepTimeGrad} instance + * @return a {@link MultiStepTimeGrad} instance + */ + MultiStepTimeGrad create(); + } + + /** + * Update dL_dzt based on last steps gradients + * @param lastGradients gradients from last time step + * @param dL_dzt Loss gradient for z(t) for current time step + */ + void prepareStep(INDArray[] lastGradients, INDArray dL_dzt); + + /** + * Update after a single step backwards + * @param timeIndexer Points to the current time step + * @param gradients Current gradients + */ + void updateStep(INDArrayIndex[] timeIndexer, INDArray[] gradients); + + /** + * Update after the last step backwards has been performed. Note that this will update the input gradient + * @param timeIndexer Points to the current time step + * @param gradients Current gradients. Might be updated as a result + * @param dL_dzt0 Loss gradient for z(t0). 
+ * @return Updated gradients + */ + INDArray[] updateLastStep(INDArrayIndex[] timeIndexer, INDArray[] gradients, INDArray dL_dzt0); + + /** + * Create an appropriate {@link TimeGrad.Factory} + * @param dL_dzt1_time Loss gradient for calculation of time gradient + * @return a {@link TimeGrad.Factory} + */ + TimeGrad.Factory createSingleStepFactory(INDArray dL_dzt1_time); + +} diff --git a/src/main/java/ode/vertex/impl/helper/backward/timegrad/NoMultiStepTimeGrad.java b/src/main/java/ode/vertex/impl/helper/backward/timegrad/NoMultiStepTimeGrad.java new file mode 100644 index 0000000..250586c --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/backward/timegrad/NoMultiStepTimeGrad.java @@ -0,0 +1,49 @@ +package ode.vertex.impl.helper.backward.timegrad; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.indexing.INDArrayIndex; + +/** + * Does not calculate any time gradients + * + * @author Christian Skarby + */ +public class NoMultiStepTimeGrad implements MultiStepTimeGrad { + + // Singleton because stateless + private final static MultiStepTimeGrad noTimeGrad = new NoMultiStepTimeGrad(); + + /** + * Factory for this class + */ + public static MultiStepTimeGrad.Factory factory = new MultiStepTimeGrad.Factory() { + + @Override + public MultiStepTimeGrad create() { + return noTimeGrad; + } + }; + + @Override + public void prepareStep(INDArray[] lastGradients, INDArray dL_dzt) { + if(lastGradients == null) return; + + dL_dzt.addi(lastGradients[0]); + } + + @Override + public void updateStep(INDArrayIndex[] timeIndexer, INDArray[] gradients) { + // Do nothing + } + + @Override + public INDArray[] updateLastStep(INDArrayIndex[] timeIndexer, INDArray[] gradients, INDArray dL_dzt) { + gradients[0].addi(dL_dzt); + return gradients; + } + + @Override + public TimeGrad.Factory createSingleStepFactory(INDArray dL_dzt1_time) { + return NoTimeGrad.factory; + } +} diff --git a/src/main/java/ode/vertex/impl/helper/backward/timegrad/NoTimeGrad.java b/src/main/java/ode/vertex/impl/helper/backward/timegrad/NoTimeGrad.java new file mode 100644 index 0000000..b5e5f64 --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/backward/timegrad/NoTimeGrad.java @@ -0,0 +1,29 @@ +package ode.vertex.impl.helper.backward.timegrad; + +import ode.solve.api.FirstOrderEquation; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +public class NoTimeGrad implements TimeGrad { + + private final static TimeGrad noTimeGrad = new NoTimeGrad(); + public static TimeGrad.Factory factory = new TimeGrad.Factory() { + + @Override + public TimeGrad create() { + // Singleton because stateless + return noTimeGrad; + } + }; + + + @Override + public INDArray calcTimeGradT1(FirstOrderEquation equation, INDArray zt1, INDArray time) { + return Nd4j.empty(); + } + + @Override + public INDArray[] createLossGradient(INDArray zAdjoint, INDArray tAdjoint) { + return new INDArray[] {zAdjoint}; + } +} diff --git a/src/main/java/ode/vertex/impl/helper/backward/timegrad/TimeGrad.java b/src/main/java/ode/vertex/impl/helper/backward/timegrad/TimeGrad.java new file mode 100644 index 0000000..190b1e9 --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/backward/timegrad/TimeGrad.java @@ -0,0 +1,42 @@ +package ode.vertex.impl.helper.backward.timegrad; + +import ode.solve.api.FirstOrderEquation; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * Interface for handling time gradient w.r.t loss for a single time step. 
Main use case is to be able to + * not calculate time gradients when not needed. + * + * @author Christian Skarby + */ +public interface TimeGrad { + + /** + * Factory for TimeGrads + */ + interface Factory { + /** + * Return a {@link TimeGrad} instance + * @return a {@link TimeGrad} instance + */ + TimeGrad create(); + } + + /** + * Calculate time gradient for the last time point (t1) + * @param equation Calculates dz(t1) / d(t1) + * @param zt1 Value of z(t1) + * @param time Time points (t0, t1) + * @return dLoss / dt1 or empty if no time gradient needed + */ + INDArray calcTimeGradT1(FirstOrderEquation equation, INDArray zt1, INDArray time); + + /** + * Create loss gradient (a.k.a. epsilons in dl4j) from adjoint state. + * @param zAdjoint dL / dz(t0) + * @param tAdjoint dL / dt0 + * @return Array of required loss gradients + */ + INDArray[] createLossGradient(INDArray zAdjoint, INDArray tAdjoint); + +} diff --git a/src/main/java/ode/vertex/impl/helper/backward/timegrad/ZeroMultiStepTimeGrad.java b/src/main/java/ode/vertex/impl/helper/backward/timegrad/ZeroMultiStepTimeGrad.java new file mode 100644 index 0000000..b5a9a3f --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/backward/timegrad/ZeroMultiStepTimeGrad.java @@ -0,0 +1,70 @@ +package ode.vertex.impl.helper.backward.timegrad; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.INDArrayIndex; + +/** + * {@link MultiStepTimeGrad} which does not compute any time gradients, but returns a dummy gradient so that array sizes + * are correct + * + * @author Christian Skarby + */ +public class ZeroMultiStepTimeGrad implements MultiStepTimeGrad { + + private final int timeIndex; + private final INDArray time; + + /** + * Factory for this class. + */ + public static class Factory implements MultiStepTimeGrad.Factory { + + private final INDArray time; + private final int timeIndex; + + public Factory(INDArray time, int timeIndex) { + this.time = time; + this.timeIndex = timeIndex; + } + + @Override + public MultiStepTimeGrad create() { + return new ZeroMultiStepTimeGrad(time, timeIndex); + } + } + + public ZeroMultiStepTimeGrad(INDArray time, int timeIndex) { + this.time = time; + this.timeIndex = timeIndex; + } + + @Override + public void prepareStep(INDArray[] lastGradients, INDArray dL_dzt) { + if(lastGradients == null) return; + + dL_dzt.addi(lastGradients[(timeIndex+1) % lastGradients.length]); + } + + @Override + public void updateStep(INDArrayIndex[] timeIndexer, INDArray[] gradients) { + // Do nothing + } + + @Override + public INDArray[] updateLastStep(INDArrayIndex[] timeIndexer, INDArray[] gradients, INDArray dL_dzt) { + + final INDArray[] toRet = new INDArray[2]; + toRet[timeIndex] = Nd4j.zerosLike(time); + final int notTimeIndex = (timeIndex + 1) % toRet.length; + toRet[notTimeIndex] = gradients[notTimeIndex % gradients.length]; + gradients[notTimeIndex % gradients.length].addi(dL_dzt); + + return toRet; + } + + @Override + public TimeGrad.Factory createSingleStepFactory(INDArray dL_dzt1_time) { + return NoTimeGrad.factory; + } +} diff --git a/src/main/java/ode/vertex/impl/helper/backward/timegrad/ZeroTimeGrad.java b/src/main/java/ode/vertex/impl/helper/backward/timegrad/ZeroTimeGrad.java new file mode 100644 index 0000000..19ad527 --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/backward/timegrad/ZeroTimeGrad.java @@ -0,0 +1,55 @@ +package ode.vertex.impl.helper.backward.timegrad; + +import ode.solve.api.FirstOrderEquation; +import org.nd4j.linalg.api.ndarray.INDArray; 
+import org.nd4j.linalg.factory.Nd4j; + +/** + * {@link TimeGrad} which does not compute any time gradients, but returns a dummy gradient so that array sizes + * are correct + * + * @author Christian Skarby + */ +public class ZeroTimeGrad implements TimeGrad { + + private final int timeIndex; + + /** + * Factory for this class + */ + public static class Factory implements TimeGrad.Factory { + + private final int timeIndex; + + public Factory(int timeIndex) { + this.timeIndex = timeIndex; + } + + @Override + public TimeGrad create() { + return new ZeroTimeGrad(timeIndex); + } + } + + + public ZeroTimeGrad(int timeIndex) { + this.timeIndex = timeIndex; + } + + @Override + public INDArray calcTimeGradT1(FirstOrderEquation equation, INDArray zt1, INDArray time) { + return Nd4j.empty(); + } + + @Override + public INDArray[] createLossGradient(INDArray zAdjoint, INDArray tAdjoint) { + final INDArray[] epsilons = new INDArray[2]; + for (int i = 0; i < epsilons.length; i++) { + if (i != timeIndex) { + epsilons[i] = zAdjoint; + } + } + epsilons[timeIndex] = Nd4j.zeros(1,2); + return epsilons; + } +} diff --git a/src/main/java/ode/vertex/impl/helper/forward/FixedStep.java b/src/main/java/ode/vertex/impl/helper/forward/FixedStep.java new file mode 100644 index 0000000..35831ba --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/forward/FixedStep.java @@ -0,0 +1,35 @@ +package ode.vertex.impl.helper.forward; + +import ode.solve.api.FirstOrderMultiStepSolver; +import ode.solve.api.FirstOrderSolver; +import ode.solve.impl.InterpolatingMultiStepSolver; +import ode.solve.impl.SingleSteppingMultiStepSolver; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * {@link OdeHelperForward} with a fixed given sequence of time steps to evaluate the ODE for. + * + * @author Christian Skarby + */ +public class FixedStep implements OdeHelperForward { + + private final OdeHelperForward helper; + + public FixedStep(FirstOrderSolver solver, INDArray time, boolean interpolateIfMultiStep) { + if(time.length() > 2) { + final FirstOrderMultiStepSolver multiStepSolver = interpolateIfMultiStep ? 
+ new InterpolatingMultiStepSolver(solver) : + new SingleSteppingMultiStepSolver(solver); + helper = new MultiStep(multiStepSolver, time); + } else { + helper = new SingleStep(solver, time); + } + } + + @Override + public INDArray solve(ComputationGraph graph, LayerWorkspaceMgr wsMgr, INDArray[] inputs) { + return helper.solve(graph, wsMgr, inputs); + } +} diff --git a/src/main/java/ode/vertex/impl/ForwardPass.java b/src/main/java/ode/vertex/impl/helper/forward/ForwardPass.java similarity index 90% rename from src/main/java/ode/vertex/impl/ForwardPass.java rename to src/main/java/ode/vertex/impl/helper/forward/ForwardPass.java index f36d9db..0d35252 100644 --- a/src/main/java/ode/vertex/impl/ForwardPass.java +++ b/src/main/java/ode/vertex/impl/helper/forward/ForwardPass.java @@ -1,6 +1,7 @@ -package ode.vertex.impl; +package ode.vertex.impl.helper.forward; import ode.solve.api.FirstOrderEquation; +import ode.vertex.impl.helper.NDArrayIndexAccumulator; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.GraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; @@ -34,12 +35,12 @@ public ForwardPass(ComputationGraph graph, this.graph = graph; this.workspaceMgr = workspaceMgr; this.training = training; - this.inputs = startInputs; + this.inputs = startInputs.clone(); } - @Override public INDArray calculateDerivative(INDArray y, INDArray t, INDArray fy) { + graph.getConfiguration().setIterationCount(graph.getIterationCount() + 1); try (WorkspacesCloseable ws = enterIfNotOpen(ArrayType.ACTIVATIONS)) { setInputsFromFlat(y); evaluate(inputs, fy); @@ -64,7 +65,7 @@ private void setInputsFromFlat(INDArray flatArray) { INDArray input = inputs[i]; final INDArray z = flatArray.get(NDArrayIndex.interval(lastInd, lastInd + input.length())); lastInd += input.length(); - inputs[i].assign(z.reshape(input.shape())); + inputs[i] = z.reshape(input.shape()); } } @@ -81,19 +82,17 @@ private void evaluate(INDArray[] inputs, INDArray output) { VertexIndices[] inputsTo = current.getOutputVertices(); - INDArray out = null; + final INDArray out; if (current.isInputVertex()) { out = inputs[vIdx]; - } else if (current.isOutputVertex()) { - for (INDArray outArr : current.getInputs()) { - outputAccum.increment(outArr); - } } else { //Standard feed-forward case out = current.doForward(training, workspaceMgr); } - if (inputsTo != null) { //Output vertices may not input to any other vertices + if (inputsTo == null) { //Output vertices may not input to any other vertices + outputAccum.increment(out); + } else { for (VertexIndices v : inputsTo) { //Note that we don't have to do anything special here: the activations are always detached in // this method diff --git a/src/main/java/ode/vertex/impl/helper/forward/InputStep.java b/src/main/java/ode/vertex/impl/helper/forward/InputStep.java new file mode 100644 index 0000000..09cfee9 --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/forward/InputStep.java @@ -0,0 +1,42 @@ +package ode.vertex.impl.helper.forward; + +import ode.solve.api.FirstOrderSolver; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.util.ArrayList; +import java.util.List; + +/** + * {@link OdeHelperForward} which uses one of the input {@link INDArray}s as the time steps to evaluate the ODE for + * + * @author Christian Skarby + */ +public class InputStep implements OdeHelperForward { + + private final FirstOrderSolver solver; + 
private final int timeInputIndex; + private final boolean interpolateIfMultiStep; + + public InputStep(FirstOrderSolver solver, int timeInputIndex, boolean interpolateIfMultiStep) { + this.solver = solver; + this.timeInputIndex = timeInputIndex; + this.interpolateIfMultiStep = interpolateIfMultiStep; + } + + @Override + public INDArray solve(ComputationGraph graph, LayerWorkspaceMgr wsMgr, INDArray[] inputs) { + final List notTimeInputs = new ArrayList<>(); + for (int i = 0; i < inputs.length; i++) { + if (i != timeInputIndex) { + notTimeInputs.add(inputs[i]); + } + } + return new FixedStep( + solver, + inputs[timeInputIndex], + interpolateIfMultiStep) + .solve(graph, wsMgr, notTimeInputs.toArray(new INDArray[0])); + } +} diff --git a/src/main/java/ode/vertex/impl/helper/forward/MultiStep.java b/src/main/java/ode/vertex/impl/helper/forward/MultiStep.java new file mode 100644 index 0000000..2bf8c69 --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/forward/MultiStep.java @@ -0,0 +1,64 @@ +package ode.vertex.impl.helper.forward; + +import com.google.common.primitives.Longs; +import ode.solve.api.FirstOrderEquation; +import ode.solve.api.FirstOrderMultiStepSolver; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.Arrays; + +/** + * {@link OdeHelperForward} capable of handling multiple time steps + * + * @author Christian Skarby + */ +public class MultiStep implements OdeHelperForward { + + private final FirstOrderMultiStepSolver solver; + private final INDArray time; + + public MultiStep(FirstOrderMultiStepSolver solver, INDArray time) { + this.solver = solver; + this.time = time; + if(time.length() <= 2 || !time.isVector()) { + throw new IllegalArgumentException("time must be a vector of size > 2! 
Was of shape: " + Arrays.toString(time.shape())+ "!"); + } + } + + @Override + public INDArray solve(ComputationGraph graph, LayerWorkspaceMgr wsMgr, INDArray[] inputs) { + if (inputs.length != 1) { + throw new IllegalArgumentException("Only single input supported!"); + } + + final FirstOrderEquation equation = new ForwardPass( + graph, + wsMgr, + true, + inputs + ); + + final INDArray z0 = inputs[0].dup(); + final INDArray zt = Nd4j.createUninitialized(Longs.concat(new long[]{time.length() - 1}, z0.shape())); + solver.integrate(equation, time, inputs[0].dup(), zt); + + return alignOutShape(zt, z0); + } + + + private INDArray alignOutShape(INDArray zt, INDArray z0) { + final long[] shape = zt.shape(); + switch (shape.length) { + case 3: // Assume recurrent output + return Nd4j.concat(0, z0.reshape(1, shape[1], shape[2]), zt).permute(1, 2, 0); + case 5: // Assume conv 3D output + return Nd4j.concat(0, z0.reshape(1, shape[1], shape[2], shape[3], shape[4]), zt).permute(1, 0, 2, 3, 4); + // Should not happen as conf throws exception for other types + default: + throw new UnsupportedOperationException("Rank not supported: " + zt.rank()); + } + } +} diff --git a/src/main/java/ode/vertex/impl/helper/forward/OdeHelperForward.java b/src/main/java/ode/vertex/impl/helper/forward/OdeHelperForward.java new file mode 100644 index 0000000..9b8309e --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/forward/OdeHelperForward.java @@ -0,0 +1,23 @@ +package ode.vertex.impl.helper.forward; + +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * Helps with input/output handling when solving ODEs inside a neural network + * + * @author Christian Skarby + */ +public interface OdeHelperForward { + + /** + * Return the solution to the ODE when assuming that a forward pass through the layers of the given graph is + * the derivative of the sought function. + * @param graph Graph of layers to do forward pass through + * @param wsMgr To handle workspaces for newly created arrays + * @param inputs Inputs to vertex, typically activations from previous layers + * @return an {@link INDArray} with the solution to the ODE + */ + INDArray solve(ComputationGraph graph, LayerWorkspaceMgr wsMgr, INDArray[] inputs); +} diff --git a/src/main/java/ode/vertex/impl/helper/forward/SingleStep.java b/src/main/java/ode/vertex/impl/helper/forward/SingleStep.java new file mode 100644 index 0000000..837911b --- /dev/null +++ b/src/main/java/ode/vertex/impl/helper/forward/SingleStep.java @@ -0,0 +1,50 @@ +package ode.vertex.impl.helper.forward; + + +import ode.solve.api.FirstOrderEquation; +import ode.solve.api.FirstOrderSolver; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.Arrays; + +/** + * Simple {@link OdeHelperForward} capable of only a single time step. Main difference compared to using + * {@link MultiStep} is that the latter will return output (zt) which includes the first step (z0) + * as well. 
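Shape-wise, the difference is the extra time dimension described in the README; a sketch with assumed activation shapes (batch size b):

```
INDArray time = Nd4j.linspace(0, 1, 5); // N = 5 time points
// Given a conv activation z0 of shape [b, 32, 9, 9]:
//   SingleStep (N == 2) returns z(t1) with shape           [b, 32, 9, 9]
//   MultiStep  (N > 2)  returns z(t0) ... z(t4) with shape [b, 5, 32, 9, 9]
// alignOutShape stacks z0 on top of the solver output and moves time to dimension 1
```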
+ */ +public class SingleStep implements OdeHelperForward { + + private final FirstOrderSolver solver; + private final INDArray time; + + public SingleStep(FirstOrderSolver solver, INDArray time) { + this.solver = solver; + this.time = time; + if(time.length() != 2 && time.rank() != 1) { + throw new IllegalArgumentException("time must be a vector with two elements! Was of shape: " + Arrays.toString(time.shape())+ "!"); + } + } + + @Override + public INDArray solve(ComputationGraph graph, LayerWorkspaceMgr wsMgr, INDArray[] inputs) { + if (inputs.length != 1) { + throw new IllegalArgumentException("Only single input supported!"); + } + + final FirstOrderEquation equation = new ForwardPass( + graph, + wsMgr, + true, // Always use training as batch norm running mean and var become messed up otherwise. Same effect seen in original pytorch repo. + inputs + ); + + final INDArray z0 = inputs[0]; + final INDArray zt = Nd4j.createUninitialized(z0.shape()); + solver.integrate(equation, time, z0, zt); + + return zt; + } +} diff --git a/src/main/java/util/listen/step/Mask.java b/src/main/java/util/listen/step/Mask.java index e438fb9..cd3d9e0 100644 --- a/src/main/java/util/listen/step/Mask.java +++ b/src/main/java/util/listen/step/Mask.java @@ -1,6 +1,7 @@ package util.listen.step; import ode.solve.api.StepListener; +import ode.solve.impl.util.SolverState; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -61,9 +62,9 @@ public void begin(INDArray t, INDArray y0) { } @Override - public void step(INDArray currTime, INDArray step, INDArray error, INDArray y) { + public void step(SolverState solverState, INDArray step, INDArray error) { if(mask) { - listener.step(currTime, step, error, y); + listener.step(solverState, step, error); } } diff --git a/src/main/java/util/listen/step/StepCounter.java b/src/main/java/util/listen/step/StepCounter.java index e9c2df7..f13e99f 100644 --- a/src/main/java/util/listen/step/StepCounter.java +++ b/src/main/java/util/listen/step/StepCounter.java @@ -1,6 +1,7 @@ package util.listen.step; import ode.solve.api.StepListener; +import ode.solve.impl.util.SolverState; import org.nd4j.linalg.api.ndarray.INDArray; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,7 +48,7 @@ public void begin(INDArray t, INDArray y0) { } @Override - public void step(INDArray currTime, INDArray step, INDArray error, INDArray y) { + public void step(SolverState solverState, INDArray step, INDArray error) { nrofSteps++; } diff --git a/src/main/java/util/plot/Plot.java b/src/main/java/util/plot/Plot.java new file mode 100644 index 0000000..2835e1f --- /dev/null +++ b/src/main/java/util/plot/Plot.java @@ -0,0 +1,89 @@ +package util.plot; + +import java.io.IOException; +import java.util.List; + +/** + * Interface for XY plotting. Supports serialization of plot data. + * @param <X> + * @param <Y> + */ +public interface Plot { + + /** + * Factory interface + */ + interface Factory { + + /** + * Create a new plot with the given title + * @param title title of the plot + * @return a new Plot instance + */ + Plot create(String title); + } + + /** + * Creates a time series for the given label. If data with the given label exists in serialized format in the + * plotDir the time series of that data will be recreated. + * @param label series label. + */ + void createSeries(String label); + + /** + * Plot some data belonging to a certain label. Will be appended to an existing series if such exists, either in + * an existing window or in serialized format in the plotDir. 
If no series with the given label exists it will + * be created in the window of this plot instance. + * @param label series label + * @param x point on x axis + * @param y point on y axis + */ + void plotData(String label, X x, Y y); + + /** + * Plot some data belonging to a certain label. Will be appended to an existing series of such exists, either in + * an existing window or in serialized format in the plotDir. If no series with the given label exists it will + * be created in the window of this plot instance. + * @param label series label + * @param x points on x axis + * @param y points on y axis + */ + void plotData(String label, List x, List y); + + /** + * Clears the data for the given label + * @param label series label + */ + void clearData(String label); + + /** + * Serialize the data for all labels. + * @throws IOException + */ + void storePlotData() throws IOException; + + /** + * Serialize the data for the given label. + * @param label series label + * @throws IOException + */ + void storePlotData(String label) throws IOException; + + /** + * Save plot as a picture + */ + void savePicture(String suffix) throws IOException; + + /** + * Convenience method for debugging purposes. Plots the given data vs list indexes + * @param data + * @param + */ + static void plot(List data, String plotName) { + final RealTimePlot plotter = new RealTimePlot<>(plotName, ""); + plotter.createSeries("data"); + for(int x = 0; x < data.size(); x ++) { + plotter.plotData("data", x, data.get(x)); + } + } +} diff --git a/src/main/java/util/plot/RealTimePlot.java b/src/main/java/util/plot/RealTimePlot.java new file mode 100644 index 0000000..7920423 --- /dev/null +++ b/src/main/java/util/plot/RealTimePlot.java @@ -0,0 +1,257 @@ + +package util.plot; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.deser.std.StdDeserializer; +import com.fasterxml.jackson.databind.module.SimpleModule; +import lombok.Data; +import org.jetbrains.annotations.NotNull; +import org.knowm.xchart.*; +import org.knowm.xchart.style.Styler; +import org.knowm.xchart.style.Styler.ChartTheme; + +import javax.swing.*; +import java.io.File; +import java.io.IOException; +import java.io.Serializable; +import java.io.UncheckedIOException; +import java.util.*; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +/** + * Real time updatable plot with support for an arbitrary number of series. Can also serialize the plotted data and + * recreate a plot from such data. Typically used for plot training/eval metrics for each iteration. Note: The amount + * of data points per timeseries is limited to 1000 as a significant slowdown was observed for higher numbers. When 1000 + * points is reached, all even points will be removed. New points after this will be added as normal until the total hits + * 1000 again. 
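+ * <br><br>
+ * A minimal usage sketch (hypothetical {@code iter} and {@code score} values assumed; an empty plot
+ * directory, as in the main method below, means no plot data is read from or written to disk):
+ * <pre>
+ * {@code
+ * RealTimePlot<Integer, Double> plot = new RealTimePlot<>("Training loss", "");
+ * plot.createSeries("loss");
+ * plot.plotData("loss", iter, score); // appends one point and repaints the chart
+ * }
+ * </pre>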
+ *
+ * @author Christian Skärby
+ */
+public class RealTimePlot<X extends Number, Y extends Number> implements Plot<X, Y> {
+
+    private final String title;
+    private final XYChart xyChart;
+    private final SwingWrapper<XYChart> swingWrapper;
+    private final String plotDir;
+
+    private final Map<String, DataXY<X, Y>> plotSeries = new HashMap<>();
+
+    @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, property = "@class")
+    @Data
+    private static class DataXY<X extends Number, Y extends Number> implements Serializable {
+
+        private static final long serialVersionUID = 7526471155622776891L;
+        private final String series;
+
+        private final LinkedList<X> xData;
+        private final LinkedList<Y> yData;
+
+        DataXY(String series) {
+            this(series, new LinkedList<>(), new LinkedList<>());
+        }
+
+        DataXY(
+                @JsonProperty("series") String series,
+                @JsonProperty("xData") List<X> xData,
+                @JsonProperty("yData") List<Y> yData) {
+            this.series = series;
+            this.xData = new LinkedList<>(xData);
+            this.yData = new LinkedList<>(yData);
+        }
+
+        private void addPoint(X x, Y y, XYChart xyChart, SwingWrapper<XYChart> swingWrapper) {
+            xData.addLast(x);
+            yData.addLast(y);
+            if (xData.size() > 1000) {
+                // Iterate backwards so that removing an element does not shift the indices yet to be visited
+                for (int i = xData.size() - 2; i >= 0; i -= 2) {
+                    xData.remove(i);
+                    yData.remove(i);
+                }
+            }
+            plotData(xyChart, swingWrapper);
+        }
+
+        private void addData(List<X> x, List<Y> y, XYChart xyChart, SwingWrapper<XYChart> swingWrapper) {
+            xData.addAll(x);
+            yData.addAll(y);
+            plotData(xyChart, swingWrapper);
+        }
+
+        private void plotData(XYChart xyChart, SwingWrapper<XYChart> swingWrapper) {
+            javax.swing.SwingUtilities.invokeLater(() -> {
+                if (!xyChart.getSeriesMap().containsKey(series)) {
+                    xyChart.addSeries(series, xData, yData, null);
+                } else {
+                    xyChart.updateXYSeries(series, xData, yData, null);
+                }
+                swingWrapper.repaintChart();
+            });
+        }
+
+        private void createSeries(final XYChart xyChart, SwingWrapper<XYChart> swingWrapper) {
+            javax.swing.SwingUtilities.invokeLater(() -> {
+                if (xData.size() == 0) {
+                    xyChart.addSeries(series, Arrays.asList(0), Arrays.asList(1));
+                } else {
+                    xyChart.addSeries(series, xData, yData);
+                }
+                swingWrapper.repaintChart();
+            });
+        }
+
+
+        private void clear() {
+            xData.clear();
+            yData.clear();
+        }
+    }
+
+    private class DataXYDeserializer extends StdDeserializer<DataXY<X, Y>> {
+
+        public DataXYDeserializer() {
+            this(null);
+        }
+
+        public DataXYDeserializer(Class<?> vc) {
+            super(vc);
+        }
+
+        @Override
+        public DataXY<X, Y> deserialize(JsonParser jp, DeserializationContext ctxt)
+                throws IOException {
+            JsonNode node = jp.getCodec().readTree(jp);
+
+            final List<X> xData = new ArrayList<>();
+            for (JsonNode xNode : node.get("xdata")) {
+                xData.add((X) xNode.numberValue());
+            }
+
+            final List<Y> yData = new ArrayList<>();
+            for (JsonNode xNode : node.get("ydata")) {
+                yData.add((Y) xNode.numberValue());
+            }
+            return new DataXY<>(node.get("series").toString(), xData, yData);
+        }
+
+    }
+
+    /**
+     * Constructor
+     *
+     * @param title Title of the plot
+     * @param plotDir Directory to store plots in.
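+     *                Series data for a label is serialized to {@code <plotDir>/<title>_<label>.plt},
+     *                matching {@code createFileName} below.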
+     */
+    public RealTimePlot(String title, String plotDir) {
+        // Create Chart
+        this.title = title;
+        xyChart = new XYChartBuilder().width(800).height(500).theme(ChartTheme.Matlab).title(title).build();
+        xyChart.getStyler().setLegendPosition(Styler.LegendPosition.OutsideE);
+        xyChart.getStyler().setDefaultSeriesRenderStyle(XYSeries.XYSeriesRenderStyle.Line);
+
+        this.swingWrapper = new SwingWrapper<>(xyChart);
+        swingWrapper.displayChart();
+        this.plotDir = plotDir;
+    }
+
+    @Override
+    public void plotData(String label, X x, Y y) {
+        final DataXY<X, Y> data = getOrCreateSeries(label);
+        data.addPoint(x, y, xyChart, swingWrapper);
+    }
+
+    @Override
+    public void plotData(String label, List<X> x, List<Y> y) {
+        final DataXY<X, Y> data = getOrCreateSeries(label);
+        data.addData(x, y, xyChart, swingWrapper);
+    }
+
+    @Override
+    public void clearData(String label) {
+        final DataXY<X, Y> data = getOrCreateSeries(label);
+        data.clear();
+    }
+
+    @Override
+    public void createSeries(String label) {
+        getOrCreateSeries(label);
+    }
+
+    @NotNull
+    private DataXY<X, Y> getOrCreateSeries(String label) {
+        DataXY<X, Y> data = plotSeries.get(label);
+        if (data == null) {
+            data = restoreOrCreatePlotData(label);
+            plotSeries.put(label, data);
+            data.createSeries(xyChart, swingWrapper);
+        }
+        return data;
+    }
+
+    @Override
+    public void storePlotData() throws IOException {
+        for (String label : plotSeries.keySet()) {
+            storePlotData(label);
+        }
+    }
+
+    @Override
+    public void storePlotData(String label) throws IOException {
+        DataXY<X, Y> data = plotSeries.get(label);
+        if (data != null) {
+            new ObjectMapper().writeValue(new File(createFileName(label)), data);
+        }
+    }
+
+    @Override
+    public void savePicture(String suffix) {
+        SwingUtilities.invokeLater(() -> {
+            try {
+                BitmapEncoder.saveBitmap(xyChart, plotDir + File.separator + title + suffix, BitmapEncoder.BitmapFormat.PNG);
+            } catch (IOException e) {
+                throw new UncheckedIOException("Save picture in " + this.getClass() + " failed!", e);
+            }
+        });
+    }
+
+    private DataXY<X, Y> restoreOrCreatePlotData(String label) {
+        File dataFile = new File(createFileName(label));
+        if (dataFile.exists()) {
+            try {
+                final SimpleModule mod = new SimpleModule();
+                mod.addDeserializer(DataXY.class, new DataXYDeserializer());
+                final ObjectMapper mapper = new ObjectMapper().registerModule(mod);
+                return mapper.readValue(dataFile, new TypeReference<DataXY<X, Y>>() {
+                });
+            } catch (IOException e) {
+                e.printStackTrace();
+            }
+        }
+        return new DataXY<>(label);
+    }
+
+    private String createFileName(String label) {
+        return plotDir + File.separator + title + "_" + label + ".plt";
+    }
+
+    public static void main(String[] args) {
+
+        final RealTimePlot<Integer, Double> plotter = new RealTimePlot<>("Test plot", "");
+        IntStream.range(1000, 2000).forEach(x -> Stream.of("s1", "s2", "s3").forEach(str -> {
+            plotter.createSeries(str);
+            plotter.plotData(str, x, 1d / ((double) x + 10));
+        }));
+//        try {
+//            plotter.storePlotData("s1");
+//        } catch (IOException e) {
+//            e.printStackTrace();
+//        }
+
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/util/random/SeededRandomFactory.java b/src/main/java/util/random/SeededRandomFactory.java
new file mode 100644
index 0000000..d4df45b
--- /dev/null
+++ b/src/main/java/util/random/SeededRandomFactory.java
@@ -0,0 +1,50 @@
+package util.random;
+
+import org.nd4j.linalg.api.rng.Random;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.factory.RandomFactory;
+
+/**
+ * {@link RandomFactory} with a configurable random seed
+ *
+ * @author Christian Skarby
+ */
+public class SeededRandomFactory extends RandomFactory {
+
+    final java.util.Random base;
+    private ThreadLocal<Random> threadRandom = new ThreadLocal<>();
+
+    public SeededRandomFactory(Class randomClass, long baseSeed) {
+        super(randomClass);
+        base = new java.util.Random(baseSeed);
+    }
+
+    @Override
+    public Random getRandom() {
+        // Copy-paste from RandomFactory
+        try {
+            if (threadRandom.get() == null) {
+                Random t = super.getNewRandomInstance(base.nextLong());
+                threadRandom.set(t);
+                return t;
+            }
+
+            return threadRandom.get();
+        } catch (Exception e) {
+            throw new IllegalStateException(e);
+        }
+    }
+
+    @Override
+    public Random getNewRandomInstance() {
+        return super.getNewRandomInstance(base.nextLong());
+    }
+
+    /**
+     * Set a base seed for all Nd4j random generators
+     * @param seed base seed
+     */
+    public static void setNd4jSeed(long seed) {
+        Nd4j.randomFactory = new SeededRandomFactory(Nd4j.randomFactory.getRandom().getClass(), seed);
+    }
+}
diff --git a/src/test/java/examples/spiral/AddKLDLabelTest.java b/src/test/java/examples/spiral/AddKLDLabelTest.java
new file mode 100644
index 0000000..5319c52
--- /dev/null
+++ b/src/test/java/examples/spiral/AddKLDLabelTest.java
@@ -0,0 +1,44 @@
+package examples.spiral;
+
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.MultiDataSet;
+import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Test cases for {@link AddKLDLabel}
+ */
+public class AddKLDLabelTest {
+
+    /**
+     * Test that labels are added
+     */
+    @Test
+    public void preProcess() {
+        final long batchSize = 3;
+        final long nrofLatentDims = 5;
+
+        final MultiDataSet mds = new MultiDataSet(Nd4j.ones(batchSize, 13), Nd4j.ones(batchSize, 11));
+
+        final MultiDataSetPreProcessor addKld = new AddKLDLabel(1.23, 2.34, nrofLatentDims);
+        addKld.preProcess(mds);
+
+        final INDArray kldLabel = mds.getLabels(1);
+
+        assertArrayEquals("Incorrect shape!", new long[] {batchSize, 2*nrofLatentDims}, kldLabel.shape());
+
+        assertEquals("Expected first half to be mean!", 1.23,
+                kldLabel.get(NDArrayIndex.all(),
+                        NDArrayIndex.interval(0, nrofLatentDims)).meanNumber().doubleValue(), 1e-5);
+
+        assertEquals("Expected second half to be log(var)!", Math.log(2.34),
+                kldLabel.get(NDArrayIndex.all(),
+                        NDArrayIndex.interval(nrofLatentDims, 2*nrofLatentDims)).meanNumber().doubleValue(), 1e-5);
+
+    }
+}
\ No newline at end of file
diff --git a/src/test/java/examples/spiral/LatentOdeTest.java b/src/test/java/examples/spiral/LatentOdeTest.java
new file mode 100644
index 0000000..fb87f8c
--- /dev/null
+++ b/src/test/java/examples/spiral/LatentOdeTest.java
@@ -0,0 +1,106 @@
+package examples.spiral;
+
+import ch.qos.logback.classic.Level;
+import examples.spiral.listener.PlotDecodedOutput;
+import examples.spiral.listener.SpiralPlot;
+import examples.spiral.loss.NormLogLikelihoodLoss;
+import ode.solve.conf.DormandPrince54Solver;
+import ode.solve.conf.SolverConfig;
+import ode.vertex.conf.helper.InputStep;
+import org.deeplearning4j.nn.api.OptimizationAlgorithm;
+import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.MultiDataSet;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.Adam;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import util.listen.training.ZeroGrad;
+import util.plot.RealTimePlot;
+
+import java.awt.*;
+
+import static junit.framework.TestCase.assertTrue;
+
+/**
+ * Test cases for {@link LatentOdeBlock}, {@link DenseDecoderBlock} and {@link ReconstructionLossBlock} together
+ *
+ * @author Christian Skarby
+ */
+public class LatentOdeTest {
+
+    /**
+     * Test that the latent ODE can (over)fit to a simple line
+     */
+    @Test
+    public void fitLine() {
+        final long nrofTimeSteps = 10;
+        final long nrofLatentDims = 4;
+
+        //SeededRandomFactory.setNd4jSeed(0);
+
+        final ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder()
+                .seed(666)
+                .weightInit(WeightInit.UNIFORM)
+                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
+                .updater(new Adam(0.01))
+                .graphBuilder()
+                .setInputTypes(InputType.feedForward(nrofLatentDims), InputType.feedForward(nrofTimeSteps));
+
+        String next = "z0";
+        builder.addInputs(next, "time");
+
+
+        next = new LatentOdeBlock(
+                20,
+                nrofLatentDims,
+                new InputStep(
+                        new DormandPrince54Solver(new SolverConfig(1e-12, 1e-6, 1e-20, 1e2)),
+                        1, true))
+                .add(builder, next, "time");
+
+        next = new DenseDecoderBlock(20, 2).add(builder, next);
+        final String decoded = next;
+        next = new ReconstructionLossBlock(new NormLogLikelihoodLoss(0.3)).add(builder, next);
+        builder.setOutputs(next);
+
+        final ComputationGraph graph = new ComputationGraph(builder.build());
+        graph.init();
+
+        final INDArray z0 = Nd4j.ones(1, nrofLatentDims);
+
+        final INDArray time = Nd4j.linspace(0, 3, nrofTimeSteps);
+        final INDArray label = Nd4j.hstack(time, Nd4j.linspace(0, 9, nrofTimeSteps)).reshape(1, 2, nrofTimeSteps);
+
+        if (!GraphicsEnvironment.isHeadless()
+                && !GraphicsEnvironment.getLocalGraphicsEnvironment().isHeadlessInstance()
+                && GraphicsEnvironment.getLocalGraphicsEnvironment().getScreenDevices() != null
+                && GraphicsEnvironment.getLocalGraphicsEnvironment().getScreenDevices().length > 0) {
+            ch.qos.logback.classic.Logger root = (ch.qos.logback.classic.Logger) LoggerFactory.getLogger(Logger.ROOT_LOGGER_NAME);
+            root.setLevel(Level.INFO);
+
+            final SpiralPlot linePlot = new SpiralPlot(new RealTimePlot<>("Decoded output", ""));
+            linePlot.plot("Ground truth", label.tensorAlongDimension(0, 1, 2));
+            graph.addListeners(new PlotDecodedOutput(linePlot, decoded, 0));
+        }
+
+        graph.addListeners(
+                new ZeroGrad(),
+                new ScoreIterationListener(1));
+
+        final MultiDataSet mds = new MultiDataSet(new INDArray[]{z0, time}, new INDArray[]{label});
+        boolean success = false;
+        for (int i = 0; i < 300; i++) {
+            graph.fit(mds);
+            success = graph.score() < 100;
+            if (success) break;
+        }
+        assertTrue("Model failed to train properly! Score after 300 iters: " + graph.score() + "!", success);
+    }
+}
diff --git a/src/test/java/examples/spiral/OdeNetModelTest.java b/src/test/java/examples/spiral/OdeNetModelTest.java
new file mode 100644
index 0000000..d7d16e0
--- /dev/null
+++ b/src/test/java/examples/spiral/OdeNetModelTest.java
@@ -0,0 +1,113 @@
+package examples.spiral;
+
+import com.beust.jcommander.JCommander;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.util.ModelSerializer;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.MultiDataSet;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+
+import static junit.framework.TestCase.assertEquals;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.fail;
+
+/**
+ * Test cases for {@link OdeNetModel}
+ *
+ * @author Christian Skarby
+ */
+public class OdeNetModelTest {
+
+    /**
+     * Test that the model can be created and that it is possible to train it for two batches
+     */
+    @Test
+    public void fit() {
+        final OdeNetModel factory = new OdeNetModel();
+
+        final long nrofTimeSteps = 10;
+        final long batchSize = 3;
+        final long nrofLatentDims = 5;
+        final ComputationGraph model = factory.createNew(nrofTimeSteps, 0.3, nrofLatentDims).trainingModel();
+        model.fit(new MultiDataSet(
+                new INDArray[]{Nd4j.randn(new long[]{batchSize, 2, nrofTimeSteps}), Nd4j.linspace(0, 3, nrofTimeSteps)},
+                new INDArray[]{Nd4j.randn(new long[] {batchSize, 2, nrofTimeSteps}), Nd4j.zeros(batchSize, nrofLatentDims)}));
+        model.fit(new MultiDataSet(
+                new INDArray[]{Nd4j.randn(new long[]{batchSize, 2, nrofTimeSteps}), Nd4j.linspace(0, 3, nrofTimeSteps)},
+                new INDArray[]{Nd4j.randn(new long[] {batchSize, 2, nrofTimeSteps}), Nd4j.zeros(batchSize, nrofLatentDims)}));
+    }
+
+    /**
+     * Test that the model can be serialized and deserialized into the same thing
+     */
+    @Test
+    public void testSerializeDeserialize() throws IOException {
+        final OdeNetModel factory = new OdeNetModel();
+        final int nrofTimeSteps = 5;
+        JCommander.newBuilder()
+                .addObject(factory)
+                .build()
+                .parse();
+        final ComputationGraph graph = factory.createNew(nrofTimeSteps, 0.1, 4).trainingModel();
+
+        final Path baseDir = Paths.get("src", "test", "resources", "OdeNetModelTest");
+        final String fileName = Paths.get(baseDir.toString(), "testSerializeDeserialize.zip").toString();
+
+        try {
+
+            baseDir.toFile().mkdirs();
+            graph.save(new File(fileName), true);
+            final ComputationGraph newGraph = ModelSerializer.restoreComputationGraph(new File(fileName), true);
+
+            assertEquals("Config was not restored properly!", graph.getConfiguration(), newGraph.getConfiguration());
+
+            final long batchSize = 3;
+            final INDArray[] input = {Nd4j.randn(new long[]{batchSize, 2, nrofTimeSteps}), Nd4j.linspace(0, 3, nrofTimeSteps)};
+
+            assertEquals("Output not the same!", graph.output(input)[0], newGraph.output(input)[0]);
+
+        } catch (IOException e) {
+            e.printStackTrace();
+            fail("Failed to serialize or deserialize graph!");
+        } finally {
+            new File(fileName).delete();
+            Files.delete(baseDir);
+        }
+    }
+
+    /**
+     * Smoke test to assert that an {@link OdeNetModel} can be represented as a {@link TimeVae} without there being
+     * any exceptions.
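+     * The three stages exercised by this test mirror the parts of the model: encode maps the input
+     * trajectory to a latent z0, timeDependency propagates z0 through time via the latent ODE and decode
+     * maps the latent trajectory back to 2D observation space.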
+ */ + @Test + public void asTimeVae() { + final long batchSize = 5; + final long nrofTimeSteps = 17; + final long nrofLatentDims = 6; + + final TimeVae timeVae = new OdeNetModel().createNew(nrofTimeSteps, 0.3, nrofLatentDims); + + final INDArray inputTraj = Nd4j.randn(new long[]{batchSize, 2, nrofTimeSteps}); + final INDArray time = Nd4j.linspace(0, 3, nrofTimeSteps); + + final INDArray z0 = timeVae.encode(inputTraj); + + assertArrayEquals("Incorrect shape of z0!", new long[] {batchSize, nrofLatentDims}, z0.shape()); + + final INDArray zt = timeVae.timeDependency(z0, time); + + assertArrayEquals("Incorrect shape of zt!", new long[] {batchSize, nrofLatentDims, nrofTimeSteps}, zt.shape()); + + final INDArray decoded = timeVae.decode(zt); + + assertArrayEquals("Incorrect shape of decoded output!", new long[] {batchSize, 2, nrofTimeSteps}, decoded.shape()); + + } +} diff --git a/src/test/java/examples/spiral/SpiralFactoryTest.java b/src/test/java/examples/spiral/SpiralFactoryTest.java new file mode 100644 index 0000000..acd60a5 --- /dev/null +++ b/src/test/java/examples/spiral/SpiralFactoryTest.java @@ -0,0 +1,30 @@ +package examples.spiral; + +import org.junit.Test; + +import java.util.List; + +import static junit.framework.TestCase.assertEquals; +import static org.junit.Assert.assertArrayEquals; + +/** + * Test cases for {@link SpiralFactory} + * + * @author Christian Skarby + */ +public class SpiralFactoryTest { + + /** + * Test that spirals can be sampled from factory + */ + @Test + public void sample() { + final SpiralFactory factory = new SpiralFactory(0, 0.3, 0, 10, 1000); + final List sample = factory.sample(4, 100, () -> 0.3, () -> true); + assertEquals("Incorrect number of samples!", 4, sample.size()); + for(SpiralFactory.Spiral spiral: sample) { + assertArrayEquals("Incorrect trajectory!", new long[] {2, 100}, spiral.trajectory().shape()); + assertArrayEquals("Incorrect theta!", new long[] {1, 100}, spiral.theta().shape()); + } + } +} \ No newline at end of file diff --git a/src/test/java/examples/spiral/SpiralIteratorTest.java b/src/test/java/examples/spiral/SpiralIteratorTest.java new file mode 100644 index 0000000..51e0024 --- /dev/null +++ b/src/test/java/examples/spiral/SpiralIteratorTest.java @@ -0,0 +1,40 @@ +package examples.spiral; + +import org.junit.Test; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; + +import java.util.Random; + +import static org.junit.Assert.assertArrayEquals; + +/** + * Test cases for {@link SpiralIterator} + * + * @author Christian Skarby + */ +public class SpiralIteratorTest { + + /** + * Test that a spiral can be generated into a {@link MultiDataSet} and that the dimensions are as expected + */ + @Test + public void next() { + final long nrofSamplesOrig = 1000; + final long nrofSamplesTrain = 100; + final int batchSize = 200; + final SpiralFactory spiralFactory = new SpiralFactory(0, 0.3, 0, 6*Math.PI, nrofSamplesOrig); + final MultiDataSetIterator iterator = new SpiralIterator( + new SpiralIterator.Generator(spiralFactory, 0.3, nrofSamplesTrain, new Random(666)), + batchSize); + + final MultiDataSet mds = iterator.next(); + + final long[] expectedShapeSpiral = {batchSize, 2, nrofSamplesTrain}; + assertArrayEquals("Incorrect shape of spiral!", expectedShapeSpiral, mds.getFeatures(0).shape()); + assertArrayEquals("Incorrect shape of label!", expectedShapeSpiral, mds.getLabels(0).shape()); + + final long[] expectedShapeTime = {1, nrofSamplesTrain}; + assertArrayEquals("Incorrect 
shape of time!", expectedShapeTime, mds.getFeatures(1).shape()); + } +} \ No newline at end of file diff --git a/src/test/java/examples/spiral/loss/NormKLDLossTest.java b/src/test/java/examples/spiral/loss/NormKLDLossTest.java new file mode 100644 index 0000000..24fc7b1 --- /dev/null +++ b/src/test/java/examples/spiral/loss/NormKLDLossTest.java @@ -0,0 +1,142 @@ +package examples.spiral.loss; + +import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.LossLayer; +import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.junit.Test; +import org.nd4j.linalg.activations.impl.ActivationIdentity; +import org.nd4j.linalg.activations.impl.ActivationTanH; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.primitives.Pair; + +import static junit.framework.TestCase.assertEquals; +import static junit.framework.TestCase.assertTrue; + +/** + * Test cases for {@link NormKLDLoss} + * + * @author Christian Skarby + */ +public class NormKLDLossTest { + + /** + * Test that score and gradient is 0 for a standard gaussian distribution + */ + @Test + public void computeGradientAndScoreStandard() { + final long batchSize = 11; + final long nrofLatentDimsTimesTwo = 14; + final INDArray meanAndLogvar = Nd4j.zeros(batchSize, nrofLatentDimsTimesTwo); + + final ILossFunction toTest = new NormKLDLoss(); + + final Pair scoreAndGrad = toTest.computeGradientAndScore(meanAndLogvar, meanAndLogvar, new ActivationIdentity(), null, false); + + assertEquals("Score shall be 0! 
", 0.0, scoreAndGrad.getFirst(), 1e-10); + assertEquals("Gradient shall be 0!", meanAndLogvar, scoreAndGrad.getSecond()); + } + + /** + * Test that gradient is pointing in direction of error + */ + @Test + public void gradientMean() { + final long batchSize = 11; + final long nrofLatentDimsTimesTwo = 6; + final INDArray meanAndLogvar = Nd4j.zeros(batchSize, nrofLatentDimsTimesTwo); + final INDArray expectedMeanAndLogVar = meanAndLogvar.dup(); + + meanAndLogvar.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 2)).assign(Nd4j.create(new double[]{-1.23, 2.34})); + meanAndLogvar.get(NDArrayIndex.point(2), NDArrayIndex.interval(0, 2)).assign(Nd4j.create(new double[]{3.45, -4.56})); + + final ILossFunction toTest = new NormKLDLoss(); + + final INDArray gradient = toTest.computeGradient(expectedMeanAndLogVar, meanAndLogvar, new ActivationIdentity(), null); + + assertEquals("Gradient shall be same as input!", meanAndLogvar, gradient); + } + + /** + * Test that gradient is pointing in direction of error + */ + @Test + public void gradientLogVar() { + final long batchSize = 3; + final long nrofLatentDimsTimesTwo = 4; + final INDArray meanAndLogvar = Nd4j.zeros(batchSize, nrofLatentDimsTimesTwo); + final INDArray expectedMeanAndLogVar = meanAndLogvar.dup(); + + meanAndLogvar.get(NDArrayIndex.point(0), NDArrayIndex.interval(2, 4)).assign(Nd4j.create(new double[]{-1.23, 2.34})); + meanAndLogvar.get(NDArrayIndex.point(2), NDArrayIndex.interval(2, 4)).assign(Nd4j.create(new double[]{3.45, -4.56})); + + final ILossFunction toTest = new NormKLDLoss(); + + final INDArray gradient = toTest.computeGradient(expectedMeanAndLogVar, meanAndLogvar, new ActivationIdentity(), null); + + assertTrue("Gradient shall be < 0!", gradient.getDouble(0, 2) < 0); + assertTrue("Gradient shall be > 0!", gradient.getDouble(0, 3) > 0); + assertEquals("Gradient shall be 0!", 0d, gradient.getDouble(1, 2), 1e-10); + assertEquals("Gradient shall be 0!", 0d, gradient.getDouble(1, 3), 1e-10); + assertTrue("Gradient shall be < 0!", gradient.getDouble(2, 2) > 0); + assertTrue("Gradient shall be > 0!", gradient.getDouble(2, 3) < 0); + } + + /** + * Test to teach a small neural network to output zero mean and unit variance + */ + @Test + public void learnStandardGaussian() { + final long batchSize = 3; + final long nrofInputs = 4; + final long nrofTimeSteps = 5; + + final DataSet ds = new DataSet( + Nd4j.linspace(-3, 3, batchSize*nrofInputs*nrofTimeSteps).reshape(batchSize, nrofInputs, nrofTimeSteps), + Nd4j.zeros(batchSize, nrofInputs)); + + final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() + .seed(666) + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + .updater(new Adam(0.01)) + .graphBuilder() + .addInputs("input") + .setInputTypes(InputType.recurrent(nrofInputs, nrofTimeSteps)) + .addLayer("rnn", new SimpleRnn.Builder() + .activation(new ActivationTanH()) + .nOut(20) + .build(), "input") + .addVertex("lastStep", new LastTimeStepVertex("input"), "rnn") + .addLayer("dnn", new DenseLayer.Builder() + .nOut(nrofInputs) + .activation(new ActivationIdentity()) + .build(), "lastStep") + + .addLayer("out", new LossLayer.Builder() + .lossFunction(new NormKLDLoss()) + .build(), "dnn") + .setOutputs("out") + .build()); + graph.init(); + + boolean success = false; + for(int i = 0; i < 300; i++) { + graph.fit(ds); + success |= graph.score() < 0.001; + if(success) break; + } + + assertTrue("Training did not succeed!", success); + + } + +} \ No newline at end of file diff --git 
a/src/test/java/examples/spiral/loss/NormLogLikelihoodLossTest.java b/src/test/java/examples/spiral/loss/NormLogLikelihoodLossTest.java new file mode 100644 index 0000000..94f4d73 --- /dev/null +++ b/src/test/java/examples/spiral/loss/NormLogLikelihoodLossTest.java @@ -0,0 +1,67 @@ +package examples.spiral.loss; + +import org.junit.Test; +import org.nd4j.linalg.activations.impl.ActivationIdentity; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.primitives.Pair; + +import static junit.framework.TestCase.assertEquals; +import static org.junit.Assert.assertArrayEquals; + +/** + * Test cases for {@link NormLogLikelihoodLoss} + * + * @author Christian Skarby + */ +public class NormLogLikelihoodLossTest { + + private final double[][][] expectedGrad = {{{-29.2222, -27.2069, -25.1916, -23.1762, -21.1609}, + {-19.1456, -17.1303, -15.1149, -13.0996, -11.0843}}, + + {{-9.0690, -7.0536, -5.0383, -3.0230, -1.0077}, + {1.0077, 3.0230, 5.0383, 7.0536, 9.0690}}, + + {{11.0843, 13.0996, 15.1149, 17.1303, 19.1456}, + {21.1609, 23.1762, 25.1916, 27.2069, 29.2222}}}; + + /** + * Test that score and gradient is zero when label and prediction are the same + */ + @Test + public void computeGradientAndScorePerfectMatch() { + final INDArray traj = Nd4j.linspace(-3.45, 2.34, 2 * 3 * 5).reshape(3, 2, 5); + + final Pair out = new NormLogLikelihoodLoss(0.3) + .computeGradientAndScore(traj, traj, new ActivationIdentity(), null, false); + + // Loss has a constant term + assertEquals("Expected minimum loss!", -8.55102825164795, out.getFirst(), 1e-5); + assertEquals("Expected zero grad!", 0, out.getSecond().amaxNumber().doubleValue(), 1e-10); + } + + /** + * Test that score and gradient is correct when label and prediction are not the same. Numbers taken from original repo. 
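+     * Assuming the loss is a Gaussian negative log-likelihood with noise std 0.3, the gradient w.r.t. each
+     * predicted element is (prediction - label) / 0.3^2 = eps / 0.09, which after division by the minibatch
+     * size of 3 (done in the test below) reproduces the expected numbers, e.g. -7.89 / 0.09 / 3 = -29.2222.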
+ */ + @Test + public void computeGradientAndScoreWithEps() { + final INDArray traj = Nd4j.linspace(-3.45, 2.34, 2 * 3 * 5).reshape(3, 2, 5); + final INDArray eps = Nd4j.linspace(-7.89, 7.89, traj.length()).reshape(traj.shape()); + + final Pair out = new NormLogLikelihoodLoss(0.3) + .computeGradientAndScore(traj, traj.add(eps), new ActivationIdentity(), null, true); + + + out.getSecond().divi(3); // Dl4j scales gradients with mini batch size centrally in BaseMultiLayerUpdater + assertEquals("Incorrect loss!", 1229.4709, out.getFirst(), 1e-4); + + for (int i = 0; i < expectedGrad.length; i++) { + for (int j = 0; j < expectedGrad[i].length; j++) { + assertArrayEquals("Incorrect gradient along " + i + ", " + j + "!", + expectedGrad[i][j], + out.getSecond().get(NDArrayIndex.point(i), NDArrayIndex.point(j), NDArrayIndex.all()).toDoubleVector(), 1e-4); + } + } + } +} \ No newline at end of file diff --git a/src/test/java/examples/spiral/vertex/impl/SampleGaussianVertexTest.java b/src/test/java/examples/spiral/vertex/impl/SampleGaussianVertexTest.java new file mode 100644 index 0000000..d170906 --- /dev/null +++ b/src/test/java/examples/spiral/vertex/impl/SampleGaussianVertexTest.java @@ -0,0 +1,141 @@ +package examples.spiral.vertex.impl; + +import examples.spiral.vertex.conf.SampleGaussianVertex; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; +import org.deeplearning4j.nn.conf.layers.LossLayer; +import org.deeplearning4j.nn.conf.memory.MemoryReport; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.graph.vertex.GraphVertex; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.jetbrains.annotations.NotNull; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.ops.transforms.Transforms; +import org.nd4j.linalg.primitives.Pair; + +import java.util.stream.LongStream; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** + * Test cases for {@link examples.spiral.vertex.conf.SampleGaussianVertex} + * + * @author Christian Skarby + */ +public class SampleGaussianVertexTest { + + /** + * Test that output has the same mean and variance which is input + */ + @Test + public void doForward() { + final long nrofLatentDims = 4; + final ComputationGraph graph = getGraph(nrofLatentDims, new SampleGaussianVertex(666)); + + final int batchSize = 100000; + final INDArray means = Nd4j.arange(nrofLatentDims); + final INDArray logVars = Nd4j.arange(nrofLatentDims, 2*nrofLatentDims); + INDArray output = graph.outputSingle(Nd4j.repeat(means, batchSize), Nd4j.repeat(logVars, batchSize)); + + assertArrayEquals("Incorrect mean!", means.toDoubleVector(), output.mean(0).toDoubleVector(), 1e-1); + assertArrayEquals("Incorrect logvar!", logVars.toDoubleVector(), Transforms.log(output.var(0)).toDoubleVector(), 1e-1); + } + + /** + * Test doBackward. 
LogVar numbers verified in pytorch + */ + @Test + public void doBackward() { + final long nrofLatentDims = 2; + final ComputationGraph graph = getGraph(nrofLatentDims, new TestSampleGaussianVertex()); + + final GraphVertex vertex = graph.getVertex("z"); + + // Need to do a forward pass to set inputs and calculate epsilon + vertex.setInput(0, Nd4j.arange(2*nrofLatentDims), LayerWorkspaceMgr.noWorkspaces()); + vertex.doForward(true, LayerWorkspaceMgr.noWorkspaces()); + + vertex.setEpsilon(Nd4j.create(new double[] {1.3, 2.4})); + final Pair result = graph.getVertex("z").doBackward(false, LayerWorkspaceMgr.noWorkspaces()); + + assertEquals("Incorrect gradient for mean!", + vertex.getEpsilon(), + result.getSecond()[0].get(NDArrayIndex.all(), NDArrayIndex.interval(0, nrofLatentDims))); + assertEquals("Incorrect gradient for log var!", + Nd4j.create(new double[] {2.6503, 23.1255}), + result.getSecond()[0].get(NDArrayIndex.all(), NDArrayIndex.interval(nrofLatentDims, 2*nrofLatentDims))); + } + + @NotNull + private static ComputationGraph getGraph(long nrofLatentDims, org.deeplearning4j.nn.conf.graph.GraphVertex vertex) { + final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() + .graphBuilder() + .addInputs("mean", "logVar") + .setInputTypes(InputType.feedForward(nrofLatentDims), InputType.feedForward(nrofLatentDims)) + .addVertex("z", vertex, "mean", "logVar") // Note: MergeVertex will be added + .setOutputs("output") + .addLayer("output", new LossLayer.Builder().build(), "z") + .build()); + graph.init(); + return graph; + } + + private static class TestSampleGaussianVertex extends org.deeplearning4j.nn.conf.graph.GraphVertex { + + private final SampleGaussianVertex helper = new SampleGaussianVertex(666); + + @Override + public org.deeplearning4j.nn.conf.graph.GraphVertex clone() { + return null; + } + + @Override + public boolean equals(Object o) { + return false; + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public long numParams(boolean backprop) { + return helper.numParams(backprop); + } + + @Override + public int minVertexInputs() { + return helper.minVertexInputs(); + } + + @Override + public int maxVertexInputs() { + return helper.maxVertexInputs(); + } + + @Override + public GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams) { + return new examples.spiral.vertex.impl.SampleGaussianVertex(graph, name, idx, shape -> { + final long sum = LongStream.of(shape).reduce(1, (l1,l2) -> l1*l2); + return Nd4j.linspace(1.5, 4.3, sum).reshape(shape); + }); + } + + @Override + public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException { + return helper.getOutputType(layerIndex, vertexInputs); + } + + @Override + public MemoryReport getMemoryReport(InputType... 
inputTypes) { + return helper.getMemoryReport(inputTypes); + } + } +} \ No newline at end of file diff --git a/src/test/java/ode/solve/impl/DormandPrince54SolverTest.java b/src/test/java/ode/solve/impl/DormandPrince54SolverTest.java index 9dd5c1a..af9e1ae 100644 --- a/src/test/java/ode/solve/impl/DormandPrince54SolverTest.java +++ b/src/test/java/ode/solve/impl/DormandPrince54SolverTest.java @@ -5,6 +5,7 @@ import ode.solve.api.StepListener; import ode.solve.commons.FirstOrderSolverAdapter; import ode.solve.conf.SolverConfig; +import ode.solve.impl.util.SolverState; import org.apache.commons.math3.ode.nonstiff.DormandPrince54Integrator; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -103,8 +104,8 @@ public void begin(INDArray t, INDArray y0) { } @Override - public void step(INDArray currTime, INDArray step, INDArray error, INDArray y) { - times.add(currTime.detach()); + public void step(SolverState solverState, INDArray step, INDArray error) { + times.add(solverState.time().detach()); } @Override diff --git a/src/test/java/ode/solve/impl/InterpolatingMultiStepSolverTest.java b/src/test/java/ode/solve/impl/InterpolatingMultiStepSolverTest.java new file mode 100644 index 0000000..5df8d7d --- /dev/null +++ b/src/test/java/ode/solve/impl/InterpolatingMultiStepSolverTest.java @@ -0,0 +1,57 @@ +package ode.solve.impl; + +import ode.solve.CircleODE; +import ode.solve.api.FirstOrderEquation; +import ode.solve.api.FirstOrderSolver; +import ode.solve.conf.SolverConfig; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +import static org.junit.Assert.assertArrayEquals; + +/** + * Test cases for {@link InterpolatingMultiStepSolver} + * + * @author Christian Skarby + */ +public class InterpolatingMultiStepSolverTest { + + /** + * Test that output from a {@link SingleSteppingMultiStepSolver} is the same as the output from an + * {@link InterpolatingMultiStepSolver} when solving the {@link CircleODE} when using the same input. 
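+     * In other words, interpolating intermediate outputs from a single adaptive solve over the whole time
+     * interval should match solving each sub-interval separately, up to interpolation error (hence the
+     * 1e-4 tolerance rather than exact equality).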
+ */ + @Test + public void integrateCircle() { + final Pair multiAndInterp = solveCircleMultiInterpol(); + + final INDArray multi = multiAndInterp.getFirst(); + final INDArray interp = multiAndInterp.getSecond(); + + for(int i = 0; i < multi.columns(); i++) + assertArrayEquals("Solutions are different in column " + i +"!!", + multi.getColumn(i).toDoubleVector(), + interp.getColumn(i).toDoubleVector(), + 1e-4); + } + + private static Pair solveCircleMultiInterpol() { + final FirstOrderSolver singleStepSolver = new DormandPrince54Solver( + new SolverConfig(1e-7, 1e-7, 1e-10, 1e2)); + + final double omega = 5.67; + final FirstOrderEquation equation = new CircleODE(new double[] {1.23, 4.56}, omega); + + final int nrofSteps = 25; + final INDArray y0 = Nd4j.create(new double[] {-5.6, 7.3}); + final INDArray ySingle = y0.dup(); + final INDArray yMulti = Nd4j.repeat(ySingle,nrofSteps-1).reshape(nrofSteps-1, y0.length()).assign(0); + final INDArray t = Nd4j.linspace(-Math.PI/omega,Math.PI/omega , nrofSteps); + + final INDArray expected = new SingleSteppingMultiStepSolver(singleStepSolver).integrate(equation, t, y0, yMulti.dup()).transposei(); + final INDArray actual = new InterpolatingMultiStepSolver(singleStepSolver).integrate(equation, t, y0, yMulti.dup()).transposei(); + + return new Pair<>(expected, actual); + } +} \ No newline at end of file diff --git a/src/test/java/ode/solve/impl/SingleSteppingMultiStepSolverTest.java b/src/test/java/ode/solve/impl/SingleSteppingMultiStepSolverTest.java new file mode 100644 index 0000000..b899f0c --- /dev/null +++ b/src/test/java/ode/solve/impl/SingleSteppingMultiStepSolverTest.java @@ -0,0 +1,40 @@ +package ode.solve.impl; + +import ode.solve.CircleODE; +import ode.solve.api.FirstOrderEquation; +import ode.solve.api.FirstOrderSolver; +import ode.solve.conf.SolverConfig; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static junit.framework.TestCase.assertEquals; + +/** + * Test cases for {@link SingleSteppingMultiStepSolver} + * + * @author Christian Skarby + */ +public class SingleSteppingMultiStepSolverTest { + + /** + * Test that solving {@link ode.solve.CircleODE} in multiple steps gives the same result as if all steps are done + * in one solve. 
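+     * This relies on the flow property of ODE solutions: integrating from t0 to tN yields the same state
+     * as chaining integrations over [t0, t1], ..., [tN-1, tN]. With an adaptive step size solver the two
+     * only agree up to the configured tolerances.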
+ */ + @Test + public void integrate() { + final int nrofSteps = 20; + final FirstOrderEquation circle = new CircleODE(new double[] {1.23, 4.56}, 1); + final FirstOrderSolver actualSolver = new DormandPrince54Solver(new SolverConfig(1e-10, 1e-10, 1e-10, 10)); + + final INDArray t = Nd4j.linspace(0, Math.PI, nrofSteps); + final INDArray y0 = Nd4j.create(new double[] {0, 0}); + final INDArray ySingle = y0.dup(); + final INDArray yMulti = Nd4j.repeat(ySingle,nrofSteps-1).reshape(nrofSteps-1, y0.length()); + + actualSolver.integrate(circle, t.getColumns(0, nrofSteps-1), y0, ySingle); + new SingleSteppingMultiStepSolver(actualSolver).integrate(circle, t, y0, yMulti); + + assertEquals("Incorrect solution!", ySingle, yMulti.getRow(nrofSteps-2)); + } +} \ No newline at end of file diff --git a/src/test/java/ode/solve/impl/util/AdaptiveRungeKuttaStepPolicyTest.java b/src/test/java/ode/solve/impl/util/AdaptiveRungeKuttaStepPolicyTest.java index eb8bb34..c1b182a 100644 --- a/src/test/java/ode/solve/impl/util/AdaptiveRungeKuttaStepPolicyTest.java +++ b/src/test/java/ode/solve/impl/util/AdaptiveRungeKuttaStepPolicyTest.java @@ -52,7 +52,7 @@ DormandPrince54Integrator setEquation(FirstOrderDifferentialEquations equations) true, 5, scale, t.getDouble(0), y0.toDoubleVector(), yDot, y1, yDot1); - final FirstOrderEquationWithState eqState = new FirstOrderEquationWithState(equation, t.getColumn(0), y0, 5); + final FirstOrderEquationWithState eqState = new FirstOrderEquationWithState(equation, t.getColumn(0), y0, new double [5]); final INDArray stepAct = new AdaptiveRungeKuttaStepPolicy( new SolverConfigINDArray(absTol, relTol, 1e-20, 1e20), 5) .initializeStep(eqState, t); diff --git a/src/test/java/ode/solve/impl/util/AggStepListenerTest.java b/src/test/java/ode/solve/impl/util/AggStepListenerTest.java index af79155..0b223b8 100644 --- a/src/test/java/ode/solve/impl/util/AggStepListenerTest.java +++ b/src/test/java/ode/solve/impl/util/AggStepListenerTest.java @@ -25,7 +25,7 @@ public void addListenersAndRemoveOne() { aggStepListener.addListeners(second,third); aggStepListener.begin(Nd4j.linspace(0,1 ,2), Nd4j.zeros(1)); - aggStepListener.step(Nd4j.zeros(1), Nd4j.zeros(1), Nd4j.zeros(1), Nd4j.zeros(1)); + aggStepListener.step(new StateContainer(0, new double[] {0}, new double[] {0}), Nd4j.zeros(1), Nd4j.zeros(1)); aggStepListener.done(); first.assertNrofCalls(1,1,1); second.assertNrofCalls(1,1,1); @@ -34,7 +34,7 @@ public void addListenersAndRemoveOne() { aggStepListener.clearListeners(second); aggStepListener.begin(Nd4j.linspace(0,1 ,2), Nd4j.zeros(1)); - aggStepListener.step(Nd4j.zeros(1), Nd4j.zeros(1), Nd4j.zeros(1), Nd4j.zeros(1)); + aggStepListener.step(new StateContainer(0, new double[] {0}, new double[] {0}), Nd4j.zeros(1), Nd4j.zeros(1)); aggStepListener.done(); first.assertNrofCalls(2,2,2); second.assertNrofCalls(1,1,1); @@ -43,7 +43,7 @@ public void addListenersAndRemoveOne() { aggStepListener.clearListeners(third, first); aggStepListener.begin(Nd4j.linspace(0,1 ,2), Nd4j.zeros(1)); - aggStepListener.step(Nd4j.zeros(1), Nd4j.zeros(1), Nd4j.zeros(1), Nd4j.zeros(1)); + aggStepListener.step(new StateContainer(0, new double[] {0}, new double[] {0}), Nd4j.zeros(1), Nd4j.zeros(1)); aggStepListener.done(); first.assertNrofCalls(2,2,2); second.assertNrofCalls(1,1,1); @@ -53,7 +53,7 @@ public void addListenersAndRemoveOne() { aggStepListener.clearListeners(); aggStepListener.begin(Nd4j.linspace(0,1 ,2), Nd4j.zeros(1)); - aggStepListener.step(Nd4j.zeros(1), Nd4j.zeros(1), Nd4j.zeros(1), Nd4j.zeros(1)); + 
aggStepListener.step(new StateContainer(0, new double[] {0}, new double[] {0}), Nd4j.zeros(1), Nd4j.zeros(1)); aggStepListener.done(); first.assertNrofCalls(2,2,2); second.assertNrofCalls(1,1,1); diff --git a/src/test/java/ode/solve/impl/util/InterpolatingStepListenerTest.java b/src/test/java/ode/solve/impl/util/InterpolatingStepListenerTest.java new file mode 100644 index 0000000..d1fefe2 --- /dev/null +++ b/src/test/java/ode/solve/impl/util/InterpolatingStepListenerTest.java @@ -0,0 +1,116 @@ +package ode.solve.impl.util; + +import examples.spiral.listener.SpiralPlot; +import ode.solve.CircleODE; +import ode.solve.api.FirstOrderEquation; +import ode.solve.api.FirstOrderSolver; +import ode.solve.conf.SolverConfig; +import ode.solve.impl.DormandPrince54Solver; +import ode.solve.impl.SingleSteppingMultiStepSolver; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.primitives.Pair; +import util.plot.RealTimePlot; + +import static org.junit.Assert.assertArrayEquals; + +/** + * Test cases for {@link InterpolatingStepListener} + * + * @author Christian Skarby + */ +public class InterpolatingStepListenerTest { + + /** + * Verify that the result from using an {@link InterpolatingStepListener} is equivalent to using a + * {@link SingleSteppingMultiStepSolver} when solving the {@link CircleODE} when using the same input. + */ + @Test + public void iterpolateCircleForward() { + final Pair multiAndInterp = solveCircleMultiInterpol(false); + + final INDArray multi = multiAndInterp.getFirst(); + final INDArray interp = multiAndInterp.getSecond(); + + for (int i = 0; i < multi.columns(); i++) + assertArrayEquals("Solutions are different in column " + i + "!!", + multi.getColumn(i).toDoubleVector(), + interp.getColumn(i).toDoubleVector(), + 1e-4); + } + + /** + * Verify that the result from using an {@link InterpolatingStepListener} is equivalent to using a + * {@link SingleSteppingMultiStepSolver} when solving the {@link CircleODE} when using the same input. 
+ */ + @Test + public void iterpolateCircleBackward() { + final Pair multiAndInterp = solveCircleMultiInterpol(true); + + final INDArray multi = multiAndInterp.getFirst(); + final INDArray interp = multiAndInterp.getSecond(); + + for (int i = 0; i < multi.columns(); i++) + assertArrayEquals("Solutions are different in column " + i + "!!", + multi.getColumn(i).toDoubleVector(), + interp.getColumn(i).toDoubleVector(), + 1e-4); + } + + + private static Pair solveCircleMultiInterpol(boolean backwards) { + final FirstOrderSolver singleStepSolver = new DormandPrince54Solver( + new SolverConfig(1e-7, 1e-7, 1e-10, 1e2)); + + final double omega = 5.67; + final FirstOrderEquation equation = new CircleODE(new double[]{1.23, 4.56}, omega); + + final int nrofSteps = 25; + final INDArray y0 = Nd4j.create(new double[]{-5.6, 7.3}); + final INDArray ySingle = y0.dup(); + final INDArray yMulti = Nd4j.repeat(ySingle, nrofSteps - 1).reshape(nrofSteps - 1, y0.length()).assign(0); + + final INDArray tStartEnd = Nd4j.create(new double[]{-Math.PI / omega, Math.PI / omega}); + if (backwards) tStartEnd.negi(); + final INDArray t = Nd4j.linspace(tStartEnd.getDouble(0), tStartEnd.getDouble(1), nrofSteps); + + final INDArray expected = new SingleSteppingMultiStepSolver(singleStepSolver).integrate(equation, t, y0, yMulti.dup()).transposei(); + + singleStepSolver.addListener(new InterpolatingStepListener(t.get(NDArrayIndex.interval(1, nrofSteps)), yMulti)); + final INDArray yLast = y0.dup(); + singleStepSolver.integrate(equation, tStartEnd, y0, yLast); + + return new Pair<>(expected, yMulti.transposei()); + } + + /** + * Main method for plotting results from both approaches + * + * @param args not used + */ + public static void main(String[] args) { + + final Pair multiAndInterpBackwards = solveCircleMultiInterpol(true); + plotSolutions(new SpiralPlot(new RealTimePlot<>("Multi step vs interpol bwd", "")), multiAndInterpBackwards); + + final Pair multiAndInterpForwards = solveCircleMultiInterpol(false); + plotSolutions(new SpiralPlot(new RealTimePlot<>("Multi step vs interpol fwd" , "")), multiAndInterpForwards); + } + + private static void plotSolutions(SpiralPlot plot, Pair multiAndInterp) { + final INDArray multi = multiAndInterp.getFirst(); + final INDArray interp = multiAndInterp.getSecond(); + + plot.plot("Multi", multi); + plot.plot( "Interp", interp); + + plot.plot("Multi start", multi.getColumn(0)); + plot.plot("Interp start", interp.getColumn(0)); + + final long lastStep = multi.size(1) - 1; + plot.plot("Multi stop", multi.getColumn(lastStep)); + plot.plot("Interp stop", interp.getColumn(lastStep)); + } +} \ No newline at end of file diff --git a/src/test/java/ode/solve/impl/util/InterpolationTest.java b/src/test/java/ode/solve/impl/util/InterpolationTest.java new file mode 100644 index 0000000..3d85f52 --- /dev/null +++ b/src/test/java/ode/solve/impl/util/InterpolationTest.java @@ -0,0 +1,48 @@ +package ode.solve.impl.util; + +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static junit.framework.TestCase.assertEquals; + +/** + * Test cases for {@link Interpolation} + * + * @author Christian Skarby + */ +public class InterpolationTest { + + /** + * Test interpolation by comparing to the result for the same input from original repo + */ + @Test + public void interpolate() { + final long[] shape = {2, 3, 5}; + final long nrofElems = 2 * 3 * 5; + final INDArray y0 = Nd4j.linspace(-10, 10, nrofElems).reshape(shape); + final INDArray yMid = 
Nd4j.linspace(-7, 13, nrofElems).reshape(shape);
+        final INDArray y1 = Nd4j.linspace(-12, 7, nrofElems).reshape(shape);
+        final INDArray f0 = Nd4j.linspace(2, 5, nrofElems).reshape(shape);
+        final INDArray f1 = Nd4j.linspace(-3, -1, nrofElems).reshape(shape);
+        final INDArray dt = Nd4j.scalar(1.23);
+
+        final Interpolation interpolation = new Interpolation();
+
+        interpolation.fitCoeffs(y0, y1, yMid, f0, f1, dt);
+
+        final INDArray output = interpolation.interpolate(-2, 3, 2.34);
+
+        // Output from pytorch repo
+        final INDArray expected = Nd4j.create(new double[][][]{
+                {{-10.8218, -10.1690, -9.5162, -8.8633, -8.2105},
+                        {-7.5577, -6.9049, -6.2521, -5.5993, -4.9465},
+                        {-4.2936, -3.6408, -2.9880, -2.3352, -1.6824}},
+                {{-1.0296, -0.3768, 0.2760, 0.9288, 1.5817},
+                        {2.2345, 2.8873, 3.5401, 4.1929, 4.8457},
+                        {5.4985, 6.1513, 6.8042, 7.4570, 8.1098}}
+        });
+
+        assertEquals("Incorrect output!", expected.toString(), output.toString());
+    }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/conf/helper/backward/AbstractHelperConfTest.java b/src/test/java/ode/vertex/conf/helper/backward/AbstractHelperConfTest.java
new file mode 100644
index 0000000..2b2ad6c
--- /dev/null
+++ b/src/test/java/ode/vertex/conf/helper/backward/AbstractHelperConfTest.java
@@ -0,0 +1,188 @@
+package ode.vertex.conf.helper.backward;
+
+import ode.vertex.impl.gradview.NonContiguous1DView;
+import ode.vertex.impl.helper.backward.OdeHelperBackward.InputArrays;
+import ode.vertex.impl.helper.backward.OdeHelperBackward.MiscPar;
+import org.deeplearning4j.nn.conf.ConvolutionMode;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.WorkspaceMode;
+import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.Convolution2D;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.jetbrains.annotations.NotNull;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.io.IOException;
+
+import static junit.framework.TestCase.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+
+abstract class AbstractHelperConfTest {
+
+    private final static InputType convInput = InputType.convolutional(9,9,2);
+
+    /**
+     * Create a {@link OdeHelperBackward} which shall be tested
+     *
+     * @param nrofTimeSteps Number of time steps to create helper for
+     * @return a new {@link OdeHelperBackward}
+     */
+    abstract OdeHelperBackward create(int nrofTimeSteps, boolean needTimeGradient);
+
+    /**
+     * Create inputs from a given input array
+     *
+     * @param input input array
+     * @param nrofTimeSteps number of time steps to create input for
+     * @return all needed inputs
+     */
+    abstract INDArray[] createInputs(INDArray input, int nrofTimeSteps);
+
+    /**
+     * Test serialization and deserialization
+     */
+    @Test
+    public void serializeDeserializeSingleStep() throws IOException {
+        final OdeHelperBackward conf = create(2, true);
+        final String json = NeuralNetConfiguration.mapper().writeValueAsString(conf);
+        final OdeHelperBackward newConf = NeuralNetConfiguration.mapper().readValue(json, OdeHelperBackward.class);
+        assertEquals("Did not deserialize into the same thing!", conf, newConf);
+    }
+
+    /**
+     * Test serialization and deserialization
+     */
+    @Test
+    public void serializeDeserializeMultiStep() throws IOException {
+        final OdeHelperBackward conf = create(5, true);
+        final String json = NeuralNetConfiguration.mapper().writeValueAsString(conf);
+        final OdeHelperBackward newConf = NeuralNetConfiguration.mapper().readValue(json, OdeHelperBackward.class);
+        assertEquals("Did not deserialize into the same thing!", conf, newConf);
+    }
+
+    /**
+     * Test that helper can be instantiated and that it does something
+     */
+    @Test
+    public void instantiateAndSolveConvSingleStep() {
+        final int nrofTimeSteps = 2;
+        final ode.vertex.impl.helper.backward.OdeHelperBackward helper = create(nrofTimeSteps, true).instantiate();
+        final ComputationGraph graph = createGraph();
+        final InputArrays input = getTestInputArrays(nrofTimeSteps, graph);
+        final INDArray[] output = helper.solve(graph, input, new MiscPar(
+                false,
+                LayerWorkspaceMgr.noWorkspaces()));
+
+        assertNotEquals("Expected non-zero param gradient!", 0, graph.getGradientsViewArray().sumNumber().doubleValue(), 1e-10);
+        for(INDArray inputGrad: output) {
+            assertNotEquals("Expected non-zero input gradient!", 0, inputGrad.sumNumber().doubleValue(), 1e-10);
+        }
+    }
+
+    /**
+     * Test that helper can be instantiated and that it does something
+     */
+    @Test
+    public void instantiateAndSolveConvSingleStepNoTimeGrad() {
+        final int nrofTimeSteps = 2;
+        final ode.vertex.impl.helper.backward.OdeHelperBackward helper = create(nrofTimeSteps, false).instantiate();
+        final ComputationGraph graph = createGraph();
+        final InputArrays input = getTestInputArrays(nrofTimeSteps, graph);
+        final INDArray[] output = helper.solve(graph, input, new MiscPar(
+                false,
+                LayerWorkspaceMgr.noWorkspaces()));
+
+        assertNotEquals("Expected non-zero param gradient!", 0, graph.getGradientsViewArray().sumNumber().doubleValue(), 1e-10);
+        assertNotEquals("Expected non-zero input gradient!", 0, output[0].amaxNumber().doubleValue(), 1e-10);
+        if(output.length > 1) {
+            assertEquals("Expected zero time gradient!", 0, output[1].amaxNumber().doubleValue(), 1e-10);
+        }
+    }
+
+    /**
+     * Test that helper can be instantiated and that it does something
+     */
+    @Test
+    public void instantiateAndSolveConvMultiStep() {
+        final int nrofTimeSteps = 7;
+        final ode.vertex.impl.helper.backward.OdeHelperBackward helper = create(nrofTimeSteps, true).instantiate();
+        final ComputationGraph graph = createGraph();
+        final InputArrays input = getTestInputArrays(nrofTimeSteps, graph);
+        final INDArray[] output = helper.solve(graph, input, new MiscPar(
+                false,
+                LayerWorkspaceMgr.noWorkspaces()));
+
+        assertNotEquals("Expected non-zero param gradient!", 0, graph.getGradientsViewArray().amaxNumber().doubleValue(), 1e-10);
+        for(INDArray inputGrad: output) {
+            assertNotEquals("Expected non-zero input gradient!", 0, inputGrad.amaxNumber().doubleValue(), 1e-10);
+        }
+    }
+
+    /**
+     * Test that helper can be instantiated and that it does something
+     */
+    @Test
+    public void instantiateAndSolveConvMultiStepNoTimeGrad() {
+        final int nrofTimeSteps = 7;
+        final ode.vertex.impl.helper.backward.OdeHelperBackward helper = create(nrofTimeSteps, false).instantiate();
+        final ComputationGraph graph = createGraph();
+        final InputArrays input = getTestInputArrays(nrofTimeSteps, graph);
+        final INDArray[] output = helper.solve(graph, input, new MiscPar(
+                false,
+                LayerWorkspaceMgr.noWorkspaces()));
+
+        assertNotEquals("Expected non-zero param gradient!", 0, graph.getGradientsViewArray().amaxNumber().doubleValue(), 1e-10);
+        assertNotEquals("Expected non-zero input gradient!", 0, output[0].amaxNumber().doubleValue(), 1e-10);
+        if(output.length > 1) {
+            assertEquals("Expected zero time gradient!", 0, output[1].amaxNumber().doubleValue(), 1e-10);
+        }
+    }
+
+    private ComputationGraph createGraph() {
+        final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder()
+                .trainingWorkspaceMode(WorkspaceMode.NONE)
+                .inferenceWorkspaceMode(WorkspaceMode.NONE)
+                .weightInit(new ConstantDistribution(0.01))
+                .graphBuilder()
+                .allowNoOutput(true)
+                .addInputs("input")
+                .setInputTypes(convInput)
+                .addLayer("1", new Convolution2D.Builder(3, 3).nOut(2).convolutionMode(ConvolutionMode.Same).build(), "input")
+                .build());
+        graph.init();
+        graph.initGradientsView();
+        return graph;
+    }
+
+    @NotNull
+    private InputArrays getTestInputArrays(int nrofTimeSteps, ComputationGraph graph) {
+
+        final int batchSize = 3;
+        final long[] shape = convInput.getShape(true);
+        shape[0] = batchSize;
+
+        final int nrofTimeStepsToUse = nrofTimeSteps == 2 ? 1 : nrofTimeSteps;
+        final long[] outputShape = nrofTimeSteps == 2 ? shape : new long[]{batchSize, nrofTimeStepsToUse, shape[1], shape[2], shape[3]};
+
+        final INDArray input = Nd4j.arange(batchSize * convInput.arrayElementsPerExample()).reshape(shape);
+        final INDArray output = Nd4j.arange(batchSize * nrofTimeStepsToUse * convInput.arrayElementsPerExample())
+                .reshape(outputShape);
+        final INDArray epsilon = Nd4j.ones(outputShape).assign(0.01);
+        final NonContiguous1DView realGrads = new NonContiguous1DView();
+        realGrads.addView(graph.getGradientsViewArray());
+
+        return new InputArrays(
+                createInputs(input, nrofTimeSteps),
+                output,
+                epsilon,
+                realGrads
+        );
+    }
+}
diff --git a/src/test/java/ode/vertex/conf/helper/backward/FixedStepAdjointTest.java b/src/test/java/ode/vertex/conf/helper/backward/FixedStepAdjointTest.java
new file mode 100644
index 0000000..8cf86e4
--- /dev/null
+++ b/src/test/java/ode/vertex/conf/helper/backward/FixedStepAdjointTest.java
@@ -0,0 +1,23 @@
+package ode.vertex.conf.helper.backward;
+
+import ode.solve.conf.DormandPrince54Solver;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+/**
+ * Test cases for {@link FixedStepAdjoint}
+ *
+ * @author Christian Skarby
+ */
+public class FixedStepAdjointTest extends AbstractHelperConfTest {
+
+    @Override
+    OdeHelperBackward create(int nrofTimeSteps, boolean needTimeGradient) {
+        return new FixedStepAdjoint(new DormandPrince54Solver(), Nd4j.linspace(0,3,nrofTimeSteps));
+    }
+
+    @Override
+    INDArray[] createInputs(INDArray input, int nrofTimeSteps) {
+        return new INDArray[] {input};
+    }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/conf/helper/backward/InputStepAdjointTest.java b/src/test/java/ode/vertex/conf/helper/backward/InputStepAdjointTest.java
new file mode 100644
index 0000000..c5d6ed8
--- /dev/null
+++ b/src/test/java/ode/vertex/conf/helper/backward/InputStepAdjointTest.java
@@ -0,0 +1,23 @@
+package ode.vertex.conf.helper.backward;
+
+import ode.solve.conf.DormandPrince54Solver;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+/**
+ * Test cases for {@link InputStepAdjoint}
+ *
+ * @author Christian Skarby
+ */
+public class InputStepAdjointTest extends AbstractHelperConfTest {
+
+    @Override
+    OdeHelperBackward create(int nrofTimeSteps, boolean needTimeGrad) {
+        return new InputStepAdjoint(new DormandPrince54Solver(), 1, needTimeGrad);
+    }
+
+    @Override
+    INDArray[] createInputs(INDArray input, int nrofTimeSteps) {
+        return new INDArray[]{input, Nd4j.linspace(0, 2, nrofTimeSteps)};
+    }
+}
\ No newline at end of file
diff --git
a/src/test/java/ode/vertex/conf/helper/forward/AbstractHelperConfTest.java b/src/test/java/ode/vertex/conf/helper/forward/AbstractHelperConfTest.java new file mode 100644 index 0000000..5bdaf38 --- /dev/null +++ b/src/test/java/ode/vertex/conf/helper/forward/AbstractHelperConfTest.java @@ -0,0 +1,64 @@ +package ode.vertex.conf.helper.forward; + +import org.deeplearning4j.nn.conf.ConvolutionMode; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.Convolution2D; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.io.IOException; + +import static junit.framework.TestCase.assertEquals; +import static org.junit.Assert.assertNotEquals; + +abstract class AbstractHelperConfTest { + + /** + * Create a {@link OdeHelperForward} which shall be tested + * + * @return a new {@link OdeHelperForward} + */ + abstract OdeHelperForward create(); + + /** + * Create inputs from a given input array + * + * @param input input array + * @return all needed inputs + */ + abstract INDArray[] createInputs(INDArray input); + + /** + * Test serialization and deserialization + */ + @Test + public void serializeDeserialize() throws IOException { + final OdeHelperForward conf = create(); + final String json = NeuralNetConfiguration.mapper().writeValueAsString(conf); + final OdeHelperForward newConf = NeuralNetConfiguration.mapper().readValue(json, OdeHelperForward.class); + assertEquals("Did not deserialize into the same thing!", conf, newConf); + } + + @Test + public void instantiateAndSolveConv() { + final ode.vertex.impl.helper.forward.OdeHelperForward helper = create().instantiate(); + final INDArray output = helper.solve(createGraph(), LayerWorkspaceMgr.noWorkspaces(), createInputs(Nd4j.randn(new long[] {5, 2, 3, 3}))); + assertNotEquals("Expected non-zero output!", 0, output.sumNumber().doubleValue() ,1e-10); + } + + private ComputationGraph createGraph() { + final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() + .graphBuilder() + .allowNoOutput(true) + .addInputs("input") + .setInputTypes(InputType.convolutional(9, 9, 2)) + .addLayer("1", new Convolution2D.Builder(3, 3).nOut(2).convolutionMode(ConvolutionMode.Same).build(), "input") + .build()); + graph.init(); + return graph; + } +} diff --git a/src/test/java/ode/vertex/conf/helper/forward/FixedStepTest.java b/src/test/java/ode/vertex/conf/helper/forward/FixedStepTest.java new file mode 100644 index 0000000..765aa62 --- /dev/null +++ b/src/test/java/ode/vertex/conf/helper/forward/FixedStepTest.java @@ -0,0 +1,24 @@ +package ode.vertex.conf.helper.forward; + +import ode.solve.conf.DormandPrince54Solver; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +/** + * Test cases for {@link FixedStep} + * + * @author Christian Skarby + */ +public class FixedStepTest extends AbstractHelperConfTest{ + + + @Override + OdeHelperForward create() { + return new FixedStep(new DormandPrince54Solver(), Nd4j.linspace(0,3,4), false); + } + + @Override + INDArray[] createInputs(INDArray input) { + return new INDArray[] {input}; + } +} \ No newline at end of file diff --git a/src/test/java/ode/vertex/conf/helper/forward/InputStepTest.java b/src/test/java/ode/vertex/conf/helper/forward/InputStepTest.java new file mode 100644 index 
0000000..d839272 --- /dev/null +++ b/src/test/java/ode/vertex/conf/helper/forward/InputStepTest.java @@ -0,0 +1,24 @@ +package ode.vertex.conf.helper.forward; + +import ode.solve.conf.DormandPrince54Solver; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + + +/** + * Test cases for {@link InputStep} + * + * @author Christian Skarby + */ +public class InputStepTest extends AbstractHelperConfTest { + + @Override + OdeHelperForward create() { + return new InputStep(new DormandPrince54Solver(), 1, false); + } + + @Override + INDArray[] createInputs(INDArray input) { + return new INDArray[] {input, Nd4j.linspace(0, 10, 5)}; + } +} \ No newline at end of file diff --git a/src/test/java/ode/vertex/impl/NonContiguous1DViewTest.java b/src/test/java/ode/vertex/impl/NonContiguous1DViewTest.java deleted file mode 100644 index 28c6ef3..0000000 --- a/src/test/java/ode/vertex/impl/NonContiguous1DViewTest.java +++ /dev/null @@ -1,50 +0,0 @@ -package ode.vertex.impl; - -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import static org.junit.Assert.assertEquals; - -/** - * Test cases for {@link NonContiguous1DView} - * - * @author Christian Skarby - */ -public class NonContiguous1DViewTest { - - /** - * Test assignment of {@link NonContiguous1DView} from another {@link INDArray}. Set first and last three elements - * and leave three in the middle untouched - */ - @Test - public void assignFrom() { - final INDArray toView = Nd4j.zeros(9); - - final NonContiguous1DView view = new NonContiguous1DView(); - view.addView(toView, 0, 3); - view.addView(toView, 6,9); - view.assignFrom(Nd4j.ones(new long[] {6})); - - final INDArray expected = Nd4j.create(new double[] {1,1,1,0,0,0,1,1,1}); - assertEquals("Viewed array was not changed!", expected, toView); - } - - /** - * Test assignment to another {@link INDArray} from a {@link NonContiguous1DView}. 
- */ - @Test - public void assignTo() { - final INDArray toView = Nd4j.linspace(0,8,9); - - final NonContiguous1DView view = new NonContiguous1DView(); - view.addView(toView, 0, 3); - view.addView(toView, 6,9); - - final INDArray actual = Nd4j.create(view.length()); - view.assignTo(actual); - - final INDArray expected = Nd4j.create(new double[] {0,1,2,6,7,8}); - assertEquals("Viewed array was not changed!", expected, actual); - } -} \ No newline at end of file diff --git a/src/test/java/ode/vertex/impl/OdeVertexTest.java b/src/test/java/ode/vertex/impl/OdeVertexTest.java index d5e39c8..9ec5558 100644 --- a/src/test/java/ode/vertex/impl/OdeVertexTest.java +++ b/src/test/java/ode/vertex/impl/OdeVertexTest.java @@ -1,5 +1,7 @@ package ode.vertex.impl; +import ode.solve.conf.DormandPrince54Solver; +import ode.vertex.conf.helper.InputStep; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -10,6 +12,7 @@ import org.nd4j.linalg.activations.impl.ActivationReLU; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.dataset.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; @@ -93,4 +96,34 @@ public void fit() { graph.fit(new DataSet(Nd4j.randn(new long[]{1, 1, 9, 9}), Nd4j.create(new double[] {0,1,0}))); assertNotEquals("Expected parameters to be updated!", before, graph.getVertex("1").params().dup()); } + + /** + * Smoke test to see that it is possible to do a forward pass when time is one of the inputs + */ + @Test + public void fitWithTimeAsInput() { + final long nOut = 8; + final long nrofTimeSteps = 10; + final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() + .graphBuilder() + .addInputs("input", "time") + .setInputTypes(InputType.feedForward(5), InputType.feedForward(nrofTimeSteps)) + .addLayer("0", new DenseLayer.Builder().nOut(nOut).build(), "input") + .addVertex("odeVertex", new ode.vertex.conf.OdeVertex.Builder("ode0", + new DenseLayer.Builder().nOut(nOut).build()) + .odeConf(new InputStep(new DormandPrince54Solver(), 1)) + .build(), "0", "time") + .setOutputs("output") + .addLayer("output", new RnnOutputLayer.Builder().nOut(3).build(), "odeVertex") + .build()); + + graph.init(); + + final INDArray before = graph.getVertex("odeVertex").params().dup(); + final int batchSize = 3; + graph.fit(new MultiDataSet( + new INDArray[] {Nd4j.randn(new long[]{batchSize, 5}), Nd4j.linspace(0, 2, nrofTimeSteps)}, + new INDArray[] {Nd4j.repeat(Nd4j.create(new double[] {0,1,0}).transposei(), batchSize*(int)nrofTimeSteps)})); + assertNotEquals("Expected parameters to be updated!", before, graph.getVertex("odeVertex").params().dup()); + } } \ No newline at end of file diff --git a/src/test/java/ode/vertex/impl/gradview/Contiguous1DViewTest.java b/src/test/java/ode/vertex/impl/gradview/Contiguous1DViewTest.java new file mode 100644 index 0000000..0612839 --- /dev/null +++ b/src/test/java/ode/vertex/impl/gradview/Contiguous1DViewTest.java @@ -0,0 +1,54 @@ +package ode.vertex.impl.gradview; + +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.assertEquals; + +/** + * Test cases for {@link Contiguous1DView} + * + * @author Christian Skarby + */ +public class Contiguous1DViewTest { + + /** + * Test assignment of {@link Contiguous1DView} from another {@link INDArray}. 
Set first and last three elements + * and leave three in the middle untouched + */ + @Test + public void assignFrom() { + final INDArray toView = Nd4j.ones(13); + final INDArray other = Nd4j.zeros(toView.shape()).reshape(toView.length()); + + final INDArray1DView view = new Contiguous1DView(toView); + + view.assignFrom(other); + assertEquals("View not set!", other, toView); + } + + /** + * Test assignment to another {@link INDArray} from a {@link Contiguous1DView}. + */ + @Test + public void assignTo() { + final INDArray toView = Nd4j.ones(13); + final INDArray other = Nd4j.zeros(toView.shape()).reshape(toView.length()); + + final INDArray1DView view = new Contiguous1DView(toView); + + view.assignTo(other); + assertEquals("View not set!", other, toView); + } + + /** + * Test that length is correct + */ + @Test + public void length() { + final long length = 27; + + assertEquals("Incorrect length!", length, new Contiguous1DView(Nd4j.ones(length)).length()); + } +} \ No newline at end of file diff --git a/src/test/java/ode/vertex/impl/gradview/GradientViewSelectionFromBlacklistedTest.java b/src/test/java/ode/vertex/impl/gradview/GradientViewSelectionFromBlacklistedTest.java new file mode 100644 index 0000000..3a11f55 --- /dev/null +++ b/src/test/java/ode/vertex/impl/gradview/GradientViewSelectionFromBlacklistedTest.java @@ -0,0 +1,125 @@ +package ode.vertex.impl.gradview; + +import org.deeplearning4j.nn.conf.ConvolutionMode; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.BatchNormalization; +import org.deeplearning4j.nn.conf.layers.Convolution2D; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.jetbrains.annotations.NotNull; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.primitives.Pair; +import org.nd4j.shade.jackson.databind.ObjectMapper; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Map; + +import static junit.framework.TestCase.assertEquals; +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertFalse; + +/** + * Test cases for {@link GradientViewSelectionFromBlacklisted} + * + * @author Christian Skarby + */ +public class GradientViewSelectionFromBlacklistedTest { + + /** + * Test that gradients in the black list are not selected + */ + @Test + public void createWithBlacklisted() { + final ComputationGraph graph = createGraph(); + final GradientViewFactory factory = new GradientViewSelectionFromBlacklisted(); + final ParameterGradientView gradView = factory.create(graph); + + assertEquals("Incorrect number of parameter gradients in view!", + graph.getGradientsViewArray().length() - graph.getLayer("1").getGradientsViewArray().length() / 2, + gradView.realGradientView().length()); + + for(Map.Entry nameGradEntry: gradView.allGradientsPerParam().gradientForVariable().entrySet()){ + final Pair vertexAndParName = factory.paramNameMapping().reverseMap(nameGradEntry.getKey()); + final long[] expectedShape = graph.getLayer(vertexAndParName.getFirst()).getParam(vertexAndParName.getSecond()).shape(); + assertArrayEquals("Incorrect grad size for " + nameGradEntry.getKey() + "!", + expectedShape, + nameGradEntry.getValue().shape()); + } + } + + /** + * Test that gradients in the black list are not selected + */ + @Test + public void createNoBlackList() { + final ComputationGraph graph = 
createGraph(); + final GradientViewFactory factory =new GradientViewSelectionFromBlacklisted(new ArrayList<>()); + final ParameterGradientView gradView = factory.create(graph); + + assertEquals("Incorrect number of parameter gradients in view!", + graph.getGradientsViewArray().length(), + gradView.realGradientView().length()); + + for(Map.Entry nameGradEntry: gradView.allGradientsPerParam().gradientForVariable().entrySet()){ + final Pair vertexAndParName = factory.paramNameMapping().reverseMap(nameGradEntry.getKey()); + final long[] expectedShape = graph.getLayer(vertexAndParName.getFirst()).getParam(vertexAndParName.getSecond()).shape(); + assertArrayEquals("Incorrect grad size for " + nameGradEntry.getKey() + "!", + expectedShape, + nameGradEntry.getValue().shape()); + } + } + + /** + * Test that a clone is equal to the original + */ + @Test + public void clonetest() { + final GradientViewFactory factory = new GradientViewSelectionFromBlacklisted(Arrays.asList("ff", "gg")); + assertTrue("Clones shall be equal!" , factory.equals(factory.clone())); + } + + /** + * Test equals + */ + @Test + public void equals() { + assertTrue("Shall be equal!", new GradientViewSelectionFromBlacklisted().equals(new GradientViewSelectionFromBlacklisted())); + assertTrue("Shall be equal!", + new GradientViewSelectionFromBlacklisted(Arrays.asList("aa", "bb")).equals( + new GradientViewSelectionFromBlacklisted(Arrays.asList("aa", "bb")))); + assertFalse("Shall not be equal!", + new GradientViewSelectionFromBlacklisted(Arrays.asList("aa", "bb")).equals( + new GradientViewSelectionFromBlacklisted(Arrays.asList("aa", "bb", "cc")))); + } + + /** + * Test that a {@link GradientViewSelectionFromBlacklisted} can be serialized and then deserialized into the same thing + */ + @Test + public void serializeDeserialize() throws IOException { + final GradientViewFactory factory = new GradientViewSelectionFromBlacklisted(Arrays.asList("qq", "ww")); + final String json = new ObjectMapper().writeValueAsString(factory); + final GradientViewFactory deserialized = new ObjectMapper().readValue(json, GradientViewSelectionFromBlacklisted.class); + assertTrue("Did not deserialize to the same thing!", factory.equals(deserialized)); + } + + @NotNull + ComputationGraph createGraph() { + final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() + .graphBuilder() + .addInputs("input") + .addLayer("0", new Convolution2D.Builder().convolutionMode(ConvolutionMode.Same).nOut(3).build(), "input") + .addLayer("1", new BatchNormalization.Builder().nOut(3).build(), "0") + .addLayer("2", new Convolution2D.Builder().convolutionMode(ConvolutionMode.Same).nOut(3).build(), "1") + .setOutputs("2") + .setInputTypes(InputType.convolutional(5, 5, 3)) + .build()); + graph.init(); + graph.initGradientsView(); + return graph; + } +} \ No newline at end of file diff --git a/src/test/java/ode/vertex/impl/gradview/NonContiguous1DViewTest.java b/src/test/java/ode/vertex/impl/gradview/NonContiguous1DViewTest.java new file mode 100644 index 0000000..d240d72 --- /dev/null +++ b/src/test/java/ode/vertex/impl/gradview/NonContiguous1DViewTest.java @@ -0,0 +1,102 @@ +package ode.vertex.impl.gradview; + +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; + +import static org.junit.Assert.assertEquals; + +/** + * Test cases for {@link NonContiguous1DView} + * + * @author Christian Skarby + */ +public class NonContiguous1DViewTest { + + /** 
+ * Test assignment of {@link NonContiguous1DView} from another {@link INDArray}. Set first and last three elements + * and leave three in the middle untouched + */ + @Test + public void assignFrom() { + final INDArray toView = Nd4j.zeros(9); + + final NonContiguous1DView view = new NonContiguous1DView(); + view.addView(toView.get(NDArrayIndex.interval(0, 3))); + view.addView(toView.get(NDArrayIndex.interval(6, 9))); + view.assignFrom(Nd4j.ones(new long[] {6})); + + final INDArray expected = Nd4j.create(new double[] {1,1,1,0,0,0,1,1,1}); + assertEquals("Viewed array was not changed!", expected, toView); + } + + /** + * Test assignment to another {@link INDArray} from a {@link NonContiguous1DView}. + */ + @Test + public void assignTo() { + final INDArray toView = Nd4j.linspace(0,8,9); + + final NonContiguous1DView view = new NonContiguous1DView(); + view.addView(toView.get(NDArrayIndex.interval(0, 3))); + view.addView(toView.get(NDArrayIndex.interval(6, 9))); + + final INDArray actual = Nd4j.create(view.length()); + view.assignTo(actual); + + final INDArray expected = Nd4j.create(new double[] {0,1,2,6,7,8}); + assertEquals("Viewed array was not changed!", expected, actual); + } + + /** + * Test that length is correct + */ + @Test + public void length() { + final long length0 = 13; + final long length1 = 7; + + final INDArray toView = Nd4j.ones(length0+length1+10).reshape(length0+length1+10); + + final NonContiguous1DView view = new NonContiguous1DView(); + view.addView(toView.get(NDArrayIndex.interval(0, length0))); + view.addView(toView.get(NDArrayIndex.interval(length0, length0+length1))); + + assertEquals("Incorrect length!", length0 + length1, view.length()); + } + + /** + * Test assignment of {@link NonContiguous1DView} with one 2x3x4 view and one 2x2 view from another {@link INDArray}. + */ + @Test + public void assignFrom2x3x4and2x2() { + final INDArray toView = Nd4j.zeros(2*3*4*5); + + final NonContiguous1DView view = new NonContiguous1DView(); + view.addView(toView.get(NDArrayIndex.interval(0, 2*3*4)).reshape(2,3,4)); + view.addView(toView.get(NDArrayIndex.interval(2*3*4+4, 2*3*4+8)).reshape(2,2)); + view.assignFrom(Nd4j.ones(view.length())); + + final double expectedSum = view.length(); + assertEquals("Viewed array was not changed!", expectedSum, toView.sumNumber().doubleValue(),1e-10); + } + + /** + * Test assignment to another {@link INDArray} from a {@link NonContiguous1DView}. 
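+     * Views cover elements 0-11 and 17-26 of a 37 element array; elements in between shall be left out.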
+ */ + @Test + public void assignTo2x3x4() { + final INDArray toView = Nd4j.linspace(0,36,37); + + final NonContiguous1DView view = new NonContiguous1DView(); + view.addView(toView.get(NDArrayIndex.interval(0, 12)).reshape(2,3,2)); + view.addView(toView.get(NDArrayIndex.interval(17, 17+10)).reshape(5,2)); + + final INDArray actual = Nd4j.create(view.length()); + view.assignTo(actual); + + final INDArray expected = Nd4j.create(new double[] {0,1,2,3,4,5,6,7,8,9,10,11, 17,18,19,20,21,22,23,24,25,26}); + assertEquals("Viewed array was not changed!", expected, actual); + } +} \ No newline at end of file diff --git a/src/test/java/ode/vertex/impl/gradview/parname/ConcatTest.java b/src/test/java/ode/vertex/impl/gradview/parname/ConcatTest.java new file mode 100644 index 0000000..53f6d63 --- /dev/null +++ b/src/test/java/ode/vertex/impl/gradview/parname/ConcatTest.java @@ -0,0 +1,38 @@ +package ode.vertex.impl.gradview.parname; + +import org.junit.Test; +import org.nd4j.linalg.primitives.Pair; + +import static org.junit.Assert.assertEquals; + +/** + * Test cases for {@link Concat} + * + * @author Christian Skarby + */ +public class ConcatTest { + + /** + * Test that input is concatenated + */ + @Test + public void map() { + assertEquals("vertex-param", new Concat().map("vertex", "param")); + } + + /** + * Test that input is de-concatenated + */ + @Test + public void reverseMap() { + assertEquals(new Pair<>("vertex", "param"), new Concat().reverseMap("vertex-param")); + } + + /** + * Test that an error is thrown for illegal input + */ + @Test(expected = IllegalArgumentException.class) + public void reverseMapIllegal() { + new Concat().reverseMap("aaa-vertex-param"); + } +} \ No newline at end of file diff --git a/src/test/java/ode/vertex/impl/gradview/parname/PrefixTest.java b/src/test/java/ode/vertex/impl/gradview/parname/PrefixTest.java new file mode 100644 index 0000000..6dfc909 --- /dev/null +++ b/src/test/java/ode/vertex/impl/gradview/parname/PrefixTest.java @@ -0,0 +1,30 @@ +package ode.vertex.impl.gradview.parname; + +import org.junit.Test; +import org.nd4j.linalg.primitives.Pair; + +import static org.junit.Assert.assertEquals; + +/** + * Test cases for {@link Prefix} + * + * @author Christian Skarby + */ +public class PrefixTest { + + /** + * Test that a prefix is added + */ + @Test + public void map() { + assertEquals("prefix_vertex-param", new Prefix("prefix_", new Concat()).map("vertex", "param")); + } + + /** + * Test that prefixing is reversed correctly + */ + @Test + public void reverseMap() { + assertEquals(new Pair<>("vertex", "param"), new Prefix("prefix_", new Concat()).reverseMap("prefix_vertex-param")); + } +} \ No newline at end of file diff --git a/src/test/java/ode/vertex/impl/helper/backward/MultiStepAdjointTest.java b/src/test/java/ode/vertex/impl/helper/backward/MultiStepAdjointTest.java new file mode 100644 index 0000000..562a9ca --- /dev/null +++ b/src/test/java/ode/vertex/impl/helper/backward/MultiStepAdjointTest.java @@ -0,0 +1,327 @@ +package ode.vertex.impl.helper.backward; + +import ode.solve.conf.SolverConfig; +import ode.solve.impl.DormandPrince54Solver; +import ode.vertex.conf.OdeVertex; +import ode.vertex.conf.helper.InputStep; +import ode.vertex.impl.gradview.NonContiguous1DView; +import ode.vertex.impl.helper.backward.timegrad.CalcMultiStepTimeGrad; +import ode.vertex.impl.helper.backward.timegrad.NoMultiStepTimeGrad; +import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; +import 
org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.distribution.ConstantDistribution; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.graph.vertex.GraphVertex; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.jetbrains.annotations.NotNull; +import org.junit.Test; +import org.nd4j.linalg.activations.impl.ActivationIdentity; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.primitives.Pair; + +import static junit.framework.TestCase.assertEquals; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertNotEquals; + +/** + * Test cases for {@link MultiStepAdjoint} + * + * @author Christian Skarby + */ +public class MultiStepAdjointTest { + + /** + * Smoke test for backwards solve without time gradients + */ + @Test + public void solveNoTime() { + final int nrofInputs = 7; + final int nrofTimeSteps = 5; + final ComputationGraph graph = SingleStepAdjointTest.getTestGraph(nrofInputs); + final OdeHelperBackward.InputArrays inputArrays = getTestInputArrays(nrofInputs, nrofTimeSteps, graph); + + final INDArray time = Nd4j.arange(nrofTimeSteps); + final OdeHelperBackward helper = new MultiStepAdjoint( + new DormandPrince54Solver(new SolverConfig(1e-3, 1e-3, 0.1, 10)), + time, NoMultiStepTimeGrad.factory); + + INDArray[] gradients = helper.solve(graph, inputArrays, new OdeHelperBackward.MiscPar( + false, + LayerWorkspaceMgr.noWorkspacesImmutable())); + + assertEquals("Incorrect number of input gradients!", 1, gradients.length); + + final INDArray inputGrad = gradients[0]; + assertArrayEquals("Incorrect input gradient shape!", inputArrays.getLastInputs()[0].shape(), inputGrad.shape()); + + final INDArray parGrad = graph.getGradientsViewArray(); + assertNotEquals("Expected non-zero parameter gradient!", 0.0, parGrad.maxNumber().doubleValue(), 1e-10); + assertNotEquals("Expected non-zero input gradient!", 0.0, inputGrad.maxNumber().doubleValue(), 1e-10); + } + + /** + * Smoke test for backwards solve without time gradients + */ + @Test + public void solveWithTime() throws InterruptedException { + final int nrofInputs = 5; + final int nrofTimeSteps = 7; + final ComputationGraph graph = SingleStepAdjointTest.getTestGraph(nrofInputs); + final OdeHelperBackward.InputArrays inputArrays = getTestInputArrays(nrofInputs, nrofTimeSteps, graph); + + final INDArray time = Nd4j.arange(nrofTimeSteps); + final OdeHelperBackward helper = new MultiStepAdjoint( + new DormandPrince54Solver(new SolverConfig(1e-3, 1e-3, 0.1, 10)), + time, new CalcMultiStepTimeGrad.Factory(time, 1)); + + INDArray[] gradients = helper.solve(graph, inputArrays, new OdeHelperBackward.MiscPar( + false, + LayerWorkspaceMgr.noWorkspacesImmutable())); + + + assertEquals("Incorrect number of input gradients!", 2, gradients.length); + + final INDArray inputGrad = gradients[0]; + final INDArray timeGrad = gradients[1]; + assertArrayEquals("Incorrect input gradient shape!", inputArrays.getLastInputs()[0].shape(), inputGrad.shape()); + assertArrayEquals("Incorrect time gradient shape!", time.shape(), timeGrad.shape()); + + final INDArray parGrad = graph.getGradientsViewArray(); + assertNotEquals("Expected non-zero parameter gradient!", 0.0, parGrad.maxNumber().doubleValue(), 
1e-10); + assertNotEquals("Expected non-zero input gradient!", 0.0, inputGrad.maxNumber().doubleValue(), 1e-10); + assertNotEquals("Expected non-zero time gradient!", 0.0, timeGrad.maxNumber().doubleValue(), 1e-10); + } + + @NotNull + private static OdeHelperBackward.InputArrays getTestInputArrays(int nrofInputs, int nrofTimeSteps, ComputationGraph graph) { + final int batchSize = 3; + final INDArray input = Nd4j.arange(batchSize * nrofInputs).reshape(batchSize, nrofInputs); + final INDArray output = Nd4j.arange(batchSize * nrofInputs * nrofTimeSteps).reshape(batchSize, nrofInputs, nrofTimeSteps); + final INDArray epsilon = Nd4j.ones(batchSize, nrofInputs, nrofTimeSteps).assign(0.01); + final NonContiguous1DView realGrads = new NonContiguous1DView(); + realGrads.addView(graph.getGradientsViewArray()); + + return new OdeHelperBackward.InputArrays( + new INDArray[]{input}, + output, + epsilon, + realGrads + ); + } + + /** + * Test that an exception is thrown if time array is not sorted + */ + @Test(expected = IllegalArgumentException.class) + public void timeNotSorted() { + final INDArray time = Nd4j.create(new double[]{0, 1, 2, 1.5, 3, 4}); + new MultiStepAdjoint(new DormandPrince54Solver( + new SolverConfig(1, 1, 1e-1, 1)), + time, new CalcMultiStepTimeGrad.Factory(time, 1)); + } + + /** + * Test the result vs the result from the original repo. Reimplementation of test_adjoint in + * https://github.com/rtqichen/torchdiffeq/blob/master/tests/gradient_tests.py. + *

+     * Numbers below from the following test ODE:
+     *
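+     * (i.e. dy/dt = y.matmul(a) + b, an ODE which is affine in y, as can be seen in forward below)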
+     *     class Linear1DODE(torch.nn.Module):
+     *
+     *     def __init__(self, device):
+     *         super(Linear1DODE, self).__init__()
+     *         self.a = torch.nn.Parameter(torch.tensor([0.2, 0.2, 0.2, 0.2]).reshape(2,2).to(device))
+     *         self.b = torch.nn.Parameter(torch.tensor([3.0, 3.0]).reshape(1,2).to(device))
+     *
+     *     def forward(self, t, y):
+     *         return y.matmul(self.a)+ self.b
+     *
+     *     def y_exact(self, t):
+     *         return t
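+     *
+     * (a and b correspond to the ConstantDistribution(0.2) weight init and biasInit(3) of the
+     * DenseLayer created in createOdeVertex below)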
+     * Test code (to set same inputs as below):
+     *         f, y0, t_points, _ = construct_problem(TEST_DEVICE, ode='linear1D')
+     *
+     *         y0 = torch.linspace(-1.23,2.34,4).reshape(2, 2).type_as(y0)
+     *         y0.requires_grad = True
+     *
+     *         print('t: ', t_points)
+     *         print('y0: ', y0)
+     *
+     *         params = (list(f.parameters()))
+     *         optimizer = torch.optim.SGD(params, lr=0.0001)
+     *
+     *
+     *         optimizer.zero_grad()
+     *         func = lambda y0, t_points: torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
+     *         ys = func(y0, t_points).permute(1, 2, 0)
+     *
+     *         gradys = torch.linspace(3,  13, ys.numel()).type_as(ys).reshape(ys.size())
+     *         for i in range (0, int(gradys.size()[0]/2)):
+     *             gradys[i*2, :, i*2+1] = -gradys[i*2,:,i*2+1]
+     *
+     *         ys.backward(gradys)
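+     *
+     * (the sign flips applied to gradys above are mirrored by the negi() calls on lossgrad in the
+     * tests below)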
+     * 
+ */ + @Test + public void testGradVsReferenceLinear1dOdeForward() { + + final long nrofTimeSteps = 10; + final long nrofDims = 2; + + final GraphVertex odevert = createOdeVertex(nrofTimeSteps, nrofDims); + + final INDArray time = Nd4j.linspace(1, 8, nrofTimeSteps); + final INDArray y0 = Nd4j.linspace(-1.23, 2.34, 2 * nrofDims).reshape(2, nrofDims); + + odevert.setInputs(y0, time); + final INDArray ys = odevert.doForward(true, LayerWorkspaceMgr.noWorkspacesImmutable()); + + // From original repo + final double[][][] expectedYs = {{{-1.23, 1.2753194351313588, 4.694932254437282, 9.36250433337755, 15.73345986285669, 24.429438907671752, 36.29894018833409, 52.50011823313958, 74.61376567400269, 104.79757535834692}, + {-0.040000000000000036, 2.4653194351313603, 5.884932254437281, 10.552504333377543, 16.92345986285671, 25.61943890767168, 37.488940188334105, 53.69011823313958, 75.8037656740027, 105.98757535834685}}, + {{1.15, 4.523878831433274, 9.129023844467975, 15.414778231911933, 23.994455416185012, 35.70520942482531, 51.689701681157864, 73.50760277718506, 103.28774415967294, 143.93586077027197}, + {2.34, 5.7138788314332745, 10.319023844467976, 16.604778231911936, 25.184455416185003, 36.89520942482521, 52.879701681157854, 74.69760277718505, 104.47774415967294, 145.12586077027197}}}; + + for (int i = 0; i < expectedYs.length; i++) { + for (int j = 0; j < expectedYs.length; j++) { + assertArrayEquals("Incorrect ys: ", + expectedYs[i][j], + ys.reshape(ys.size(0), nrofDims, nrofTimeSteps).getRow(i).getRow(j).toDoubleVector(), 1e-3); + } + } + + final INDArray lossgrad = Nd4j.linspace(3, 13, ys.length()).reshape(ys.shape()); + for (int i = 0; i < lossgrad.size(0) / 2; i++) { + lossgrad.get(NDArrayIndex.point(i * 2), NDArrayIndex.all(), NDArrayIndex.point(i * 2 + 1)).negi(); + } + + odevert.setEpsilon(lossgrad.reshape(ys.shape())); + Pair grads = odevert.doBackward(false, LayerWorkspaceMgr.noWorkspacesImmutable()); + + final double[][] expectedYsGrad = {{330.3413671858052, 350.8541876986256}, {641.5288548171546, 667.1698804581804}}; + for (int i = 0; i < expectedYsGrad.length; i++) { + assertArrayEquals("Incorrect loss gradient: ", + expectedYsGrad[i], + grads.getSecond()[0].getRow(i).toDoubleVector(), 1e-3); + } + + // Note: Order of element 1 and 2 are swapped + final double[] expectedParsGrad = {26399.28182908193, 28805.83177942673, 29982.48959443847, 32597.88284962647, 2022.3108828170273, 2197.8094583156044}; + // Compare one by one due to large dynamic range + for (int i = 0; i < expectedParsGrad.length; i++) { + assertEquals("Incorrect parameter gradient for param " + i + "!", + 1, + grads.getFirst().gradient().getDouble(i) / expectedParsGrad[i], 1e-4); + } + + // First element from reference implementation sure looks weird... 
+ final double[] expectedTimeGrad = {-6617.017543489908, 63.564528808863464, 185.79311916695133, 262.00021152296233, 369.0851188922797, 519.4357040639363, + 730.3690674996153, 1026.0795986642863, 1440.3518309918456, 2020.3383638791681}; + // Compare one by one due to large dynamic range + for (int i = 0; i < expectedTimeGrad.length; i++) { + assertEquals("Incorrect time gradient for param " + i + "!", 1, grads.getSecond()[1].getDouble(i) / expectedTimeGrad[i], 1e-4); + } + } + + /** + * Same test as above but backwards + */ + @Test + public void testGradVsReferenceLinear1dOdeBackward() { + + final long nrofTimeSteps = 10; + final long nrofDims = 2; + + final GraphVertex odevert = createOdeVertex(nrofTimeSteps, nrofDims); + + final INDArray time = Nd4j.linspace(-1, -8, nrofTimeSteps); + final INDArray y0 = Nd4j.linspace(-1.23, 2.34, 2 * nrofDims).reshape(2, nrofDims); + + odevert.setInputs(y0, time); + final INDArray ys = odevert.doForward(true, LayerWorkspaceMgr.noWorkspacesImmutable()); + + // From original repo + final double[][][] expectedYs = {{{-1.23, -3.0654783048539187, -4.410209436983333, -5.395402672421891, -6.117186729574334, -6.645989569834943, -7.033407758920375, -7.317242594011386, -7.525189149078903, -7.677538128370835}, + {-0.040000000000000036, -1.875478304853917, -3.2202094369833336, -4.205402672421891, -4.927186729574334, -5.4559895698349425, -5.843407758920375, -6.127242594011386, -6.335189149078852, -6.487538128370829}}, + {{1.15, -1.321813099544714, -3.132743808435675, -4.459489833436336, -5.4315063823619125, -6.143637811088716, -6.665368496900053, -7.047604920850001, -7.327643653785053, -7.532809904849017}, + {2.34, -0.13181309954471376, -1.9427438084356758, -3.2694898334363294, -4.241506382361924, -4.953637811088717, -5.475368496900053, -5.857604920850002, -6.137643653785053, -6.342809904849003}}}; + + for (int i = 0; i < expectedYs.length; i++) { + for (int j = 0; j < expectedYs.length; j++) { + assertArrayEquals("Incorrect ys: ", + expectedYs[i][j], + ys.reshape(ys.size(0), nrofDims, nrofTimeSteps).getRow(i).getRow(j).toDoubleVector(), 1e-3); + } + } + + final INDArray lossgrad = Nd4j.linspace(3, 13, ys.length()).reshape(ys.shape()); + for (int i = 0; i < lossgrad.size(0) / 2; i++) { + lossgrad.get(NDArrayIndex.point(i * 2), NDArrayIndex.all(), NDArrayIndex.point(i * 2 + 1)).negi(); + } + + odevert.setEpsilon(lossgrad.reshape(ys.shape())); + Pair grads = odevert.doBackward(false, LayerWorkspaceMgr.noWorkspacesImmutable()); + + final double[][] expectedYsGrad = {{0.4791608727137171, 20.99198138553419}, {22.890949753489018, 48.53197539451459}}; + for (int i = 0; i < expectedYsGrad.length; i++) { + assertArrayEquals("Incorrect loss gradient: ", + expectedYsGrad[i], + grads.getSecond()[0].getRow(i).toDoubleVector(), 1e-3); + } + + // Note: Order of element 1 and 2 are swapped + final double[] expectedParsGrad = {902.9542328576254, 696.559656699922, 1683.4912664748028, 1268.25338547379, -173.44082030059053, -348.93939579916616}; + // Compare one by one due to large dynamic range + for (int i = 0; i < expectedParsGrad.length; i++) { + assertEquals("Incorrect parameter gradient for param " + i + "!", + 1, + grads.getFirst().gradient().getDouble(i) / expectedParsGrad[i], 1e-4); + } + + // First element from reference implementation sure looks weird... 
+ final double[] expectedTimeGrad = {-229.96656912353825, 34.11827941396485, 53.52715859322072, 40.515245326331545, 30.634856397062812, 23.14160607429379, + 17.465314677923114, 13.17005445759693, 9.923109065345606, 7.470945117798939}; + // Compare one by one due to large dynamic range + for (int i = 0; i < expectedTimeGrad.length; i++) { + assertEquals("Incorrect time gradient for param " + i + "!", 1, grads.getSecond()[1].getDouble(i) / expectedTimeGrad[i], 1e-4); + } + } + + GraphVertex createOdeVertex(long nrofTimeSteps, long nrofDims) { + final ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder() + .seed(666) + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + .graphBuilder() + .setInputTypes(InputType.feedForward(nrofDims), InputType.feedForward(nrofTimeSteps)); + + String next = "y0"; + builder.addInputs(next, "time"); + + builder.addVertex("ode", new OdeVertex.Builder("0", new DenseLayer.Builder() + .activation(new ActivationIdentity()) + .weightInit(new ConstantDistribution(0.2)) + .biasInit(3) + .nOut(nrofDims) + .build()) + .odeConf(new InputStep( + new ode.solve.conf.DormandPrince54Solver( + new SolverConfig(1e-12, 1e-6, 1e-20, 1e2)), + 1, true, true)) + .build(), next, "time"); + + builder.allowNoOutput(true); + + final ComputationGraph graph = new ComputationGraph(builder.build()); + graph.init(); + graph.initGradientsView(); + + return graph.getVertex("ode"); + } +} \ No newline at end of file diff --git a/src/test/java/ode/vertex/impl/helper/backward/SingleStepAdjointTest.java b/src/test/java/ode/vertex/impl/helper/backward/SingleStepAdjointTest.java new file mode 100644 index 0000000..462fe32 --- /dev/null +++ b/src/test/java/ode/vertex/impl/helper/backward/SingleStepAdjointTest.java @@ -0,0 +1,126 @@ +package ode.vertex.impl.helper.backward; + +import ode.solve.conf.SolverConfig; +import ode.solve.impl.DormandPrince54Solver; +import ode.vertex.impl.gradview.NonContiguous1DView; +import ode.vertex.impl.helper.backward.timegrad.CalcTimeGrad; +import ode.vertex.impl.helper.backward.timegrad.NoTimeGrad; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.distribution.ConstantDistribution; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.jetbrains.annotations.NotNull; +import org.junit.Test; +import org.nd4j.linalg.activations.impl.ActivationIdentity; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static junit.framework.TestCase.assertEquals; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertNotEquals; + +/** + * Test cases for {@link SingleStepAdjoint} + * + * @author Christian Skarby + */ +public class SingleStepAdjointTest { + + /** + * Smoke test for backwards solve without time gradients + */ + @Test + public void solveNoTime() { + final int nrofInputs = 7; + final ComputationGraph graph = getTestGraph(nrofInputs); + final OdeHelperBackward.InputArrays inputArrays = getTestInputArrays(nrofInputs, graph); + + final INDArray time = Nd4j.arange(2); + final OdeHelperBackward helper = new SingleStepAdjoint( + new DormandPrince54Solver(new SolverConfig(1e-3, 1e-3, 0.1, 10)), + time, NoTimeGrad.factory); + + INDArray[] gradients = helper.solve(graph, inputArrays, new OdeHelperBackward.MiscPar( + false, + 
LayerWorkspaceMgr.noWorkspaces())); + + assertEquals("Incorrect number of input gradients!", 1, gradients.length); + + final INDArray inputGrad = gradients[0]; + assertArrayEquals("Incorrect input gradient shape!", inputArrays.getLastInputs()[0].shape(), inputGrad.shape()); + + final INDArray parGrad = graph.getGradientsViewArray(); + assertNotEquals("Expected non-zero parameter gradient!", 0.0, parGrad.sumNumber().doubleValue(),1e-10); + assertNotEquals("Expected non-zero input gradient!", 0.0, inputGrad.sumNumber().doubleValue(), 1e-10); + } + + /** + * Smoke test for backwards solve without time gradients + */ + @Test + public void solveWithTime() { + final int nrofInputs = 5; + final ComputationGraph graph = getTestGraph(nrofInputs); + final OdeHelperBackward.InputArrays inputArrays = getTestInputArrays(nrofInputs, graph); + + final INDArray time = Nd4j.arange(2); + final OdeHelperBackward helper = new SingleStepAdjoint( + new DormandPrince54Solver(new SolverConfig(1e-3, 1e-3, 0.1, 10)), + time, new CalcTimeGrad.Factory(inputArrays.getLossGradient(), 1)); + + INDArray[] gradients = helper.solve(graph, inputArrays, new OdeHelperBackward.MiscPar( + false, + LayerWorkspaceMgr.noWorkspaces())); + + assertEquals("Incorrect number of input gradients!", 2, gradients.length); + + final INDArray inputGrad = gradients[0]; + final INDArray timeGrad = gradients[1]; + assertArrayEquals("Incorrect input gradient shape!", inputArrays.getLastInputs()[0].shape(), inputGrad.shape()); + assertArrayEquals("Incorrect time gradient shape!", time.shape(), timeGrad.shape()); + + final INDArray parGrad = graph.getGradientsViewArray(); + assertNotEquals("Expected non-zero parameter gradient!", 0.0, parGrad.sumNumber().doubleValue(),1e-10); + assertNotEquals("Expected non-zero input gradient!", 0.0, inputGrad.sumNumber().doubleValue(), 1e-10); + assertNotEquals("Expected non-zero time gradient!", 0.0, timeGrad.sumNumber().doubleValue(), 1e-10); + } + + @NotNull + private static OdeHelperBackward.InputArrays getTestInputArrays(int nrofInputs, ComputationGraph graph) { + final INDArray input = Nd4j.arange(nrofInputs); + final INDArray output = input.add(1); + final INDArray epsilon = Nd4j.ones(nrofInputs).assign(0.01); + final NonContiguous1DView realGrads = new NonContiguous1DView(); + realGrads.addView(graph.getGradientsViewArray()); + + return new OdeHelperBackward.InputArrays( + new INDArray[]{input}, + output, + epsilon, + realGrads + ); + } + + /** + * + * Create a simple graph for testing + * @param nrofInputs Determines the number of inputs (nIn) to the graph + * @return a {@link ComputationGraph} + */ + @NotNull + static ComputationGraph getTestGraph(int nrofInputs) { + final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() + .weightInit(new ConstantDistribution(0.01)) + .graphBuilder() + .setInputTypes(InputType.feedForward(nrofInputs)) + .addInputs("input") + .addLayer("dense", new DenseLayer.Builder().nOut(nrofInputs).activation(new ActivationIdentity()).build(), "input") + .allowNoOutput(true) + .build()); + graph.init(); + graph.initGradientsView(); + return graph; + } +} \ No newline at end of file diff --git a/src/test/java/ode/vertex/impl/helper/forward/ForwardPassTest.java b/src/test/java/ode/vertex/impl/helper/forward/ForwardPassTest.java new file mode 100644 index 0000000..9a2ff81 --- /dev/null +++ b/src/test/java/ode/vertex/impl/helper/forward/ForwardPassTest.java @@ -0,0 +1,53 @@ +package ode.vertex.impl.helper.forward; + +import 
org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.junit.Test; +import org.nd4j.linalg.activations.impl.ActivationIdentity; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.assertEquals; + +/** + * Test cases for {@link ForwardPass} + * + * @author Christian Skarby + */ +public class ForwardPassTest { + + /** + * Test that the derivative is a forward pass through the layers + */ + @Test + public void calculateDerivative() { + final long nrofInputs = 5; + final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() + .graphBuilder() + .setInputTypes(InputType.feedForward(nrofInputs)) + .allowNoOutput(true) + .addInputs("input") + // Very simple dense layer which just performs element wise multiplication of input + .addLayer("dense", new DenseLayer.Builder() + .nOut(nrofInputs) + .hasBias(false) + .weightInit(WeightInit.IDENTITY) + .activation(new ActivationIdentity()) + .build(), "input") + .build()); + graph.init(); + double mul = 1.23; + graph.params().muli(mul); + + final INDArray input = Nd4j.arange(nrofInputs); + final INDArray expected = input.mul(mul); + final INDArray actual = new ForwardPass(graph, LayerWorkspaceMgr.noWorkspaces(), false, new INDArray[]{input}) + .calculateDerivative(input, Nd4j.scalar(0), input.dup()); + + assertEquals("Incorrect output!", expected, actual); + } +} \ No newline at end of file diff --git a/src/test/java/ode/vertex/impl/helper/forward/InputStepTest.java b/src/test/java/ode/vertex/impl/helper/forward/InputStepTest.java new file mode 100644 index 0000000..d444be6 --- /dev/null +++ b/src/test/java/ode/vertex/impl/helper/forward/InputStepTest.java @@ -0,0 +1,69 @@ +package ode.vertex.impl.helper.forward; + +import ode.solve.api.FirstOrderSolver; +import ode.solve.conf.SolverConfig; +import ode.solve.impl.DormandPrince54Solver; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.ops.transforms.Transforms; + +import static org.junit.Assert.assertArrayEquals; + +/** + * Test cases for {@link InputStep}. 
Also tests {@link FixedStep} as {@link InputStep} depends on it + * + * @author Christian Skarby + */ +public class InputStepTest { + + /** + * Test if a simple ODE can be solved + */ + @Test + public void solveSingleStep() { + final double exponent = 2; + final int nrofInputs = 7; + final ComputationGraph graph = SingleStepTest.getSimpleExpGraph(exponent, nrofInputs); + + final INDArray input = Nd4j.arange(nrofInputs); + final INDArray expected = input.mul(Math.exp(2)); + + final FirstOrderSolver solver = new DormandPrince54Solver(new SolverConfig(1e-10, 1e-10, 1e-10, 100)); + final OdeHelperForward helper = new InputStep( + solver, + 1, false); + + final INDArray actual = helper.solve(graph, LayerWorkspaceMgr.noWorkspaces(), new INDArray[]{input, Nd4j.linspace(0, 1, 2)}); + + assertArrayEquals("Incorrect answer!", expected.toDoubleVector(),actual.toDoubleVector(), 1e-3); + } + + + /** + * Test if a simple ODE can be solved + */ + @Test + public void solveMultiStep() { + final INDArray t = Nd4j.arange(5); + final double exponent = 0.12; + final int nrofInputs = 4; + final ComputationGraph graph = SingleStepTest.getSimpleExpGraph(exponent, nrofInputs); + + final INDArray input = Nd4j.arange(nrofInputs); + final INDArray expected = input.transpose().mmul(Transforms.exp(t.mul(exponent))); + + final FirstOrderSolver solver = new DormandPrince54Solver(new SolverConfig(1e-10, 1e-10, 1e-10, 100)); + final OdeHelperForward helper = new InputStep( + solver, + 1, false); + + final INDArray actual = helper.solve(graph, LayerWorkspaceMgr.noWorkspaces(), new INDArray[]{input, t}).reshape(expected.shape()); + + for(int row = 0; row < actual.rows(); row++) { + assertArrayEquals("Incorrect answer!", expected.getRow(row).toDoubleVector(), actual.getRow(row).toDoubleVector(), 1e-3); + } + } +} \ No newline at end of file diff --git a/src/test/java/ode/vertex/impl/helper/forward/MultiStepTest.java b/src/test/java/ode/vertex/impl/helper/forward/MultiStepTest.java new file mode 100644 index 0000000..c7caf56 --- /dev/null +++ b/src/test/java/ode/vertex/impl/helper/forward/MultiStepTest.java @@ -0,0 +1,47 @@ +package ode.vertex.impl.helper.forward; + +import ode.solve.api.FirstOrderSolver; +import ode.solve.conf.SolverConfig; +import ode.solve.impl.DormandPrince54Solver; +import ode.solve.impl.SingleSteppingMultiStepSolver; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.ops.transforms.Transforms; + +import static org.junit.Assert.assertArrayEquals; + +/** + * Test cases for {@link MultiStep} + * + * @author Christian Skarby + */ +public class MultiStepTest { + + /** + * Test if a simple ODE can be solved + */ + @Test + public void solve() { + final INDArray t = Nd4j.arange(4); + final double exponent = 1.23; + final int nrofInputs = 3; + final ComputationGraph graph = SingleStepTest.getSimpleExpGraph(exponent, nrofInputs); + + final INDArray input = Nd4j.arange(nrofInputs); + final INDArray expected = input.transpose().mmul(Transforms.exp(t.mul(exponent))); + + final FirstOrderSolver solver = new DormandPrince54Solver(new SolverConfig(1e-10, 1e-10, 1e-10, 100)); + final OdeHelperForward helper = new MultiStep( + new SingleSteppingMultiStepSolver(solver), + t); + + final INDArray actual = helper.solve(graph, LayerWorkspaceMgr.noWorkspaces(), new INDArray[]{input}).reshape(expected.shape()); + + for(int row = 0; row < 
actual.rows(); row++) { + assertArrayEquals("Incorrect answer!", expected.getRow(row).toDoubleVector(), actual.getRow(row).toDoubleVector(), 1e-3); + } + } +} \ No newline at end of file diff --git a/src/test/java/ode/vertex/impl/helper/forward/SingleStepTest.java b/src/test/java/ode/vertex/impl/helper/forward/SingleStepTest.java new file mode 100644 index 0000000..7ba9f13 --- /dev/null +++ b/src/test/java/ode/vertex/impl/helper/forward/SingleStepTest.java @@ -0,0 +1,66 @@ +package ode.vertex.impl.helper.forward; + +import ode.solve.api.FirstOrderSolver; +import ode.solve.conf.SolverConfig; +import ode.solve.impl.DormandPrince54Solver; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.graph.ScaleVertex; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.jetbrains.annotations.NotNull; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.assertArrayEquals; + +/** + * Test cases for {@link SingleStep} + * + * @author Christian Skarby + */ +public class SingleStepTest { + + /** + * Test if a simple ODE can be solved + */ + @Test + public void solve() { + final double exponent = 2; + final int nrofInputs = 7; + final ComputationGraph graph = getSimpleExpGraph(exponent, nrofInputs); + + final INDArray input = Nd4j.arange(nrofInputs); + final INDArray expected = input.mul(Math.exp(2)); + + final FirstOrderSolver solver = new DormandPrince54Solver(new SolverConfig(1e-10, 1e-10, 1e-10, 100)); + final OdeHelperForward helper = new SingleStep( + solver, + Nd4j.linspace(0, 1, 2)); + + final INDArray actual = helper.solve(graph, LayerWorkspaceMgr.noWorkspaces(), new INDArray[]{input}); + + assertArrayEquals("Incorrect answer!", expected.toDoubleVector(),actual.toDoubleVector(), 1e-3); + } + + /** + * dy/dt = exponent*y => y = y0*e^(exponent*t) + * @param exponent Exponent of e + * @param nrofInputs Number of elements in y + * @return a {@link ComputationGraph} + */ + @NotNull + static ComputationGraph getSimpleExpGraph(double exponent, int nrofInputs) { + final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() + .graphBuilder() + .setInputTypes(InputType.feedForward(nrofInputs)) + .addInputs("input") + + .addVertex("scale", new ScaleVertex(exponent), "input") + .allowNoOutput(true) + .build()); + graph.init(); + return graph; + } +} \ No newline at end of file diff --git a/src/test/java/util/listen/step/MaskTest.java b/src/test/java/util/listen/step/MaskTest.java index 4e68ee1..53190c9 100644 --- a/src/test/java/util/listen/step/MaskTest.java +++ b/src/test/java/util/listen/step/MaskTest.java @@ -1,6 +1,7 @@ package util.listen.step; import ode.solve.api.StepListener; +import ode.solve.impl.util.StateContainer; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -34,13 +35,13 @@ public void backward() { private void testMask(ProbeStepListener probeStepListener, StepListener mask, INDArray masked, INDArray notMasked) { mask.begin(masked, Nd4j.zeros(0)); - mask.step(Nd4j.zeros(0), Nd4j.zeros(0), Nd4j.zeros(0), Nd4j.zeros(0)); + mask.step(new StateContainer(0, new double[] {0}, new double[] {0}), Nd4j.zeros(0), Nd4j.zeros(0)); mask.done(); probeStepListener.assertNrofCalls(0,0,0); mask.begin(notMasked, Nd4j.zeros(0)); - mask.step(Nd4j.zeros(0), Nd4j.zeros(0), 
Nd4j.zeros(0), Nd4j.zeros(0));
+        mask.step(new StateContainer(1, new double[] {0}, new double[] {0}), Nd4j.zeros(0), Nd4j.zeros(0));
         mask.done();
 
         probeStepListener.assertNrofCalls(1,1,1);
diff --git a/src/test/java/util/listen/step/ProbeStepListener.java b/src/test/java/util/listen/step/ProbeStepListener.java
index 0659237..a8ad9c4 100644
--- a/src/test/java/util/listen/step/ProbeStepListener.java
+++ b/src/test/java/util/listen/step/ProbeStepListener.java
@@ -1,6 +1,7 @@
 package util.listen.step;
 
 import ode.solve.api.StepListener;
+import ode.solve.impl.util.SolverState;
 import org.nd4j.linalg.api.ndarray.INDArray;
 
 import static org.junit.Assert.assertEquals;
@@ -23,7 +24,7 @@ public void begin(INDArray t, INDArray y0) {
     }
 
     @Override
-    public void step(INDArray currTime, INDArray step, INDArray error, INDArray y) {
+    public void step(SolverState solverState, INDArray step, INDArray error) {
         nrofStep++;
     }
 
diff --git a/src/test/java/util/listen/step/StepCounterTest.java b/src/test/java/util/listen/step/StepCounterTest.java
index 979d0bd..409d3ed 100644
--- a/src/test/java/util/listen/step/StepCounterTest.java
+++ b/src/test/java/util/listen/step/StepCounterTest.java
@@ -1,6 +1,7 @@
 package util.listen.step;
 
 import ode.solve.api.StepListener;
+import ode.solve.impl.util.StateContainer;
 import org.junit.Test;
 import org.nd4j.linalg.factory.Nd4j;
 
@@ -24,8 +25,8 @@ public void countAndReport() {
         for(int i = 0; i < 7; i++) {
             listener.begin(Nd4j.linspace(0, 1, 2), Nd4j.zeros(1));
-            listener.step(Nd4j.ones(1), Nd4j.ones(1), Nd4j.zeros(1), Nd4j.zeros(0));
-            listener.step(Nd4j.ones(1), Nd4j.ones(1), Nd4j.zeros(1), Nd4j.zeros(0));
+            listener.step(new StateContainer(1, new double[] {0}, new double[] {0}), Nd4j.ones(1), Nd4j.zeros(1));
+            listener.step(new StateContainer(1, new double[] {0}, new double[] {0}), Nd4j.ones(1), Nd4j.zeros(1));
             listener.done();
 
             assertEquals("Incorrect number of calls!", (i+1) / 3, probe.nrofCalls);
diff --git a/src/test/java/util/random/SeededRandomFactoryTest.java b/src/test/java/util/random/SeededRandomFactoryTest.java
new file mode 100644
index 0000000..5287293
--- /dev/null
+++ b/src/test/java/util/random/SeededRandomFactoryTest.java
@@ -0,0 +1,40 @@
+package util.random;
+
+import org.junit.Test;
+import org.nd4j.linalg.api.rng.Random;
+import org.nd4j.linalg.factory.Nd4j;
+
+import static junit.framework.TestCase.assertEquals;
+
+public class SeededRandomFactoryTest {
+
+    /**
+     * Test that the random seed can be reset to generate the same sequence again
+     */
+    @Test
+    public void getRandom() {
+        final long baseSeed = 666;
+        final long[] shape = {2,3,4};
+        SeededRandomFactory.setNd4jSeed(baseSeed);
+        final Random first = Nd4j.getRandom();
+        SeededRandomFactory.setNd4jSeed(baseSeed);
+        final Random second = Nd4j.getRandom();
+
+        assertEquals("Not same random!", first.nextGaussian(shape), second.nextGaussian(shape));
+    }
+
+    /**
+     * Test that new random instances created from the same base seed generate the same sequence
+     */
+    @Test
+    public void getNewRandomInstance() {
+        final long baseSeed = 666;
+        final long[] shape = {2,3,4};
+        SeededRandomFactory.setNd4jSeed(baseSeed);
+        final Random first = Nd4j.getRandomFactory().getNewRandomInstance();
+        SeededRandomFactory.setNd4jSeed(baseSeed);
+        final Random second = Nd4j.getRandomFactory().getNewRandomInstance();
+
+        assertEquals("Not same random!", first.nextGaussian(shape), second.nextGaussian(shape));
+    }
+}
\ No newline at end of file