+ * Compute estimated values of Y(t1), Y(t2), ..., Y(tN) for which dY/dt = F(Y(t))
+ *
+ * @param equation is the function F(Y(t)) for which dY/dt = F(Y(t))
+ * @param t is a vector of times t0, t1, t2, ..., tN for which Y(t1), Y(t2), ..., Y(tN) is sought
+ * @param y0 is the value of Y(t) at t = t0, i.e. Y(t0)
+ * @param yOut will contain the estimated values of Y(t1), Y(t2), ..., Y(tN).
+ * @return yOut, same instance as input param. Note that yOut does not contain Y(t0) in order to comply
+ * with {@link FirstOrderSolver} API in the sense that if t has only two values, only Y(t1) is returned
+ */
+ @Override
+ INDArray integrate(FirstOrderEquation equation, INDArray t, INDArray y0, INDArray yOut);
+}
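A minimal usage sketch of the new multi-step API (the multiStepSolver instance is assumed to be any FirstOrderMultiStepSolver implementation, and the lambda form assumes FirstOrderEquation is a single-method interface):

    // Sketch: solve dY/dt = -Y from t=0 in four steps; row x of yOut holds Y(t[x+1]).
    FirstOrderEquation equation = (y, t, fy) -> fy.assign(y.neg());
    INDArray time = Nd4j.create(new double[]{0, 0.25, 0.5, 0.75, 1.0});
    INDArray y0 = Nd4j.ones(1);
    INDArray yOut = Nd4j.create(4, 1); // one row per time after t0; Y(t0) is excluded
    multiStepSolver.integrate(equation, time, y0, yOut);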
diff --git a/src/main/java/ode/solve/api/FirstOrderSolver.java b/src/main/java/ode/solve/api/FirstOrderSolver.java
index 505bdb1..d0fed27 100644
--- a/src/main/java/ode/solve/api/FirstOrderSolver.java
+++ b/src/main/java/ode/solve/api/FirstOrderSolver.java
@@ -11,25 +11,28 @@
public interface FirstOrderSolver {
/**
- * Compute estimated value of Y(t1) for which
- * @param equation is the function F(Y(t) for which dY/dt = F(Y(t)
- * @param t is a vector with initial value t0 and desired value t1
- * @param y0 is the value of Y(t) at t = t0
- * @param yOut will contain the estimated value of Y(t1). Implementations shall assume that y0 and yOut
- * are the same instance.
+ * Compute estimated value of Y(t1) for which dY/dt = F(Y(t))
+ *
+ * @param equation is the function F(Y(t)) for which dY/dt = F(Y(t))
+ * @param t is a vector with initial time t0 and time t1 for which Y(t1) is sought
+ * @param y0 is the value of Y(t) at t = t0, i.e. Y(t0)
+ * @param yOut will contain the estimated value of Y(t1). Implementations may not make any assumption
+ * as to whether y0 and yOut are the same instance or not.
* @return yOut, same instance as input param
*/
INDArray integrate(FirstOrderEquation equation, INDArray t, INDArray y0, INDArray yOut);
/**
* Add {@link StepListener}s which will be notified of steps taken
+ *
* @param listeners listeners to add
*/
void addListener(StepListener... listeners);
/**
* Clear the given listeners. Clear all listeners if empty
+ *
* @param listeners listeners to remove
*/
- void clearListeners(StepListener ... listeners);
+ void clearListeners(StepListener... listeners);
}
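To illustrate the sharpened contract above, a single-step call may now pass distinct y0 and yOut instances (solver and equation are assumed to exist):

    INDArray time = Nd4j.create(new double[]{0, 1}); // integrate from t0 = 0 to t1 = 1
    INDArray y0 = Nd4j.ones(1);
    INDArray yOut = Nd4j.create(1);                  // distinct from y0, which is now allowed
    solver.integrate(equation, time, y0, yOut);      // yOut holds the estimate of Y(1)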
diff --git a/src/main/java/ode/solve/api/StepListener.java b/src/main/java/ode/solve/api/StepListener.java
index 5067818..4f44fd4 100644
--- a/src/main/java/ode/solve/api/StepListener.java
+++ b/src/main/java/ode/solve/api/StepListener.java
@@ -1,5 +1,6 @@
package ode.solve.api;
+import ode.solve.impl.util.SolverState;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
@@ -18,12 +19,11 @@ public interface StepListener {
/**
* Indicates a step has been taken
- * @param currTime Current time (after step has been taken)
+ * @param solverState Current state of the solver
* @param step Taken step
* @param error Estimated error
- * @param y current state (i.e estimated state at currTime)
*/
- void step(INDArray currTime, INDArray step, INDArray error, INDArray y);
+ void step(SolverState solverState, INDArray step, INDArray error);
/**
* Indicates that previous step was the last step taken
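A minimal listener against the new SolverState-based signature might look as follows (the logging body is illustrative only; note that error can be null, as the StepListenerAdapter below shows):

    public class LoggingStepListener implements StepListener {

        @Override
        public void begin(INDArray t, INDArray y0) { /* no-op */ }

        @Override
        public void step(SolverState solverState, INDArray step, INDArray error) {
            // error may be null when the underlying solver does not expose it
            System.out.println("t: " + solverState.time() + ", step: " + step + ", error: " + error);
        }

        @Override
        public void done() { /* no-op */ }
    }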
diff --git a/src/main/java/ode/solve/commons/StepListenerAdapter.java b/src/main/java/ode/solve/commons/StepListenerAdapter.java
index d7079a7..cfabd4b 100644
--- a/src/main/java/ode/solve/commons/StepListenerAdapter.java
+++ b/src/main/java/ode/solve/commons/StepListenerAdapter.java
@@ -1,6 +1,7 @@
package ode.solve.commons;
import ode.solve.api.StepListener;
+import ode.solve.impl.util.StateContainer;
import org.apache.commons.math3.exception.MaxCountExceededException;
import org.apache.commons.math3.ode.sampling.StepHandler;
import org.apache.commons.math3.ode.sampling.StepInterpolator;
@@ -22,10 +23,12 @@ public void init(double t0, double[] y0, double t) {
@Override
public void handleStep(StepInterpolator interpolator, boolean isLast) throws MaxCountExceededException {
wrappedListener.step(
- Nd4j.create(1).assign(interpolator.getCurrentTime()),
+ new StateContainer(
+ interpolator.getCurrentTime(),
+ interpolator.getInterpolatedState(),
+ interpolator.getInterpolatedDerivatives()),
Nd4j.create(1).assign(interpolator.getCurrentTime() - interpolator.getPreviousTime()),
- null, // error not observable :(
- Nd4j.create(interpolator.getInterpolatedState())
+ null // error not observable :(
);
}
}
diff --git a/src/main/java/ode/solve/impl/AdaptiveRungeKuttaSolver.java b/src/main/java/ode/solve/impl/AdaptiveRungeKuttaSolver.java
index 09747ba..0b87fe7 100644
--- a/src/main/java/ode/solve/impl/AdaptiveRungeKuttaSolver.java
+++ b/src/main/java/ode/solve/impl/AdaptiveRungeKuttaSolver.java
@@ -115,7 +115,7 @@ public INDArray integrate(FirstOrderEquation equation, INDArray t, INDArray y0,
equation,
t.getScalar(0).dup(),
yOut.assign(y0),
- tableu.c.length() + 1);
+ tableu.cMid);
listener.begin(t, y0);
@@ -134,6 +134,7 @@ private void solve(FirstOrderEquationWithState equation, INDArray t) {
final INDArray error = Nd4j.create(1);
// Alg variable used for new steps
final INDArray step = stepPolicy.initializeStep(equation, t);
+
// Alg variable for where next step starts
final TimeLimit timeLimit = t.argMax().getInt(0) == 1 ?
new TimeLimitForwards(t.getDouble(1), equation.time()) :
@@ -170,7 +171,9 @@ private boolean acceptStep(FirstOrderEquationWithState equation, INDArray step,
// local error is small enough: accept the step,
equation.update();
- listener.step(equation.time(), step, error, equation.getCurrentState());
+ listener.step(equation, step, error);
+
+ equation.shiftDerivative();
return true;
}
diff --git a/src/main/java/ode/solve/impl/DormandPrince54Solver.java b/src/main/java/ode/solve/impl/DormandPrince54Solver.java
index deb77c5..2a8e1f5 100644
--- a/src/main/java/ode/solve/impl/DormandPrince54Solver.java
+++ b/src/main/java/ode/solve/impl/DormandPrince54Solver.java
@@ -42,7 +42,10 @@ public class DormandPrince54Solver implements FirstOrderSolver {
})
.c(new double[]{
1.0 / 5.0, 3.0 / 10.0, 4.0 / 5.0, 8.0 / 9.0, 1.0, 1.0
- });
+ })
+ .cMid(new double[]{
+ 6025192743d / 30085553152d / 2d, 0, 51252292925d / 65400821598d / 2d, -2691868925d / 45128329728d / 2d,
+ 187940372067d / 1594534317056d / 2d, -1776094331 / 19743644256d / 2d, 11237099 / 235043384d / 2d});
private final AdaptiveRungeKuttaSolver solver;
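For context, the cMid coefficients are consumed as yMid = y0 + h * sum_i(cMid[i] * k_i), where k_i are the stage derivatives (see InterpolatingStepListener.fitInterpolationCoeffs further down in this diff); the trailing "/ 2d" on each coefficient corresponds to evaluation at the half step. A hypothetical helper showing the combination:

    // Illustrative only, not part of the diff: midpoint estimate from stage derivatives.
    static INDArray midpointEstimate(INDArray y0, INDArray step, double[] cMid, INDArray[] kDot) {
        INDArray acc = kDot[0].mul(cMid[0]);
        for (int i = 1; i < cMid.length; i++) {
            acc.addi(kDot[i].mul(cMid[i])); // acc = sum_i cMid[i] * k_i
        }
        return y0.add(acc.mul(step));       // yMid = y0 + h * acc
    }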
diff --git a/src/main/java/ode/solve/impl/DummyIteration.java b/src/main/java/ode/solve/impl/DummyIteration.java
index f447cb2..6e8c5e0 100644
--- a/src/main/java/ode/solve/impl/DummyIteration.java
+++ b/src/main/java/ode/solve/impl/DummyIteration.java
@@ -4,6 +4,7 @@
import ode.solve.api.FirstOrderSolver;
import ode.solve.api.StepListener;
import ode.solve.impl.util.AggStepListener;
+import ode.solve.impl.util.StateContainer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@@ -32,7 +33,11 @@ public INDArray integrate(FirstOrderEquation equation, INDArray t, INDArray y0,
for (int i = 0; i < nrofIters; i++) {
next = equation.calculateDerivative(next, t.getColumn(0), yOut);
- listener.step(t.getColumn(0), Nd4j.zeros(1), Nd4j.zeros(1), next);
+ listener.step(
+ new StateContainer(t.getColumn(0),
+ next,
+ Nd4j.zeros(next.shape())),
+ Nd4j.zeros(1), Nd4j.zeros(1));
}
listener.done();
diff --git a/src/main/java/ode/solve/impl/InterpolatingMultiStepSolver.java b/src/main/java/ode/solve/impl/InterpolatingMultiStepSolver.java
new file mode 100644
index 0000000..544e35a
--- /dev/null
+++ b/src/main/java/ode/solve/impl/InterpolatingMultiStepSolver.java
@@ -0,0 +1,68 @@
+package ode.solve.impl;
+
+import ode.solve.api.FirstOrderEquation;
+import ode.solve.api.FirstOrderMultiStepSolver;
+import ode.solve.api.FirstOrderSolver;
+import ode.solve.api.StepListener;
+import ode.solve.impl.util.InterpolatingStepListener;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.indexing.SpecifiedIndex;
+
+import java.util.Arrays;
+
+/**
+ * {@link FirstOrderMultiStepSolver} which uses an {@link InterpolatingStepListener} to estimate intermediate time
+ * steps. Will invoke the wrapped {@link FirstOrderSolver} with the first and last time steps and return the
+ * interpolated state (i.e. not the output from the wrapped solver).
+ */
+public class InterpolatingMultiStepSolver implements FirstOrderMultiStepSolver {
+
+ private final FirstOrderSolver solver;
+
+ public InterpolatingMultiStepSolver(FirstOrderSolver solver) {
+ this.solver = solver;
+ }
+
+ @Override
+ public INDArray integrate(FirstOrderEquation equation, INDArray t, INDArray y0, INDArray yOut) {
+ if(!t.isVector()) {
+ throw new IllegalStateException("t must be a vector! Was shape " + Arrays.toString(t.shape()));
+ }
+
+ final INDArray wantedTimes = t.get(NDArrayIndex.interval(1, t.length())).reshape(getTimeShapeForLength(t, t.length()-1));
+
+ final InterpolatingStepListener interpolation = new InterpolatingStepListener(wantedTimes, yOut);
+
+ addListener(interpolation);
+
+ // Note: skip first time index in order to comply with API description
+ final INDArray tStartEnd = t.get(new SpecifiedIndex(0, t.length()-1)).reshape(getTimeShapeForLength(t, 2));
+ solver.integrate(equation, tStartEnd, y0, y0);
+
+ clearListeners(interpolation);
+
+ // interpolation has made sure yOut contains estimated solutions at desired times
+ return yOut;
+ }
+
+ private long[] getTimeShapeForLength(INDArray t, long length) {
+ final long[] shape = t.shape().clone();
+ for(int i = 0; i < shape.length; i++) {
+ if(shape[i] != 1) {
+ shape[i] = length;
+ }
+ }
+ return shape;
+ }
+
+ @Override
+ public void addListener(StepListener... listeners) {
+ solver.addListener(listeners);
+ }
+
+ @Override
+ public void clearListeners(StepListener... listeners) {
+ solver.clearListeners(listeners);
+ }
+}
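Wrapping any single-step solver yields a multi-step solver that runs the underlying integration only once (singleStepSolver, equation and y0 are assumed to exist):

    FirstOrderMultiStepSolver multiStep = new InterpolatingMultiStepSolver(singleStepSolver);
    INDArray time = Nd4j.create(new double[]{0, 0.5, 1.0, 1.5});
    INDArray yOut = Nd4j.create(3, 1); // rows for t = 0.5, 1.0 and 1.5
    multiStep.integrate(equation, time, y0, yOut);

Note that the implementation passes y0 as both input and output to the wrapped solver, so y0 is overwritten with the state at the last time step.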
diff --git a/src/main/java/ode/solve/impl/SingleSteppingMultiStepSolver.java b/src/main/java/ode/solve/impl/SingleSteppingMultiStepSolver.java
new file mode 100644
index 0000000..7f7a338
--- /dev/null
+++ b/src/main/java/ode/solve/impl/SingleSteppingMultiStepSolver.java
@@ -0,0 +1,55 @@
+package ode.solve.impl;
+
+import ode.solve.api.FirstOrderEquation;
+import ode.solve.api.FirstOrderMultiStepSolver;
+import ode.solve.api.FirstOrderSolver;
+import ode.solve.api.StepListener;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.indexing.INDArrayIndex;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+
+/**
+ * {@link FirstOrderMultiStepSolver} which invokes the wrapped {@link FirstOrderSolver} once for
+ * every pair {@code t[n], t[n+1], 0 <= n < t.length()-1} and concatenates the solutions.
+ *
+ * @author Christian Skarby
+ */
+public class SingleSteppingMultiStepSolver implements FirstOrderMultiStepSolver {
+
+ private final FirstOrderSolver solver;
+
+ public SingleSteppingMultiStepSolver(FirstOrderSolver solver) {
+ this.solver = solver;
+ }
+
+ @Override
+ public INDArray integrate(FirstOrderEquation equation, INDArray t, INDArray y0, INDArray yOut) {
+ if(yOut.size(0) + 1 != t.length()) {
+ throw new IllegalArgumentException("yOut must have one fewer element in its first dimension than t!");
+ }
+ final INDArrayIndex[] yOutAccess = new INDArrayIndex[yOut.rank()];
+ for(int dim = 0; dim < yOutAccess.length; dim++) {
+ yOutAccess[dim] = NDArrayIndex.all();
+ }
+
+ for (int step = 1; step < t.length(); step++) {
+ yOutAccess[0] = NDArrayIndex.point(step - 2);
+ final INDArray yPrev = step == 1 ? y0 : yOut.get(yOutAccess);
+ yOutAccess[0] = NDArrayIndex.point(step - 1);
+ final INDArray yCurr = yOut.get(yOutAccess);
+ solver.integrate(equation, t.get(NDArrayIndex.interval(step - 1, step + 1)), yPrev, yCurr);
+ }
+
+ return yOut;
+ }
+
+ @Override
+ public void addListener(StepListener... listeners) {
+ solver.addListener(listeners);
+ }
+
+ @Override
+ public void clearListeners(StepListener... listeners) {
+ solver.clearListeners(listeners);
+ }
+}
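The shape contract enforced above, in a sketch (singleStepSolver, equation and y0 assumed to exist):

    // t has N+1 = 4 elements; yOut must have N = 3 rows in dimension 0.
    FirstOrderMultiStepSolver multiStep = new SingleSteppingMultiStepSolver(singleStepSolver);
    INDArray time = Nd4j.create(new double[]{0, 1, 2, 3});
    INDArray yOut = Nd4j.create(3, 1);
    multiStep.integrate(equation, time, y0, yOut); // three solver runs, one per adjacent time pair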
diff --git a/src/main/java/ode/solve/impl/util/AdaptiveRungeKuttaStepPolicy.java b/src/main/java/ode/solve/impl/util/AdaptiveRungeKuttaStepPolicy.java
index bf54de8..fbaf574 100644
--- a/src/main/java/ode/solve/impl/util/AdaptiveRungeKuttaStepPolicy.java
+++ b/src/main/java/ode/solve/impl/util/AdaptiveRungeKuttaStepPolicy.java
@@ -119,6 +119,7 @@ public INDArray initializeStep(FirstOrderEquationWithState equation, INDArray t)
// step size is computed such that
// step^order * max (||y'/tol||, ||y''/tol||) = 0.01
+ // TODO: Should be abs(yDDotOnScale) for when negative step?
final INDArray maxInv2 = max(sqrt(yDotOnScale2), yDDotOnScale);
final INDArray step1 = maxInv2.getDouble(0) < 1e-15 ?
max(MIN_H, abs(step).muli(0.001)) :
diff --git a/src/main/java/ode/solve/impl/util/AggStepListener.java b/src/main/java/ode/solve/impl/util/AggStepListener.java
index 5386b4a..06d6f1e 100644
--- a/src/main/java/ode/solve/impl/util/AggStepListener.java
+++ b/src/main/java/ode/solve/impl/util/AggStepListener.java
@@ -24,9 +24,9 @@ public void begin(INDArray t, INDArray y0) {
}
@Override
- public void step(INDArray currTime, INDArray step, INDArray error, INDArray y) {
+ public void step(SolverState solverState, INDArray step, INDArray error) {
for (StepListener listener : listeners) {
- listener.step(currTime, step, error, y);
+ listener.step(solverState, step, error);
}
}
diff --git a/src/main/java/ode/solve/impl/util/ButcherTableu.java b/src/main/java/ode/solve/impl/util/ButcherTableu.java
index 96bfe1e..5cf9898 100644
--- a/src/main/java/ode/solve/impl/util/ButcherTableu.java
+++ b/src/main/java/ode/solve/impl/util/ButcherTableu.java
@@ -18,6 +18,7 @@ public class ButcherTableu {
public final INDArray b;
public final INDArray bStar;
public final INDArray c;
+ public final double[] cMid;
/**
* Create a new {@link Builder} instance
@@ -27,11 +28,12 @@ public static Builder builder() {
return new Builder();
}
- private ButcherTableu(INDArray[] a, INDArray b, INDArray bStar, INDArray c) {
+ private ButcherTableu(INDArray[] a, INDArray b, INDArray bStar, INDArray c, double[] cMid) {
this.a = a;
this.b = b;
this.bStar = bStar;
this.c = c;
+ this.cMid = cMid;
}
/**
@@ -46,6 +48,7 @@ public static class Builder {
private double[] b;
private double[] bStar;
private double[] c;
+ private double[] cMid;
public Builder a(double[][] a) {
cache.clear();
@@ -67,6 +70,11 @@ public Builder c(double[] c) {
this.c = c; return this;
}
+ public Builder cMid(double[] cMid) {
+ cache.clear();
+ this.cMid = cMid; return this;
+ }
+
public ButcherTableu build() {
ButcherTableu tableu = cache.get(Nd4j.dataType());
if(tableu == null) {
@@ -78,7 +86,8 @@ public ButcherTableu build() {
aArr,
Nd4j.create(b),
Nd4j.create(bStar),
- Nd4j.create(c));
+ Nd4j.create(c),
+ cMid);
cache.put(Nd4j.dataType(), tableu);
}
return tableu;
diff --git a/src/main/java/ode/solve/impl/util/FirstOrderEquationWithState.java b/src/main/java/ode/solve/impl/util/FirstOrderEquationWithState.java
index d25a4a2..109fa36 100644
--- a/src/main/java/ode/solve/impl/util/FirstOrderEquationWithState.java
+++ b/src/main/java/ode/solve/impl/util/FirstOrderEquationWithState.java
@@ -11,11 +11,12 @@
*
* @author Christian Skarby
*/
-public class FirstOrderEquationWithState {
+public class FirstOrderEquationWithState implements SolverState {
private final FirstOrderEquation equation;
private final INDArray time;
private final State state;
+ private final double[] midPointCoeffs;
private final static class State {
private final INDArray y; // Last value of y. May be of any shape
@@ -48,13 +49,14 @@ public FirstOrderEquationWithState(
FirstOrderEquation equation,
INDArray time,
INDArray state,
- long nrofStages) {
+ double[] midPointCoeffs) {
if (!time.isScalar()) {
throw new IllegalArgumentException("Expected time to be a scalar! Was: " + time);
}
this.equation = equation;
this.time = time;
- this.state = new State(state, nrofStages);
+ this.state = new State(state, midPointCoeffs.length);
+ this.midPointCoeffs = midPointCoeffs;
}
@@ -64,10 +66,12 @@ public FirstOrderEquationWithState(
* @param stage which stage to update
*/
public void calculateDerivative(long stage) {
equation.calculateDerivative(
state.yWorking,
time.add(state.timeOffset),
state.getStateDot(stage)); // Note, stateDot of the given stage will be updated by this operation
}
/**
@@ -75,6 +79,7 @@ public void calculateDerivative(long stage) {
* @param stage wanted stage of derivative
* @return The derivative of the given stage
*/
+ @Override
public INDArray getStateDot(long stage) {
return state.getStateDot(stage);
}
@@ -83,10 +88,25 @@ public INDArray getStateDot(long stage) {
* Return the current state
* @return the current state
*/
+ @Override
public INDArray getCurrentState() {
return state.y;
}
+ /**
+ * Return the current time state
+ * @return the current time state
+ */
+ @Override
+ public INDArray time() {
+ return time;
+ }
+
+ @Override
+ public double[] getInterpolationMidpoints() {
+ return midPointCoeffs;
+ }
+
/**
* Estimate the error using the given {@link AdaptiveRungeKuttaSolver.MseComputation}
* @param mseComputation Strategy for computing the error
@@ -96,14 +116,6 @@ public INDArray estimateError(AdaptiveRungeKuttaSolver.MseComputation mseComputa
return mseComputation.estimateMse(state.yDotK, state.y, state.yWorking, state.timeOffset);
}
- /**
- * Return the current time state
- * @return
- */
- public INDArray time() {
- return time;
- }
-
/**
* Update the working state by taking a step accumulated over all stages up to the given stage. The base step is
* weighted with the given coefficients for each stage.
@@ -115,12 +127,18 @@ public void step(INDArray stepCoeffPerStage, INDArray step) {
}
/**
- * Update the current state to the working state. This also includes shifting the derivative so that previous
- * stage 1 becomes new stage 0.
+ * Update the current state to the working state.
*/
public void update() {
time.addi(state.timeOffset);
state.y.assign(state.yWorking);
+
+ }
+
+ /**
+ * Shift the derivative so that previous stage 1 becomes new stage 0 in preparation for next step.
+ */
+ public void shiftDerivative() {
state.yDotK.putRow(0, state.yDotK.getRow(state.yDotK.rows()-1));
}
}
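The update/shiftDerivative split above exists so that listeners can still read every stage derivative when notified. The call order in AdaptiveRungeKuttaSolver.acceptStep (earlier in this diff) becomes:

    equation.update();                    // commit working state: t += h, y = yWorking
    listener.step(equation, step, error); // listeners see all stage derivatives, e.g. for interpolation
    equation.shiftDerivative();           // then reuse last stage as stage 0 of the next step
                                          // (the first-same-as-last property of Dormand-Prince type tableaus)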
diff --git a/src/main/java/ode/solve/impl/util/InterpolatingStepListener.java b/src/main/java/ode/solve/impl/util/InterpolatingStepListener.java
new file mode 100644
index 0000000..0395a01
--- /dev/null
+++ b/src/main/java/ode/solve/impl/util/InterpolatingStepListener.java
@@ -0,0 +1,151 @@
+package ode.solve.impl.util;
+
+import ode.solve.api.StepListener;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.INDArrayIndex;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.indexing.conditions.And;
+import org.nd4j.linalg.indexing.conditions.GreaterThan;
+import org.nd4j.linalg.indexing.conditions.LessThan;
+
+import java.util.Arrays;
+
+/**
+ * Samples {@link SolverState} at given time indexes using an {@link Interpolation}. Useful to extract multiple time
+ * steps from a single time step solver. The advantage compared to calling the solver once per time step pair is fewer
+ * function evaluations.
+ *
+ * @author Christian Skarby
+ */
+public class InterpolatingStepListener implements StepListener {
+
+ private final INDArray wantedTimeInds;
+ private final INDArray yInterpol;
+ private final INDArrayIndex[] yInterpolAccess;
+ private State state;
+
+ private final class State {
+ private Interpolation interpolation = new Interpolation();
+ private INDArray y0;
+ private INDArray t0;
+ }
+
+ /**
+ * Create an {@link InterpolatingStepListener}. Output for each wanted time will be assigned along dimension 0
+ * of the provided yInterpol. In other words, output for wantedTimes[x] can be accessed through
+ * yInterpol.get(NDArrayIndex.point(x), NDArrayIndex.all(), NDArrayIndex.all(), ...)
+ *
+ * @param wantedTimes Time samples for which output is desired.
+ * @param yInterpol Will contain output from the provided {@link SolverState} at the desired times.
+ */
+ public InterpolatingStepListener(INDArray wantedTimes, INDArray yInterpol) {
+ if (wantedTimes.length() != yInterpol.size(0)) {
+ throw new IllegalArgumentException("Must have one wanted time per element in dimension 0 of yInterpol! " +
+ "wantedTimes shape: " + Arrays.toString(wantedTimes.shape()) + ", yInterpol shape: " +
+ Arrays.toString(yInterpol.shape()));
+ }
+
+ this.wantedTimeInds = wantedTimes;
+ this.yInterpol = yInterpol;
+
+ this.yInterpolAccess = new INDArrayIndex[yInterpol.rank()];
+ for (int dim = 0; dim < yInterpolAccess.length; dim++) {
+ yInterpolAccess[dim] = NDArrayIndex.all();
+ }
+ }
+
+ @Override
+ public void begin(INDArray t, INDArray y0) {
+ this.state = new State();
+ this.state.t0 = t.getScalar(0);
+ this.state.y0 = y0;
+
+ // Edge case: The first wanted time index is the start time -> user wants the starting state to be added to output
+ if (state.t0.equalsWithEps(wantedTimeInds.getScalar(0), 1e-10)) {
+ yInterpolAccess[0] = NDArrayIndex.point(0);
+ yInterpol.put(yInterpolAccess, state.y0);
+ }
+ }
+
+ @Override
+ public void step(SolverState solverState, INDArray step, INDArray error) {
+ final INDArray greaterThanTime;
+ final INDArray lessThanTime;
+ if (step.getDouble(0) > 0) {
+ greaterThanTime = state.t0;
+ lessThanTime = solverState.time();
+ } else {
+ greaterThanTime = solverState.time();
+ lessThanTime = state.t0;
+ }
+
+ final INDArray timeInds = wantedTimeInds.cond(
+ new And(
+ new GreaterThan(greaterThanTime.getDouble(0)),
+ new LessThan(lessThanTime.getDouble(0)))
+ );
+
+ if (timeInds.sumNumber().doubleValue() > 0) {
+ fitInterpolationCoeffs(solverState, step);
+ doInterpolation(timeInds, solverState.time());
+ }
+
+ state.t0 = solverState.time().dup();
+ state.y0 = solverState.getCurrentState().dup();
+ }
+
+ private void fitInterpolationCoeffs(SolverState solverState, INDArray step) {
+ final INDArray[] yDotStages = new INDArray[solverState.getInterpolationMidpoints().length];
+ for (int i = 0; i < yDotStages.length; i++) {
+ yDotStages[i] = solverState.getStateDot(i);
+ }
+
+ final INDArray state0 = state.y0;
+ final INDArray state1 = solverState.getCurrentState();
+
+ final INDArray yMid = state.y0.add(scaledDotProduct(
+ Nd4j.createUninitialized(state0.shape()),
+ solverState.getInterpolationMidpoints(),
+ yDotStages,
+ step));
+
+ state.interpolation.fitCoeffs(
+ state0,
+ state1,
+ yMid,
+ yDotStages[0],
+ yDotStages[yDotStages.length - 1],
+ step);
+ }
+
+ private INDArray scaledDotProduct(INDArray output, double[] factors, INDArray[] inputs, INDArray scale) {
+ output.assign((inputs[0].mul(factors[0])).mul(scale));
+ for (int i = 1; i < inputs.length; i++) {
+ output.addi(inputs[i].mul(factors[i]).mul(scale));
+ }
+ return output;
+ }
+
+ private void doInterpolation(INDArray timeInds, INDArray tNew) {
+ final int startInd = timeInds.argMax().getInt(0);
+ final int stopInd = startInd + timeInds.sumNumber().intValue();
+
+ for (int i = startInd; i < stopInd; i++) {
+ yInterpolAccess[0] = NDArrayIndex.point(i);
+
+ yInterpol.put(yInterpolAccess,
+ state.interpolation.interpolate(state.t0.getDouble(0), tNew.getDouble(0), wantedTimeInds.getDouble(i)));
+ }
+ }
+
+ @Override
+ public void done() {
+
+ // Edge case: User wants last time step to be added to interpolation
+ if (state.t0.equalsWithEps(wantedTimeInds.getScalar(wantedTimeInds.length() - 1), 1e-10)) {
+ yInterpolAccess[0] = NDArrayIndex.point(wantedTimeInds.length() - 1);
+ yInterpol.put(yInterpolAccess, state.y0);
+ }
+ }
+}
diff --git a/src/main/java/ode/solve/impl/util/Interpolation.java b/src/main/java/ode/solve/impl/util/Interpolation.java
new file mode 100644
index 0000000..7c17009
--- /dev/null
+++ b/src/main/java/ode/solve/impl/util/Interpolation.java
@@ -0,0 +1,85 @@
+package ode.solve.impl.util;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+/**
+ * Performs interpolation using a fourth order polynomial. Useful for reducing the number of function evaluations
+ * when multiple (closely spaced) time steps are used.
+ * Reimplementation of https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/interp.py
+ *
+ * @author Christian Skarby
+ */
+public class Interpolation {
+
+ private final INDArray[] coeffs = new INDArray[5];
+
+ private void initCoeffs(long[] shape) {
+ for (int i = 0; i < coeffs.length; i++) {
+ coeffs[i] = Nd4j.createUninitialized(shape);
+ }
+ }
+
+ /**
+ * Fit coefficients for a fourth order polynomial: p = coeffs[0] * x^4 + coeffs[1] * x^3 + coeffs[2] * x^2 + coeffs[3] * x + coeffs[4]
+ *
+ * @param y0 Function value at start of interval
+ * @param y1 Function value at end of interval
+ * @param yMid Function value at midpoint of interval
+ * @param f0 Derivative value at start of interval
+ * @param f1 Derivative value at end of interval
+ * @param dt Time between start and end of interval
+ */
+ public void fitCoeffs(INDArray y0, INDArray y1, INDArray yMid, INDArray f0, INDArray f1, INDArray dt) {
+ if (coeffs[0] == null) {
+ initCoeffs(y0.shape());
+ }
+
+ final INDArray[] inputs = {f0, f1, y0, y1, yMid};
+
+ final double dtdub = dt.getDouble(0);
+ dotProduct(coeffs[0], new double[]{-2 * dtdub, 2 * dtdub, -8, -8, 16}, inputs);
+
+ dotProduct(coeffs[1], new double[]{5 * dtdub, -3 * dtdub, 18, 14, -32}, inputs);
+
+ dotProduct(coeffs[2], new double[]{-4 * dtdub, dtdub, -11, -5, 16}, inputs);
+
+ coeffs[3] = f0.mul(dt);
+ coeffs[4] = y0.dup();
+ }
+
+ private INDArray dotProduct(INDArray output, double[] factors, INDArray[] inputs) {
+ output.assign((inputs[0].mul(factors[0])));
+ for (int i = 1; i < inputs.length; i++) {
+ output.addi(inputs[i].mul(factors[i]));
+ }
+ return output;
+ }
+
+ /**
+ * Evaluate interpolation of a fourth order polynomial: p = coeffs[0] * x^4 + coeffs[1] * x^3 + coeffs[2] * x^2 + coeffs[3] * x + coeffs[4]
+ *
+ * @param t0 Start of interval
+ * @param t1 End of interval
+ * @param t Wanted time
+ * @return Result of the interpolation
+ */
+ public INDArray interpolate(double t0, double t1, double t) {
+
+ if (Math.max(t0, t1) < t || Math.min(t0, t1) > t) {
+ throw new IllegalArgumentException("t0 <= t <= t1 or t1 <= t <= t0 not satisfied! t0: " + t0 + ", t: " + t + ", t1: " + t1);
+ }
+
+ final double x = ((t - t0) / (t1 - t0));
+
+ // Create array {x^4, x^3, x^2, x, 1};
+ final double[] xs = new double[coeffs.length];
+ xs[coeffs.length - 1] = 1;
+ for (int i = coeffs.length - 2; i > -1; i--) {
+ xs[i] = xs[i + 1] * x;
+ }
+
+ return dotProduct(Nd4j.createUninitialized(coeffs[0].shape()), xs, coeffs);
+ }
+
+}
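A worked check of the fit: a quartic through y(t) = t^2 on [0, 1] reduces to p(x) = x^2 (the cubic and quartic coefficients cancel), so the midpoint is recovered exactly:

    Interpolation interpolation = new Interpolation();
    interpolation.fitCoeffs(
            Nd4j.create(new double[]{0}),    // y0   = y(0)   = 0
            Nd4j.create(new double[]{1}),    // y1   = y(1)   = 1
            Nd4j.create(new double[]{0.25}), // yMid = y(0.5) = 0.25
            Nd4j.create(new double[]{0}),    // f0   = y'(0)  = 0
            Nd4j.create(new double[]{2}),    // f1   = y'(1)  = 2
            Nd4j.create(new double[]{1}));   // dt   = 1
    interpolation.interpolate(0, 1, 0.5);    // returns [0.25]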
diff --git a/src/main/java/ode/solve/impl/util/SolverState.java b/src/main/java/ode/solve/impl/util/SolverState.java
new file mode 100644
index 0000000..a9b2e98
--- /dev/null
+++ b/src/main/java/ode/solve/impl/util/SolverState.java
@@ -0,0 +1,36 @@
+package ode.solve.impl.util;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+/**
+ * Interface for providing state of a solver
+ *
+ * @author Christian Skarby
+ */
+public interface SolverState {
+
+ /**
+ * Return the given stage of the derivative of the state
+ * @param stage wanted stage of derivative
+ * @return The derivative of the given stage
+ */
+ INDArray getStateDot(long stage);
+
+ /**
+ * Return the current state
+ * @return the current state
+ */
+ INDArray getCurrentState();
+
+ /**
+ * Return the current time state
+ * @return the current time state
+ */
+ INDArray time();
+
+ /**
+ * Return coefficients for calculating midpoints for the given solver
+ * @return coefficients for calculating midpoints for the given solver
+ */
+ double[] getInterpolationMidpoints();
+}
diff --git a/src/main/java/ode/solve/impl/util/StateContainer.java b/src/main/java/ode/solve/impl/util/StateContainer.java
new file mode 100644
index 0000000..1ce89a9
--- /dev/null
+++ b/src/main/java/ode/solve/impl/util/StateContainer.java
@@ -0,0 +1,46 @@
+package ode.solve.impl.util;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+/**
+ * Simple container for solver state.
+ *
+ * @author Christian Skarby
+ */
+public class StateContainer implements SolverState {
+
+ private final INDArray currentTime;
+ private final INDArray currentState;
+ private final INDArray currentStateDot;
+
+ public StateContainer(double currentTime, double[] currentState, double[] currentStateDot) {
+ this(Nd4j.scalar(currentTime), Nd4j.create(currentState), Nd4j.create(currentStateDot));
+ }
+
+ public StateContainer(INDArray currentTime, INDArray currentState, INDArray currentStateDot) {
+ this.currentTime = currentTime;
+ this.currentState = currentState;
+ this.currentStateDot = currentStateDot;
+ }
+
+ @Override
+ public INDArray getStateDot(long stage) {
+ return currentStateDot;
+ }
+
+ @Override
+ public INDArray getCurrentState() {
+ return currentState;
+ }
+
+ @Override
+ public INDArray time() {
+ return currentTime;
+ }
+
+ @Override
+ public double[] getInterpolationMidpoints() {
+ return new double[0];
+ }
+}
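For example, packaging a plain time/state/derivative triplet for the listener API:

    SolverState state = new StateContainer(
            0.5,                        // current time
            new double[]{1.0, 2.0},     // current state
            new double[]{-1.0, -2.0});  // current state derivative
    state.getStateDot(0);               // stage argument is ignored; the single derivative is returned

Since getInterpolationMidpoints() returns an empty array, a StateContainer cannot drive the InterpolatingStepListener; it serves simple notifications such as those in StepListenerAdapter and DummyIteration above.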
diff --git a/src/main/java/ode/vertex/conf/DefaultTrainingConfig.java b/src/main/java/ode/vertex/conf/DefaultTrainingConfig.java
index a41690e..3510ad3 100644
--- a/src/main/java/ode/vertex/conf/DefaultTrainingConfig.java
+++ b/src/main/java/ode/vertex/conf/DefaultTrainingConfig.java
@@ -1,8 +1,12 @@
package ode.vertex.conf;
+import ode.vertex.impl.gradview.parname.ParamNameMapping;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.conf.GradientNormalization;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.nd4j.linalg.learning.config.IUpdater;
+import org.nd4j.linalg.primitives.Pair;
/**
* Basic {@link TrainingConfig} for vertices which do not have a {@link org.deeplearning4j.nn.conf.layers.Layer}
@@ -11,13 +15,14 @@
*/
public class DefaultTrainingConfig implements TrainingConfig {
-
private final String name;
- private final IUpdater updater;
+ private final ComputationGraph graph;
+ private final ParamNameMapping paramNameMapping;
- public DefaultTrainingConfig(String name, IUpdater updater) {
+ public DefaultTrainingConfig(ComputationGraph graph, String name, ParamNameMapping paramNameMapping) {
this.name = name;
- this.updater = updater;
+ this.graph = graph;
+ this.paramNameMapping = paramNameMapping;
}
@Override
@@ -32,22 +37,30 @@ public boolean isPretrain() {
@Override
public double getL1ByParam(String paramName) {
- return 0;
+ final Pair vertexAndParam = paramNameMapping.reverseMap(paramName);
+ final GraphVertex vertex = graph.getVertex(vertexAndParam.getFirst());
+ return vertex.getConfig().getL1ByParam(vertexAndParam.getSecond());
}
@Override
public double getL2ByParam(String paramName) {
- return 0;
+ final Pair vertexAndParam = paramNameMapping.reverseMap(paramName);
+ final GraphVertex vertex = graph.getVertex(vertexAndParam.getFirst());
+ return vertex.getConfig().getL2ByParam(vertexAndParam.getSecond());
}
@Override
public boolean isPretrainParam(String paramName) {
- return false;
+ final Pair vertexAndParam = paramNameMapping.reverseMap(paramName);
+ final GraphVertex vertex = graph.getVertex(vertexAndParam.getFirst());
+ return vertex.getConfig().isPretrainParam(vertexAndParam.getSecond());
}
@Override
public IUpdater getUpdaterByParam(String paramName) {
- return updater;
+ final Pair vertexAndParam = paramNameMapping.reverseMap(paramName);
+ final GraphVertex vertex = graph.getVertex(vertexAndParam.getFirst());
+ return vertex.getConfig().getUpdaterByParam(vertexAndParam.getSecond());
}
@Override
diff --git a/src/main/java/ode/vertex/conf/OdeVertex.java b/src/main/java/ode/vertex/conf/OdeVertex.java
index f01a660..1150723 100644
--- a/src/main/java/ode/vertex/conf/OdeVertex.java
+++ b/src/main/java/ode/vertex/conf/OdeVertex.java
@@ -1,23 +1,40 @@
package ode.vertex.conf;
import lombok.Data;
-import ode.solve.api.FirstOrderSolver;
-import ode.solve.api.FirstOrderSolverConf;
import ode.solve.conf.DormandPrince54Solver;
+import ode.vertex.conf.helper.OdeHelper;
+import ode.vertex.conf.helper.backward.FixedStepAdjoint;
+import ode.vertex.conf.helper.backward.OdeHelperBackward;
+import ode.vertex.conf.helper.forward.FixedStep;
+import ode.vertex.conf.helper.forward.OdeHelperForward;
+import ode.vertex.impl.gradview.GradientViewFactory;
+import ode.vertex.impl.gradview.GradientViewSelectionFromBlacklisted;
+import ode.vertex.impl.helper.OdeGraphHelper;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
-import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonProperty;
/**
- * Configuration of an ODE block.
+ * Configuration of an ODE block. Contains a {@link ComputationGraphConfiguration} which defines the structure of the
+ * learnable function {@code f = dz(t)/dt} for which the {@link ode.vertex.impl.OdeVertex} will output an estimate
+ * of z(t) for given t(s).
+ *
+ * A {@link Builder} is used to add {@link Layer}s and {@link GraphVertex GraphVertices} to the internal
+ * {@link ComputationGraphConfiguration}.
+ *
+ * Note that the internal {@code ComputationGraphConfiguration} is not the same as the "outer"
+ * {@code ComputationGraphConfiguration} which houses the OdeVertex itself. This understandably confusing composition
+ * comes from the fact that the {@code OdeVertex} needs to operate on an arbitrary graph and I didn't want to
+ * reimplement all the routing for doing this. If dl4j had something similar to pytorch's nn.Module I would rather have
+ * used that.
*
* @author Christian Skarby
*/
@@ -26,23 +43,26 @@ public class OdeVertex extends GraphVertex {
protected ComputationGraphConfiguration conf;
protected String firstVertex;
- protected String lastVertex;
- protected FirstOrderSolverConf odeSolver;
+ protected OdeHelperForward odeForwardConf;
+ protected OdeHelperBackward odeBackwardConf;
+ protected GradientViewFactory gradientViewFactory;
public OdeVertex(
@JsonProperty("conf") ComputationGraphConfiguration conf,
@JsonProperty("firstVertex") String firstVertex,
- @JsonProperty("lastVertex") String lastVertex,
- @JsonProperty("odeSolver") FirstOrderSolverConf odeSolver) {
+ @JsonProperty("odeForwardConf") OdeHelperForward odeForwardConf,
+ @JsonProperty("odeBackwardConf") OdeHelperBackward odeBackwardConf,
+ @JsonProperty("gradientViewFactory") GradientViewFactory gradientViewFactory) {
this.conf = conf;
this.firstVertex = firstVertex;
- this.lastVertex = lastVertex;
- this.odeSolver = odeSolver;
+ this.odeForwardConf = odeForwardConf;
+ this.odeBackwardConf = odeBackwardConf;
+ this.gradientViewFactory = gradientViewFactory;
}
@Override
public GraphVertex clone() {
- return new OdeVertex(conf.clone(), firstVertex, lastVertex, odeSolver.clone());
+ return new OdeVertex(conf.clone(), firstVertex, odeForwardConf.clone(), odeBackwardConf.clone(), gradientViewFactory.clone());
}
@Override
@@ -50,11 +70,12 @@ public boolean equals(Object o) {
if (!(o instanceof OdeVertex)) {
return false;
}
- final OdeVertex other = (OdeVertex)o;
+ final OdeVertex other = (OdeVertex) o;
return conf.equals(other.conf)
&& firstVertex.equals(other.firstVertex)
- && lastVertex.equals(other.lastVertex)
- && odeSolver.equals(other.odeSolver);
+ && odeForwardConf.equals(other.odeForwardConf)
+ && odeBackwardConf.equals(other.odeBackwardConf)
+ && gradientViewFactory.equals(other.gradientViewFactory);
}
@Override
@@ -71,12 +92,12 @@ public long numParams(boolean backprop) {
@Override
public int minVertexInputs() {
- return conf.getVertices().get(firstVertex).minVertexInputs();
+ return conf.getVertices().get(firstVertex).minVertexInputs() + odeForwardConf.nrofTimeInputs();
}
@Override
public int maxVertexInputs() {
- return conf.getVertices().get(firstVertex).maxVertexInputs();
+ return conf.getVertices().get(firstVertex).maxVertexInputs() + odeForwardConf.nrofTimeInputs();
}
@Override
@@ -110,18 +131,27 @@ public void setBackpropGradientsViewArray(INDArray gradient) {
innerGraph.init(paramsView, false); // This does not update any parameters, just sets them
- return new ode.vertex.impl.OdeVertex(
- graph,
- name,
- idx,
+ final DefaultTrainingConfig trainingConfig = new DefaultTrainingConfig(
innerGraph,
- odeSolver.instantiate(),
- new DefaultTrainingConfig(name, graph.getVertices()[1].getConfig().getUpdaterByParam("W").clone()));
+ name,
+ gradientViewFactory.paramNameMapping());
+
+ return new ode.vertex.impl.OdeVertex(
+ new ode.vertex.impl.OdeVertex.BaseGraphVertexInputs(graph, name, idx),
+ new OdeGraphHelper(
+ odeForwardConf.instantiate(),
+ odeBackwardConf.instantiate(),
+ new OdeGraphHelper.CompGraphAsOdeFunction(
+ innerGraph,
+ // Hacky handling for legacy models. To be removed...
+ gradientViewFactory == null ? new GradientViewSelectionFromBlacklisted() : gradientViewFactory)
+ ),
+ trainingConfig);
}
@Override
public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
- return conf.getLayerActivationTypes(vertexInputs).get(lastVertex);
+ return odeForwardConf.getOutputType(conf, vertexInputs);
}
@Override
@@ -131,17 +161,25 @@ public MemoryReport getMemoryReport(InputType... inputTypes) {
public static class Builder {
- private final String inputName = this.toString() + "_input";
- private final String outputName = this.toString() + "_output";
- private final ComputationGraphConfiguration.GraphBuilder graphBuilder = new NeuralNetConfiguration.Builder()
- .graphBuilder();
-
- private String first = null;
- private String last;
+ private final ComputationGraphConfiguration.GraphBuilder graphBuilder;
+ private final String first;
+ private OdeHelperForward odeForwardConf = new FixedStep(new DormandPrince54Solver(), Nd4j.arange(2), true);
+ private OdeHelperBackward odeBackwardConf = new FixedStepAdjoint(new DormandPrince54Solver(), Nd4j.arange(2));
+ private GradientViewFactory gradientViewFactory = new GradientViewSelectionFromBlacklisted();
- private FirstOrderSolverConf odeSolver = new DormandPrince54Solver();
public Builder(String name, Layer layer) {
+ graphBuilder = new NeuralNetConfiguration.Builder().graphBuilder();
+ final String inputName = this.toString() + "_input";
+ graphBuilder
+ .addInputs(inputName)
+ .addLayer(name, layer, inputName);
+ first = name;
+ }
+
+ public Builder(NeuralNetConfiguration.Builder globalConf, String name, Layer layer) {
+ graphBuilder = globalConf.clone().graphBuilder();
+ final String inputName = this.toString() + "_input";
graphBuilder
.addInputs(inputName)
.addLayer(name, layer, inputName);
@@ -153,8 +191,6 @@ public Builder(String name, Layer layer) {
*/
public Builder addLayer(String name, Layer layer, String... inputs) {
graphBuilder.addLayer(name, layer, inputs);
- checkFirst(name);
- last = name;
return this;
}
@@ -163,25 +199,51 @@ public Builder addLayer(String name, Layer layer, String... inputs) {
*/
public Builder addVertex(String name, GraphVertex vertex, String... inputs) {
graphBuilder.addVertex(name, vertex, inputs);
- checkFirst(name);
- last = name;
return this;
}
/**
- * Sets the {@link FirstOrderSolver} to use
- * @param odeSolver solver instance
+ * Set the {@link OdeHelper} to use
+ *
+ * @param odeConf ODE configuration
+ * @return the Builder for fluent API
+ */
+ public Builder odeConf(OdeHelper odeConf) {
+ odeForward(odeConf.forward());
+ return odeBackward(odeConf.backward());
+ }
+
+ /**
+ * Sets the {@link OdeHelperForward} to use
+ *
+ * @param odeForwardConf Configuration of forward helper
+ * @return the Builder for fluent API
+ */
+ public Builder odeForward(OdeHelperForward odeForwardConf) {
+ this.odeForwardConf = odeForwardConf;
+ return this;
+ }
+
+ /**
+ * Sets the {@link OdeHelperBackward} to use
+ *
+ * @param odeBackwardConf Configuration of backward helper
* @return the Builder for fluent API
*/
- public Builder odeSolver(FirstOrderSolverConf odeSolver) {
- this.odeSolver = odeSolver;
+ public Builder odeBackward(OdeHelperBackward odeBackwardConf) {
+ this.odeBackwardConf = odeBackwardConf;
return this;
}
- private void checkFirst(String name) {
- if (first == null) {
- first = name;
- }
+ /**
+ * Sets the {@link GradientViewFactory} to use
+ *
+ * @param gradientViewFactory Factory for gradient views
+ * @return the Builder for fluent API
+ */
+ public Builder gradientViewFactory(GradientViewFactory gradientViewFactory) {
+ this.gradientViewFactory = gradientViewFactory;
+ return this;
}
/**
@@ -191,9 +253,12 @@ private void checkFirst(String name) {
*/
public OdeVertex build() {
return new OdeVertex(graphBuilder
- .setOutputs(outputName)
- .addLayer(outputName, new CnnLossLayer(), last)
- .build(), first, last, odeSolver);
+ .allowNoOutput(true)
+ .build(),
+ first,
+ odeForwardConf,
+ odeBackwardConf,
+ gradientViewFactory);
}
}
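A configuration sketch with the new helper-based Builder (layer and vertex names are illustrative; FixedStep here is the convenience ode.vertex.conf.helper.FixedStep added later in this diff and DormandPrince54Solver is the serializable conf class):

    graphBuilder.addVertex("ode",
            new OdeVertex.Builder("dense", new DenseLayer.Builder().nOut(4).build())
                    .odeConf(new FixedStep(new DormandPrince54Solver(), Nd4j.arange(2), false))
                    .build(),
            "previousLayer");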
diff --git a/src/main/java/ode/vertex/conf/helper/FixedStep.java b/src/main/java/ode/vertex/conf/helper/FixedStep.java
new file mode 100644
index 0000000..9b6f0be
--- /dev/null
+++ b/src/main/java/ode/vertex/conf/helper/FixedStep.java
@@ -0,0 +1,40 @@
+package ode.vertex.conf.helper;
+
+import ode.solve.api.FirstOrderSolverConf;
+import ode.vertex.conf.helper.backward.FixedStepAdjoint;
+import ode.vertex.conf.helper.backward.OdeHelperBackward;
+import ode.vertex.conf.helper.forward.OdeHelperForward;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+/**
+ * Configuration for using a {@link ode.solve.api.FirstOrderSolver} inside a ComputationGraph when a fixed predefined
+ * set of time steps shall be used when solving the ODE.
+ *
+ * @author Christian Skarby
+ */
+public class FixedStep implements OdeHelper {
+
+ private final FirstOrderSolverConf solver;
+ private final INDArray time;
+ private final boolean interpolateForwardIfMultiStep;
+
+ public FixedStep(FirstOrderSolverConf solver, INDArray time) {
+ this(solver, time, false);
+ }
+
+ public FixedStep(FirstOrderSolverConf solver, INDArray time, boolean interpolateForwardIfMultiStep) {
+ this.solver = solver;
+ this.time = time.dup();
+ this.interpolateForwardIfMultiStep = interpolateForwardIfMultiStep;
+ }
+
+ @Override
+ public OdeHelperForward forward() {
+ return new ode.vertex.conf.helper.forward.FixedStep(solver, time, interpolateForwardIfMultiStep);
+ }
+
+ @Override
+ public OdeHelperBackward backward() {
+ return new FixedStepAdjoint(solver, time);
+ }
+}
diff --git a/src/main/java/ode/vertex/conf/helper/InputStep.java b/src/main/java/ode/vertex/conf/helper/InputStep.java
new file mode 100644
index 0000000..9e20f58
--- /dev/null
+++ b/src/main/java/ode/vertex/conf/helper/InputStep.java
@@ -0,0 +1,53 @@
+package ode.vertex.conf.helper;
+
+import ode.solve.api.FirstOrderSolverConf;
+import ode.vertex.conf.helper.backward.InputStepAdjoint;
+import ode.vertex.conf.helper.backward.OdeHelperBackward;
+import ode.vertex.conf.helper.forward.OdeHelperForward;
+
+/**
+ * Configuration for using a {@link ode.solve.api.FirstOrderSolver} inside a {@code ComputationGraph} when time steps for solving
+ * the ODE comes as inputs to the {@code GraphVertex} housing the ODE.
+ *
+ * Example:
+ *
+ * graphBuilder.addVertex("odeVertex",
+ * new OdeVertex.Builder("0", new DenseLayer.Builder().nOut(4).build())
+ * .odeConf(new InputStep(solverConf, 1)) // Refers to input "time" on the line below
+ * .build(), "someLayer", "time");
+ *
+ *
+ * @author Christian Skarby
+ */
+public class InputStep implements OdeHelper {
+
+ private final FirstOrderSolverConf solverConf;
+ private final int timeInputIndex;
+ private final boolean interpolateForwardIfMultiStep;
+ private final boolean needTimeGradient;
+
+ public InputStep(FirstOrderSolverConf solverConf, int timeInputIndex) {
+ this(solverConf, timeInputIndex, false);
+ }
+
+ public InputStep(FirstOrderSolverConf solverConf, int timeInputIndex, boolean interpolateForwardIfMultiStep) {
+ this(solverConf, timeInputIndex, interpolateForwardIfMultiStep, false);
+ }
+
+ public InputStep(FirstOrderSolverConf solverConf, int timeInputIndex, boolean interpolateForwardIfMultiStep, boolean needTimeGradient) {
+ this.solverConf = solverConf;
+ this.timeInputIndex = timeInputIndex;
+ this.interpolateForwardIfMultiStep = interpolateForwardIfMultiStep;
+ this.needTimeGradient = needTimeGradient;
+ }
+
+ @Override
+ public OdeHelperForward forward() {
+ return new ode.vertex.conf.helper.forward.InputStep(solverConf, timeInputIndex, interpolateForwardIfMultiStep);
+ }
+
+ @Override
+ public OdeHelperBackward backward() {
+ return new InputStepAdjoint(solverConf, timeInputIndex, needTimeGradient);
+ }
+}
diff --git a/src/main/java/ode/vertex/conf/helper/OdeHelper.java b/src/main/java/ode/vertex/conf/helper/OdeHelper.java
new file mode 100644
index 0000000..9a65e85
--- /dev/null
+++ b/src/main/java/ode/vertex/conf/helper/OdeHelper.java
@@ -0,0 +1,27 @@
+package ode.vertex.conf.helper;
+
+import ode.vertex.conf.helper.backward.OdeHelperBackward;
+import ode.vertex.conf.helper.forward.OdeHelperForward;
+
+/**
+ * Convenience configuration of both {@link OdeHelperForward} and {@link OdeHelperBackward}.
+ *
+ * @author Christian Skarby
+ */
+public interface OdeHelper {
+
+ /**
+ * Create the helper config in the forward direction
+ *
+ * @return helper in forward direction
+ */
+ OdeHelperForward forward();
+
+ /**
+ * Create the helper config in the backward direction
+ *
+ * @return helper in backward direction
+ */
+ OdeHelperBackward backward();
+
+}
diff --git a/src/main/java/ode/vertex/conf/helper/backward/FixedStepAdjoint.java b/src/main/java/ode/vertex/conf/helper/backward/FixedStepAdjoint.java
new file mode 100644
index 0000000..618e46e
--- /dev/null
+++ b/src/main/java/ode/vertex/conf/helper/backward/FixedStepAdjoint.java
@@ -0,0 +1,41 @@
+package ode.vertex.conf.helper.backward;
+
+import lombok.Data;
+import ode.solve.api.FirstOrderSolverConf;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.serde.jackson.shaded.NDArrayDeSerializer;
+import org.nd4j.serde.jackson.shaded.NDArraySerializer;
+import org.nd4j.shade.jackson.annotation.JsonProperty;
+import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
+import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
+
+/**
+ * Serializable configuration for {@link ode.vertex.impl.helper.backward.BackpropagateAdjoint}
+ *
+ * @author Christian Skarby
+ */
+@Data
+public class FixedStepAdjoint implements OdeHelperBackward {
+
+ private final FirstOrderSolverConf solverConf;
+ @JsonSerialize(using = NDArraySerializer.class)
+ @JsonDeserialize(using = NDArrayDeSerializer.class)
+ private final INDArray time;
+
+ public FixedStepAdjoint(
+ @JsonProperty("solverConf") FirstOrderSolverConf solverConf,
+ @JsonProperty("time") INDArray time) {
+ this.solverConf = solverConf;
+ this.time = time;
+ }
+
+ @Override
+ public ode.vertex.impl.helper.backward.OdeHelperBackward instantiate() {
+ return new ode.vertex.impl.helper.backward.FixedStepAdjoint(solverConf.instantiate(), time);
+ }
+
+ @Override
+ public FixedStepAdjoint clone() {
+ return new FixedStepAdjoint(solverConf.clone(), time.dup());
+ }
+}
diff --git a/src/main/java/ode/vertex/conf/helper/backward/InputStepAdjoint.java b/src/main/java/ode/vertex/conf/helper/backward/InputStepAdjoint.java
new file mode 100644
index 0000000..022c87b
--- /dev/null
+++ b/src/main/java/ode/vertex/conf/helper/backward/InputStepAdjoint.java
@@ -0,0 +1,41 @@
+package ode.vertex.conf.helper.backward;
+
+import lombok.Data;
+import ode.solve.api.FirstOrderSolverConf;
+import org.nd4j.shade.jackson.annotation.JsonProperty;
+
+/**
+ * Serializable configuration of an {@link ode.vertex.impl.helper.backward.InputStepAdjoint}
+ *
+ * @author Christian Skarby
+ */
+@Data
+public class InputStepAdjoint implements OdeHelperBackward {
+
+ private final FirstOrderSolverConf solverConf;
+ private final int timeInputIndex;
+ private final boolean needTimeGradient;
+
+ public InputStepAdjoint(FirstOrderSolverConf solverConf, int timeInputIndex) {
+ this(solverConf, timeInputIndex, false);
+ }
+
+ public InputStepAdjoint(
+ @JsonProperty("solverConf") FirstOrderSolverConf solverConf,
+ @JsonProperty("timeInputIndex") int timeInputIndex,
+ @JsonProperty("needTimeGradient") boolean needTimeGradient) {
+ this.solverConf = solverConf;
+ this.timeInputIndex = timeInputIndex;
+ this.needTimeGradient = needTimeGradient;
+ }
+
+ @Override
+ public ode.vertex.impl.helper.backward.OdeHelperBackward instantiate() {
+ return new ode.vertex.impl.helper.backward.InputStepAdjoint(solverConf.instantiate(), timeInputIndex, needTimeGradient);
+ }
+
+ @Override
+ public InputStepAdjoint clone() {
+ return new InputStepAdjoint(solverConf.clone(), timeInputIndex, needTimeGradient);
+ }
+}
diff --git a/src/main/java/ode/vertex/conf/helper/backward/OdeHelperBackward.java b/src/main/java/ode/vertex/conf/helper/backward/OdeHelperBackward.java
new file mode 100644
index 0000000..8c4c1f2
--- /dev/null
+++ b/src/main/java/ode/vertex/conf/helper/backward/OdeHelperBackward.java
@@ -0,0 +1,24 @@
+package ode.vertex.conf.helper.backward;
+
+import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
+
+/**
+ * Serializable configuration of an {@link ode.vertex.impl.helper.backward.OdeHelperBackward}
+ *
+ * @author Christian Skarby
+ */
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
+public interface OdeHelperBackward {
+
+ /**
+ * Instantiate the helper
+ * @return a New {@link ode.vertex.impl.helper.backward.OdeHelperBackward}
+ */
+ ode.vertex.impl.helper.backward.OdeHelperBackward instantiate();
+
+ /**
+ * Clone the configuration
+ * @return a clone of the configuration
+ */
+ OdeHelperBackward clone();
+}
diff --git a/src/main/java/ode/vertex/conf/helper/forward/FixedStep.java b/src/main/java/ode/vertex/conf/helper/forward/FixedStep.java
new file mode 100644
index 0000000..af5fbf4
--- /dev/null
+++ b/src/main/java/ode/vertex/conf/helper/forward/FixedStep.java
@@ -0,0 +1,66 @@
+package ode.vertex.conf.helper.forward;
+
+import lombok.Data;
+import ode.solve.api.FirstOrderSolverConf;
+import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.serde.jackson.shaded.NDArrayDeSerializer;
+import org.nd4j.serde.jackson.shaded.NDArraySerializer;
+import org.nd4j.shade.jackson.annotation.JsonProperty;
+import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
+import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
+
+/**
+ * Serializable configuration for {@link ode.vertex.impl.helper.forward.FixedStep}
+ *
+ * @author Christian Skarby
+ */
+@Data
+public class FixedStep implements OdeHelperForward{
+
+ private final FirstOrderSolverConf solverConf;
+ @JsonSerialize(using = NDArraySerializer.class)
+ @JsonDeserialize(using = NDArrayDeSerializer.class)
+ private final INDArray time;
+ private final boolean interpolateIfMultiStep;
+
+ public FixedStep(
+ @JsonProperty("solverConf") FirstOrderSolverConf solverConf,
+ @JsonProperty("time") INDArray time,
+ @JsonProperty("interpolateIfMultiStep") boolean interpolateIfMultiStep) {
+ this.solverConf = solverConf;
+ this.time = time;
+ this.interpolateIfMultiStep = interpolateIfMultiStep;
+ }
+
+ @Override
+ public ode.vertex.impl.helper.forward.OdeHelperForward instantiate() {
+ return new ode.vertex.impl.helper.forward.FixedStep(solverConf.instantiate(), time, interpolateIfMultiStep);
+ }
+
+ @Override
+ public int nrofTimeInputs() {
+ return 0;
+ }
+
+ @Override
+ public FixedStep clone() {
+ return new FixedStep(solverConf.clone(), time.dup(), interpolateIfMultiStep);
+ }
+
+ @Override
+ public InputType getOutputType(ComputationGraphConfiguration conf, InputType... vertexInputs) throws InvalidInputTypeException {
+
+ final OutputTypeHelper confHelper = new OutputTypeFromConfig();
+ if(time.length() == 2) {
+ return confHelper.getOutputType(conf, vertexInputs);
+ }
+
+ final InputType[] withTime = new InputType[vertexInputs.length+1];
+ System.arraycopy(vertexInputs, 0, withTime, 0, vertexInputs.length);
+ withTime[vertexInputs.length] = InputType.inferInputType(this.time);
+ return new OutputTypeAddTimeAsDimension(vertexInputs.length, confHelper).getOutputType(conf, withTime);
+ }
+}
diff --git a/src/main/java/ode/vertex/conf/helper/forward/InputStep.java b/src/main/java/ode/vertex/conf/helper/forward/InputStep.java
new file mode 100644
index 0000000..948b1b4
--- /dev/null
+++ b/src/main/java/ode/vertex/conf/helper/forward/InputStep.java
@@ -0,0 +1,67 @@
+package ode.vertex.conf.helper.forward;
+
+import lombok.Data;
+import ode.solve.api.FirstOrderSolverConf;
+import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
+import org.nd4j.shade.jackson.annotation.JsonProperty;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Serializable configuration of an {@link ode.vertex.impl.helper.forward.InputStep}
+ *
+ * @author Christian Skarby
+ */
+@Data
+public class InputStep implements OdeHelperForward {
+
+ private final FirstOrderSolverConf solverConf;
+ private final int timeInputIndex;
+ private final boolean interpolateIfMultiStep;
+
+ public InputStep(
+ @JsonProperty("solverConf") FirstOrderSolverConf solverConf,
+ @JsonProperty("timeInputIndex") int timeInputIndex,
+ @JsonProperty("interpolateIfMultiStep") boolean interpolateIfMultiStep) {
+ this.solverConf = solverConf;
+ this.timeInputIndex = timeInputIndex;
+ this.interpolateIfMultiStep = interpolateIfMultiStep;
+ }
+
+ @Override
+ public ode.vertex.impl.helper.forward.OdeHelperForward instantiate() {
+ return new ode.vertex.impl.helper.forward.InputStep(solverConf.instantiate(), timeInputIndex, interpolateIfMultiStep);
+ }
+
+ @Override
+ public int nrofTimeInputs() {
+ return 1;
+ }
+
+ @Override
+ public InputStep clone() {
+ return new InputStep(solverConf.clone(), timeInputIndex, interpolateIfMultiStep);
+ }
+
+ @Override
+ public InputType getOutputType(ComputationGraphConfiguration conf, InputType... vertexInputs) throws InvalidInputTypeException {
+ if(vertexInputs.length <= timeInputIndex) {
+ throw new InvalidInputTypeException("Time input index was not part of input types!!");
+ }
+
+ final InputType timeInput = vertexInputs[timeInputIndex];
+ if(timeInput.arrayElementsPerExample() > 2) {
+ return new OutputTypeAddTimeAsDimension(timeInputIndex, new OutputTypeFromConfig()).getOutputType(conf, vertexInputs);
+ }
+ List<InputType> inputTypeList = new ArrayList<>();
+ for (int i = 0; i < vertexInputs.length; i++) {
+ if (i != timeInputIndex) {
+ inputTypeList.add(vertexInputs[i]);
+ }
+ }
+ return new OutputTypeFromConfig().getOutputType(conf, inputTypeList.toArray(new InputType[0]));
+ }
+}
diff --git a/src/main/java/ode/vertex/conf/helper/forward/OdeHelperForward.java b/src/main/java/ode/vertex/conf/helper/forward/OdeHelperForward.java
new file mode 100644
index 0000000..7c5eda8
--- /dev/null
+++ b/src/main/java/ode/vertex/conf/helper/forward/OdeHelperForward.java
@@ -0,0 +1,31 @@
+package ode.vertex.conf.helper.forward;
+
+import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
+
+/**
+ * Serializable configuration of an {@link ode.vertex.impl.helper.forward.OdeHelperForward}
+ *
+ * @author Christian Skarby
+ */
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
+public interface OdeHelperForward extends OutputTypeHelper {
+
+ /**
+ * Instantiate the helper
+ * @return a New {@link ode.vertex.impl.helper.forward.OdeHelperForward}
+ */
+ ode.vertex.impl.helper.forward.OdeHelperForward instantiate();
+
+ /**
+ * How many time inputs are needed
+ * @return the number of needed time inputs
+ */
+ int nrofTimeInputs();
+
+
+ /**
+ * Clone the configuration
+ * @return a clone of the configuration
+ */
+ OdeHelperForward clone();
+}
diff --git a/src/main/java/ode/vertex/conf/helper/forward/OutputTypeAddTimeAsDimension.java b/src/main/java/ode/vertex/conf/helper/forward/OutputTypeAddTimeAsDimension.java
new file mode 100644
index 0000000..4f0d844
--- /dev/null
+++ b/src/main/java/ode/vertex/conf/helper/forward/OutputTypeAddTimeAsDimension.java
@@ -0,0 +1,58 @@
+package ode.vertex.conf.helper.forward;
+
+import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
+import org.deeplearning4j.nn.conf.layers.Convolution3D;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Adds a time dimension to the output type of a given {@link OutputTypeHelper} based on the type of the time input
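+ *
+ * For example, a feed-forward output of size 10 combined with a time input of length 5 becomes a
+ * recurrent type with 10 features over 5 time steps, while a convolutional output becomes a 3D
+ * convolutional type (NDHWC) with time as the depth dimension.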
+ *
+ * @author Christian Skarby
+ */
+public class OutputTypeAddTimeAsDimension implements OutputTypeHelper {
+
+ private final int timeInputIndex;
+ private final OutputTypeHelper sourceHelper;
+
+ public OutputTypeAddTimeAsDimension(int timeInputIndex, OutputTypeHelper sourceHelper) {
+ this.timeInputIndex = timeInputIndex;
+ this.sourceHelper = sourceHelper;
+ }
+
+ @Override
+ public InputType getOutputType(ComputationGraphConfiguration conf, InputType... vertexInputs) throws InvalidInputTypeException {
+        List<InputType> inputTypeList = new ArrayList<>();
+ InputType time = vertexInputs[timeInputIndex];
+ for (int i = 0; i < vertexInputs.length; i++) {
+ if (i != timeInputIndex) {
+ inputTypeList.add(vertexInputs[i]);
+ }
+ }
+
+ InputType outputs = sourceHelper.getOutputType(conf, inputTypeList.toArray(new InputType[0]));
+
+ if(time.getType() != InputType.Type.FF) {
+ throw new IllegalArgumentException("Time must be 1D!");
+ }
+
+ return addTimeDim(outputs, time);
+ }
+
+ private InputType addTimeDim(InputType type, InputType timeDim) {
+ switch (type.getType()) {
+ case FF: return InputType.recurrent(type.arrayElementsPerExample(), timeDim.arrayElementsPerExample());
+ case CNN:
+ InputType.InputTypeConvolutional convType = (InputType.InputTypeConvolutional)type;
+ return InputType.convolutional3D(Convolution3D.DataFormat.NDHWC,
+ timeDim.arrayElementsPerExample(),
+ convType.getHeight(),
+ convType.getWidth(),
+ convType.getChannels());
+ default: throw new InvalidInputTypeException("Input type not supported with time as input!");
+ }
+ }
+}
diff --git a/src/main/java/ode/vertex/conf/helper/forward/OutputTypeFromConfig.java b/src/main/java/ode/vertex/conf/helper/forward/OutputTypeFromConfig.java
new file mode 100644
index 0000000..d92cfb4
--- /dev/null
+++ b/src/main/java/ode/vertex/conf/helper/forward/OutputTypeFromConfig.java
@@ -0,0 +1,27 @@
+package ode.vertex.conf.helper.forward;
+
+import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
+
+import java.util.Map;
+
+/**
+ * Determines output type from the given configuration
+ *
+ * @author Christian Skarby
+ */
+public class OutputTypeFromConfig implements OutputTypeHelper {
+
+ @Override
+ public InputType getOutputType(ComputationGraphConfiguration conf, InputType... vertexInputs) throws InvalidInputTypeException {
+        final Map<String, InputType> inputTypeMap = conf.getLayerActivationTypes(vertexInputs);
+ inputTypeMap.keySet().removeAll(conf.getVertexInputs().keySet());
+
+ if(inputTypeMap.size() != 1) {
+ throw new IllegalStateException("Can only support one single output!");
+ }
+
+ return inputTypeMap.values().iterator().next();
+ }
+}
diff --git a/src/main/java/ode/vertex/conf/helper/forward/OutputTypeHelper.java b/src/main/java/ode/vertex/conf/helper/forward/OutputTypeHelper.java
new file mode 100644
index 0000000..7e6c77b
--- /dev/null
+++ b/src/main/java/ode/vertex/conf/helper/forward/OutputTypeHelper.java
@@ -0,0 +1,22 @@
+package ode.vertex.conf.helper.forward;
+
+import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
+
+/**
+ * Interface for determining the shape of output given shape of inputs.
+ *
+ * @author Christian Skarby
+ */
+public interface OutputTypeHelper {
+
+ /**
+     * Return the output type for the given input types
+     * @param conf {@link ComputationGraphConfiguration} which defines the function used as the derivative
+     * @param vertexInputs Inputs to the vertex
+     * @return an {@link InputType} for the next vertex in the graph
+     * @throws InvalidInputTypeException if the given input types are invalid or not supported
+ */
+ InputType getOutputType(ComputationGraphConfiguration conf, InputType... vertexInputs) throws InvalidInputTypeException;
+}
diff --git a/src/main/java/ode/vertex/impl/OdeVertex.java b/src/main/java/ode/vertex/impl/OdeVertex.java
index aeb757a..582395c 100644
--- a/src/main/java/ode/vertex/impl/OdeVertex.java
+++ b/src/main/java/ode/vertex/impl/OdeVertex.java
@@ -1,31 +1,27 @@
package ode.vertex.impl;
-import ode.solve.api.FirstOrderEquation;
-import ode.solve.api.FirstOrderSolver;
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import ode.vertex.impl.helper.OdeGraphHelper;
+import ode.vertex.impl.helper.backward.OdeHelperBackward;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.TrainingConfig;
-import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
-import org.deeplearning4j.nn.graph.vertex.GraphVertex;
-import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
import java.util.Map;
/**
- * Implementation of an ODE block.
+ * Implementation of an ODE block. Contains a {@link ComputationGraph} which defines the learnable function
+ * {@code f = dz(t)/dt} for which the {@code OdeVertex} will output an estimate of z(t) for given t(s).
*
* @author Christian Skarby
*/
@@ -33,41 +29,28 @@ public class OdeVertex extends BaseGraphVertex {
private static final Logger log = LoggerFactory.getLogger(OdeVertex.class);
- private final static String parName = "params";
-
- private final ComputationGraph graph;
- private final FirstOrderSolver odeSolver;
+ private final OdeGraphHelper odeHelper;
private final TrainingConfig trainingConfig;
- private final Parameters parameters;
-
- private static class Parameters {
- private final INDArray time;
- private INDArray lastOutput; // z(t1) from paper
- private final NonContiguous1DView realGradients; // Parts of graph.getFlattenedGradients() which are actually gradients
-
- public Parameters(INDArray time) {
- this.time = time;
- realGradients = new NonContiguous1DView();
- }
+ @AllArgsConstructor
+ @Getter
+ public static class BaseGraphVertexInputs {
+ private final ComputationGraph graph;
+ private final String name;
+ private final int vertexIndex;
}
- public OdeVertex(ComputationGraph actualGraph,
- String name,
- int vertexIndex,
- ComputationGraph innerGraph,
- FirstOrderSolver odeSolver,
+ public OdeVertex(BaseGraphVertexInputs baseGraphVertexInputs,
+ OdeGraphHelper odeHelper,
TrainingConfig trainingConfig) {
- super(actualGraph, name, vertexIndex, null, null);
- this.graph = innerGraph;
+ super(baseGraphVertexInputs.getGraph(), baseGraphVertexInputs.getName(), baseGraphVertexInputs.getVertexIndex(), null, null);
this.trainingConfig = trainingConfig;
- this.odeSolver = odeSolver;
- this.parameters = new Parameters(Nd4j.create(new double[]{0, 1}));
+ this.odeHelper = odeHelper;
}
@Override
public String toString() {
- return graph.toString();
+ return odeHelper.getFunction().toString();
}
@Override
@@ -82,24 +65,23 @@ public Layer getLayer() {
@Override
public long numParams() {
- return graph.numParams();
+ return odeHelper.getFunction().numParams();
}
@Override
public INDArray params() {
- return graph.params();
+ return odeHelper.getFunction().params();
}
@Override
public void clear() {
super.clear();
- graph.clearLayersStates();
- parameters.lastOutput = null;
+ odeHelper.clear();
}
@Override
public Map<String, INDArray> paramTable(boolean backpropOnly) {
- return Collections.synchronizedMap(Collections.singletonMap(parName, params()));
+ return odeHelper.paramTable(backpropOnly);
}
@Override
@@ -110,10 +92,6 @@ public TrainingConfig getConfig() {
private void validateForward() {
if (!canDoForward())
throw new IllegalStateException("Cannot do forward pass: inputs not set");
-
- if (getInputs().length != 1) {
- throw new IllegalStateException("Only one input supported!");
- }
}
private void validateBackprop() {
@@ -137,26 +115,9 @@ public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
leverageInputs(workspaceMgr);
- final LayerWorkspaceMgr innerWorkspaceMgr = createWorkspaceMgr(workspaceMgr);
-
- final ForwardPass equation = new ForwardPass(
- graph,
- innerWorkspaceMgr,
- true, // Always use training as batch norm running mean and var become messed up otherwise. Same effect seen in original pytorch repo.
- getInputs());
-
- // nrof outputs must be same as number of inputs due to resblock
- parameters.lastOutput = workspaceMgr.createUninitialized(ArrayType.INPUT, getInputs()[0].shape()).detach();
- odeSolver.integrate(equation, parameters.time, getInputs()[0], parameters.lastOutput);
+ final INDArray output = odeHelper.doForward(workspaceMgr, getInputs());
- for (GraphVertex vertex : graph.getVertices()) {
- final INDArray[] inputs = vertex.getInputs();
- for (int i = 0; i < inputs.length; i++) {
- vertex.setInput(i, workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, inputs[i]), workspaceMgr);
- }
- }
-
- return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, parameters.lastOutput);
+ return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, output);
}
@Override
@@ -164,51 +125,18 @@ public Pair doBackward(boolean tbptt, LayerWorkspaceMgr wo
validateBackprop();
log.trace("Start backward");
- // Create augmented dynamics for adjoint method
- // Initialization: S0:
- // z(t1) = lastoutput
- // a(t) = -dL/d(z(t1)) = -epsilon from next layer (i.e getEpsilon)
- // parameters = zeros
- // dL/dt1 = dL / dz(t1) dot z(t1)
- final INDArray dL_dtN = getEpsilon().reshape(1, parameters.lastOutput.length())
- .mmul(parameters.lastOutput.reshape(parameters.lastOutput.length(), 1)).muli(-1);
-
- final INDArray zAug = Nd4j.create(1, parameters.lastOutput.length() + getEpsilon().length() + graph.numParams() + dL_dtN.length());
-
- final NDArrayIndexAccumulator accumulator = new NDArrayIndexAccumulator(zAug);
- accumulator.increment(parameters.lastOutput.reshape(new long[]{1, parameters.lastOutput.length()}))
- .increment(getEpsilon().reshape(new long[]{1, getEpsilon().length()}))
- .increment(Nd4j.zeros(parameters.realGradients.length()).reshape(new long[]{1, Nd4j.zeros(parameters.realGradients.length()).length()}))
- .increment(dL_dtN.reshape(new long[]{1, dL_dtN.length()}));
-
- final AugmentedDynamics augmentedDynamics = new AugmentedDynamics(
- zAug,
- getEpsilon().shape(),
- new long[]{parameters.realGradients.length()},
- dL_dtN.shape());
-
- final LayerWorkspaceMgr innerWorkspaceMgr = createWorkspaceMgr(workspaceMgr);
-
- final FirstOrderEquation equation = new BackpropagateAdjoint(
- augmentedDynamics,
- new ForwardPass(graph,
- innerWorkspaceMgr,
- true,
- getInputs()),
- new BackpropagateAdjoint.GraphInfo(graph, parameters.realGradients, innerWorkspaceMgr, tbptt)
- );
-
- INDArray augAns = odeSolver.integrate(equation, Nd4j.reverse(parameters.time.dup()), zAug, zAug.dup());
-
- augmentedDynamics.updateFrom(augAns);
-
- final INDArray epsilonOut = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, augmentedDynamics.zAdjoint());
+        final Pair<Gradient, INDArray[]> gradients = odeHelper.doBackward(
+ new OdeHelperBackward.MiscPar(tbptt, workspaceMgr),
+ getEpsilon(),
+ getInputs());
- parameters.realGradients.assignFrom(augmentedDynamics.paramAdjoint());
- final Gradient gradient = new DefaultGradient(graph.getFlattenedGradients());
- gradient.setGradientFor(parName, graph.getFlattenedGradients());
+ final INDArray[] inputGrads = gradients.getSecond();
+ final INDArray[] leveragedGrads = new INDArray[inputGrads.length];
+ for (int i = 0; i < inputGrads.length; i++) {
+ leveragedGrads[i] = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, inputGrads[i]);
+ }
- return new Pair<>(gradient, new INDArray[]{epsilonOut});
+ return new Pair<>(gradients.getFirst(), leveragedGrads);
}
private void leverageInputs(LayerWorkspaceMgr workspaceMgr) {
@@ -217,63 +145,9 @@ private void leverageInputs(LayerWorkspaceMgr workspaceMgr) {
}
}
- private LayerWorkspaceMgr createWorkspaceMgr(final LayerWorkspaceMgr outerWsMgr) {
-
- return new ComputationGraph(graph.getConfiguration()) {
- public LayerWorkspaceMgr spyWsConfigs() {
- // A little bit too many methods to comfortably decorate. Try to copy config instead
- final LayerWorkspaceMgr.Builder wsBuilder = LayerWorkspaceMgr.builder();
- for (ArrayType type : ArrayType.values()) {
- if (outerWsMgr.hasConfiguration(type)) {
- wsBuilder.with(type, outerWsMgr.getWorkspaceName(type), outerWsMgr.getConfiguration(type));
- }
- }
-
- final LayerWorkspaceMgr wsMgr = wsBuilder
- .with(ArrayType.FF_WORKING_MEM, "WS_ODE_VERTEX_LAYER_WORKING_MEM", WS_LAYER_WORKING_MEM_CONFIG)
- .with(ArrayType.BP_WORKING_MEM, "WS_ODE_VERTEX_LAYER_WORKING_MEM", WS_LAYER_WORKING_MEM_CONFIG)
- .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, "WS_ODE_VERTEX_RNN_LOOP_WORKING_MEM", WS_RNN_LOOP_WORKING_MEM_CONFIG)
- .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, "WS_ODE_VERTEX_RNN_LOOP_WORKING_MEM", WS_RNN_LOOP_WORKING_MEM_CONFIG)
- .with(ArrayType.ACTIVATIONS, "WS_ODE_VERTEX_ALL_LAYERS_ACT", WS_ALL_LAYERS_ACT_CONFIG)
- .with(ArrayType.ACTIVATION_GRAD, "WS_ODE_VERTEX_ALL_LAYERS_GRAD", WS_ALL_LAYERS_ACT_CONFIG)
- .build();
- wsMgr.setHelperWorkspacePointers(outerWsMgr.getHelperWorkspacePointers());
- return wsMgr;
- }
- }.spyWsConfigs();
-
- }
-
@Override
public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
- graph.setBackpropGradientsViewArray(backpropGradientsViewArray);
-
- // What is this about? Some layers "abuse" the gradient to perform updates of parameters for which no gradient
- // is calculated and this screws up the ODE solvers idea of what the solution is. The following layers are known
- // to do this:
- //
- // * BatchNormalization: The global variance and mean are just the (sliding) average of the batch dittos.
- // However, in order to support distributed training the updates are performed by adding
- // the change to the state as a gradient even through it is not really.
-
- // Maybe get these from config so user can specify others e.g. for custom layers
- final List nonGradientParamNames = Arrays.asList(
- BatchNormalizationParamInitializer.GLOBAL_VAR,
- BatchNormalizationParamInitializer.GLOBAL_MEAN);
-
- parameters.realGradients.clear();
- for (Layer layer : graph.getLayers()) {
- Map gradParams = layer.conf().getLayer().initializer().getGradientsFromFlattened(layer.conf(), layer.getGradientsViewArray());
- for (Map.Entry parNameAndGradView : gradParams.entrySet()) {
-
- final String parName = parNameAndGradView.getKey();
- final INDArray grad = parNameAndGradView.getValue();
-
- if (!nonGradientParamNames.contains(parName)) {
- parameters.realGradients.addView(grad.reshape(grad.length()));
- }
- }
- }
+ odeHelper.setBackpropGradientsViewArray(backpropGradientsViewArray);
}
@Override
diff --git a/src/main/java/ode/vertex/impl/gradview/Contiguous1DView.java b/src/main/java/ode/vertex/impl/gradview/Contiguous1DView.java
new file mode 100644
index 0000000..cf31163
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/gradview/Contiguous1DView.java
@@ -0,0 +1,49 @@
+package ode.vertex.impl.gradview;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+/**
+ * A simple contiguous gradient view
+ *
+ * @author Christian Skarby
+ */
+public class Contiguous1DView implements INDArray1DView {
+
+ private final INDArray view;
+
+ public Contiguous1DView(INDArray view) {
+ this.view = view;
+ }
+
+ @Override
+ public void assignFrom(INDArray toAssign) {
+ if(toAssign.length() != view.length()) {
+ throw new IllegalArgumentException("Array to assignFrom must have same length! " +
+ "This length: " + view.length() +" array length: " + toAssign.length());
+ }
+
+ if(toAssign.rank() != 1) {
+ throw new IllegalArgumentException("Array toAssign must have rank 1!");
+ }
+
+ view.assign(toAssign);
+ }
+
+ @Override
+ public void assignTo(INDArray assignTo) {
+ if(assignTo.length() != view.length()) {
+ throw new IllegalArgumentException("Array assignTo must have same length! " +
+ "This length: " + view.length() +" array length: " + assignTo.length());
+ }
+ if(assignTo.rank() != 1) {
+ throw new IllegalArgumentException("Array assignTo must have rank 1!");
+ }
+
+ assignTo.assign(view);
+ }
+
+ @Override
+ public long length() {
+ return view.length();
+ }
+}
diff --git a/src/main/java/ode/vertex/impl/gradview/GradientViewFactory.java b/src/main/java/ode/vertex/impl/gradview/GradientViewFactory.java
new file mode 100644
index 0000000..933db86
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/gradview/GradientViewFactory.java
@@ -0,0 +1,41 @@
+package ode.vertex.impl.gradview;
+
+import ode.vertex.impl.gradview.parname.ParamNameMapping;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
+
+import java.io.Serializable;
+
+/**
+ * Creates an {@link INDArray1DView} of all gradients in a set of vertices. The reason why this exists is that there
+ * are instances of parameters which DL4J puts in the gradient view but which are not actually gradients. One example
+ * is the running mean and variance of batchnorm. Such parameters do not play well with adjoint backpropagation for
+ * Neural ODEs.
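+ *
+ * A usage sketch (the {@code graph} variable is hypothetical):
+ * <pre>{@code
+ * GradientViewFactory factory = new GradientViewSelectionFromBlacklisted();
+ * ParameterGradientView parView = factory.create(graph);
+ * Gradient perParam = parView.allGradientsPerParam(); // everything DL4J labels a gradient
+ * INDArray1DView real = parView.realGradientView();   // only the true gradients
+ * }</pre>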
+ *
+ * @author Christian Skarby
+ */
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
+public interface GradientViewFactory extends Serializable {
+
+ /**
+ * Create a {@link ParameterGradientView} of the gradients of the given graph.
+ *
+ * @param graph Graph to extract gradients from
+ * @return Views of the gradients
+ */
+ ParameterGradientView create(ComputationGraph graph);
+
+ /**
+ * Return the mapping used to create non-colliding names
+ *
+ * @return the mapping used to create non-colliding names
+ */
+ ParamNameMapping paramNameMapping();
+
+ /**
+ * Clone the factory
+ *
+ * @return a clone
+ */
+ GradientViewFactory clone();
+}
diff --git a/src/main/java/ode/vertex/impl/gradview/GradientViewSelectionFromBlacklisted.java b/src/main/java/ode/vertex/impl/gradview/GradientViewSelectionFromBlacklisted.java
new file mode 100644
index 0000000..d6e2d0f
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/gradview/GradientViewSelectionFromBlacklisted.java
@@ -0,0 +1,151 @@
+package ode.vertex.impl.gradview;
+
+import lombok.Data;
+import ode.vertex.impl.gradview.parname.Concat;
+import ode.vertex.impl.gradview.parname.ParamNameMapping;
+import org.deeplearning4j.nn.api.Layer;
+import org.deeplearning4j.nn.gradient.DefaultGradient;
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.graph.vertex.GraphVertex;
+import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.shade.jackson.annotation.JsonProperty;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * {@link GradientViewFactory} which selects either a {@link Contiguous1DView} or a {@link NonContiguous1DView} based on
+ * presence of blacklisted parameters in the graph.
+ *
+ * @author Christian Skarby
+ */
+@Data
+public class GradientViewSelectionFromBlacklisted implements GradientViewFactory {
+
+    private final List<String> nonGradientParamNames;
+ private final ParamNameMapping paramNameMapping;
+
+ public GradientViewSelectionFromBlacklisted() {
+ this(Arrays.asList(
+ BatchNormalizationParamInitializer.GLOBAL_VAR,
+ BatchNormalizationParamInitializer.GLOBAL_MEAN));
+ }
+
+    public GradientViewSelectionFromBlacklisted(List<String> nonGradientParamNames) {
+ this(nonGradientParamNames,
+ new Concat());
+ }
+
+    public GradientViewSelectionFromBlacklisted(@JsonProperty("nonGradientParamNames") List<String> nonGradientParamNames,
+ @JsonProperty("paramNameMapping") ParamNameMapping paramNameMapping) {
+ this.nonGradientParamNames = nonGradientParamNames;
+ this.paramNameMapping = paramNameMapping;
+ }
+
+ public ParameterGradientView create(ComputationGraph graph) {
+
+ final Gradient gradient = getAllGradients(graph);
+
+ for (GraphVertex vertex : graph.getVertices()) {
+ if (hasNonGradient(vertex)) {
+ return new ParameterGradientView(gradient, createNonContiguous1DView(graph));
+ }
+ }
+
+ return new ParameterGradientView(gradient, new Contiguous1DView(graph.getGradientsViewArray()));
+ }
+
+ private boolean hasNonGradient(GraphVertex vertex) {
+ boolean anyNonGrad = false;
+ for (String parName : vertex.paramTable(false).keySet()) {
+ anyNonGrad |= nonGradientParamNames.contains(parName);
+ }
+ return anyNonGrad;
+ }
+
+ private NonContiguous1DView createNonContiguous1DView(ComputationGraph graph) {
+ final NonContiguous1DView gradView = new NonContiguous1DView();
+
+ for (GraphVertex vertex : graph.getVertices()) {
+ addGradientView(gradView, vertex);
+ }
+ return gradView;
+ }
+
+ private void addGradientView(NonContiguous1DView gradView, GraphVertex vertex) {
+ if (vertex.numParams() > 0 && hasNonGradient(vertex)) {
+ Layer layer = vertex.getLayer();
+
+ if (layer == null) {
+                // Only way I have found to get the mapping from gradient view to gradient view per parameter is through
+                // a ParameterInitializer as done below, and only Layers seem to be able to provide them
+ throw new UnsupportedOperationException("Can not (reliably) get correct gradient views from non-layer " +
+ "vertices with blacklisted parameters!");
+ }
+
+            Map<String, INDArray> gradParams = layer.conf().getLayer().initializer().getGradientsFromFlattened(layer.conf(), layer.getGradientsViewArray());
+            for (Map.Entry<String, INDArray> parNameAndGradView : gradParams.entrySet()) {
+ final String parName = parNameAndGradView.getKey();
+ final INDArray grad = parNameAndGradView.getValue();
+
+ if (!nonGradientParamNames.contains(parName)) {
+ gradView.addView(grad);
+ }
+ }
+ } else if (vertex.numParams() > 0) {
+ gradView.addView(vertex.getGradientsViewArray());
+ }
+ }
+
+ private Gradient getAllGradients(ComputationGraph graph) {
+ final Gradient allGradients = new DefaultGradient(graph.getGradientsViewArray());
+ for (GraphVertex vertex : graph.getVertices()) {
+ if (vertex.numParams() > 0) {
+ Layer layer = vertex.getLayer();
+
+ if (layer == null) {
+                    // Only way I have found to get the mapping from gradient view to gradient view per parameter is through
+                    // a ParameterInitializer as done below, and only Layers seem to be able to provide them
+ throw new UnsupportedOperationException("Can not (reliably) get correct gradient views from non-layer " +
+ "vertices with blacklisted parameters!");
+ }
+
+                Map<String, INDArray> gradParams = layer.conf().getLayer().initializer().getGradientsFromFlattened(layer.conf(), layer.getGradientsViewArray());
+                for (Map.Entry<String, INDArray> parNameAndGradView : gradParams.entrySet()) {
+ final String parName = parNameAndGradView.getKey();
+ final INDArray grad = parNameAndGradView.getValue();
+ allGradients.setGradientFor(paramNameMapping.map(vertex.getVertexName(), parName), grad);
+ }
+ }
+ }
+ return allGradients;
+ }
+
+ @Override
+ public ParamNameMapping paramNameMapping() {
+ return paramNameMapping;
+ }
+
+ @Override
+ public GradientViewFactory clone() {
+ return new GradientViewSelectionFromBlacklisted(nonGradientParamNames, paramNameMapping);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (!(o instanceof GradientViewSelectionFromBlacklisted)) return false;
+ GradientViewSelectionFromBlacklisted that = (GradientViewSelectionFromBlacklisted) o;
+ return nonGradientParamNames.equals(that.nonGradientParamNames) && paramNameMapping.equals(that.paramNameMapping);
+ }
+
+ @Override
+ public int hashCode() {
+        return Objects.hash(nonGradientParamNames, paramNameMapping);
+ }
+
+}
diff --git a/src/main/java/ode/vertex/impl/gradview/INDArray1DView.java b/src/main/java/ode/vertex/impl/gradview/INDArray1DView.java
new file mode 100644
index 0000000..43367dd
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/gradview/INDArray1DView.java
@@ -0,0 +1,29 @@
+package ode.vertex.impl.gradview;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+/**
+ * A 1D view of one or more {@link INDArray}s.
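+ *
+ * A minimal sketch of the contract, assuming rank-1 arrays of length 5:
+ * <pre>{@code
+ * INDArray1DView view = new Contiguous1DView(Nd4j.zeros(5).reshape(5));
+ * view.assignFrom(Nd4j.ones(5).reshape(5)); // view now holds five ones
+ * INDArray copy = Nd4j.zeros(5).reshape(5);
+ * view.assignTo(copy);                      // copy now holds five ones
+ * }</pre>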
+ *
+ * @author Christian Skarby
+ */
+public interface INDArray1DView {
+
+ /**
+ * Sets the view to the given {@link INDArray}
+ * @param toAssign view will be set to this. Must be same size as view
+ */
+ void assignFrom(INDArray toAssign);
+
+ /**
+ * Sets the values of the given {@link INDArray} to the values of the view
+ * @param assignTo will be set to state of the view. Must be same size as view
+ */
+ void assignTo(INDArray assignTo);
+
+ /**
+ * Return the current length (total number of elements) of the view
+ * @return the current length of the view
+ */
+ long length();
+}
diff --git a/src/main/java/ode/vertex/impl/NonContiguous1DView.java b/src/main/java/ode/vertex/impl/gradview/NonContiguous1DView.java
similarity index 61%
rename from src/main/java/ode/vertex/impl/NonContiguous1DView.java
rename to src/main/java/ode/vertex/impl/gradview/NonContiguous1DView.java
index 2f981a5..4d72c79 100644
--- a/src/main/java/ode/vertex/impl/NonContiguous1DView.java
+++ b/src/main/java/ode/vertex/impl/gradview/NonContiguous1DView.java
@@ -1,11 +1,10 @@
-package ode.vertex.impl;
+package ode.vertex.impl.gradview;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.List;
/**
@@ -13,28 +12,17 @@
*
* @author Christian Skarby
*/
-public class NonContiguous1DView {
+public class NonContiguous1DView implements INDArray1DView {
private final List<INDArray> view = new ArrayList<>();
private long length = 0;
- public void addView(INDArray array, long begin, long end) {
- INDArray viewSlice = array.get(NDArrayIndex.interval(begin, end));
- addView(viewSlice);
- }
-
public void addView(INDArray viewSlice) {
- if(viewSlice.isColumnVectorOrScalar()) {
- throw new IllegalArgumentException("Must be vector or scalar! Had shape: " + Arrays.toString(viewSlice.shape()));
- }
length += viewSlice.length();
view.add(viewSlice);
}
- /**
- * Sets the view to the given {@link INDArray}
- * @param toAssign view will be set to this. Must be same size as view
- */
+ @Override
public void assignFrom(INDArray toAssign) {
if(toAssign.length() != length) {
throw new IllegalArgumentException("Array to assignFrom must have same length! " +
@@ -47,15 +35,12 @@ public void assignFrom(INDArray toAssign) {
long ptr = 0;
for(INDArray viewSlice: view) {
- viewSlice.assign(toAssign.get(NDArrayIndex.interval(ptr, ptr + viewSlice.length())));
+ viewSlice.assign(toAssign.get(NDArrayIndex.interval(ptr, ptr + viewSlice.length())).reshape(viewSlice.shape()));
ptr += viewSlice.length();
}
}
- /**
- * Sets the values of the given {@link INDArray} to the values of the view
- * @param assignTo will be set to state of the view. Must be same size as view
- */
+ @Override
public void assignTo(INDArray assignTo) {
if(assignTo.length() != length) {
throw new IllegalArgumentException("Array assignTo must have same length! " +
@@ -67,24 +52,19 @@ public void assignTo(INDArray assignTo) {
long ptr = 0;
for(INDArray viewSlice: view) {
- assignTo.put(new INDArrayIndex[] {NDArrayIndex.interval(ptr, ptr + viewSlice.length())}, viewSlice);
+ assignTo.put(new INDArrayIndex[] {NDArrayIndex.interval(ptr, ptr + viewSlice.length())}, viewSlice.reshape(viewSlice.length()));
ptr += viewSlice.length();
}
}
- /**
- * Return the current length (total number of elements) of the view
- * @return the current length of the view
- */
+ @Override
public long length() {
return length;
}
- /**
- * Clears the view
- */
- public void clear() {
- view.clear();
- length = 0;
+ @Override
+ public String toString() {
+ return view.toString();
}
+
}
diff --git a/src/main/java/ode/vertex/impl/gradview/ParameterGradientView.java b/src/main/java/ode/vertex/impl/gradview/ParameterGradientView.java
new file mode 100644
index 0000000..917129a
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/gradview/ParameterGradientView.java
@@ -0,0 +1,37 @@
+package ode.vertex.impl.gradview;
+
+import org.deeplearning4j.nn.gradient.Gradient;
+
+/**
+ * Different views of the parameter gradients of a graph.
+ *
+ * @author Christian Skarby
+ */
+public class ParameterGradientView {
+
+ private final Gradient allGradients;
+ private final INDArray1DView realGradientView;
+
+ public ParameterGradientView(Gradient allGradients, INDArray1DView realGradientView) {
+ this.allGradients = allGradients;
+ this.realGradientView = realGradientView;
+ }
+
+
+ /**
+     * Returns everything labeled as a gradient (even entries which are not actually gradients) per parameter
+ * @return a Gradient for all parameters
+ */
+ public Gradient allGradientsPerParam() {
+ return allGradients;
+ }
+
+ /**
+ * Returns an {@link INDArray1DView} of only the parts of the gradient view which are actually gradients.
+     * @return an {@link INDArray1DView}
+ */
+ public INDArray1DView realGradientView() {
+ return realGradientView;
+ }
+
+}
diff --git a/src/main/java/ode/vertex/impl/gradview/parname/Concat.java b/src/main/java/ode/vertex/impl/gradview/parname/Concat.java
new file mode 100644
index 0000000..c4d3681
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/gradview/parname/Concat.java
@@ -0,0 +1,39 @@
+package ode.vertex.impl.gradview.parname;
+
+import lombok.Data;
+import org.nd4j.linalg.primitives.Pair;
+import org.nd4j.shade.jackson.annotation.JsonProperty;
+
+/**
+ * {@link ParamNameMapping} which concatenates the names
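+ *
+ * For example, with the default "-" separator:
+ * <pre>{@code
+ * ParamNameMapping mapping = new Concat();
+ * mapping.map("vertex0", "W");     // "vertex0-W"
+ * mapping.reverseMap("vertex0-W"); // Pair of ("vertex0", "W")
+ * }</pre>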
+ *
+ * @author Christian Skarby
+ */
+@Data
+public class Concat implements ParamNameMapping {
+
+ private final String concatStr;
+
+ public Concat() {
+ this("-");
+ }
+
+ public Concat(@JsonProperty("concatStr") String concatStr) {
+ this.concatStr = concatStr;
+ }
+
+
+ @Override
+ public String map(String vertexName, String paramName) {
+ return vertexName + concatStr + paramName;
+ }
+
+ @Override
+    public Pair<String, String> reverseMap(String combinedName) {
+ final String[] split = combinedName.split(concatStr);
+ if(split.length != 2) {
+ throw new IllegalArgumentException("Can not reverse mapping for " + combinedName);
+ }
+ return new Pair<>(split[0], split[1]);
+ }
+}
diff --git a/src/main/java/ode/vertex/impl/gradview/parname/ParamNameMapping.java b/src/main/java/ode/vertex/impl/gradview/parname/ParamNameMapping.java
new file mode 100644
index 0000000..467e448
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/gradview/parname/ParamNameMapping.java
@@ -0,0 +1,30 @@
+package ode.vertex.impl.gradview.parname;
+
+import org.nd4j.linalg.primitives.Pair;
+import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
+
+/**
+ * A mapping from vertex and parameter names to a combined name. Also capable of reverse mapping.
+ *
+ * @author Christian Skarby
+ */
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
+public interface ParamNameMapping {
+
+ /**
+     * Map vertexName and paramName to a new parameter name which is unique for the given input
+     * @param vertexName Name of the vertex
+     * @param paramName Name of the parameter
+     * @return A combined name
+ */
+ String map(String vertexName, String paramName);
+
+ /**
+ * Reverse mapping. In other words, {@code mapping.reverseMap(mapping.map(vertexName, paramName)); } returns
+ * {@code [vertexName, paramName]}
+     * @param combinedName Combined name to reverse map
+ * @return a Pair where vertexName is the first member and paramName is the second.
+ */
+    Pair<String, String> reverseMap(String combinedName);
+
+}
diff --git a/src/main/java/ode/vertex/impl/gradview/parname/Prefix.java b/src/main/java/ode/vertex/impl/gradview/parname/Prefix.java
new file mode 100644
index 0000000..813c6bf
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/gradview/parname/Prefix.java
@@ -0,0 +1,33 @@
+package ode.vertex.impl.gradview.parname;
+
+import lombok.Data;
+import org.nd4j.linalg.primitives.Pair;
+import org.nd4j.shade.jackson.annotation.JsonProperty;
+
+/**
+ * {@link ParamNameMapping} which adds a prefix to another mapping
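+ *
+ * For example (hypothetical names): {@code new Prefix("ode_", new Concat()).map("v0", "W")} yields
+ * {@code "ode_v0-W"}, and {@code reverseMap("ode_v0-W")} recovers ("v0", "W").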
+ *
+ * @author Christian Skarby
+ */
+@Data
+public class Prefix implements ParamNameMapping {
+
+ private final ParamNameMapping mapping;
+ private final String prefix;
+
+ public Prefix(@JsonProperty("prefix") String prefix,
+ @JsonProperty("mapping") ParamNameMapping mapping) {
+ this.mapping = mapping;
+ this.prefix = prefix;
+ }
+
+ @Override
+ public String map(String vertexName, String paramName) {
+ return prefix + mapping.map(vertexName, paramName);
+ }
+
+ @Override
+    public Pair<String, String> reverseMap(String combinedName) {
+ return mapping.reverseMap(combinedName.substring(prefix.length()));
+ }
+}
diff --git a/src/main/java/ode/vertex/impl/NDArrayIndexAccumulator.java b/src/main/java/ode/vertex/impl/helper/NDArrayIndexAccumulator.java
similarity index 80%
rename from src/main/java/ode/vertex/impl/NDArrayIndexAccumulator.java
rename to src/main/java/ode/vertex/impl/helper/NDArrayIndexAccumulator.java
index 790473a..9873878 100644
--- a/src/main/java/ode/vertex/impl/NDArrayIndexAccumulator.java
+++ b/src/main/java/ode/vertex/impl/helper/NDArrayIndexAccumulator.java
@@ -1,4 +1,4 @@
-package ode.vertex.impl;
+package ode.vertex.impl.helper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
@@ -10,12 +10,12 @@
*
* @author Christian Skarby
*/
-class NDArrayIndexAccumulator {
+public class NDArrayIndexAccumulator {
private final INDArrayIndex[] state;
private final INDArray array;
- NDArrayIndexAccumulator(INDArray array) {
+ public NDArrayIndexAccumulator(INDArray array) {
this.array = array;
state = new INDArrayIndex[array.shape().length];
for(int i = 0; i < array.shape().length; i++) {
@@ -23,7 +23,9 @@ class NDArrayIndexAccumulator {
}
}
- NDArrayIndexAccumulator increment(INDArray toAdd) {
+ public NDArrayIndexAccumulator increment(INDArray toAdd) {
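+        // Nothing to accumulate; e.g. the time gradient slot may be an empty array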
+ if(toAdd.isEmpty()) return this;
+
for(int dim = 0; dim < toAdd.shape().length; dim++) {
if(toAdd.size(dim) != array.size(dim)) {
final long curr = state[dim] instanceof NDArrayIndexAll ? 0 : state[dim].end();
diff --git a/src/main/java/ode/vertex/impl/helper/OdeGraphHelper.java b/src/main/java/ode/vertex/impl/helper/OdeGraphHelper.java
new file mode 100644
index 0000000..7bca036
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/OdeGraphHelper.java
@@ -0,0 +1,189 @@
+package ode.vertex.impl.helper;
+
+import ode.vertex.impl.gradview.GradientViewFactory;
+import ode.vertex.impl.gradview.INDArray1DView;
+import ode.vertex.impl.gradview.ParameterGradientView;
+import ode.vertex.impl.helper.backward.OdeHelperBackward;
+import ode.vertex.impl.helper.forward.OdeHelperForward;
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.graph.vertex.GraphVertex;
+import org.deeplearning4j.nn.workspace.ArrayType;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.primitives.Pair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Helper which jumps through the hoops so that a {@link ComputationGraph} can be seen as the function which provides
+ * the derivatives for an ODE solver.
+ *
+ * @author Christian Skarby
+ */
+public class OdeGraphHelper {
+
+ private static final Logger log = LoggerFactory.getLogger(OdeGraphHelper.class);
+
+ private final OdeHelperForward odeHelperForward;
+ private final OdeHelperBackward odeHelperBackward;
+ private final CompGraphAsOdeFunction odeFunction;
+
+ public OdeGraphHelper(OdeHelperForward odeHelperForward, OdeHelperBackward odeHelperBackward, CompGraphAsOdeFunction odeFunction) {
+ this.odeHelperForward = odeHelperForward;
+ this.odeHelperBackward = odeHelperBackward;
+ this.odeFunction = odeFunction;
+ }
+
+ public static class CompGraphAsOdeFunction {
+
+ private INDArray lastOutput; // z(t1) from paper
+ private ParameterGradientView parameterGradientView;
+ private final ComputationGraph function;
+ private final GradientViewFactory gradientViewFactory;
+
+ public CompGraphAsOdeFunction(ComputationGraph odeFunction, GradientViewFactory gradientViewFactory) {
+ this.function = odeFunction;
+ this.gradientViewFactory = gradientViewFactory;
+ }
+
+ private INDArray lastOutput() {
+ return lastOutput;
+ }
+
+ private INDArray1DView realGradients() {
+ return parameterGradientView.realGradientView();
+ }
+
+ private void setLastOutput(INDArray lastOutput) {
+ this.lastOutput = lastOutput;
+ }
+
+ void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
+ function.setBackpropGradientsViewArray(backpropGradientsViewArray);
+ parameterGradientView = gradientViewFactory.create(function);
+ }
+
+        Map<String, INDArray> paramTable(boolean backpropOnly) {
+            final Map<String, INDArray> output = new HashMap<>();
+            for(GraphVertex vertex: function.getVertices()) {
+                final Map<String, INDArray> partable = vertex.paramTable(backpropOnly);
+                if(partable != null) {
+                    for(Map.Entry<String, INDArray> parEntry: partable.entrySet()) {
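+                        // E.g. with the Concat mapping, parameter "W" of vertex "dense" is keyed "dense-W"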
+ output.put(
+ gradientViewFactory.paramNameMapping().map(vertex.getVertexName(), parEntry.getKey()),
+ parEntry.getValue());
+ }
+ }
+ }
+ return output;
+ }
+ }
+
+ /**
+ * Clears the current state wrt training. Gradient views are not touched.
+ */
+ public void clear() {
+ getFunction().clearLayersStates();
+ odeFunction.setLastOutput(null);
+ }
+
+    public Map<String, INDArray> paramTable(boolean backpropOnly) {
+ return odeFunction.paramTable(backpropOnly);
+ }
+
+ public ComputationGraph getFunction() {
+ return odeFunction.function;
+ }
+
+ /**
+     * What is this about? Some layers "abuse" the gradient to perform updates of parameters for which no gradient
+     * is calculated, and this screws up the ODE solver's idea of what the solution is. The following layers are known
+     * to do this:
+     *
+     * * BatchNormalization: The global variance and mean are just the (sliding) average of the batch dittos.
+     *                       However, in order to support distributed training the updates are performed by adding
+     *                       the change to the state as a gradient even though it is not really one.
+ * @param backpropGradientsViewArray View of parameter gradients
+ */
+ public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
+ odeFunction.setBackpropGradientsViewArray(backpropGradientsViewArray);
+ }
+
+
+ public INDArray doForward(LayerWorkspaceMgr workspaceMgr, INDArray[] inputs) {
+
+ final LayerWorkspaceMgr innerWorkspaceMgr = createWorkspaceMgr(workspaceMgr, getFunction());
+
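+        // Reset the iteration count so it can serve as a function evaluation counter during the solve below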
+ getFunction().getConfiguration().setIterationCount(0);
+ final INDArray output = odeHelperForward.solve(getFunction(), innerWorkspaceMgr, inputs);
+ log.debug("Nrof func eval forward " + getFunction().getIterationCount());
+
+ odeFunction.setLastOutput(output.detach());
+
+ return output;
+ }
+
+    public Pair<Gradient, INDArray[]> doBackward(
+ OdeHelperBackward.MiscPar miscPars,
+ INDArray lossGradient,
+ INDArray[] lastInputs) {
+
+ final OdeHelperBackward.InputArrays inputArrays = new OdeHelperBackward.InputArrays(
+ lastInputs,
+ odeFunction.lastOutput(),
+ lossGradient,
+ odeFunction.realGradients()
+ );
+
+ final OdeHelperBackward.MiscPar miscParNewWsMgr = new OdeHelperBackward.MiscPar(
+ miscPars.isUseTruncatedBackPropTroughTime(),
+ createWorkspaceMgr(miscPars.getWsMgr(), getFunction())
+ );
+
+ getFunction().getConfiguration().setIterationCount(0);
+ final INDArray[] gradients = odeHelperBackward.solve(getFunction(), inputArrays, miscParNewWsMgr);
+ log.debug("Nrof func eval backward " + getFunction().getIterationCount());
+
+ return new Pair<>(odeFunction.parameterGradientView.allGradientsPerParam(), gradients);
+ }
+
+ /**
+ * Changes names of workspaces associated with certain {@link ArrayType}s in order to avoid workspace conflicts
+ * due to "graph in graph".
+ * @param outerWsMgr workspace manager
+ * @return LayerWorkspaceMgr with new workspace names but using the same workspace configs as in {@link ComputationGraph}
+ */
+ private LayerWorkspaceMgr createWorkspaceMgr(final LayerWorkspaceMgr outerWsMgr, ComputationGraph graph) {
+ if(outerWsMgr == LayerWorkspaceMgr.noWorkspacesImmutable()) {
+            // This could be handled better, but checking presence for every array type is not worth it right now...
+ return outerWsMgr;
+ }
+
+ return new ComputationGraph(graph.getConfiguration()) {
+ LayerWorkspaceMgr spyWsConfigs() {
+ // A little bit too many methods to comfortably decorate. Try to copy config instead
+ final LayerWorkspaceMgr.Builder wsBuilder = LayerWorkspaceMgr.builder();
+ for (ArrayType type : ArrayType.values()) {
+ if (outerWsMgr.hasConfiguration(type)) {
+ wsBuilder.with(type, outerWsMgr.getWorkspaceName(type), outerWsMgr.getConfiguration(type));
+ }
+ }
+
+ final LayerWorkspaceMgr wsMgr = wsBuilder
+ .with(ArrayType.FF_WORKING_MEM, "WS_ODE_VERTEX_LAYER_WORKING_MEM", WS_LAYER_WORKING_MEM_CONFIG)
+ .with(ArrayType.BP_WORKING_MEM, "WS_ODE_VERTEX_LAYER_WORKING_MEM", WS_LAYER_WORKING_MEM_CONFIG)
+ .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, "WS_ODE_VERTEX_RNN_LOOP_WORKING_MEM", WS_RNN_LOOP_WORKING_MEM_CONFIG)
+ .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, "WS_ODE_VERTEX_RNN_LOOP_WORKING_MEM", WS_RNN_LOOP_WORKING_MEM_CONFIG)
+ .with(ArrayType.ACTIVATIONS, "WS_ODE_VERTEX_ALL_LAYERS_ACT", WS_ALL_LAYERS_ACT_CONFIG)
+ .with(ArrayType.ACTIVATION_GRAD, "WS_ODE_VERTEX_ALL_LAYERS_GRAD", WS_ALL_LAYERS_ACT_CONFIG)
+ .build();
+ wsMgr.setHelperWorkspacePointers(outerWsMgr.getHelperWorkspacePointers());
+ return wsMgr;
+ }
+ }.spyWsConfigs();
+ }
+}
diff --git a/src/main/java/ode/vertex/impl/AugmentedDynamics.java b/src/main/java/ode/vertex/impl/helper/backward/AugmentedDynamics.java
similarity index 91%
rename from src/main/java/ode/vertex/impl/AugmentedDynamics.java
rename to src/main/java/ode/vertex/impl/helper/backward/AugmentedDynamics.java
index f859712..e071582 100644
--- a/src/main/java/ode/vertex/impl/AugmentedDynamics.java
+++ b/src/main/java/ode/vertex/impl/helper/backward/AugmentedDynamics.java
@@ -1,4 +1,4 @@
-package ode.vertex.impl;
+package ode.vertex.impl.helper.backward;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
@@ -11,7 +11,7 @@
*
* @author Christian Skarby
*/
-class AugmentedDynamics {
+public class AugmentedDynamics {
private final INDArray augStateFlat;
private final INDArray z;
@@ -20,7 +20,7 @@ class AugmentedDynamics {
private final INDArray tAdjoint;
- AugmentedDynamics(INDArray zAug, long[] zShape, long[] paramShape, long[] tShape) {
+ public AugmentedDynamics(INDArray zAug, long[] zShape, long[] paramShape, long[] tShape) {
this(
zAug,
zAug.get(NDArrayIndex.interval(0, length(zShape))).reshape(zShape),
@@ -45,7 +45,7 @@ private static long length(long[] shape) {
return length;
}
- void updateFrom(INDArray zAug) {
+ public void updateFrom(INDArray zAug) {
augStateFlat.assign(zAug);
}
diff --git a/src/main/java/ode/vertex/impl/BackpropagateAdjoint.java b/src/main/java/ode/vertex/impl/helper/backward/BackpropagateAdjoint.java
similarity index 90%
rename from src/main/java/ode/vertex/impl/BackpropagateAdjoint.java
rename to src/main/java/ode/vertex/impl/helper/backward/BackpropagateAdjoint.java
index 0fa0014..4173669 100644
--- a/src/main/java/ode/vertex/impl/BackpropagateAdjoint.java
+++ b/src/main/java/ode/vertex/impl/helper/backward/BackpropagateAdjoint.java
@@ -1,7 +1,8 @@
-package ode.vertex.impl;
+package ode.vertex.impl.helper.backward;
import lombok.AllArgsConstructor;
import ode.solve.api.FirstOrderEquation;
+import ode.vertex.impl.gradview.INDArray1DView;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
@@ -25,7 +26,7 @@
*
* f(z(t), theta) = output from forward pass through the layers of the ODE vertex (i.e. the layers of graph)
* -a(t)*df/dz(t) = dL / dz(t) = epsilon from a backward pass through the layers of the ODE vertex (i.e. the layers of graph) wrt previous output.
- * -a(t) * df / dt = not used, set to 0
+ * -a(t) * df / dt = not used (as of now), set to 0
* -a(t) df/dtheta = -dL / dtheta = Gradient from a backward pass through the layers of the ODE vertex (i.e. the layers of graph) wrt -epsilon.
*
*
@@ -40,7 +41,7 @@ public class BackpropagateAdjoint implements FirstOrderEquation {
@AllArgsConstructor
public static class GraphInfo {
private final ComputationGraph graph;
- private final NonContiguous1DView realGradients;
+ private final INDArray1DView realGradients;
private final LayerWorkspaceMgr workspaceMgr;
private final boolean truncatedBPTT;
}
@@ -59,7 +60,7 @@ public INDArray calculateDerivative(INDArray zAug, INDArray t, INDArray fzAug) {
augmentedDynamics.updateFrom(zAug);
// Note: Will also update z
- forwardPass.calculateDerivative(augmentedDynamics.z(), t, augmentedDynamics.z());
+ forwardPass.calculateDerivative(augmentedDynamics.z().dup(), t, augmentedDynamics.z());
try (WorkspacesCloseable ws = graphInfo.workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS, ArrayType.ACTIVATION_GRAD)) {
@@ -86,20 +87,14 @@ private List backPropagate(INDArray epsilon) {
final int[] topologicalOrder = graphInfo.graph.topologicalSortOrder();
final GraphVertex[] vertices = graphInfo.graph.getVertices();
+ vertices[topologicalOrder[topologicalOrder.length-1]].setEpsilon(epsilon);
+
List<INDArray> outputEpsilons = new ArrayList<>();
boolean[] setVertexEpsilon = new boolean[topologicalOrder.length]; //If true: already set epsilon for this vertex; later epsilons should be *added* to the existing one, not set
for (int i = topologicalOrder.length - 1; i >= 0; i--) {
GraphVertex current = vertices[topologicalOrder[i]];
- if (current.isOutputVertex()) {
- for (VertexIndices vertexIndices : current.getInputVertices()) {
- final String inputName = vertices[vertexIndices.getVertexIndex()].getVertexName();
- graphInfo.graph.getVertex(inputName).setEpsilon(epsilon);
- }
- continue;
- }
-
if (current.isInputVertex()) {
continue;
}
diff --git a/src/main/java/ode/vertex/impl/helper/backward/FixedStepAdjoint.java b/src/main/java/ode/vertex/impl/helper/backward/FixedStepAdjoint.java
new file mode 100644
index 0000000..ebe2fcc
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/backward/FixedStepAdjoint.java
@@ -0,0 +1,30 @@
+package ode.vertex.impl.helper.backward;
+
+import ode.solve.api.FirstOrderSolver;
+import ode.vertex.impl.helper.backward.timegrad.NoMultiStepTimeGrad;
+import ode.vertex.impl.helper.backward.timegrad.NoTimeGrad;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+/**
+ * {@link OdeHelperBackward} with a fixed (predefined) sequence of time steps for which to evaluate the ODE.
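+ *
+ * For example (solver instance hypothetical), {@code new FixedStepAdjoint(solver, Nd4j.linspace(0, 1, 10))}
+ * delegates to {@link MultiStepAdjoint} since more than two time steps are given, while a two element
+ * time vector delegates to {@link SingleStepAdjoint}.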
+ *
+ * @author Christian Skarby
+ */
+public class FixedStepAdjoint implements OdeHelperBackward {
+
+ private final OdeHelperBackward helper;
+
+ public FixedStepAdjoint(FirstOrderSolver solver, INDArray time) {
+ if(time.length() > 2) {
+ helper = new MultiStepAdjoint(solver, time, NoMultiStepTimeGrad.factory);
+ } else {
+ helper = new SingleStepAdjoint(solver, time, NoTimeGrad.factory);
+ }
+ }
+
+ @Override
+ public INDArray[] solve(ComputationGraph graph, InputArrays input, MiscPar miscPars) {
+ return helper.solve(graph, input, miscPars);
+ }
+}
diff --git a/src/main/java/ode/vertex/impl/helper/backward/InputStepAdjoint.java b/src/main/java/ode/vertex/impl/helper/backward/InputStepAdjoint.java
new file mode 100644
index 0000000..9e5ded5
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/backward/InputStepAdjoint.java
@@ -0,0 +1,59 @@
+package ode.vertex.impl.helper.backward;
+
+import ode.solve.api.FirstOrderSolver;
+import ode.vertex.impl.helper.backward.timegrad.*;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * {@link OdeHelperBackward} which uses one of the input {@link INDArray}s as the time steps for which to evaluate the ODE
+ *
+ * @author Christian Skarby
+ */
+public class InputStepAdjoint implements OdeHelperBackward {
+
+ private final FirstOrderSolver solver;
+ private final int timeIndex;
+ private final boolean needTimeGradient;
+
+ public InputStepAdjoint(FirstOrderSolver solver, int timeIndex, boolean needTimeGradient) {
+ this.solver = solver;
+ this.timeIndex = timeIndex;
+ this.needTimeGradient = needTimeGradient;
+ }
+
+
+ @Override
+ public INDArray[] solve(ComputationGraph graph, InputArrays input, MiscPar miscPars) {
+ final INDArray time = input.getLastInputs()[timeIndex];
+        final List<INDArray> notTimeInputs = new ArrayList<>();
+ for (int i = 0; i < input.getLastInputs().length; i++) {
+ if (i != timeIndex) {
+ notTimeInputs.add(input.getLastInputs()[i]);
+ }
+ }
+ final InputArrays newInput = new InputArrays(
+ notTimeInputs.toArray(new INDArray[0]),
+ input.getLastOutput(),
+ input.getLossGradient(),
+ input.getRealGradientView()
+ );
+
+ if (time.length() > 2) {
+ final MultiStepTimeGrad.Factory factory = needTimeGradient ?
+ new CalcMultiStepTimeGrad.Factory(time, timeIndex) :
+ new ZeroMultiStepTimeGrad.Factory(time, timeIndex);
+
+ return new MultiStepAdjoint(solver, time, factory).solve(graph, newInput, miscPars);
+ }
+
+ final TimeGrad.Factory factory = needTimeGradient ?
+ new CalcTimeGrad.Factory(input.getLossGradient(), timeIndex) :
+ new ZeroTimeGrad.Factory(timeIndex);
+
+ return new SingleStepAdjoint(solver, time, factory).solve(graph, newInput, miscPars);
+ }
+}
diff --git a/src/main/java/ode/vertex/impl/helper/backward/MultiStepAdjoint.java b/src/main/java/ode/vertex/impl/helper/backward/MultiStepAdjoint.java
new file mode 100644
index 0000000..4d93c8c
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/backward/MultiStepAdjoint.java
@@ -0,0 +1,124 @@
+package ode.vertex.impl.helper.backward;
+
+import ode.solve.api.FirstOrderSolver;
+import ode.vertex.impl.helper.backward.timegrad.MultiStepTimeGrad;
+import ode.vertex.impl.helper.backward.timegrad.TimeGrad;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.indexing.INDArrayIndex;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.ops.transforms.Transforms;
+
+import java.util.Arrays;
+
+/**
+ * {@link OdeHelperBackward} using the adjoint method capable of handling multiple time steps. Gradients will be provided
+ * for the last time step only.
+ *
+ * @author Christian Skarby
+ */
+public class MultiStepAdjoint implements OdeHelperBackward {
+
+ private final FirstOrderSolver solver;
+ private final INDArray time;
+ private final MultiStepTimeGrad.Factory timeGradFactory;
+
+ public MultiStepAdjoint(FirstOrderSolver solver, INDArray time, MultiStepTimeGrad.Factory timeGradFactory) {
+ this.solver = solver;
+ this.time = time;
+ this.timeGradFactory = timeGradFactory;
+
+ if(time.length() <= 2 || !time.isVector()) {
+ throw new IllegalArgumentException("time must be a vector of size > 2! Was of shape: " + Arrays.toString(time.shape())+ "!");
+ }
+ assertSorted(time);
+ }
+
+ private void assertSorted(INDArray time) {
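+        // Sum the signs of successive differences: e.g. time = [0, 1, 2, 3] gives signDiffSum = -3
+        // and |-3| + 1 == 4 == time.length(), i.e. time is strictly ascending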
+ int signDiffSum = 0;
+ for(int i = 0; i < time.length()-1; i++) {
+ signDiffSum += Transforms.sign(time.getScalar(i).sub(time.getScalar(i+1))).getDouble(0);
+ }
+
+ if(Math.abs(signDiffSum)+1 != time.length()) {
+ throw new IllegalArgumentException("Time must be in ascending or descending order! Got: " + time);
+ }
+ }
+
+ @Override
+ public INDArray[] solve(ComputationGraph graph, InputArrays input, MiscPar miscPars) {
+ final INDArray zt = alignInShapeToTimeFirst(input.getLastOutput());
+ final INDArray dL_dzt = alignInShapeToTimeFirst(input.getLossGradient());
+
+ assertSizeVsTime(zt);
+ assertSizeVsTime(dL_dzt);
+
+ final INDArrayIndex[] timeIndexer = createIndexer(time);
+ timeIndexer[1] = NDArrayIndex.interval(time.length()-2, time.length());
+ final INDArrayIndex[] ztIndexer = createIndexer(input.getLastOutput());
+ final INDArrayIndex[] dL_dztIndexer= createIndexer(input.getLossGradient());
+
+ INDArray[] gradients = null;
+ final MultiStepTimeGrad timeGrad = timeGradFactory.create();
+
+ // Go backwards in time
+ for (int step = (int)time.length()-1; step > 0; step--) {
+ final INDArray ztStep = getStep(ztIndexer,zt, step);
+
+ final INDArray dL_dztStep = getStep(dL_dztIndexer, dL_dzt, step);
+ final TimeGrad.Factory stepTimeGradFactory = timeGrad.createSingleStepFactory(dL_dztStep.dup());
+
+ timeGrad.prepareStep(gradients, dL_dztStep);
+
+ final InputArrays stepInput = new InputArrays(
+ input.getLastInputs(),
+ ztStep,
+ dL_dztStep,
+ input.getRealGradientView()
+ );
+ timeIndexer[1] = NDArrayIndex.interval(step - 1, step+1);
+
+ final OdeHelperBackward stepSolve = new SingleStepAdjoint(solver, time.get(timeIndexer), stepTimeGradFactory);
+ gradients = stepSolve.solve(graph, stepInput, miscPars);
+
+ timeGrad.updateStep(timeIndexer, gradients);
+ }
+
+ gradients = timeGrad.updateLastStep(timeIndexer, gradients, getStep(dL_dztIndexer, dL_dzt, 0));
+
+ return gradients;
+ }
+
+ private void assertSizeVsTime(INDArray array) {
+ if(array.size(0) != time.length()) {
+            throw new IllegalArgumentException("Must have same number of elements in first dimension as there are time steps! Input: "
+ + array.size(0) + ", time: " + time.length());
+ }
+ }
+
+ private INDArrayIndex[] createIndexer(INDArray array) {
+ final INDArrayIndex[] indexer = new INDArrayIndex[array.rank()];
+ for(int dim = 0; dim < indexer.length; dim++) {
+ indexer[dim] = NDArrayIndex.all();
+ }
+ return indexer;
+ }
+
+ private INDArray getStep(INDArrayIndex[] indexer, INDArray array, int step) {
+ indexer[0] = NDArrayIndex.point(step);
+ return array.get(indexer);
+ }
+
+ private INDArray alignInShapeToTimeFirst(INDArray array) {
+
+ final long[] shape = array.shape();
+ switch (shape.length) {
+ case 3: // Assume recurrent output
+ return array.permute(2,0,1);
+ case 5: // Assume conv 3D output
+ return array.permute(1,0,2,3,4);
+ // Should not happen as conf throws exception for other types
+ default: throw new UnsupportedOperationException("Rank not supported: " + array.rank());
+ }
+ }
+}
diff --git a/src/main/java/ode/vertex/impl/helper/backward/OdeHelperBackward.java b/src/main/java/ode/vertex/impl/helper/backward/OdeHelperBackward.java
new file mode 100644
index 0000000..c35af70
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/backward/OdeHelperBackward.java
@@ -0,0 +1,55 @@
+package ode.vertex.impl.helper.backward;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import ode.vertex.impl.gradview.INDArray1DView;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+/**
+ * Helps with input/output handling when solving ODEs inside a neural network
+ *
+ * @author Christian Skarby
+ */
+public interface OdeHelperBackward {
+
+ /**
+ * Input arrays needed to do backward pass. Has the following definitions:
+ * {@code lossGradient}: Gradient w.r.t loss from subsequent layers (typically called epsilon in dl4j)
+ * {@code lastOutput}: Last computed output from a forward pass used to calculate the loss gradient
+     * {@code realGradientView}: View of all array elements which are actually gradients in the given
+     * {@link ComputationGraph}s gradient view array. Notable exceptions (i.e. things labeled as gradients
+     * which are not) are the running mean and variance of Batch Normalization layers.
+ */
+ @Getter @AllArgsConstructor
+ class InputArrays {
+
+ private final INDArray[] lastInputs;
+ private final INDArray lastOutput;
+ private final INDArray lossGradient;
+ private final INDArray1DView realGradientView;
+ }
+
+ /**
+ * Misc parameters needed to jump through the hoops of doing back propagation
+ */
+ @Getter @AllArgsConstructor
+ class MiscPar {
+ private final boolean useTruncatedBackPropTroughTime;
+ private final LayerWorkspaceMgr wsMgr;
+ }
+
+ /**
+ * Return the solution to the ODE when assuming that a backwards pass through the layers of the given graph is
+     * the derivative of the sought function. Note that the parameter gradient is set in the given graph and is therefore not returned.
+ *
+ * @param graph Graph of layers to do backwards pass through
+ * @param input Input arrays
+ * @param miscPars Misc parameters needed to jump through the hoops of doing back propagation
+ *
+ * @return Loss gradients (a.k.a epsilon in dl4j) w.r.t last input from previous layers. Note that parameter gradients
+ * are set in graph and can be accessed through graph.getGradientsViewArray()
+ */
+ INDArray[] solve(ComputationGraph graph, InputArrays input, MiscPar miscPars);
+}
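A minimal wiring sketch for this interface, assuming the arrays and workspace manager are produced by the surrounding vertex; every method parameter below stands in for a value the source constructs elsewhere:

import ode.solve.api.FirstOrderSolver;
import ode.vertex.impl.gradview.INDArray1DView;
import ode.vertex.impl.helper.backward.OdeHelperBackward;
import ode.vertex.impl.helper.backward.SingleStepAdjoint;
import ode.vertex.impl.helper.backward.timegrad.NoTimeGrad;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class AdjointBackwardSketch {

    static INDArray[] backward(ComputationGraph graph, FirstOrderSolver solver, INDArray[] lastInputs,
                               INDArray lastOutput, INDArray lossGradient, INDArray1DView realGradientView,
                               LayerWorkspaceMgr wsMgr) {
        // Integrate backwards over a single interval; time gradients are not needed here
        final INDArray time = Nd4j.create(new double[]{0, 1});
        final OdeHelperBackward helper = new SingleStepAdjoint(solver, time, NoTimeGrad.factory);

        final OdeHelperBackward.InputArrays input = new OdeHelperBackward.InputArrays(
                lastInputs, lastOutput, lossGradient, realGradientView);
        final OdeHelperBackward.MiscPar misc = new OdeHelperBackward.MiscPar(false, wsMgr);

        // Parameter gradients end up in graph.getGradientsViewArray(); the return value is the input epsilon(s)
        return helper.solve(graph, input, misc);
    }
}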
diff --git a/src/main/java/ode/vertex/impl/helper/backward/SingleStepAdjoint.java b/src/main/java/ode/vertex/impl/helper/backward/SingleStepAdjoint.java
new file mode 100644
index 0000000..84f6097
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/backward/SingleStepAdjoint.java
@@ -0,0 +1,88 @@
+package ode.vertex.impl.helper.backward;
+
+import ode.solve.api.FirstOrderEquation;
+import ode.solve.api.FirstOrderSolver;
+import ode.vertex.impl.gradview.INDArray1DView;
+import ode.vertex.impl.helper.NDArrayIndexAccumulator;
+import ode.vertex.impl.helper.backward.timegrad.TimeGrad;
+import ode.vertex.impl.helper.forward.ForwardPass;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.util.Arrays;
+
+/**
+ * {@link OdeHelperBackward} using the adjoint method capable of handling a single time step. Gradients for time steps
+ * will only be provided if required.
+ *
+ * @author Christian Skarby
+ */
+public class SingleStepAdjoint implements OdeHelperBackward {
+
+ private final FirstOrderSolver solver;
+ private final INDArray time;
+ private final TimeGrad.Factory timeGradFactory;
+
+ public SingleStepAdjoint(FirstOrderSolver solver, INDArray time, TimeGrad.Factory timeGradFactory) {
+ this.solver = solver;
+ this.time = time;
+ this.timeGradFactory = timeGradFactory;
+ if (time.length() != 2 && time.rank() != 1) {
+ throw new IllegalArgumentException("time must be a vector with two elements! Was of shape: " + Arrays.toString(time.shape()) + "!");
+ }
+ }
+
+ @Override
+ public INDArray[] solve(ComputationGraph graph, InputArrays input, MiscPar miscPars) {
+
+ // Create augmented dynamics for adjoint method
+ // Initialization: S0:
+ // z(t1) = lastOutput
+ // a(t) = -dL/d(z(t1)) = -epsilon from next layer (i.e. getEpsilon). Use last row if more than one time step
+ // parameters = zeros
+ // dL/dt1 = -dL / dz(t1) dot dz(t1) / dt1
+
+ final INDArray dL_dzt1 = input.getLossGradient();
+ final INDArray zt1 = input.getLastOutput();
+ final INDArray1DView realParamGrads = input.getRealGradientView();
+
+ final FirstOrderEquation forward = new ForwardPass(graph,
+ miscPars.getWsMgr(),
+ true, // Always use training as batch norm running mean and var become messed up otherwise. Same effect seen in original pytorch repo.
+ input.getLastInputs());
+
+ final TimeGrad timeGrad = timeGradFactory.create();
+ final INDArray dL_dt1 = timeGrad.calcTimeGradT1(forward, zt1, time);
+
+ final INDArray zAug = Nd4j.create(1, zt1.length() + dL_dzt1.length() + graph.numParams() + dL_dt1.length());
+ final INDArray paramAdj = Nd4j.zeros(realParamGrads.length());
+ realParamGrads.assignTo(paramAdj);
+
+ final NDArrayIndexAccumulator accumulator = new NDArrayIndexAccumulator(zAug);
+ accumulator.increment(zt1.reshape(new long[]{1, zt1.length()}))
+ .increment(dL_dzt1.reshape(new long[]{1, dL_dzt1.length()}))
+ .increment(paramAdj.reshape(new long[]{1, paramAdj.length()}))
+ .increment(dL_dt1);
+
+ final AugmentedDynamics augmentedDynamics = new AugmentedDynamics(
+ zAug,
+ dL_dzt1.shape(),
+ new long[]{realParamGrads.length()},
+ dL_dt1.shape());
+
+ final FirstOrderEquation equation = new BackpropagateAdjoint(
+ augmentedDynamics,
+ forward,
+ new BackpropagateAdjoint.GraphInfo(graph, realParamGrads, miscPars.getWsMgr(), miscPars.isUseTruncatedBackPropTroughTime())
+ );
+
+ INDArray augAns = solver.integrate(equation, Nd4j.reverse(time.dup()), zAug, zAug.dup());
+
+ augmentedDynamics.updateFrom(augAns);
+
+ realParamGrads.assignFrom(augmentedDynamics.paramAdjoint());
+
+ return timeGrad.createLossGradient(augmentedDynamics.zAdjoint(), augmentedDynamics.tAdjoint());
+ }
+}
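The augmented state assembled above is just a flat concatenation. A numeric sketch of the layout, with made-up sizes for illustration:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class AugmentedStateSketch {
    public static void main(String[] args) {
        final INDArray z1 = Nd4j.create(new double[]{1, 2});         // z(t1), two state elements
        final INDArray dLdz1 = Nd4j.create(new double[]{0.1, 0.2});  // adjoint a(t1) = dL/dz(t1)
        final INDArray pAdj = Nd4j.zeros(1, 3);                      // parameter adjoint, starts at zero
        final INDArray dLdt1 = Nd4j.create(new double[]{0.05});      // time gradient, single element

        // Same layout as zAug above: [z(t1) | dL/dz(t1) | param adjoint | dL/dt1], shape [1, 8]
        final INDArray zAug = Nd4j.hstack(
                z1.reshape(1, 2), dLdz1.reshape(1, 2), pAdj, dLdt1.reshape(1, 1));
        System.out.println(zAug);
    }
}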
diff --git a/src/main/java/ode/vertex/impl/helper/backward/timegrad/CalcMultiStepTimeGrad.java b/src/main/java/ode/vertex/impl/helper/backward/timegrad/CalcMultiStepTimeGrad.java
new file mode 100644
index 0000000..1728fd0
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/backward/timegrad/CalcMultiStepTimeGrad.java
@@ -0,0 +1,70 @@
+package ode.vertex.impl.helper.backward.timegrad;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.INDArrayIndex;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+
+/**
+ * Calculates a time gradient for multiple time steps.
+ *
+ * @author Christian Skarby
+ */
+public class CalcMultiStepTimeGrad implements MultiStepTimeGrad {
+
+ private final INDArray timeGradient;
+ private final int timeIndex;
+
+ private final INDArray lastTime = Nd4j.scalar(0);
+
+ /**
+ * Factory for this class.
+ */
+ public static class Factory implements MultiStepTimeGrad.Factory {
+
+ private final INDArray time;
+ private final int timeIndex;
+
+ public Factory(INDArray time, int timeIndex) {
+ this.time = time;
+ this.timeIndex = timeIndex;
+ }
+
+ @Override
+ public MultiStepTimeGrad create() {
+ return new CalcMultiStepTimeGrad(time, timeIndex);
+ }
+ }
+
+ public CalcMultiStepTimeGrad(INDArray time, int timeIndex) {
+ timeGradient = Nd4j.zeros(time.shape());
+ this.timeIndex = timeIndex;
+ }
+
+ @Override
+ public void prepareStep(INDArray[] lastGradients, INDArray dL_dzt) {
+ if(lastGradients == null) return;
+
+ dL_dzt.addi(lastGradients[(timeIndex+1) % lastGradients.length]);
+ }
+
+ @Override
+ public void updateStep(INDArrayIndex[] timeIndexer, INDArray[] gradients) {
+ timeGradient.put(timeIndexer, gradients[timeIndex]);
+ lastTime.subi(timeGradient.get(timeIndexer).getScalar(0));
+ }
+
+ @Override
+ public INDArray[] updateLastStep(INDArrayIndex[] timeIndexer, INDArray[] gradients, INDArray dL_dzt0) {
+ timeIndexer[1] = NDArrayIndex.point(0);
+ timeGradient.put(timeIndexer, lastTime);
+ gradients[timeIndex] = timeGradient;
+ gradients[(timeIndex + 1) % gradients.length].addi(dL_dzt0);
+ return gradients;
+ }
+
+ @Override
+ public TimeGrad.Factory createSingleStepFactory(INDArray dL_dzt1_time) {
+ return new CalcTimeGrad.Factory(dL_dzt1_time, timeIndex);
+ }
+}
diff --git a/src/main/java/ode/vertex/impl/helper/backward/timegrad/CalcTimeGrad.java b/src/main/java/ode/vertex/impl/helper/backward/timegrad/CalcTimeGrad.java
new file mode 100644
index 0000000..67439e3
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/backward/timegrad/CalcTimeGrad.java
@@ -0,0 +1,67 @@
+package ode.vertex.impl.helper.backward.timegrad;
+
+import ode.solve.api.FirstOrderEquation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+/**
+ * Calculates a time gradient for a single time step.
+ *
+ * @author Christian Skarby
+ */
+public class CalcTimeGrad implements TimeGrad {
+
+ private final INDArray dL_dzt1_time;
+ private final int timeIndex;
+
+ private INDArray dL_dt1;
+
+ /**
+ * Factory for this class
+ */
+ public static class Factory implements TimeGrad.Factory {
+
+ private final INDArray dL_dzt1_time;
+ private final int timeIndex;
+
+ public Factory(INDArray dL_dzt1_time, int timeIndex) {
+ this.dL_dzt1_time = dL_dzt1_time;
+ this.timeIndex = timeIndex;
+ }
+
+ @Override
+ public TimeGrad create() {
+ return new CalcTimeGrad(dL_dzt1_time, timeIndex);
+ }
+ }
+
+ public CalcTimeGrad(INDArray dL_dzt1_time, int timeIndex) {
+ this.dL_dzt1_time = dL_dzt1_time;
+ this.timeIndex = timeIndex;
+ }
+
+ @Override
+ public INDArray calcTimeGradT1(FirstOrderEquation equation, INDArray zt1, INDArray time) {
+ final INDArray dzt1_dt1 = equation.calculateDerivative(zt1, time.getColumn(1), zt1.dup());
+
+ this.dL_dt1 = dL_dzt1_time.reshape(1, dzt1_dt1.length())
+ .mmul(dzt1_dt1.reshape(dzt1_dt1.length(), 1));
+ return dL_dt1;
+ }
+
+ @Override
+ public INDArray[] createLossGradient(INDArray zAdjoint, INDArray tAdjoint) {
+ if(dL_dt1 == null) {
+ throw new IllegalStateException("Must compute dL / dt1 before creating loss gradient!");
+ }
+
+ final INDArray[] epsilons = new INDArray[2];
+ for (int i = 0; i < epsilons.length; i++) {
+ if (i != timeIndex) {
+ epsilons[i] = zAdjoint;
+ }
+ }
+ epsilons[timeIndex] = Nd4j.hstack(dL_dt1, tAdjoint);
+ return epsilons;
+ }
+}
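calcTimeGradT1 above is the chain rule dL/dt1 = (dL/dz(t1)) · (dz(t1)/dt1), realized as a 1xN times Nx1 matrix product. A small numeric check with invented values:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class TimeGradDotProduct {
    public static void main(String[] args) {
        final INDArray dLdz1 = Nd4j.create(new double[]{0.5, -1.0, 2.0});  // loss gradient w.r.t. z(t1)
        final INDArray dz1dt1 = Nd4j.create(new double[]{0.1, 0.2, 0.3});  // derivative of z(t1) w.r.t. t1
        final INDArray dLdt1 = dLdz1.reshape(1, 3).mmul(dz1dt1.reshape(3, 1));
        System.out.println(dLdt1); // 0.5*0.1 - 1.0*0.2 + 2.0*0.3 = 0.45
    }
}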
diff --git a/src/main/java/ode/vertex/impl/helper/backward/timegrad/MultiStepTimeGrad.java b/src/main/java/ode/vertex/impl/helper/backward/timegrad/MultiStepTimeGrad.java
new file mode 100644
index 0000000..4252f7e
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/backward/timegrad/MultiStepTimeGrad.java
@@ -0,0 +1,55 @@
+package ode.vertex.impl.helper.backward.timegrad;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.indexing.INDArrayIndex;
+
+/**
+ * Interface for handling the time gradient w.r.t. the loss when multiple time steps are used. Main use case is to
+ * avoid calculating time gradients when they are not needed.
+ *
+ * @author Christian Skarby
+ */
+public interface MultiStepTimeGrad {
+
+ /**
+ * Factory for MultiStepTimeGrads
+ */
+ interface Factory {
+ /**
+ * Return a {@link MultiStepTimeGrad} instance
+ * @return a {@link MultiStepTimeGrad} instance
+ */
+ MultiStepTimeGrad create();
+ }
+
+ /**
+ * Update dL_dzt based on the last step's gradients
+ * @param lastGradients gradients from last time step
+ * @param dL_dzt Loss gradient for z(t) for current time step
+ */
+ void prepareStep(INDArray[] lastGradients, INDArray dL_dzt);
+
+ /**
+ * Update after a single step backwards
+ * @param timeIndexer Points to the current time step
+ * @param gradients Current gradients
+ */
+ void updateStep(INDArrayIndex[] timeIndexer, INDArray[] gradients);
+
+ /**
+ * Update after the last step backwards has been performed. Note that this will update the input gradient
+ * @param timeIndexer Points to the current time step
+ * @param gradients Current gradients. Might be updated as a result
+ * @param dL_dzt0 Loss gradient for z(t0).
+ * @return Updated gradients
+ */
+ INDArray[] updateLastStep(INDArrayIndex[] timeIndexer, INDArray[] gradients, INDArray dL_dzt0);
+
+ /**
+ * Create an appropriate {@link TimeGrad.Factory}
+ * @param dL_dzt1_time Loss gradient for calculation of time gradient
+ * @return a {@link TimeGrad.Factory}
+ */
+ TimeGrad.Factory createSingleStepFactory(INDArray dL_dzt1_time);
+
+}
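A compressed sketch of the call order this interface implies, inferred from the updateLastStep call in MultiStepAdjoint near the top of this diff. The adjoint solve itself is stubbed out, so treat this as pseudo-wiring rather than the actual implementation:

import ode.vertex.impl.helper.backward.timegrad.MultiStepTimeGrad;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;

class MultiStepSweepSketch {

    // Stub standing in for one adjoint solve from t[step] to t[step - 1]
    private static INDArray[] solveAdjointStep(INDArray dL_dzt) {
        return new INDArray[]{dL_dzt, dL_dzt};
    }

    static INDArray[] sweep(MultiStepTimeGrad timeGrad, INDArrayIndex[] timeIndexer,
                            INDArray[] dL_dzt, INDArray dL_dzt0) {
        INDArray[] gradients = null;
        for (int step = dL_dzt.length - 1; step > 0; step--) {
            timeGrad.prepareStep(gradients, dL_dzt[step]); // fold the previous step into the current loss gradient
            gradients = solveAdjointStep(dL_dzt[step]);
            timeGrad.updateStep(timeIndexer, gradients);   // record this step's time gradient
        }
        // Write the t0 entry of the time gradient and add dL/dz(t0) to the input gradient
        return timeGrad.updateLastStep(timeIndexer, gradients, dL_dzt0);
    }
}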
diff --git a/src/main/java/ode/vertex/impl/helper/backward/timegrad/NoMultiStepTimeGrad.java b/src/main/java/ode/vertex/impl/helper/backward/timegrad/NoMultiStepTimeGrad.java
new file mode 100644
index 0000000..250586c
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/backward/timegrad/NoMultiStepTimeGrad.java
@@ -0,0 +1,49 @@
+package ode.vertex.impl.helper.backward.timegrad;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.indexing.INDArrayIndex;
+
+/**
+ * Does not calculate any time gradients
+ *
+ * @author Christian Skarby
+ */
+public class NoMultiStepTimeGrad implements MultiStepTimeGrad {
+
+ // Singleton because stateless
+ private final static MultiStepTimeGrad noTimeGrad = new NoMultiStepTimeGrad();
+
+ /**
+ * Factory for this class
+ */
+ public static MultiStepTimeGrad.Factory factory = new MultiStepTimeGrad.Factory() {
+
+ @Override
+ public MultiStepTimeGrad create() {
+ return noTimeGrad;
+ }
+ };
+
+ @Override
+ public void prepareStep(INDArray[] lastGradients, INDArray dL_dzt) {
+ if(lastGradients == null) return;
+
+ dL_dzt.addi(lastGradients[0]);
+ }
+
+ @Override
+ public void updateStep(INDArrayIndex[] timeIndexer, INDArray[] gradients) {
+ // Do nothing
+ }
+
+ @Override
+ public INDArray[] updateLastStep(INDArrayIndex[] timeIndexer, INDArray[] gradients, INDArray dL_dzt) {
+ gradients[0].addi(dL_dzt);
+ return gradients;
+ }
+
+ @Override
+ public TimeGrad.Factory createSingleStepFactory(INDArray dL_dzt1_time) {
+ return NoTimeGrad.factory;
+ }
+}
diff --git a/src/main/java/ode/vertex/impl/helper/backward/timegrad/NoTimeGrad.java b/src/main/java/ode/vertex/impl/helper/backward/timegrad/NoTimeGrad.java
new file mode 100644
index 0000000..b5e5f64
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/backward/timegrad/NoTimeGrad.java
@@ -0,0 +1,33 @@
+package ode.vertex.impl.helper.backward.timegrad;
+
+import ode.solve.api.FirstOrderEquation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+/**
+ * {@link TimeGrad} which does not calculate any time gradients
+ *
+ * @author Christian Skarby
+ */
+public class NoTimeGrad implements TimeGrad {
+
+ private final static TimeGrad noTimeGrad = new NoTimeGrad();
+ public static TimeGrad.Factory factory = new TimeGrad.Factory() {
+
+ @Override
+ public TimeGrad create() {
+ // Singleton because stateless
+ return noTimeGrad;
+ }
+ };
+
+ @Override
+ public INDArray calcTimeGradT1(FirstOrderEquation equation, INDArray zt1, INDArray time) {
+ return Nd4j.empty();
+ }
+
+ @Override
+ public INDArray[] createLossGradient(INDArray zAdjoint, INDArray tAdjoint) {
+ return new INDArray[] {zAdjoint};
+ }
+}
diff --git a/src/main/java/ode/vertex/impl/helper/backward/timegrad/TimeGrad.java b/src/main/java/ode/vertex/impl/helper/backward/timegrad/TimeGrad.java
new file mode 100644
index 0000000..190b1e9
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/backward/timegrad/TimeGrad.java
@@ -0,0 +1,42 @@
+package ode.vertex.impl.helper.backward.timegrad;
+
+import ode.solve.api.FirstOrderEquation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+/**
+ * Interface for handling the time gradient w.r.t. the loss for a single time step. Main use case is to be able to
+ * skip calculating time gradients when they are not needed.
+ *
+ * @author Christian Skarby
+ */
+public interface TimeGrad {
+
+ /**
+ * Factory for TimeGrads
+ */
+ interface Factory {
+ /**
+ * Return a {@link TimeGrad} instance
+ * @return a {@link TimeGrad} instance
+ */
+ TimeGrad create();
+ }
+
+ /**
+ * Calculate time gradient for the last time point (t1)
+ * @param equation Calculates dz(t1) / d(t1)
+ * @param zt1 Value of z(t1)
+ * @param time Vector with the two time points t0 and t1
+ * @return dLoss / dt1 or empty if no time gradient needed
+ */
+ INDArray calcTimeGradT1(FirstOrderEquation equation, INDArray zt1, INDArray time);
+
+ /**
+ * Create the loss gradient (a.k.a. epsilons in dl4j) from the adjoint state.
+ * @param zAdjoint dL / dz(t0)
+ * @param tAdjoint dL / dt0
+ * @return Array of required loss gradients
+ */
+ INDArray[] createLossGradient(INDArray zAdjoint, INDArray tAdjoint);
+
+}
diff --git a/src/main/java/ode/vertex/impl/helper/backward/timegrad/ZeroMultiStepTimeGrad.java b/src/main/java/ode/vertex/impl/helper/backward/timegrad/ZeroMultiStepTimeGrad.java
new file mode 100644
index 0000000..b5a9a3f
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/backward/timegrad/ZeroMultiStepTimeGrad.java
@@ -0,0 +1,70 @@
+package ode.vertex.impl.helper.backward.timegrad;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.INDArrayIndex;
+
+/**
+ * {@link MultiStepTimeGrad} which does not compute any time gradients, but returns a zero dummy gradient so that
+ * array sizes are correct
+ *
+ * @author Christian Skarby
+ */
+public class ZeroMultiStepTimeGrad implements MultiStepTimeGrad {
+
+ private final int timeIndex;
+ private final INDArray time;
+
+ /**
+ * Factory for this class.
+ */
+ public static class Factory implements MultiStepTimeGrad.Factory {
+
+ private final INDArray time;
+ private final int timeIndex;
+
+ public Factory(INDArray time, int timeIndex) {
+ this.time = time;
+ this.timeIndex = timeIndex;
+ }
+
+ @Override
+ public MultiStepTimeGrad create() {
+ return new ZeroMultiStepTimeGrad(time, timeIndex);
+ }
+ }
+
+ public ZeroMultiStepTimeGrad(INDArray time, int timeIndex) {
+ this.time = time;
+ this.timeIndex = timeIndex;
+ }
+
+ @Override
+ public void prepareStep(INDArray[] lastGradients, INDArray dL_dzt) {
+ if(lastGradients == null) return;
+
+ dL_dzt.addi(lastGradients[(timeIndex+1) % lastGradients.length]);
+ }
+
+ @Override
+ public void updateStep(INDArrayIndex[] timeIndexer, INDArray[] gradients) {
+ // Do nothing
+ }
+
+ @Override
+ public INDArray[] updateLastStep(INDArrayIndex[] timeIndexer, INDArray[] gradients, INDArray dL_dzt) {
+
+ final INDArray[] toRet = new INDArray[2];
+ toRet[timeIndex] = Nd4j.zerosLike(time);
+ final int notTimeIndex = (timeIndex + 1) % toRet.length;
+ toRet[notTimeIndex] = gradients[notTimeIndex % gradients.length];
+ gradients[notTimeIndex % gradients.length].addi(dL_dzt);
+
+ return toRet;
+ }
+
+ @Override
+ public TimeGrad.Factory createSingleStepFactory(INDArray dL_dzt1_time) {
+ return NoTimeGrad.factory;
+ }
+}
diff --git a/src/main/java/ode/vertex/impl/helper/backward/timegrad/ZeroTimeGrad.java b/src/main/java/ode/vertex/impl/helper/backward/timegrad/ZeroTimeGrad.java
new file mode 100644
index 0000000..19ad527
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/backward/timegrad/ZeroTimeGrad.java
@@ -0,0 +1,55 @@
+package ode.vertex.impl.helper.backward.timegrad;
+
+import ode.solve.api.FirstOrderEquation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+/**
+ * {@link TimeGrad} which does not compute any time gradients, but returns a zero dummy gradient so that
+ * array sizes are correct
+ *
+ * @author Christian Skarby
+ */
+public class ZeroTimeGrad implements TimeGrad {
+
+ private final int timeIndex;
+
+ /**
+ * Factory for this class
+ */
+ public static class Factory implements TimeGrad.Factory {
+
+ private final int timeIndex;
+
+ public Factory(int timeIndex) {
+ this.timeIndex = timeIndex;
+ }
+
+ @Override
+ public TimeGrad create() {
+ return new ZeroTimeGrad(timeIndex);
+ }
+ }
+
+
+ public ZeroTimeGrad(int timeIndex) {
+ this.timeIndex = timeIndex;
+ }
+
+ @Override
+ public INDArray calcTimeGradT1(FirstOrderEquation equation, INDArray zt1, INDArray time) {
+ return Nd4j.empty();
+ }
+
+ @Override
+ public INDArray[] createLossGradient(INDArray zAdjoint, INDArray tAdjoint) {
+ final INDArray[] epsilons = new INDArray[2];
+ for (int i = 0; i < epsilons.length; i++) {
+ if (i != timeIndex) {
+ epsilons[i] = zAdjoint;
+ }
+ }
+ epsilons[timeIndex] = Nd4j.zeros(1,2);
+ return epsilons;
+ }
+}
diff --git a/src/main/java/ode/vertex/impl/helper/forward/FixedStep.java b/src/main/java/ode/vertex/impl/helper/forward/FixedStep.java
new file mode 100644
index 0000000..35831ba
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/forward/FixedStep.java
@@ -0,0 +1,35 @@
+package ode.vertex.impl.helper.forward;
+
+import ode.solve.api.FirstOrderMultiStepSolver;
+import ode.solve.api.FirstOrderSolver;
+import ode.solve.impl.InterpolatingMultiStepSolver;
+import ode.solve.impl.SingleSteppingMultiStepSolver;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+/**
+ * {@link OdeHelperForward} with a fixed given sequence of time steps to evaluate the ODE for.
+ *
+ * @author Christian Skarby
+ */
+public class FixedStep implements OdeHelperForward {
+
+ private final OdeHelperForward helper;
+
+ public FixedStep(FirstOrderSolver solver, INDArray time, boolean interpolateIfMultiStep) {
+ if(time.length() > 2) {
+ final FirstOrderMultiStepSolver multiStepSolver = interpolateIfMultiStep ?
+ new InterpolatingMultiStepSolver(solver) :
+ new SingleSteppingMultiStepSolver(solver);
+ helper = new MultiStep(multiStepSolver, time);
+ } else {
+ helper = new SingleStep(solver, time);
+ }
+ }
+
+ @Override
+ public INDArray solve(ComputationGraph graph, LayerWorkspaceMgr wsMgr, INDArray[] inputs) {
+ return helper.solve(graph, wsMgr, inputs);
+ }
+}
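A usage sketch, assuming a FirstOrderSolver implementation and an initialized graph are available; both are placeholders here, not values from the source:

import ode.solve.api.FirstOrderSolver;
import ode.vertex.impl.helper.forward.FixedStep;
import ode.vertex.impl.helper.forward.OdeHelperForward;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class FixedStepSketch {
    static INDArray solveTenSteps(FirstOrderSolver solver, ComputationGraph graph,
                                  LayerWorkspaceMgr wsMgr, INDArray z0) {
        final INDArray time = Nd4j.linspace(0, 3, 10); // more than two time points selects the multi-step path
        // true: interpolate between solver steps; false: restart the solver for every interval
        final OdeHelperForward helper = new FixedStep(solver, time, true);
        return helper.solve(graph, wsMgr, new INDArray[]{z0});
    }
}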
diff --git a/src/main/java/ode/vertex/impl/ForwardPass.java b/src/main/java/ode/vertex/impl/helper/forward/ForwardPass.java
similarity index 90%
rename from src/main/java/ode/vertex/impl/ForwardPass.java
rename to src/main/java/ode/vertex/impl/helper/forward/ForwardPass.java
index f36d9db..0d35252 100644
--- a/src/main/java/ode/vertex/impl/ForwardPass.java
+++ b/src/main/java/ode/vertex/impl/helper/forward/ForwardPass.java
@@ -1,6 +1,7 @@
-package ode.vertex.impl;
+package ode.vertex.impl.helper.forward;
import ode.solve.api.FirstOrderEquation;
+import ode.vertex.impl.helper.NDArrayIndexAccumulator;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
@@ -34,12 +35,12 @@ public ForwardPass(ComputationGraph graph,
this.graph = graph;
this.workspaceMgr = workspaceMgr;
this.training = training;
- this.inputs = startInputs;
+ this.inputs = startInputs.clone();
}
-
@Override
public INDArray calculateDerivative(INDArray y, INDArray t, INDArray fy) {
+ graph.getConfiguration().setIterationCount(graph.getIterationCount() + 1);
try (WorkspacesCloseable ws = enterIfNotOpen(ArrayType.ACTIVATIONS)) {
setInputsFromFlat(y);
evaluate(inputs, fy);
@@ -64,7 +65,7 @@ private void setInputsFromFlat(INDArray flatArray) {
INDArray input = inputs[i];
final INDArray z = flatArray.get(NDArrayIndex.interval(lastInd, lastInd + input.length()));
lastInd += input.length();
- inputs[i].assign(z.reshape(input.shape()));
+ inputs[i] = z.reshape(input.shape());
}
}
@@ -81,19 +82,17 @@ private void evaluate(INDArray[] inputs, INDArray output) {
VertexIndices[] inputsTo = current.getOutputVertices();
- INDArray out = null;
+ final INDArray out;
if (current.isInputVertex()) {
out = inputs[vIdx];
- } else if (current.isOutputVertex()) {
- for (INDArray outArr : current.getInputs()) {
- outputAccum.increment(outArr);
- }
} else {
//Standard feed-forward case
out = current.doForward(training, workspaceMgr);
}
- if (inputsTo != null) { //Output vertices may not input to any other vertices
+ if (inputsTo == null) { //Output vertices may not input to any other vertices
+ outputAccum.increment(out);
+ } else {
for (VertexIndices v : inputsTo) {
//Note that we don't have to do anything special here: the activations are always detached in
// this method
diff --git a/src/main/java/ode/vertex/impl/helper/forward/InputStep.java b/src/main/java/ode/vertex/impl/helper/forward/InputStep.java
new file mode 100644
index 0000000..09cfee9
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/forward/InputStep.java
@@ -0,0 +1,42 @@
+package ode.vertex.impl.helper.forward;
+
+import ode.solve.api.FirstOrderSolver;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * {@link OdeHelperForward} which uses one of the input {@link INDArray}s as the time steps to evaluate the ODE for
+ *
+ * @author Christian Skarby
+ */
+public class InputStep implements OdeHelperForward {
+
+ private final FirstOrderSolver solver;
+ private final int timeInputIndex;
+ private final boolean interpolateIfMultiStep;
+
+ public InputStep(FirstOrderSolver solver, int timeInputIndex, boolean interpolateIfMultiStep) {
+ this.solver = solver;
+ this.timeInputIndex = timeInputIndex;
+ this.interpolateIfMultiStep = interpolateIfMultiStep;
+ }
+
+ @Override
+ public INDArray solve(ComputationGraph graph, LayerWorkspaceMgr wsMgr, INDArray[] inputs) {
+ final List<INDArray> notTimeInputs = new ArrayList<>();
+ for (int i = 0; i < inputs.length; i++) {
+ if (i != timeInputIndex) {
+ notTimeInputs.add(inputs[i]);
+ }
+ }
+ return new FixedStep(
+ solver,
+ inputs[timeInputIndex],
+ interpolateIfMultiStep)
+ .solve(graph, wsMgr, notTimeInputs.toArray(new INDArray[0]));
+ }
+}
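Compared to FixedStep, the only difference is that the time vector arrives as one of the vertex inputs. A short sketch, with solver, graph and workspace manager again as placeholders:

import ode.solve.api.FirstOrderSolver;
import ode.vertex.impl.helper.forward.InputStep;
import ode.vertex.impl.helper.forward.OdeHelperForward;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class InputStepSketch {
    static INDArray solve(FirstOrderSolver solver, ComputationGraph graph,
                          LayerWorkspaceMgr wsMgr, INDArray z0) {
        // Input index 1 carries the time vector; input 0 is integrated over those times
        final OdeHelperForward helper = new InputStep(solver, 1, true);
        return helper.solve(graph, wsMgr, new INDArray[]{z0, Nd4j.linspace(0, 3, 10)});
    }
}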
diff --git a/src/main/java/ode/vertex/impl/helper/forward/MultiStep.java b/src/main/java/ode/vertex/impl/helper/forward/MultiStep.java
new file mode 100644
index 0000000..2bf8c69
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/forward/MultiStep.java
@@ -0,0 +1,64 @@
+package ode.vertex.impl.helper.forward;
+
+import com.google.common.primitives.Longs;
+import ode.solve.api.FirstOrderEquation;
+import ode.solve.api.FirstOrderMultiStepSolver;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.util.Arrays;
+
+/**
+ * {@link OdeHelperForward} capable of handling multiple time steps
+ *
+ * @author Christian Skarby
+ */
+public class MultiStep implements OdeHelperForward {
+
+ private final FirstOrderMultiStepSolver solver;
+ private final INDArray time;
+
+ public MultiStep(FirstOrderMultiStepSolver solver, INDArray time) {
+ this.solver = solver;
+ this.time = time;
+ if(time.length() <= 2 || !time.isVector()) {
+ throw new IllegalArgumentException("time must be a vector of size > 2! Was of shape: " + Arrays.toString(time.shape())+ "!");
+ }
+ }
+
+ @Override
+ public INDArray solve(ComputationGraph graph, LayerWorkspaceMgr wsMgr, INDArray[] inputs) {
+ if (inputs.length != 1) {
+ throw new IllegalArgumentException("Only single input supported!");
+ }
+
+ final FirstOrderEquation equation = new ForwardPass(
+ graph,
+ wsMgr,
+ true,
+ inputs
+ );
+
+ final INDArray z0 = inputs[0].dup();
+ final INDArray zt = Nd4j.createUninitialized(Longs.concat(new long[]{time.length() - 1}, z0.shape()));
+ solver.integrate(equation, time, inputs[0].dup(), zt);
+
+ return alignOutShape(zt, z0);
+ }
+
+
+ private INDArray alignOutShape(INDArray zt, INDArray z0) {
+ final long[] shape = zt.shape();
+ switch (shape.length) {
+ case 3: // Assume recurrent output
+ return Nd4j.concat(0, z0.reshape(1, shape[1], shape[2]), zt).permute(1, 2, 0);
+ case 5: // Assume conv 3D output
+ return Nd4j.concat(0, z0.reshape(1, shape[1], shape[2], shape[3], shape[4]), zt).permute(1, 0, 2, 3, 4);
+ // Should not happen as conf throws exception for other types
+ default:
+ throw new UnsupportedOperationException("Rank not supported: " + zt.rank());
+ }
+ }
+}
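Shape walk-through for the recurrent branch of alignOutShape, with illustrative sizes: the solver emits a time-major array, z0 is prepended, and the result is permuted back to dl4j's [batch, features, time] layout:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Arrays;

class AlignOutShapeCheck {
    public static void main(String[] args) {
        final INDArray z0 = Nd4j.zeros(4, 3);    // first time point: [batch, features]
        final INDArray zt = Nd4j.zeros(9, 4, 3); // solver output: [timePoints - 1, batch, features]
        final INDArray out = Nd4j.concat(0, z0.reshape(1, 4, 3), zt).permute(1, 2, 0);
        System.out.println(Arrays.toString(out.shape())); // prints [4, 3, 10]
    }
}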
diff --git a/src/main/java/ode/vertex/impl/helper/forward/OdeHelperForward.java b/src/main/java/ode/vertex/impl/helper/forward/OdeHelperForward.java
new file mode 100644
index 0000000..9b8309e
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/forward/OdeHelperForward.java
@@ -0,0 +1,23 @@
+package ode.vertex.impl.helper.forward;
+
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+/**
+ * Helps with input/output handling when solving ODEs inside a neural network
+ *
+ * @author Christian Skarby
+ */
+public interface OdeHelperForward {
+
+ /**
+ * Return the solution to the ODE when assuming that a forward pass through the layers of the given graph is
+ * the derivative of the sought function.
+ * @param graph Graph of layers to do forward pass through
+ * @param wsMgr To handle workspaces for newly created arrays
+ * @param inputs Inputs to vertex, typically activations from previous layers
+ * @return an {@link INDArray} with the solution to the ODE
+ */
+ INDArray solve(ComputationGraph graph, LayerWorkspaceMgr wsMgr, INDArray[] inputs);
+}
diff --git a/src/main/java/ode/vertex/impl/helper/forward/SingleStep.java b/src/main/java/ode/vertex/impl/helper/forward/SingleStep.java
new file mode 100644
index 0000000..837911b
--- /dev/null
+++ b/src/main/java/ode/vertex/impl/helper/forward/SingleStep.java
@@ -0,0 +1,50 @@
+package ode.vertex.impl.helper.forward;
+
+
+import ode.solve.api.FirstOrderEquation;
+import ode.solve.api.FirstOrderSolver;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.util.Arrays;
+
+/**
+ * Simple {@link OdeHelperForward} capable of only a single time step. Main difference compared to using
+ * {@link MultiStep} is that the latter will return output (zt) which includes the first step (z0)
+ * as well.
+ */
+public class SingleStep implements OdeHelperForward {
+
+ private final FirstOrderSolver solver;
+ private final INDArray time;
+
+ public SingleStep(FirstOrderSolver solver, INDArray time) {
+ this.solver = solver;
+ this.time = time;
+ if(time.length() != 2 && time.rank() != 1) {
+ throw new IllegalArgumentException("time must be a vector with two elements! Was of shape: " + Arrays.toString(time.shape())+ "!");
+ }
+ }
+
+ @Override
+ public INDArray solve(ComputationGraph graph, LayerWorkspaceMgr wsMgr, INDArray[] inputs) {
+ if (inputs.length != 1) {
+ throw new IllegalArgumentException("Only single input supported!");
+ }
+
+ final FirstOrderEquation equation = new ForwardPass(
+ graph,
+ wsMgr,
+ true, // Always use training as batch norm running mean and var become messed up otherwise. Same effect seen in original pytorch repo.
+ inputs
+ );
+
+ final INDArray z0 = inputs[0];
+ final INDArray zt = Nd4j.createUninitialized(z0.shape());
+ solver.integrate(equation, time, z0, zt);
+
+ return zt;
+ }
+}
diff --git a/src/main/java/util/listen/step/Mask.java b/src/main/java/util/listen/step/Mask.java
index e438fb9..cd3d9e0 100644
--- a/src/main/java/util/listen/step/Mask.java
+++ b/src/main/java/util/listen/step/Mask.java
@@ -1,6 +1,7 @@
package util.listen.step;
import ode.solve.api.StepListener;
+import ode.solve.impl.util.SolverState;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
@@ -61,9 +62,9 @@ public void begin(INDArray t, INDArray y0) {
}
@Override
- public void step(INDArray currTime, INDArray step, INDArray error, INDArray y) {
+ public void step(SolverState solverState, INDArray step, INDArray error) {
if(mask) {
- listener.step(currTime, step, error, y);
+ listener.step(solverState, step, error);
}
}
diff --git a/src/main/java/util/listen/step/StepCounter.java b/src/main/java/util/listen/step/StepCounter.java
index e9c2df7..f13e99f 100644
--- a/src/main/java/util/listen/step/StepCounter.java
+++ b/src/main/java/util/listen/step/StepCounter.java
@@ -1,6 +1,7 @@
package util.listen.step;
import ode.solve.api.StepListener;
+import ode.solve.impl.util.SolverState;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -47,7 +48,7 @@ public void begin(INDArray t, INDArray y0) {
}
@Override
- public void step(INDArray currTime, INDArray step, INDArray error, INDArray y) {
+ public void step(SolverState solverState, INDArray step, INDArray error) {
nrofSteps++;
}
diff --git a/src/main/java/util/plot/Plot.java b/src/main/java/util/plot/Plot.java
new file mode 100644
index 0000000..2835e1f
--- /dev/null
+++ b/src/main/java/util/plot/Plot.java
@@ -0,0 +1,89 @@
+package util.plot;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * Interface for XY plotting. Supports serialization of plot data.
+ * @param <X> Type of data along the x axis
+ * @param <Y> Type of data along the y axis
+ */
+public interface Plot<X extends Number, Y extends Number> {
+
+ /**
+ * Factory interface
+ */
+ interface Factory<X extends Number, Y extends Number> {
+
+ /**
+ * Create a new plot with the given title
+ * @param title title of the plot
+ * @return a new Plot instance
+ */
+ Plot<X, Y> create(String title);
+ }
+
+ /**
+ * Creates a time series for the given label. If data with the given label exists in serialized format in the
+ * plotDir, the time series of that data will be recreated.
+ * @param label series label.
+ */
+ void createSeries(String label);
+
+ /**
+ * Plot some data belonging to a certain label. Will be appended to an existing series if such exists, either in
+ * an existing window or in serialized format in the plotDir. If no series with the given label exists it will
+ * be created in the window of this plot instance.
+ * @param label series label
+ * @param x point on x axis
+ * @param y point on y axis
+ */
+ void plotData(String label, X x, Y y);
+
+ /**
+ * Plot some data belonging to a certain label. Will be appended to an existing series if such exists, either in
+ * an existing window or in serialized format in the plotDir. If no series with the given label exists it will
+ * be created in the window of this plot instance.
+ * @param label series label
+ * @param x points on x axis
+ * @param y points on y axis
+ */
+ void plotData(String label, List x, List y);
+
+ /**
+ * Clears the data for the given label
+ * @param label series label
+ */
+ void clearData(String label);
+
+ /**
+ * Serialize the data for all labels.
+ * @throws IOException if the data could not be written
+ */
+ void storePlotData() throws IOException;
+
+ /**
+ * Serialize the data for the given label.
+ * @param label series label
+ * @throws IOException
+ */
+ void storePlotData(String label) throws IOException;
+
+ /**
+ * Save the plot as a picture with the given file name suffix
+ */
+ void savePicture(String suffix) throws IOException;
+
+ /**
+ * Convenience method for debugging purposes. Plots the given data vs list indexes
+ * @param data data to plot against its list indexes
+ * @param plotName name of the plot
+ */
+ static <N extends Number> void plot(List<N> data, String plotName) {
+ final RealTimePlot<Integer, N> plotter = new RealTimePlot<>(plotName, "");
+ plotter.createSeries("data");
+ for(int x = 0; x < data.size(); x ++) {
+ plotter.plotData("data", x, data.get(x));
+ }
+ }
+}
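A usage sketch of the interface, assuming a graphical environment and an existing plots directory; both are assumptions of this sketch, not requirements stated by the source:

import util.plot.Plot;
import util.plot.RealTimePlot;

import java.io.IOException;

class PlotSketch {
    public static void main(String[] args) throws IOException {
        final Plot<Integer, Double> plot = new RealTimePlot<>("Training loss", "plots");
        plot.createSeries("train"); // restored from plots/Training loss_train.plt if that file exists
        for (int iter = 0; iter < 10; iter++) {
            plot.plotData("train", iter, 1.0 / (iter + 1));
        }
        plot.storePlotData(); // serializes every series so the plot can be recreated later
    }
}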
diff --git a/src/main/java/util/plot/RealTimePlot.java b/src/main/java/util/plot/RealTimePlot.java
new file mode 100644
index 0000000..7920423
--- /dev/null
+++ b/src/main/java/util/plot/RealTimePlot.java
@@ -0,0 +1,257 @@
+
+package util.plot;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonTypeInfo;
+import com.fasterxml.jackson.core.JsonParser;
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.DeserializationContext;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
+import com.fasterxml.jackson.databind.module.SimpleModule;
+import lombok.Data;
+import org.jetbrains.annotations.NotNull;
+import org.knowm.xchart.*;
+import org.knowm.xchart.style.Styler;
+import org.knowm.xchart.style.Styler.ChartTheme;
+
+import javax.swing.*;
+import java.io.File;
+import java.io.IOException;
+import java.io.Serializable;
+import java.io.UncheckedIOException;
+import java.util.*;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+
+/**
+ * Real time updatable plot with support for an arbitrary number of series. Can also serialize the plotted data and
+ * recreate a plot from such data. Typically used for plotting training/eval metrics for each iteration. Note: The number
+ * of data points per time series is limited to 1000 as a significant slowdown was observed for higher numbers. When 1000
+ * points is reached, every other point will be removed. New points after this will be added as normal until the total hits
+ * 1000 again.
+ *
+ * @author Christian Skärby
+ */
+public class RealTimePlot<X extends Number, Y extends Number> implements Plot<X, Y> {
+
+ private final String title;
+ private final XYChart xyChart;
+ private final SwingWrapper<XYChart> swingWrapper;
+ private final String plotDir;
+
+ private final Map<String, DataXY<X, Y>> plotSeries = new HashMap<>();
+
+ @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, property = "@class")
+ @Data
+ private static class DataXY<X extends Number, Y extends Number> implements Serializable {
+
+ private static final long serialVersionUID = 7526471155622776891L;
+ private final String series;
+
+ private final LinkedList<X> xData;
+ private final LinkedList<Y> yData;
+
+ DataXY(String series) {
+ this(series, new LinkedList<>(), new LinkedList<>());
+ }
+
+ DataXY(
+ @JsonProperty("series") String series,
+ @JsonProperty("xData") List xData,
+ @JsonProperty("yData") List yData) {
+ this.series = series;
+ this.xData = new LinkedList<>(xData);
+ this.yData = new LinkedList<>(yData);
+ }
+
+ private void addPoint(X x, Y y, XYChart xyChart, SwingWrapper<XYChart> swingWrapper) {
+ xData.addLast(x);
+ yData.addLast(y);
+ if (xData.size() > 1000) {
+ for (int i = 0; i < xData.size(); i += 2) {
+ xData.remove(i);
+ yData.remove(i);
+ }
+ }
+ plotData(xyChart, swingWrapper);
+ }
+
+ private void addData(List<X> x, List<Y> y, XYChart xyChart, SwingWrapper<XYChart> swingWrapper) {
+ xData.addAll(x);
+ yData.addAll(y);
+ plotData(xyChart, swingWrapper);
+ }
+
+ private void plotData(XYChart xyChart, SwingWrapper<XYChart> swingWrapper) {
+ javax.swing.SwingUtilities.invokeLater(() -> {
+ if (!xyChart.getSeriesMap().containsKey(series)) {
+ xyChart.addSeries(series, xData, yData, null);
+ } else {
+ xyChart.updateXYSeries(series, xData, yData, null);
+ }
+ swingWrapper.repaintChart();
+ });
+ }
+
+ private void createSeries(final XYChart xyChart, SwingWrapper<XYChart> swingWrapper) {
+ javax.swing.SwingUtilities.invokeLater(() -> {
+ if (xData.size() == 0) {
+ xyChart.addSeries(series, Arrays.asList(0), Arrays.asList(1));
+ } else {
+ xyChart.addSeries(series, xData, yData);
+ }
+ swingWrapper.repaintChart();
+ });
+ }
+
+
+ private void clear() {
+ xData.clear();
+ yData.clear();
+ }
+ }
+
+ private class DataXYDeserializer extends StdDeserializer<DataXY<X, Y>> {
+
+ public DataXYDeserializer() {
+ this(null);
+ }
+
+ public DataXYDeserializer(Class<?> vc) {
+ super(vc);
+ }
+
+ @Override
+ public DataXY<X, Y> deserialize(JsonParser jp, DeserializationContext ctxt)
+ throws IOException {
+ JsonNode node = jp.getCodec().readTree(jp);
+
+ final List<X> xData = new ArrayList<>();
+ for (JsonNode xNode : node.get("xdata")) {
+ xData.add((X) xNode.numberValue());
+ }
+
+ final List<Y> yData = new ArrayList<>();
+ for (JsonNode xNode : node.get("ydata")) {
+ yData.add((Y) xNode.numberValue());
+ }
+ return new DataXY<>(node.get("series").toString(), xData, yData);
+ }
+
+ }
+
+ /**
+ * Constructor
+ *
+ * @param title Title of the plot
+ * @param plotDir Directory to store plots in.
+ */
+ public RealTimePlot(String title, String plotDir) {
+ // Create Chart
+ this.title = title;
+ xyChart = new XYChartBuilder().width(800).height(500).theme(ChartTheme.Matlab).title(title).build();
+ xyChart.getStyler().setLegendPosition(Styler.LegendPosition.OutsideE);
+ xyChart.getStyler().setDefaultSeriesRenderStyle(XYSeries.XYSeriesRenderStyle.Line);
+
+ this.swingWrapper = new SwingWrapper<>(xyChart);
+ swingWrapper.displayChart();
+ this.plotDir = plotDir;
+ }
+
+ @Override
+ public void plotData(String label, X x, Y y) {
+ final DataXY<X, Y> data = getOrCreateSeries(label);
+ data.addPoint(x, y, xyChart, swingWrapper);
+ }
+
+ @Override
+ public void plotData(String label, List x, List y) {
+ final DataXY<X, Y> data = getOrCreateSeries(label);
+ data.addData(x, y, xyChart, swingWrapper);
+ }
+
+ @Override
+ public void clearData(String label) {
+ final DataXY<X, Y> data = getOrCreateSeries(label);
+ data.clear();
+ }
+
+ @Override
+ public void createSeries(String label) {
+ getOrCreateSeries(label);
+ }
+
+ @NotNull
+ private DataXY<X, Y> getOrCreateSeries(String label) {
+ DataXY<X, Y> data = plotSeries.get(label);
+ if (data == null) {
+ data = restoreOrCreatePlotData(label);
+ plotSeries.put(label, data);
+ data.createSeries(xyChart, swingWrapper);
+ }
+ return data;
+ }
+
+ @Override
+ public void storePlotData() throws IOException {
+ for (String label : plotSeries.keySet()) {
+ storePlotData(label);
+ }
+ }
+
+ @Override
+ public void storePlotData(String label) throws IOException {
+ DataXY<X, Y> data = plotSeries.get(label);
+ if (data != null) {
+ new ObjectMapper().writeValue(new File(createFileName(label)), data);
+ }
+ }
+
+ @Override
+ public void savePicture(String suffix) {
+ SwingUtilities.invokeLater(() -> {
+ try {
+ BitmapEncoder.saveBitmap(xyChart, plotDir + File.separator + title + suffix, BitmapEncoder.BitmapFormat.PNG);
+ } catch (IOException e) {
+ throw new UncheckedIOException("Save picture in " + this.getClass() + " failed!", e);
+ }
+ });
+ }
+
+ private DataXY<X, Y> restoreOrCreatePlotData(String label) {
+ File dataFile = new File(createFileName(label));
+ if (dataFile.exists()) {
+ try {
+ final SimpleModule mod = new SimpleModule();
+ mod.addDeserializer(DataXY.class, new DataXYDeserializer());
+ final ObjectMapper mapper = new ObjectMapper().registerModule(mod);
+ return mapper.readValue(dataFile, new TypeReference<DataXY<X, Y>>() {
+ });
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+ return new DataXY<>(label);
+ }
+
+ private String createFileName(String label) {
+ return plotDir + File.separator + title + "_" + label + ".plt";
+ }
+
+ public static void main(String[] args) {
+
+ final RealTimePlot plotter = new RealTimePlot<>("Test plot", "");
+ IntStream.range(1000, 2000).forEach(x -> Stream.of("s1", "s2", "s3").forEach(str -> {
+ plotter.createSeries(str);
+ plotter.plotData(str, x, 1d / ((double) x + 10));
+ }));
+// try {
+// plotter.storePlotData("s1");
+// } catch (IOException e) {
+// e.printStackTrace();
+// }
+
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/util/random/SeededRandomFactory.java b/src/main/java/util/random/SeededRandomFactory.java
new file mode 100644
index 0000000..d4df45b
--- /dev/null
+++ b/src/main/java/util/random/SeededRandomFactory.java
@@ -0,0 +1,50 @@
+package util.random;
+
+import org.nd4j.linalg.api.rng.Random;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.factory.RandomFactory;
+
+/**
+ * {@link RandomFactory} with a configurable random seed
+ *
+ * @author Christian Skarby
+ */
+public class SeededRandomFactory extends RandomFactory {
+
+ final java.util.Random base;
+ private ThreadLocal<Random> threadRandom = new ThreadLocal<>();
+
+ public SeededRandomFactory(Class randomClass, long baseSeed) {
+ super(randomClass);
+ base = new java.util.Random(baseSeed);
+ }
+
+ @Override
+ public Random getRandom() {
+ // Copy-pasted from RandomFactory
+ try {
+ if (threadRandom.get() == null) {
+ Random t = super.getNewRandomInstance(base.nextLong());
+ threadRandom.set(t);
+ return t;
+ }
+
+ return threadRandom.get();
+ } catch (Exception e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+ @Override
+ public Random getNewRandomInstance() {
+ return super.getNewRandomInstance(base.nextLong());
+ }
+
+ /**
+ * Set a base seed for all Nd4j random generators
+ * @param seed base seed
+ */
+ public static void setNd4jSeed(long seed) {
+ Nd4j.randomFactory = new SeededRandomFactory(Nd4j.randomFactory.getRandom().getClass(), seed);
+ }
+}
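Usage is a single static call before any Nd4j randomness is drawn; after that, every run produces the same numbers:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import util.random.SeededRandomFactory;

class SeededRandomSketch {
    public static void main(String[] args) {
        SeededRandomFactory.setNd4jSeed(666);
        final INDArray a = Nd4j.randn(new long[]{2, 3}); // identical values on every program run
        System.out.println(a);
    }
}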
diff --git a/src/test/java/examples/spiral/AddKLDLabelTest.java b/src/test/java/examples/spiral/AddKLDLabelTest.java
new file mode 100644
index 0000000..5319c52
--- /dev/null
+++ b/src/test/java/examples/spiral/AddKLDLabelTest.java
@@ -0,0 +1,44 @@
+package examples.spiral;
+
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.MultiDataSet;
+import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Test cases for {@link AddKLDLabel}
+ */
+public class AddKLDLabelTest {
+
+ /**
+ * Test that labels are added
+ */
+ @Test
+ public void preProcess() {
+ final long batchSize = 3;
+ final long nrofLatentDims = 5;
+
+ final MultiDataSet mds = new MultiDataSet(Nd4j.ones(batchSize, 13), Nd4j.ones(batchSize, 11));
+
+ final MultiDataSetPreProcessor addKld = new AddKLDLabel(1.23, 2.34, nrofLatentDims);
+ addKld.preProcess(mds);
+
+ final INDArray kldLabel = mds.getLabels(1);
+
+ assertArrayEquals("Incorrect shape!", new long[] {batchSize, 2*nrofLatentDims}, kldLabel.shape());
+
+ assertEquals("Expected first half to be mean!", 1.23,
+ kldLabel.get(NDArrayIndex.all(),
+ NDArrayIndex.interval(0, nrofLatentDims)).meanNumber().doubleValue(), 1e-5);
+
+ assertEquals("Expected second half to be log(var)!", Math.log(2.34),
+ kldLabel.get(NDArrayIndex.all(),
+ NDArrayIndex.interval(nrofLatentDims, 2*nrofLatentDims)).meanNumber().doubleValue(), 1e-5);
+
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/examples/spiral/LatentOdeTest.java b/src/test/java/examples/spiral/LatentOdeTest.java
new file mode 100644
index 0000000..fb87f8c
--- /dev/null
+++ b/src/test/java/examples/spiral/LatentOdeTest.java
@@ -0,0 +1,106 @@
+package examples.spiral;
+
+import ch.qos.logback.classic.Level;
+import examples.spiral.listener.PlotDecodedOutput;
+import examples.spiral.listener.SpiralPlot;
+import examples.spiral.loss.NormLogLikelihoodLoss;
+import ode.solve.conf.DormandPrince54Solver;
+import ode.solve.conf.SolverConfig;
+import ode.vertex.conf.helper.InputStep;
+import org.deeplearning4j.nn.api.OptimizationAlgorithm;
+import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.MultiDataSet;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.Adam;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import util.listen.training.ZeroGrad;
+import util.plot.RealTimePlot;
+
+import java.awt.*;
+
+import static junit.framework.TestCase.assertTrue;
+
+/**
+ * Test cases for {@link LatentOdeBlock}, {@link DenseDecoderBlock} and {@link ReconstructionLossBlock} together
+ *
+ * @author Christian Skarby
+ */
+public class LatentOdeTest {
+
+ /**
+ * Test that the latent ODE can (over)fit to a simple line
+ */
+ @Test
+ public void fitLine() {
+ final long nrofTimeSteps = 10;
+ final long nrofLatentDims = 4;
+
+ //SeededRandomFactory.setNd4jSeed(0);
+
+ final ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder()
+ .seed(666)
+ .weightInit(WeightInit.UNIFORM)
+ .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
+ .updater(new Adam(0.01))
+ .graphBuilder()
+ .setInputTypes(InputType.feedForward(nrofLatentDims), InputType.feedForward(nrofTimeSteps));
+
+ String next = "z0";
+ builder.addInputs(next, "time");
+
+
+ next = new LatentOdeBlock(
+ 20,
+ nrofLatentDims,
+ new InputStep(
+ new DormandPrince54Solver(new SolverConfig(1e-12, 1e-6, 1e-20, 1e2)),
+ 1, true))
+ .add(builder, next, "time");
+
+ next = new DenseDecoderBlock(20, 2).add(builder, next);
+ final String decoded = next;
+ next = new ReconstructionLossBlock(new NormLogLikelihoodLoss(0.3)).add(builder, next);
+ builder.setOutputs(next);
+
+ final ComputationGraph graph = new ComputationGraph(builder.build());
+ graph.init();
+
+ final INDArray z0 = Nd4j.ones(1, nrofLatentDims);
+
+ final INDArray time = Nd4j.linspace(0, 3, nrofTimeSteps);
+ final INDArray label = Nd4j.hstack(time, Nd4j.linspace(0, 9, nrofTimeSteps)).reshape(1, 2, nrofTimeSteps);
+
+ if (!GraphicsEnvironment.isHeadless()
+ && !GraphicsEnvironment.getLocalGraphicsEnvironment().isHeadlessInstance()
+ && GraphicsEnvironment.getLocalGraphicsEnvironment().getScreenDevices() != null
+ && GraphicsEnvironment.getLocalGraphicsEnvironment().getScreenDevices().length > 0) {
+ ch.qos.logback.classic.Logger root = (ch.qos.logback.classic.Logger) LoggerFactory.getLogger(Logger.ROOT_LOGGER_NAME);
+ root.setLevel(Level.INFO);
+
+ final SpiralPlot linePlot = new SpiralPlot(new RealTimePlot<>("Decoded output", ""));
+ linePlot.plot("Ground truth", label.tensorAlongDimension(0, 1, 2));
+ graph.addListeners(new PlotDecodedOutput(linePlot, decoded, 0));
+ }
+
+ graph.addListeners(
+ new ZeroGrad(),
+ new ScoreIterationListener(1));
+
+ final MultiDataSet mds = new MultiDataSet(new INDArray[]{z0, time}, new INDArray[]{label});
+ boolean success = false;
+ for (int i = 0; i < 300; i++) {
+ graph.fit(mds);
+ success = graph.score() < 100;
+ if (success) break;
+ }
+ assertTrue("Model failed to train properly! Score after 300 iters: " + graph.score() + "!", success);
+ }
+}
diff --git a/src/test/java/examples/spiral/OdeNetModelTest.java b/src/test/java/examples/spiral/OdeNetModelTest.java
new file mode 100644
index 0000000..d7d16e0
--- /dev/null
+++ b/src/test/java/examples/spiral/OdeNetModelTest.java
@@ -0,0 +1,113 @@
+package examples.spiral;
+
+import com.beust.jcommander.JCommander;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.util.ModelSerializer;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.MultiDataSet;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+
+import static junit.framework.TestCase.assertEquals;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.fail;
+
+/**
+ * Test cases for {@link OdeNetModel}
+ *
+ * @author Christian Skarby
+ */
+public class OdeNetModelTest {
+
+ /**
+ * Test that the model can be created and that it is possible to make train for two batches
+ */
+ @Test
+ public void fit() {
+ final OdeNetModel factory = new OdeNetModel();
+
+ final long nrofTimeSteps = 10;
+ final long batchSize = 3;
+ final long nrofLatentDims = 5;
+ final ComputationGraph model = factory.createNew(nrofTimeSteps, 0.3, nrofLatentDims).trainingModel();
+ model.fit(new MultiDataSet(
+ new INDArray[]{Nd4j.randn(new long[]{batchSize, 2, nrofTimeSteps}), Nd4j.linspace(0, 3, nrofTimeSteps)},
+ new INDArray[]{Nd4j.randn(new long[] {batchSize, 2, nrofTimeSteps}), Nd4j.zeros(batchSize, nrofLatentDims)}));
+ model.fit(new MultiDataSet(
+ new INDArray[]{Nd4j.randn(new long[]{batchSize, 2, nrofTimeSteps}), Nd4j.linspace(0, 3, nrofTimeSteps)},
+ new INDArray[]{Nd4j.randn(new long[] {batchSize, 2, nrofTimeSteps}), Nd4j.zeros(batchSize, nrofLatentDims)}));
+ }
+
+ /**
+ * Test that the model can be serialized and deserialized into the same thing
+ */
+ @Test
+ public void testSerializeDeserialize() throws IOException {
+ final OdeNetModel factory = new OdeNetModel();
+ final int nrofTimeSteps = 5;
+ JCommander.newBuilder()
+ .addObject(factory)
+ .build()
+ .parse();
+ final ComputationGraph graph = factory.createNew(nrofTimeSteps, 0.1, 4).trainingModel();
+
+ final Path baseDir = Paths.get("src", "test", "resources", "OdeNetModelTest");
+ final String fileName = Paths.get(baseDir.toString(), "testSerializeDeserialize.zip").toString();
+
+ try {
+
+ baseDir.toFile().mkdirs();
+ graph.save(new File(fileName), true);
+ final ComputationGraph newGraph = ModelSerializer.restoreComputationGraph(new File(fileName), true);
+
+ assertEquals("Config was not restored properly!", graph.getConfiguration(), newGraph.getConfiguration());
+
+ final long batchSize = 3;
+ final INDArray[] input = {Nd4j.randn(new long[]{batchSize, 2, nrofTimeSteps}), Nd4j.linspace(0, 3, nrofTimeSteps)};
+
+ assertEquals("Output not the same!", graph.output(input)[0], newGraph.output(input)[0]);
+
+ } catch (IOException e) {
+ e.printStackTrace();
+ fail("Failed to serialize or deserialize graph!");
+ } finally {
+ new File(fileName).delete();
+ Files.delete(baseDir);
+ }
+ }
+
+ /**
+ * Smoke test to assert that an {@link OdeNetModel} can be represented as a {@link TimeVae} without there being
+ * any exceptions.
+ */
+ @Test
+ public void asTimeVae() {
+ final long batchSize = 5;
+ final long nrofTimeSteps = 17;
+ final long nrofLatentDims = 6;
+
+ final TimeVae timeVae = new OdeNetModel().createNew(nrofTimeSteps, 0.3, nrofLatentDims);
+
+ final INDArray inputTraj = Nd4j.randn(new long[]{batchSize, 2, nrofTimeSteps});
+ final INDArray time = Nd4j.linspace(0, 3, nrofTimeSteps);
+
+ final INDArray z0 = timeVae.encode(inputTraj);
+
+ assertArrayEquals("Incorrect shape of z0!", new long[] {batchSize, nrofLatentDims}, z0.shape());
+
+ final INDArray zt = timeVae.timeDependency(z0, time);
+
+ assertArrayEquals("Incorrect shape of zt!", new long[] {batchSize, nrofLatentDims, nrofTimeSteps}, zt.shape());
+
+ final INDArray decoded = timeVae.decode(zt);
+
+ assertArrayEquals("Incorrect shape of decoded output!", new long[] {batchSize, 2, nrofTimeSteps}, decoded.shape());
+
+ }
+}
diff --git a/src/test/java/examples/spiral/SpiralFactoryTest.java b/src/test/java/examples/spiral/SpiralFactoryTest.java
new file mode 100644
index 0000000..acd60a5
--- /dev/null
+++ b/src/test/java/examples/spiral/SpiralFactoryTest.java
@@ -0,0 +1,30 @@
+package examples.spiral;
+
+import org.junit.Test;
+
+import java.util.List;
+
+import static junit.framework.TestCase.assertEquals;
+import static org.junit.Assert.assertArrayEquals;
+
+/**
+ * Test cases for {@link SpiralFactory}
+ *
+ * @author Christian Skarby
+ */
+public class SpiralFactoryTest {
+
+ /**
+ * Test that spirals can be sampled from factory
+ */
+ @Test
+ public void sample() {
+ final SpiralFactory factory = new SpiralFactory(0, 0.3, 0, 10, 1000);
+ final List<SpiralFactory.Spiral> sample = factory.sample(4, 100, () -> 0.3, () -> true);
+ assertEquals("Incorrect number of samples!", 4, sample.size());
+ for(SpiralFactory.Spiral spiral: sample) {
+ assertArrayEquals("Incorrect trajectory!", new long[] {2, 100}, spiral.trajectory().shape());
+ assertArrayEquals("Incorrect theta!", new long[] {1, 100}, spiral.theta().shape());
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/examples/spiral/SpiralIteratorTest.java b/src/test/java/examples/spiral/SpiralIteratorTest.java
new file mode 100644
index 0000000..51e0024
--- /dev/null
+++ b/src/test/java/examples/spiral/SpiralIteratorTest.java
@@ -0,0 +1,40 @@
+package examples.spiral;
+
+import org.junit.Test;
+import org.nd4j.linalg.dataset.api.MultiDataSet;
+import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
+
+import java.util.Random;
+
+import static org.junit.Assert.assertArrayEquals;
+
+/**
+ * Test cases for {@link SpiralIterator}
+ *
+ * @author Christian Skarby
+ */
+public class SpiralIteratorTest {
+
+ /**
+ * Test that a spiral can be generated into a {@link MultiDataSet} and that the dimensions are as expected
+ */
+ @Test
+ public void next() {
+ final long nrofSamplesOrig = 1000;
+ final long nrofSamplesTrain = 100;
+ final int batchSize = 200;
+ final SpiralFactory spiralFactory = new SpiralFactory(0, 0.3, 0, 6*Math.PI, nrofSamplesOrig);
+ final MultiDataSetIterator iterator = new SpiralIterator(
+ new SpiralIterator.Generator(spiralFactory, 0.3, nrofSamplesTrain, new Random(666)),
+ batchSize);
+
+ final MultiDataSet mds = iterator.next();
+
+ final long[] expectedShapeSpiral = {batchSize, 2, nrofSamplesTrain};
+ assertArrayEquals("Incorrect shape of spiral!", expectedShapeSpiral, mds.getFeatures(0).shape());
+ assertArrayEquals("Incorrect shape of label!", expectedShapeSpiral, mds.getLabels(0).shape());
+
+ final long[] expectedShapeTime = {1, nrofSamplesTrain};
+ assertArrayEquals("Incorrect shape of time!", expectedShapeTime, mds.getFeatures(1).shape());
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/examples/spiral/loss/NormKLDLossTest.java b/src/test/java/examples/spiral/loss/NormKLDLossTest.java
new file mode 100644
index 0000000..24fc7b1
--- /dev/null
+++ b/src/test/java/examples/spiral/loss/NormKLDLossTest.java
@@ -0,0 +1,142 @@
+package examples.spiral.loss;
+
+import org.deeplearning4j.nn.api.OptimizationAlgorithm;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.conf.layers.LossLayer;
+import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.junit.Test;
+import org.nd4j.linalg.activations.impl.ActivationIdentity;
+import org.nd4j.linalg.activations.impl.ActivationTanH;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.learning.config.Adam;
+import org.nd4j.linalg.lossfunctions.ILossFunction;
+import org.nd4j.linalg.primitives.Pair;
+
+import static junit.framework.TestCase.assertEquals;
+import static junit.framework.TestCase.assertTrue;
+
+/**
+ * Test cases for {@link NormKLDLoss}
+ *
+ * @author Christian Skarby
+ */
+public class NormKLDLossTest {
+
+ /**
+ * Test that score and gradient is 0 for a standard gaussian distribution
+ */
+ @Test
+ public void computeGradientAndScoreStandard() {
+ final long batchSize = 11;
+ final long nrofLatentDimsTimesTwo = 14;
+ final INDArray meanAndLogvar = Nd4j.zeros(batchSize, nrofLatentDimsTimesTwo);
+
+ final ILossFunction toTest = new NormKLDLoss();
+
+ final Pair scoreAndGrad = toTest.computeGradientAndScore(meanAndLogvar, meanAndLogvar, new ActivationIdentity(), null, false);
+
+ assertEquals("Score shall be 0! ", 0.0, scoreAndGrad.getFirst(), 1e-10);
+ assertEquals("Gradient shall be 0!", meanAndLogvar, scoreAndGrad.getSecond());
+ }
+
+ /**
+ * Test that gradient is pointing in direction of error
+ */
+ @Test
+ public void gradientMean() {
+ final long batchSize = 11;
+ final long nrofLatentDimsTimesTwo = 6;
+ final INDArray meanAndLogvar = Nd4j.zeros(batchSize, nrofLatentDimsTimesTwo);
+ final INDArray expectedMeanAndLogVar = meanAndLogvar.dup();
+
+ meanAndLogvar.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 2)).assign(Nd4j.create(new double[]{-1.23, 2.34}));
+ meanAndLogvar.get(NDArrayIndex.point(2), NDArrayIndex.interval(0, 2)).assign(Nd4j.create(new double[]{3.45, -4.56}));
+
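+        // d KLD / d mean = mean for a standard gaussian prior, so the gradient shall equal the input itself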
+ final ILossFunction toTest = new NormKLDLoss();
+
+ final INDArray gradient = toTest.computeGradient(expectedMeanAndLogVar, meanAndLogvar, new ActivationIdentity(), null);
+
+ assertEquals("Gradient shall be same as input!", meanAndLogvar, gradient);
+ }
+
+ /**
+     * Test that the gradient for the log variance points in the direction of the error
+ */
+ @Test
+ public void gradientLogVar() {
+ final long batchSize = 3;
+ final long nrofLatentDimsTimesTwo = 4;
+ final INDArray meanAndLogvar = Nd4j.zeros(batchSize, nrofLatentDimsTimesTwo);
+ final INDArray expectedMeanAndLogVar = meanAndLogvar.dup();
+
+ meanAndLogvar.get(NDArrayIndex.point(0), NDArrayIndex.interval(2, 4)).assign(Nd4j.create(new double[]{-1.23, 2.34}));
+ meanAndLogvar.get(NDArrayIndex.point(2), NDArrayIndex.interval(2, 4)).assign(Nd4j.create(new double[]{3.45, -4.56}));
+
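+        // d KLD / d logVar = 0.5 * (exp(logVar) - 1): positive for logVar > 0 and negative for logVar < 0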
+ final ILossFunction toTest = new NormKLDLoss();
+
+ final INDArray gradient = toTest.computeGradient(expectedMeanAndLogVar, meanAndLogvar, new ActivationIdentity(), null);
+
+ assertTrue("Gradient shall be < 0!", gradient.getDouble(0, 2) < 0);
+ assertTrue("Gradient shall be > 0!", gradient.getDouble(0, 3) > 0);
+ assertEquals("Gradient shall be 0!", 0d, gradient.getDouble(1, 2), 1e-10);
+ assertEquals("Gradient shall be 0!", 0d, gradient.getDouble(1, 3), 1e-10);
+ assertTrue("Gradient shall be < 0!", gradient.getDouble(2, 2) > 0);
+ assertTrue("Gradient shall be > 0!", gradient.getDouble(2, 3) < 0);
+ }
+
+ /**
+ * Test to teach a small neural network to output zero mean and unit variance
+ */
+ @Test
+ public void learnStandardGaussian() {
+ final long batchSize = 3;
+ final long nrofInputs = 4;
+ final long nrofTimeSteps = 5;
+
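+        // Labels are all zeros, i.e. the target mean and logVar of a standard gaussian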
+ final DataSet ds = new DataSet(
+ Nd4j.linspace(-3, 3, batchSize*nrofInputs*nrofTimeSteps).reshape(batchSize, nrofInputs, nrofTimeSteps),
+ Nd4j.zeros(batchSize, nrofInputs));
+
+ final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder()
+ .seed(666)
+ .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
+ .updater(new Adam(0.01))
+ .graphBuilder()
+ .addInputs("input")
+ .setInputTypes(InputType.recurrent(nrofInputs, nrofTimeSteps))
+ .addLayer("rnn", new SimpleRnn.Builder()
+ .activation(new ActivationTanH())
+ .nOut(20)
+ .build(), "input")
+ .addVertex("lastStep", new LastTimeStepVertex("input"), "rnn")
+ .addLayer("dnn", new DenseLayer.Builder()
+ .nOut(nrofInputs)
+ .activation(new ActivationIdentity())
+ .build(), "lastStep")
+
+ .addLayer("out", new LossLayer.Builder()
+ .lossFunction(new NormKLDLoss())
+ .build(), "dnn")
+ .setOutputs("out")
+ .build());
+ graph.init();
+
+ boolean success = false;
+ for(int i = 0; i < 300; i++) {
+ graph.fit(ds);
+ success |= graph.score() < 0.001;
+ if(success) break;
+ }
+
+ assertTrue("Training did not succeed!", success);
+
+ }
+
+}
\ No newline at end of file
diff --git a/src/test/java/examples/spiral/loss/NormLogLikelihoodLossTest.java b/src/test/java/examples/spiral/loss/NormLogLikelihoodLossTest.java
new file mode 100644
index 0000000..94f4d73
--- /dev/null
+++ b/src/test/java/examples/spiral/loss/NormLogLikelihoodLossTest.java
@@ -0,0 +1,67 @@
+package examples.spiral.loss;
+
+import org.junit.Test;
+import org.nd4j.linalg.activations.impl.ActivationIdentity;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.primitives.Pair;
+
+import static junit.framework.TestCase.assertEquals;
+import static org.junit.Assert.assertArrayEquals;
+
+/**
+ * Test cases for {@link NormLogLikelihoodLoss}
+ *
+ * @author Christian Skarby
+ */
+public class NormLogLikelihoodLossTest {
+
+ private final double[][][] expectedGrad = {{{-29.2222, -27.2069, -25.1916, -23.1762, -21.1609},
+ {-19.1456, -17.1303, -15.1149, -13.0996, -11.0843}},
+
+ {{-9.0690, -7.0536, -5.0383, -3.0230, -1.0077},
+ {1.0077, 3.0230, 5.0383, 7.0536, 9.0690}},
+
+ {{11.0843, 13.0996, 15.1149, 17.1303, 19.1456},
+ {21.1609, 23.1762, 25.1916, 27.2069, 29.2222}}};
+
+ /**
+     * Test that score and gradient are zero when label and prediction are the same
+ */
+ @Test
+ public void computeGradientAndScorePerfectMatch() {
+ final INDArray traj = Nd4j.linspace(-3.45, 2.34, 2 * 3 * 5).reshape(3, 2, 5);
+
+        final Pair<Double, INDArray> out = new NormLogLikelihoodLoss(0.3)
+ .computeGradientAndScore(traj, traj, new ActivationIdentity(), null, false);
+
+ // Loss has a constant term
+ assertEquals("Expected minimum loss!", -8.55102825164795, out.getFirst(), 1e-5);
+ assertEquals("Expected zero grad!", 0, out.getSecond().amaxNumber().doubleValue(), 1e-10);
+ }
+
+ /**
+     * Test that score and gradient are correct when label and prediction are not the same. Numbers taken from the original repo.
+ */
+ @Test
+ public void computeGradientAndScoreWithEps() {
+ final INDArray traj = Nd4j.linspace(-3.45, 2.34, 2 * 3 * 5).reshape(3, 2, 5);
+ final INDArray eps = Nd4j.linspace(-7.89, 7.89, traj.length()).reshape(traj.shape());
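+        // For a gaussian log likelihood with sigma = 0.3, the gradient wrt the prediction is
+        // (prediction - label) / sigma^2 = eps / 0.09, matching expectedGrad after batch size scaling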
+
+        final Pair<Double, INDArray> out = new NormLogLikelihoodLoss(0.3)
+ .computeGradientAndScore(traj, traj.add(eps), new ActivationIdentity(), null, true);
+
+
+        out.getSecond().divi(3); // DL4J scales gradients by the mini-batch size centrally in BaseMultiLayerUpdater
+ assertEquals("Incorrect loss!", 1229.4709, out.getFirst(), 1e-4);
+
+ for (int i = 0; i < expectedGrad.length; i++) {
+ for (int j = 0; j < expectedGrad[i].length; j++) {
+ assertArrayEquals("Incorrect gradient along " + i + ", " + j + "!",
+ expectedGrad[i][j],
+ out.getSecond().get(NDArrayIndex.point(i), NDArrayIndex.point(j), NDArrayIndex.all()).toDoubleVector(), 1e-4);
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/examples/spiral/vertex/impl/SampleGaussianVertexTest.java b/src/test/java/examples/spiral/vertex/impl/SampleGaussianVertexTest.java
new file mode 100644
index 0000000..d170906
--- /dev/null
+++ b/src/test/java/examples/spiral/vertex/impl/SampleGaussianVertexTest.java
@@ -0,0 +1,141 @@
+package examples.spiral.vertex.impl;
+
+import examples.spiral.vertex.conf.SampleGaussianVertex;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
+import org.deeplearning4j.nn.conf.layers.LossLayer;
+import org.deeplearning4j.nn.conf.memory.MemoryReport;
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.graph.vertex.GraphVertex;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.jetbrains.annotations.NotNull;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.ops.transforms.Transforms;
+import org.nd4j.linalg.primitives.Pair;
+
+import java.util.stream.LongStream;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Test cases for {@link examples.spiral.vertex.conf.SampleGaussianVertex}
+ *
+ * @author Christian Skarby
+ */
+public class SampleGaussianVertexTest {
+
+ /**
+     * Test that the output has the same mean and log variance as the input
+ */
+ @Test
+ public void doForward() {
+ final long nrofLatentDims = 4;
+ final ComputationGraph graph = getGraph(nrofLatentDims, new SampleGaussianVertex(666));
+
+ final int batchSize = 100000;
+ final INDArray means = Nd4j.arange(nrofLatentDims);
+ final INDArray logVars = Nd4j.arange(nrofLatentDims, 2*nrofLatentDims);
+ INDArray output = graph.outputSingle(Nd4j.repeat(means, batchSize), Nd4j.repeat(logVars, batchSize));
+
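+        // With 100000 samples per latent dim, sample mean and log variance shall match the
+        // configured values to within 1e-1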
+ assertArrayEquals("Incorrect mean!", means.toDoubleVector(), output.mean(0).toDoubleVector(), 1e-1);
+ assertArrayEquals("Incorrect logvar!", logVars.toDoubleVector(), Transforms.log(output.var(0)).toDoubleVector(), 1e-1);
+ }
+
+ /**
+     * Test doBackward. LogVar numbers verified in PyTorch
+ */
+ @Test
+ public void doBackward() {
+ final long nrofLatentDims = 2;
+ final ComputationGraph graph = getGraph(nrofLatentDims, new TestSampleGaussianVertex());
+
+ final GraphVertex vertex = graph.getVertex("z");
+
+ // Need to do a forward pass to set inputs and calculate epsilon
+ vertex.setInput(0, Nd4j.arange(2*nrofLatentDims), LayerWorkspaceMgr.noWorkspaces());
+ vertex.doForward(true, LayerWorkspaceMgr.noWorkspaces());
+
+ vertex.setEpsilon(Nd4j.create(new double[] {1.3, 2.4}));
+        final Pair<Gradient, INDArray[]> result = graph.getVertex("z").doBackward(false, LayerWorkspaceMgr.noWorkspaces());
+
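+        // Assuming z = mean + noise * exp(logVar / 2): dz/dMean = 1, so the mean gradient equals the upstream
+        // epsilon, while the logVar gradient is epsilon * 0.5 * noise * exp(logVar / 2). With
+        // noise = [1.5, 4.3] and logVar = [2, 3] this gives [2.6503, 23.1255].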
+ assertEquals("Incorrect gradient for mean!",
+ vertex.getEpsilon(),
+ result.getSecond()[0].get(NDArrayIndex.all(), NDArrayIndex.interval(0, nrofLatentDims)));
+ assertEquals("Incorrect gradient for log var!",
+ Nd4j.create(new double[] {2.6503, 23.1255}),
+ result.getSecond()[0].get(NDArrayIndex.all(), NDArrayIndex.interval(nrofLatentDims, 2*nrofLatentDims)));
+ }
+
+ @NotNull
+ private static ComputationGraph getGraph(long nrofLatentDims, org.deeplearning4j.nn.conf.graph.GraphVertex vertex) {
+ final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder()
+ .graphBuilder()
+ .addInputs("mean", "logVar")
+ .setInputTypes(InputType.feedForward(nrofLatentDims), InputType.feedForward(nrofLatentDims))
+ .addVertex("z", vertex, "mean", "logVar") // Note: MergeVertex will be added
+ .setOutputs("output")
+ .addLayer("output", new LossLayer.Builder().build(), "z")
+ .build());
+ graph.init();
+ return graph;
+ }
+
+ private static class TestSampleGaussianVertex extends org.deeplearning4j.nn.conf.graph.GraphVertex {
+
+ private final SampleGaussianVertex helper = new SampleGaussianVertex(666);
+
+ @Override
+ public org.deeplearning4j.nn.conf.graph.GraphVertex clone() {
+ return null;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ return false;
+ }
+
+ @Override
+ public int hashCode() {
+ return 0;
+ }
+
+ @Override
+ public long numParams(boolean backprop) {
+ return helper.numParams(backprop);
+ }
+
+ @Override
+ public int minVertexInputs() {
+ return helper.minVertexInputs();
+ }
+
+ @Override
+ public int maxVertexInputs() {
+ return helper.maxVertexInputs();
+ }
+
+ @Override
+ public GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams) {
+ return new examples.spiral.vertex.impl.SampleGaussianVertex(graph, name, idx, shape -> {
+ final long sum = LongStream.of(shape).reduce(1, (l1,l2) -> l1*l2);
+ return Nd4j.linspace(1.5, 4.3, sum).reshape(shape);
+ });
+ }
+
+ @Override
+ public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
+ return helper.getOutputType(layerIndex, vertexInputs);
+ }
+
+ @Override
+ public MemoryReport getMemoryReport(InputType... inputTypes) {
+ return helper.getMemoryReport(inputTypes);
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/solve/impl/DormandPrince54SolverTest.java b/src/test/java/ode/solve/impl/DormandPrince54SolverTest.java
index 9dd5c1a..af9e1ae 100644
--- a/src/test/java/ode/solve/impl/DormandPrince54SolverTest.java
+++ b/src/test/java/ode/solve/impl/DormandPrince54SolverTest.java
@@ -5,6 +5,7 @@
import ode.solve.api.StepListener;
import ode.solve.commons.FirstOrderSolverAdapter;
import ode.solve.conf.SolverConfig;
+import ode.solve.impl.util.SolverState;
import org.apache.commons.math3.ode.nonstiff.DormandPrince54Integrator;
import org.junit.AfterClass;
import org.junit.BeforeClass;
@@ -103,8 +104,8 @@ public void begin(INDArray t, INDArray y0) {
}
@Override
- public void step(INDArray currTime, INDArray step, INDArray error, INDArray y) {
- times.add(currTime.detach());
+ public void step(SolverState solverState, INDArray step, INDArray error) {
+ times.add(solverState.time().detach());
}
@Override
diff --git a/src/test/java/ode/solve/impl/InterpolatingMultiStepSolverTest.java b/src/test/java/ode/solve/impl/InterpolatingMultiStepSolverTest.java
new file mode 100644
index 0000000..5df8d7d
--- /dev/null
+++ b/src/test/java/ode/solve/impl/InterpolatingMultiStepSolverTest.java
@@ -0,0 +1,57 @@
+package ode.solve.impl;
+
+import ode.solve.CircleODE;
+import ode.solve.api.FirstOrderEquation;
+import ode.solve.api.FirstOrderSolver;
+import ode.solve.conf.SolverConfig;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.primitives.Pair;
+
+import static org.junit.Assert.assertArrayEquals;
+
+/**
+ * Test cases for {@link InterpolatingMultiStepSolver}
+ *
+ * @author Christian Skarby
+ */
+public class InterpolatingMultiStepSolverTest {
+
+ /**
+     * Test that the output from a {@link SingleSteppingMultiStepSolver} is the same as the output from an
+     * {@link InterpolatingMultiStepSolver} when solving the {@link CircleODE} with the same input.
+ */
+ @Test
+ public void integrateCircle() {
+        final Pair<INDArray, INDArray> multiAndInterp = solveCircleMultiInterpol();
+
+ final INDArray multi = multiAndInterp.getFirst();
+ final INDArray interp = multiAndInterp.getSecond();
+
+ for(int i = 0; i < multi.columns(); i++)
+ assertArrayEquals("Solutions are different in column " + i +"!!",
+ multi.getColumn(i).toDoubleVector(),
+ interp.getColumn(i).toDoubleVector(),
+ 1e-4);
+ }
+
+    private static Pair<INDArray, INDArray> solveCircleMultiInterpol() {
+ final FirstOrderSolver singleStepSolver = new DormandPrince54Solver(
+ new SolverConfig(1e-7, 1e-7, 1e-10, 1e2));
+
+ final double omega = 5.67;
+ final FirstOrderEquation equation = new CircleODE(new double[] {1.23, 4.56}, omega);
+
+ final int nrofSteps = 25;
+ final INDArray y0 = Nd4j.create(new double[] {-5.6, 7.3});
+ final INDArray ySingle = y0.dup();
+ final INDArray yMulti = Nd4j.repeat(ySingle,nrofSteps-1).reshape(nrofSteps-1, y0.length()).assign(0);
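+        // omega * (t_end - t_start) = 2 * pi, i.e. the time grid spans one full revolution of the circle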
+        final INDArray t = Nd4j.linspace(-Math.PI/omega, Math.PI/omega, nrofSteps);
+
+ final INDArray expected = new SingleSteppingMultiStepSolver(singleStepSolver).integrate(equation, t, y0, yMulti.dup()).transposei();
+ final INDArray actual = new InterpolatingMultiStepSolver(singleStepSolver).integrate(equation, t, y0, yMulti.dup()).transposei();
+
+ return new Pair<>(expected, actual);
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/solve/impl/SingleSteppingMultiStepSolverTest.java b/src/test/java/ode/solve/impl/SingleSteppingMultiStepSolverTest.java
new file mode 100644
index 0000000..b899f0c
--- /dev/null
+++ b/src/test/java/ode/solve/impl/SingleSteppingMultiStepSolverTest.java
@@ -0,0 +1,40 @@
+package ode.solve.impl;
+
+import ode.solve.CircleODE;
+import ode.solve.api.FirstOrderEquation;
+import ode.solve.api.FirstOrderSolver;
+import ode.solve.conf.SolverConfig;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import static junit.framework.TestCase.assertEquals;
+
+/**
+ * Test cases for {@link SingleSteppingMultiStepSolver}
+ *
+ * @author Christian Skarby
+ */
+public class SingleSteppingMultiStepSolverTest {
+
+ /**
+ * Test that solving {@link ode.solve.CircleODE} in multiple steps gives the same result as if all steps are done
+ * in one solve.
+ */
+ @Test
+ public void integrate() {
+ final int nrofSteps = 20;
+ final FirstOrderEquation circle = new CircleODE(new double[] {1.23, 4.56}, 1);
+ final FirstOrderSolver actualSolver = new DormandPrince54Solver(new SolverConfig(1e-10, 1e-10, 1e-10, 10));
+
+ final INDArray t = Nd4j.linspace(0, Math.PI, nrofSteps);
+ final INDArray y0 = Nd4j.create(new double[] {0, 0});
+ final INDArray ySingle = y0.dup();
+ final INDArray yMulti = Nd4j.repeat(ySingle,nrofSteps-1).reshape(nrofSteps-1, y0.length());
+
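+        // Solve from t0 directly to tN in one go and compare against stepping through all intermediate times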
+ actualSolver.integrate(circle, t.getColumns(0, nrofSteps-1), y0, ySingle);
+ new SingleSteppingMultiStepSolver(actualSolver).integrate(circle, t, y0, yMulti);
+
+ assertEquals("Incorrect solution!", ySingle, yMulti.getRow(nrofSteps-2));
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/solve/impl/util/AdaptiveRungeKuttaStepPolicyTest.java b/src/test/java/ode/solve/impl/util/AdaptiveRungeKuttaStepPolicyTest.java
index eb8bb34..c1b182a 100644
--- a/src/test/java/ode/solve/impl/util/AdaptiveRungeKuttaStepPolicyTest.java
+++ b/src/test/java/ode/solve/impl/util/AdaptiveRungeKuttaStepPolicyTest.java
@@ -52,7 +52,7 @@ DormandPrince54Integrator setEquation(FirstOrderDifferentialEquations equations)
true, 5, scale, t.getDouble(0), y0.toDoubleVector(), yDot,
y1, yDot1);
- final FirstOrderEquationWithState eqState = new FirstOrderEquationWithState(equation, t.getColumn(0), y0, 5);
+ final FirstOrderEquationWithState eqState = new FirstOrderEquationWithState(equation, t.getColumn(0), y0, new double [5]);
final INDArray stepAct = new AdaptiveRungeKuttaStepPolicy(
new SolverConfigINDArray(absTol, relTol, 1e-20, 1e20), 5)
.initializeStep(eqState, t);
diff --git a/src/test/java/ode/solve/impl/util/AggStepListenerTest.java b/src/test/java/ode/solve/impl/util/AggStepListenerTest.java
index af79155..0b223b8 100644
--- a/src/test/java/ode/solve/impl/util/AggStepListenerTest.java
+++ b/src/test/java/ode/solve/impl/util/AggStepListenerTest.java
@@ -25,7 +25,7 @@ public void addListenersAndRemoveOne() {
aggStepListener.addListeners(second,third);
aggStepListener.begin(Nd4j.linspace(0,1 ,2), Nd4j.zeros(1));
- aggStepListener.step(Nd4j.zeros(1), Nd4j.zeros(1), Nd4j.zeros(1), Nd4j.zeros(1));
+ aggStepListener.step(new StateContainer(0, new double[] {0}, new double[] {0}), Nd4j.zeros(1), Nd4j.zeros(1));
aggStepListener.done();
first.assertNrofCalls(1,1,1);
second.assertNrofCalls(1,1,1);
@@ -34,7 +34,7 @@ public void addListenersAndRemoveOne() {
aggStepListener.clearListeners(second);
aggStepListener.begin(Nd4j.linspace(0,1 ,2), Nd4j.zeros(1));
- aggStepListener.step(Nd4j.zeros(1), Nd4j.zeros(1), Nd4j.zeros(1), Nd4j.zeros(1));
+ aggStepListener.step(new StateContainer(0, new double[] {0}, new double[] {0}), Nd4j.zeros(1), Nd4j.zeros(1));
aggStepListener.done();
first.assertNrofCalls(2,2,2);
second.assertNrofCalls(1,1,1);
@@ -43,7 +43,7 @@ public void addListenersAndRemoveOne() {
aggStepListener.clearListeners(third, first);
aggStepListener.begin(Nd4j.linspace(0,1 ,2), Nd4j.zeros(1));
- aggStepListener.step(Nd4j.zeros(1), Nd4j.zeros(1), Nd4j.zeros(1), Nd4j.zeros(1));
+ aggStepListener.step(new StateContainer(0, new double[] {0}, new double[] {0}), Nd4j.zeros(1), Nd4j.zeros(1));
aggStepListener.done();
first.assertNrofCalls(2,2,2);
second.assertNrofCalls(1,1,1);
@@ -53,7 +53,7 @@ public void addListenersAndRemoveOne() {
aggStepListener.clearListeners();
aggStepListener.begin(Nd4j.linspace(0,1 ,2), Nd4j.zeros(1));
- aggStepListener.step(Nd4j.zeros(1), Nd4j.zeros(1), Nd4j.zeros(1), Nd4j.zeros(1));
+ aggStepListener.step(new StateContainer(0, new double[] {0}, new double[] {0}), Nd4j.zeros(1), Nd4j.zeros(1));
aggStepListener.done();
first.assertNrofCalls(2,2,2);
second.assertNrofCalls(1,1,1);
diff --git a/src/test/java/ode/solve/impl/util/InterpolatingStepListenerTest.java b/src/test/java/ode/solve/impl/util/InterpolatingStepListenerTest.java
new file mode 100644
index 0000000..d1fefe2
--- /dev/null
+++ b/src/test/java/ode/solve/impl/util/InterpolatingStepListenerTest.java
@@ -0,0 +1,116 @@
+package ode.solve.impl.util;
+
+import examples.spiral.listener.SpiralPlot;
+import ode.solve.CircleODE;
+import ode.solve.api.FirstOrderEquation;
+import ode.solve.api.FirstOrderSolver;
+import ode.solve.conf.SolverConfig;
+import ode.solve.impl.DormandPrince54Solver;
+import ode.solve.impl.SingleSteppingMultiStepSolver;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.primitives.Pair;
+import util.plot.RealTimePlot;
+
+import static org.junit.Assert.assertArrayEquals;
+
+/**
+ * Test cases for {@link InterpolatingStepListener}
+ *
+ * @author Christian Skarby
+ */
+public class InterpolatingStepListenerTest {
+
+ /**
+ * Verify that the result from using an {@link InterpolatingStepListener} is equivalent to using a
+     * {@link SingleSteppingMultiStepSolver} when solving the {@link CircleODE} with the same input.
+ */
+ @Test
+    public void interpolateCircleForward() {
+        final Pair<INDArray, INDArray> multiAndInterp = solveCircleMultiInterpol(false);
+
+ final INDArray multi = multiAndInterp.getFirst();
+ final INDArray interp = multiAndInterp.getSecond();
+
+ for (int i = 0; i < multi.columns(); i++)
+ assertArrayEquals("Solutions are different in column " + i + "!!",
+ multi.getColumn(i).toDoubleVector(),
+ interp.getColumn(i).toDoubleVector(),
+ 1e-4);
+ }
+
+ /**
+ * Verify that the result from using an {@link InterpolatingStepListener} is equivalent to using a
+     * {@link SingleSteppingMultiStepSolver} when solving the {@link CircleODE} with the same input.
+ */
+ @Test
+    public void interpolateCircleBackward() {
+        final Pair<INDArray, INDArray> multiAndInterp = solveCircleMultiInterpol(true);
+
+ final INDArray multi = multiAndInterp.getFirst();
+ final INDArray interp = multiAndInterp.getSecond();
+
+ for (int i = 0; i < multi.columns(); i++)
+ assertArrayEquals("Solutions are different in column " + i + "!!",
+ multi.getColumn(i).toDoubleVector(),
+ interp.getColumn(i).toDoubleVector(),
+ 1e-4);
+ }
+
+
+    private static Pair<INDArray, INDArray> solveCircleMultiInterpol(boolean backwards) {
+ final FirstOrderSolver singleStepSolver = new DormandPrince54Solver(
+ new SolverConfig(1e-7, 1e-7, 1e-10, 1e2));
+
+ final double omega = 5.67;
+ final FirstOrderEquation equation = new CircleODE(new double[]{1.23, 4.56}, omega);
+
+ final int nrofSteps = 25;
+ final INDArray y0 = Nd4j.create(new double[]{-5.6, 7.3});
+ final INDArray ySingle = y0.dup();
+ final INDArray yMulti = Nd4j.repeat(ySingle, nrofSteps - 1).reshape(nrofSteps - 1, y0.length()).assign(0);
+
+ final INDArray tStartEnd = Nd4j.create(new double[]{-Math.PI / omega, Math.PI / omega});
+ if (backwards) tStartEnd.negi();
+ final INDArray t = Nd4j.linspace(tStartEnd.getDouble(0), tStartEnd.getDouble(1), nrofSteps);
+
+ final INDArray expected = new SingleSteppingMultiStepSolver(singleStepSolver).integrate(equation, t, y0, yMulti.dup()).transposei();
+
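+        // The listener fills yMulti with interpolated solutions at the intermediate times while the solver
+        // itself only integrates from the first to the last time in a single call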
+ singleStepSolver.addListener(new InterpolatingStepListener(t.get(NDArrayIndex.interval(1, nrofSteps)), yMulti));
+ final INDArray yLast = y0.dup();
+ singleStepSolver.integrate(equation, tStartEnd, y0, yLast);
+
+ return new Pair<>(expected, yMulti.transposei());
+ }
+
+ /**
+ * Main method for plotting results from both approaches
+ *
+ * @param args not used
+ */
+ public static void main(String[] args) {
+
+        final Pair<INDArray, INDArray> multiAndInterpBackwards = solveCircleMultiInterpol(true);
+ plotSolutions(new SpiralPlot(new RealTimePlot<>("Multi step vs interpol bwd", "")), multiAndInterpBackwards);
+
+        final Pair<INDArray, INDArray> multiAndInterpForwards = solveCircleMultiInterpol(false);
+        plotSolutions(new SpiralPlot(new RealTimePlot<>("Multi step vs interpol fwd", "")), multiAndInterpForwards);
+ }
+
+    private static void plotSolutions(SpiralPlot plot, Pair<INDArray, INDArray> multiAndInterp) {
+ final INDArray multi = multiAndInterp.getFirst();
+ final INDArray interp = multiAndInterp.getSecond();
+
+ plot.plot("Multi", multi);
+ plot.plot( "Interp", interp);
+
+ plot.plot("Multi start", multi.getColumn(0));
+ plot.plot("Interp start", interp.getColumn(0));
+
+ final long lastStep = multi.size(1) - 1;
+ plot.plot("Multi stop", multi.getColumn(lastStep));
+ plot.plot("Interp stop", interp.getColumn(lastStep));
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/solve/impl/util/InterpolationTest.java b/src/test/java/ode/solve/impl/util/InterpolationTest.java
new file mode 100644
index 0000000..3d85f52
--- /dev/null
+++ b/src/test/java/ode/solve/impl/util/InterpolationTest.java
@@ -0,0 +1,48 @@
+package ode.solve.impl.util;
+
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import static junit.framework.TestCase.assertEquals;
+
+/**
+ * Test cases for {@link Interpolation}
+ *
+ * @author Christian Skarby
+ */
+public class InterpolationTest {
+
+ /**
+     * Test interpolation by comparing to the result for the same input from the original repo
+ */
+ @Test
+ public void interpolate() {
+ final long[] shape = {2, 3, 5};
+ final long nrofElems = 2 * 3 * 5;
+ final INDArray y0 = Nd4j.linspace(-10, 10, nrofElems).reshape(shape);
+ final INDArray yMid = Nd4j.linspace(-7, 13, nrofElems).reshape(shape);
+ final INDArray y1 = Nd4j.linspace(-12, 7, nrofElems).reshape(shape);
+ final INDArray f0 = Nd4j.linspace(2, 5, nrofElems).reshape(shape);
+ final INDArray f1 = Nd4j.linspace(-3, -1, nrofElems).reshape(shape);
+ final INDArray dt = Nd4j.scalar(1.23);
+
+ final Interpolation interpolation = new Interpolation();
+
+ interpolation.fitCoeffs(y0, y1, yMid, f0, f1, dt);
+
+ final INDArray output = interpolation.interpolate(-2, 3, 2.34);
+
+        // Output from the PyTorch repo
+ final INDArray expected = Nd4j.create(new double[][][]{
+ {{-10.8218, -10.1690, -9.5162, -8.8633, -8.2105},
+ {-7.5577, -6.9049, -6.2521, -5.5993, -4.9465},
+ {-4.2936, -3.6408, -2.9880, -2.3352, -1.6824}},
+ {{-1.0296, -0.3768, 0.2760, 0.9288, 1.5817},
+ {2.2345, 2.8873, 3.5401, 4.1929, 4.8457},
+ {5.4985, 6.1513, 6.8042, 7.4570, 8.1098}}
+ });
+
+ assertEquals("Incorrect output!", output.toString(), expected.toString());
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/conf/helper/backward/AbstractHelperConfTest.java b/src/test/java/ode/vertex/conf/helper/backward/AbstractHelperConfTest.java
new file mode 100644
index 0000000..2b2ad6c
--- /dev/null
+++ b/src/test/java/ode/vertex/conf/helper/backward/AbstractHelperConfTest.java
@@ -0,0 +1,188 @@
+package ode.vertex.conf.helper.backward;
+
+import ode.vertex.impl.gradview.NonContiguous1DView;
+import ode.vertex.impl.helper.backward.OdeHelperBackward.InputArrays;
+import ode.vertex.impl.helper.backward.OdeHelperBackward.MiscPar;
+import org.deeplearning4j.nn.conf.ConvolutionMode;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.WorkspaceMode;
+import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.Convolution2D;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.jetbrains.annotations.NotNull;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.io.IOException;
+
+import static junit.framework.TestCase.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+
+abstract class AbstractHelperConfTest {
+
+ private final static InputType convInput = InputType.convolutional(9,9,2);
+
+ /**
+     * Create an {@link OdeHelperBackward} to be tested
+ *
+ * @param nrofTimeSteps Number of time steps to create helper for
+ * @return a new {@link OdeHelperBackward}
+ */
+ abstract OdeHelperBackward create(int nrofTimeSteps, boolean needTimeGradient);
+
+ /**
+ * Create inputs from a given input array
+ *
+ * @param input input array
+ * @param nrofTimeSteps number of time steps to create input for
+ * @return all needed inputs
+ */
+ abstract INDArray[] createInputs(INDArray input, int nrofTimeSteps);
+
+ /**
+ * Test serialization and deserialization
+ */
+ @Test
+ public void serializeDeserializeSingleStep() throws IOException {
+ final OdeHelperBackward conf = create(2, true);
+ final String json = NeuralNetConfiguration.mapper().writeValueAsString(conf);
+ final OdeHelperBackward newConf = NeuralNetConfiguration.mapper().readValue(json, OdeHelperBackward.class);
+ assertEquals("Did not deserialize into the same thing!", conf, newConf);
+ }
+
+ /**
+ * Test serialization and deserialization
+ */
+ @Test
+ public void serializeDeserializeMultiStep() throws IOException {
+ final OdeHelperBackward conf = create(5, true);
+ final String json = NeuralNetConfiguration.mapper().writeValueAsString(conf);
+ final OdeHelperBackward newConf = NeuralNetConfiguration.mapper().readValue(json, OdeHelperBackward.class);
+ assertEquals("Did not deserialize into the same thing!", conf, newConf);
+ }
+
+ /**
+ * Test that helper can be instantiated and that it does something
+ */
+ @Test
+ public void instantiateAndSolveConvSingleStep() {
+ final int nrofTimeSteps = 2;
+ final ode.vertex.impl.helper.backward.OdeHelperBackward helper = create(nrofTimeSteps, true).instantiate();
+ final ComputationGraph graph = createGraph();
+ final InputArrays input = getTestInputArrays(nrofTimeSteps, graph);
+ final INDArray[] output = helper.solve(graph, input, new MiscPar(
+ false,
+ LayerWorkspaceMgr.noWorkspaces()));
+
+
+ assertNotEquals("Expected non-zero param gradient!", 0, graph.getGradientsViewArray().sumNumber().doubleValue() ,1e-10);
+ for(INDArray inputGrad: output) {
+ assertNotEquals("Expected non-zero param gradient!", 0, inputGrad.sumNumber().doubleValue() ,1e-10);
+ }
+ }
+
+ /**
+     * Test that the helper can be instantiated and produces non-zero input gradients for a single time step when no time gradient is needed
+ */
+ @Test
+ public void instantiateAndSolveConvSingleStepNoTimeGrad() {
+ final int nrofTimeSteps = 2;
+ final ode.vertex.impl.helper.backward.OdeHelperBackward helper = create(nrofTimeSteps, false).instantiate();
+ final ComputationGraph graph = createGraph();
+ final InputArrays input = getTestInputArrays(nrofTimeSteps, graph);
+ final INDArray[] output = helper.solve(graph, input, new MiscPar(
+ false,
+ LayerWorkspaceMgr.noWorkspaces()));
+
+
+ assertNotEquals("Expected non-zero param gradient!", 0, graph.getGradientsViewArray().sumNumber().doubleValue() ,1e-10);
+ assertNotEquals("Expected non-zero param gradient!", 0, output[0].amaxNumber().doubleValue() ,1e-10);
+ if(output.length > 1) {
+ assertEquals("Expected zero time gradient!", 0, output[1].amaxNumber().doubleValue(), 1e-10);
+ }
+ }
+
+ /**
+ * Test that helper can be instantiated and that it does something
+ */
+ @Test
+ public void instantiateAndSolveConvMultiStep() {
+ final int nrofTimeSteps = 7;
+ final ode.vertex.impl.helper.backward.OdeHelperBackward helper = create(nrofTimeSteps, true).instantiate();
+ final ComputationGraph graph = createGraph();
+ final InputArrays input = getTestInputArrays(nrofTimeSteps, graph);
+ final INDArray[] output = helper.solve(graph, input, new MiscPar(
+ false,
+ LayerWorkspaceMgr.noWorkspaces()));
+
+ assertNotEquals("Expected non-zero param gradient!", 0, graph.getGradientsViewArray().amaxNumber().doubleValue() ,1e-10);
+ for(INDArray inputGrad: output) {
+ assertNotEquals("Expected non-zero param gradient!", 0, inputGrad.amaxNumber().doubleValue() ,1e-10);
+ }
+ }
+
+ /**
+     * Test that the helper can be instantiated and produces non-zero input gradients for multiple time steps when no time gradient is needed
+ */
+ @Test
+ public void instantiateAndSolveConvMultiStepNoTimeGrad() {
+ final int nrofTimeSteps = 7;
+ final ode.vertex.impl.helper.backward.OdeHelperBackward helper = create(nrofTimeSteps, false).instantiate();
+ final ComputationGraph graph = createGraph();
+ final InputArrays input = getTestInputArrays(nrofTimeSteps, graph);
+ final INDArray[] output = helper.solve(graph, input, new MiscPar(
+ false,
+ LayerWorkspaceMgr.noWorkspaces()));
+
+ assertNotEquals("Expected non-zero param gradient!", 0, graph.getGradientsViewArray().amaxNumber().doubleValue() ,1e-10);
+ assertNotEquals("Expected non-zero param gradient!", 0, output[0].amaxNumber().doubleValue() ,1e-10);
+ if(output.length > 1) {
+ assertEquals("Expected zero time gradient!", 0, output[1].amaxNumber().doubleValue(), 1e-10);
+ }
+ }
+
+ private ComputationGraph createGraph() {
+ final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder()
+ .trainingWorkspaceMode(WorkspaceMode.NONE)
+ .inferenceWorkspaceMode(WorkspaceMode.NONE)
+ .weightInit(new ConstantDistribution(0.01))
+ .graphBuilder()
+ .allowNoOutput(true)
+ .addInputs("input")
+ .setInputTypes(convInput)
+ .addLayer("1", new Convolution2D.Builder(3, 3).nOut(2).convolutionMode(ConvolutionMode.Same).build(), "input")
+ .build());
+ graph.init();
+ graph.initGradientsView();
+ return graph;
+ }
+
+ @NotNull
+ private InputArrays getTestInputArrays(int nrofTimeSteps, ComputationGraph graph) {
+
+ final int batchSize = 3;
+ final long[] shape = convInput.getShape(true);
+ shape[0] = batchSize;
+
+ final int nrofTimeStepsToUse = nrofTimeSteps == 2 ? 1 : nrofTimeSteps;
+ final long[] outputShape = nrofTimeSteps == 2 ? shape : new long[]{ batchSize, nrofTimeStepsToUse, shape[1], shape[2], shape[3]};
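+        // With only two time steps the expected output is just the last state, hence no extra time dimension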
+
+
+ final INDArray input = Nd4j.arange(batchSize * convInput.arrayElementsPerExample()).reshape(shape);
+ final INDArray output = Nd4j.arange(batchSize * nrofTimeStepsToUse * convInput.arrayElementsPerExample())
+ .reshape(outputShape);
+ final INDArray epsilon = Nd4j.ones(outputShape).assign(0.01);
+ final NonContiguous1DView realGrads = new NonContiguous1DView();
+ realGrads.addView(graph.getGradientsViewArray());
+
+ return new InputArrays(
+ createInputs(input, nrofTimeSteps),
+ output,
+ epsilon,
+ realGrads
+ );
+ }
+}
diff --git a/src/test/java/ode/vertex/conf/helper/backward/FixedStepAdjointTest.java b/src/test/java/ode/vertex/conf/helper/backward/FixedStepAdjointTest.java
new file mode 100644
index 0000000..8cf86e4
--- /dev/null
+++ b/src/test/java/ode/vertex/conf/helper/backward/FixedStepAdjointTest.java
@@ -0,0 +1,23 @@
+package ode.vertex.conf.helper.backward;
+
+import ode.solve.conf.DormandPrince54Solver;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+/**
+ * Test cases for {@link FixedStepAdjoint}
+ *
+ * @author Christian Skarby
+ */
+public class FixedStepAdjointTest extends AbstractHelperConfTest {
+
+ @Override
+ OdeHelperBackward create(int nrofTimeSteps, boolean needTimeGradient) {
+ return new FixedStepAdjoint(new DormandPrince54Solver(), Nd4j.linspace(0,3,nrofTimeSteps));
+ }
+
+ @Override
+ INDArray[] createInputs(INDArray input, int nrofTimeSteps) {
+ return new INDArray[] {input};
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/conf/helper/backward/InputStepAdjointTest.java b/src/test/java/ode/vertex/conf/helper/backward/InputStepAdjointTest.java
new file mode 100644
index 0000000..c5d6ed8
--- /dev/null
+++ b/src/test/java/ode/vertex/conf/helper/backward/InputStepAdjointTest.java
@@ -0,0 +1,23 @@
+package ode.vertex.conf.helper.backward;
+
+import ode.solve.conf.DormandPrince54Solver;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+/**
+ * Test cases for {@link InputStepAdjoint}
+ *
+ * @author Christian Skarby
+ */
+public class InputStepAdjointTest extends AbstractHelperConfTest {
+
+ @Override
+ OdeHelperBackward create(int nrofTimeSteps, boolean needTimeGrad) {
+ return new InputStepAdjoint(new DormandPrince54Solver(), 1, needTimeGrad);
+ }
+
+ @Override
+ INDArray[] createInputs(INDArray input, int nrofTimeSteps) {
+ return new INDArray[]{input, Nd4j.linspace(0, 2, nrofTimeSteps)};
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/conf/helper/forward/AbstractHelperConfTest.java b/src/test/java/ode/vertex/conf/helper/forward/AbstractHelperConfTest.java
new file mode 100644
index 0000000..5bdaf38
--- /dev/null
+++ b/src/test/java/ode/vertex/conf/helper/forward/AbstractHelperConfTest.java
@@ -0,0 +1,64 @@
+package ode.vertex.conf.helper.forward;
+
+import org.deeplearning4j.nn.conf.ConvolutionMode;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.Convolution2D;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.io.IOException;
+
+import static junit.framework.TestCase.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+
+abstract class AbstractHelperConfTest {
+
+ /**
+     * Create an {@link OdeHelperForward} to be tested
+ *
+ * @return a new {@link OdeHelperForward}
+ */
+ abstract OdeHelperForward create();
+
+ /**
+ * Create inputs from a given input array
+ *
+ * @param input input array
+ * @return all needed inputs
+ */
+ abstract INDArray[] createInputs(INDArray input);
+
+ /**
+ * Test serialization and deserialization
+ */
+ @Test
+ public void serializeDeserialize() throws IOException {
+ final OdeHelperForward conf = create();
+ final String json = NeuralNetConfiguration.mapper().writeValueAsString(conf);
+ final OdeHelperForward newConf = NeuralNetConfiguration.mapper().readValue(json, OdeHelperForward.class);
+ assertEquals("Did not deserialize into the same thing!", conf, newConf);
+ }
+
+ @Test
+ public void instantiateAndSolveConv() {
+ final ode.vertex.impl.helper.forward.OdeHelperForward helper = create().instantiate();
+ final INDArray output = helper.solve(createGraph(), LayerWorkspaceMgr.noWorkspaces(), createInputs(Nd4j.randn(new long[] {5, 2, 3, 3})));
+ assertNotEquals("Expected non-zero output!", 0, output.sumNumber().doubleValue() ,1e-10);
+ }
+
+ private ComputationGraph createGraph() {
+ final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder()
+ .graphBuilder()
+ .allowNoOutput(true)
+ .addInputs("input")
+ .setInputTypes(InputType.convolutional(9, 9, 2))
+ .addLayer("1", new Convolution2D.Builder(3, 3).nOut(2).convolutionMode(ConvolutionMode.Same).build(), "input")
+ .build());
+ graph.init();
+ return graph;
+ }
+}
diff --git a/src/test/java/ode/vertex/conf/helper/forward/FixedStepTest.java b/src/test/java/ode/vertex/conf/helper/forward/FixedStepTest.java
new file mode 100644
index 0000000..765aa62
--- /dev/null
+++ b/src/test/java/ode/vertex/conf/helper/forward/FixedStepTest.java
@@ -0,0 +1,24 @@
+package ode.vertex.conf.helper.forward;
+
+import ode.solve.conf.DormandPrince54Solver;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+/**
+ * Test cases for {@link FixedStep}
+ *
+ * @author Christian Skarby
+ */
+public class FixedStepTest extends AbstractHelperConfTest {
+
+ @Override
+ OdeHelperForward create() {
+ return new FixedStep(new DormandPrince54Solver(), Nd4j.linspace(0,3,4), false);
+ }
+
+ @Override
+ INDArray[] createInputs(INDArray input) {
+ return new INDArray[] {input};
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/conf/helper/forward/InputStepTest.java b/src/test/java/ode/vertex/conf/helper/forward/InputStepTest.java
new file mode 100644
index 0000000..d839272
--- /dev/null
+++ b/src/test/java/ode/vertex/conf/helper/forward/InputStepTest.java
@@ -0,0 +1,24 @@
+package ode.vertex.conf.helper.forward;
+
+import ode.solve.conf.DormandPrince54Solver;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+
+/**
+ * Test cases for {@link InputStep}
+ *
+ * @author Christian Skarby
+ */
+public class InputStepTest extends AbstractHelperConfTest {
+
+ @Override
+ OdeHelperForward create() {
+ return new InputStep(new DormandPrince54Solver(), 1, false);
+ }
+
+ @Override
+ INDArray[] createInputs(INDArray input) {
+ return new INDArray[] {input, Nd4j.linspace(0, 10, 5)};
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/impl/NonContiguous1DViewTest.java b/src/test/java/ode/vertex/impl/NonContiguous1DViewTest.java
deleted file mode 100644
index 28c6ef3..0000000
--- a/src/test/java/ode/vertex/impl/NonContiguous1DViewTest.java
+++ /dev/null
@@ -1,50 +0,0 @@
-package ode.vertex.impl;
-
-import org.junit.Test;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-
-import static org.junit.Assert.assertEquals;
-
-/**
- * Test cases for {@link NonContiguous1DView}
- *
- * @author Christian Skarby
- */
-public class NonContiguous1DViewTest {
-
- /**
- * Test assignment of {@link NonContiguous1DView} from another {@link INDArray}. Set first and last three elements
- * and leave three in the middle untouched
- */
- @Test
- public void assignFrom() {
- final INDArray toView = Nd4j.zeros(9);
-
- final NonContiguous1DView view = new NonContiguous1DView();
- view.addView(toView, 0, 3);
- view.addView(toView, 6,9);
- view.assignFrom(Nd4j.ones(new long[] {6}));
-
- final INDArray expected = Nd4j.create(new double[] {1,1,1,0,0,0,1,1,1});
- assertEquals("Viewed array was not changed!", expected, toView);
- }
-
- /**
- * Test assignment to another {@link INDArray} from a {@link NonContiguous1DView}.
- */
- @Test
- public void assignTo() {
- final INDArray toView = Nd4j.linspace(0,8,9);
-
- final NonContiguous1DView view = new NonContiguous1DView();
- view.addView(toView, 0, 3);
- view.addView(toView, 6,9);
-
- final INDArray actual = Nd4j.create(view.length());
- view.assignTo(actual);
-
- final INDArray expected = Nd4j.create(new double[] {0,1,2,6,7,8});
- assertEquals("Viewed array was not changed!", expected, actual);
- }
-}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/impl/OdeVertexTest.java b/src/test/java/ode/vertex/impl/OdeVertexTest.java
index d5e39c8..9ec5558 100644
--- a/src/test/java/ode/vertex/impl/OdeVertexTest.java
+++ b/src/test/java/ode/vertex/impl/OdeVertexTest.java
@@ -1,5 +1,7 @@
package ode.vertex.impl;
+import ode.solve.conf.DormandPrince54Solver;
+import ode.vertex.conf.helper.InputStep;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
@@ -10,6 +12,7 @@
import org.nd4j.linalg.activations.impl.ActivationReLU;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
@@ -93,4 +96,34 @@ public void fit() {
graph.fit(new DataSet(Nd4j.randn(new long[]{1, 1, 9, 9}), Nd4j.create(new double[] {0,1,0})));
assertNotEquals("Expected parameters to be updated!", before, graph.getVertex("1").params().dup());
}
+
+ /**
+     * Smoke test to see that it is possible to fit when time is one of the inputs
+ */
+ @Test
+ public void fitWithTimeAsInput() {
+ final long nOut = 8;
+ final long nrofTimeSteps = 10;
+ final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder()
+ .graphBuilder()
+ .addInputs("input", "time")
+ .setInputTypes(InputType.feedForward(5), InputType.feedForward(nrofTimeSteps))
+ .addLayer("0", new DenseLayer.Builder().nOut(nOut).build(), "input")
+ .addVertex("odeVertex", new ode.vertex.conf.OdeVertex.Builder("ode0",
+ new DenseLayer.Builder().nOut(nOut).build())
+ .odeConf(new InputStep(new DormandPrince54Solver(), 1))
+ .build(), "0", "time")
+ .setOutputs("output")
+ .addLayer("output", new RnnOutputLayer.Builder().nOut(3).build(), "odeVertex")
+ .build());
+
+ graph.init();
+
+ final INDArray before = graph.getVertex("odeVertex").params().dup();
+ final int batchSize = 3;
+ graph.fit(new MultiDataSet(
+ new INDArray[] {Nd4j.randn(new long[]{batchSize, 5}), Nd4j.linspace(0, 2, nrofTimeSteps)},
+ new INDArray[] {Nd4j.repeat(Nd4j.create(new double[] {0,1,0}).transposei(), batchSize*(int)nrofTimeSteps)}));
+ assertNotEquals("Expected parameters to be updated!", before, graph.getVertex("odeVertex").params().dup());
+ }
}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/impl/gradview/Contiguous1DViewTest.java b/src/test/java/ode/vertex/impl/gradview/Contiguous1DViewTest.java
new file mode 100644
index 0000000..0612839
--- /dev/null
+++ b/src/test/java/ode/vertex/impl/gradview/Contiguous1DViewTest.java
@@ -0,0 +1,54 @@
+package ode.vertex.impl.gradview;
+
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Test cases for {@link Contiguous1DView}
+ *
+ * @author Christian Skarby
+ */
+public class Contiguous1DViewTest {
+
+ /**
+     * Test assignment of a {@link Contiguous1DView} from another {@link INDArray}. All elements of the
+     * viewed array shall be overwritten
+ */
+ @Test
+ public void assignFrom() {
+ final INDArray toView = Nd4j.ones(13);
+ final INDArray other = Nd4j.zeros(toView.shape()).reshape(toView.length());
+
+ final INDArray1DView view = new Contiguous1DView(toView);
+
+ view.assignFrom(other);
+ assertEquals("View not set!", other, toView);
+ }
+
+ /**
+ * Test assignment to another {@link INDArray} from a {@link Contiguous1DView}.
+ */
+ @Test
+ public void assignTo() {
+ final INDArray toView = Nd4j.ones(13);
+ final INDArray other = Nd4j.zeros(toView.shape()).reshape(toView.length());
+
+ final INDArray1DView view = new Contiguous1DView(toView);
+
+ view.assignTo(other);
+ assertEquals("View not set!", other, toView);
+ }
+
+ /**
+ * Test that length is correct
+ */
+ @Test
+ public void length() {
+ final long length = 27;
+
+ assertEquals("Incorrect length!", length, new Contiguous1DView(Nd4j.ones(length)).length());
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/impl/gradview/GradientViewSelectionFromBlacklistedTest.java b/src/test/java/ode/vertex/impl/gradview/GradientViewSelectionFromBlacklistedTest.java
new file mode 100644
index 0000000..3a11f55
--- /dev/null
+++ b/src/test/java/ode/vertex/impl/gradview/GradientViewSelectionFromBlacklistedTest.java
@@ -0,0 +1,125 @@
+package ode.vertex.impl.gradview;
+
+import org.deeplearning4j.nn.conf.ConvolutionMode;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.BatchNormalization;
+import org.deeplearning4j.nn.conf.layers.Convolution2D;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.jetbrains.annotations.NotNull;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.primitives.Pair;
+import org.nd4j.shade.jackson.databind.ObjectMapper;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+import static junit.framework.TestCase.assertTrue;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertFalse;
+
+/**
+ * Test cases for {@link GradientViewSelectionFromBlacklisted}
+ *
+ * @author Christian Skarby
+ */
+public class GradientViewSelectionFromBlacklistedTest {
+
+ /**
+ * Test that gradients in the black list are not selected
+ */
+ @Test
+ public void createWithBlacklisted() {
+ final ComputationGraph graph = createGraph();
+ final GradientViewFactory factory = new GradientViewSelectionFromBlacklisted();
+ final ParameterGradientView gradView = factory.create(graph);
+
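+        // The default blacklist presumably excludes the batch norm layer's running mean and variance,
+        // which make up half of that layer's gradient view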
+ assertEquals("Incorrect number of parameter gradients in view!",
+ graph.getGradientsViewArray().length() - graph.getLayer("1").getGradientsViewArray().length() / 2,
+ gradView.realGradientView().length());
+
+        for(Map.Entry<String, INDArray> nameGradEntry: gradView.allGradientsPerParam().gradientForVariable().entrySet()){
+            final Pair<String, String> vertexAndParName = factory.paramNameMapping().reverseMap(nameGradEntry.getKey());
+ final long[] expectedShape = graph.getLayer(vertexAndParName.getFirst()).getParam(vertexAndParName.getSecond()).shape();
+ assertArrayEquals("Incorrect grad size for " + nameGradEntry.getKey() + "!",
+ expectedShape,
+ nameGradEntry.getValue().shape());
+ }
+ }
+
+ /**
+     * Test that all gradients are selected when the blacklist is empty
+ */
+ @Test
+ public void createNoBlackList() {
+ final ComputationGraph graph = createGraph();
+        final GradientViewFactory factory = new GradientViewSelectionFromBlacklisted(new ArrayList<>());
+ final ParameterGradientView gradView = factory.create(graph);
+
+ assertEquals("Incorrect number of parameter gradients in view!",
+ graph.getGradientsViewArray().length(),
+ gradView.realGradientView().length());
+
+        for(Map.Entry<String, INDArray> nameGradEntry: gradView.allGradientsPerParam().gradientForVariable().entrySet()){
+            final Pair<String, String> vertexAndParName = factory.paramNameMapping().reverseMap(nameGradEntry.getKey());
+ final long[] expectedShape = graph.getLayer(vertexAndParName.getFirst()).getParam(vertexAndParName.getSecond()).shape();
+ assertArrayEquals("Incorrect grad size for " + nameGradEntry.getKey() + "!",
+ expectedShape,
+ nameGradEntry.getValue().shape());
+ }
+ }
+
+ /**
+ * Test that a clone is equal to the original
+ */
+ @Test
+    public void cloneTest() {
+        final GradientViewFactory factory = new GradientViewSelectionFromBlacklisted(Arrays.asList("ff", "gg"));
+        assertTrue("Clones shall be equal!", factory.equals(factory.clone()));
+ }
+
+ /**
+ * Test equals
+ */
+ @Test
+ public void equals() {
+ assertTrue("Shall be equal!", new GradientViewSelectionFromBlacklisted().equals(new GradientViewSelectionFromBlacklisted()));
+ assertTrue("Shall be equal!",
+ new GradientViewSelectionFromBlacklisted(Arrays.asList("aa", "bb")).equals(
+ new GradientViewSelectionFromBlacklisted(Arrays.asList("aa", "bb"))));
+ assertFalse("Shall not be equal!",
+ new GradientViewSelectionFromBlacklisted(Arrays.asList("aa", "bb")).equals(
+ new GradientViewSelectionFromBlacklisted(Arrays.asList("aa", "bb", "cc"))));
+ }
+
+ /**
+ * Test that a {@link GradientViewSelectionFromBlacklisted} can be serialized and then deserialized into the same thing
+ */
+ @Test
+ public void serializeDeserialize() throws IOException {
+ final GradientViewFactory factory = new GradientViewSelectionFromBlacklisted(Arrays.asList("qq", "ww"));
+ final String json = new ObjectMapper().writeValueAsString(factory);
+ final GradientViewFactory deserialized = new ObjectMapper().readValue(json, GradientViewSelectionFromBlacklisted.class);
+ assertTrue("Did not deserialize to the same thing!", factory.equals(deserialized));
+ }
+
+ @NotNull
+ ComputationGraph createGraph() {
+ final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder()
+ .graphBuilder()
+ .addInputs("input")
+ .addLayer("0", new Convolution2D.Builder().convolutionMode(ConvolutionMode.Same).nOut(3).build(), "input")
+ .addLayer("1", new BatchNormalization.Builder().nOut(3).build(), "0")
+ .addLayer("2", new Convolution2D.Builder().convolutionMode(ConvolutionMode.Same).nOut(3).build(), "1")
+ .setOutputs("2")
+ .setInputTypes(InputType.convolutional(5, 5, 3))
+ .build());
+ graph.init();
+ graph.initGradientsView();
+ return graph;
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/impl/gradview/NonContiguous1DViewTest.java b/src/test/java/ode/vertex/impl/gradview/NonContiguous1DViewTest.java
new file mode 100644
index 0000000..d240d72
--- /dev/null
+++ b/src/test/java/ode/vertex/impl/gradview/NonContiguous1DViewTest.java
@@ -0,0 +1,102 @@
+package ode.vertex.impl.gradview;
+
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Test cases for {@link NonContiguous1DView}
+ *
+ * @author Christian Skarby
+ */
+public class NonContiguous1DViewTest {
+
+ /**
+ * Test assignment of {@link NonContiguous1DView} from another {@link INDArray}. Set first and last three elements
+ * and leave three in the middle untouched
+ */
+ @Test
+ public void assignFrom() {
+ final INDArray toView = Nd4j.zeros(9);
+
+ final NonContiguous1DView view = new NonContiguous1DView();
+ view.addView(toView.get(NDArrayIndex.interval(0, 3)));
+ view.addView(toView.get(NDArrayIndex.interval(6, 9)));
+ view.assignFrom(Nd4j.ones(new long[] {6}));
+
+ final INDArray expected = Nd4j.create(new double[] {1,1,1,0,0,0,1,1,1});
+ assertEquals("Viewed array was not changed!", expected, toView);
+ }
+
+ /**
+ * Test assignment to another {@link INDArray} from a {@link NonContiguous1DView}.
+ */
+ @Test
+ public void assignTo() {
+ final INDArray toView = Nd4j.linspace(0,8,9);
+
+ final NonContiguous1DView view = new NonContiguous1DView();
+ view.addView(toView.get(NDArrayIndex.interval(0, 3)));
+ view.addView(toView.get(NDArrayIndex.interval(6, 9)));
+
+ final INDArray actual = Nd4j.create(view.length());
+ view.assignTo(actual);
+
+ final INDArray expected = Nd4j.create(new double[] {0,1,2,6,7,8});
+ assertEquals("Viewed array was not changed!", expected, actual);
+ }
+
+ /**
+ * Test that length is correct
+ */
+ @Test
+ public void length() {
+ final long length0 = 13;
+ final long length1 = 7;
+
+ final INDArray toView = Nd4j.ones(length0+length1+10).reshape(length0+length1+10);
+
+ final NonContiguous1DView view = new NonContiguous1DView();
+ view.addView(toView.get(NDArrayIndex.interval(0, length0)));
+ view.addView(toView.get(NDArrayIndex.interval(length0, length0+length1)));
+
+ assertEquals("Incorrect length!", length0 + length1, view.length());
+ }
+
+ /**
+ * Test assignment of {@link NonContiguous1DView} with one 2x3x4 view and one 2x2 view from another {@link INDArray}.
+ */
+ @Test
+ public void assignFrom2x3x4and2x2() {
+ final INDArray toView = Nd4j.zeros(2*3*4*5);
+
+ final NonContiguous1DView view = new NonContiguous1DView();
+ view.addView(toView.get(NDArrayIndex.interval(0, 2*3*4)).reshape(2,3,4));
+ view.addView(toView.get(NDArrayIndex.interval(2*3*4+4, 2*3*4+8)).reshape(2,2));
+ view.assignFrom(Nd4j.ones(view.length()));
+
+ final double expectedSum = view.length();
+ assertEquals("Viewed array was not changed!", expectedSum, toView.sumNumber().doubleValue(),1e-10);
+ }
+
+ /**
+ * Test assignment to another {@link INDArray} from a {@link NonContiguous1DView}.
+ */
+ @Test
+ public void assignTo2x3x4() {
+ final INDArray toView = Nd4j.linspace(0,36,37);
+
+ final NonContiguous1DView view = new NonContiguous1DView();
+ view.addView(toView.get(NDArrayIndex.interval(0, 12)).reshape(2,3,2));
+ view.addView(toView.get(NDArrayIndex.interval(17, 17+10)).reshape(5,2));
+
+ final INDArray actual = Nd4j.create(view.length());
+ view.assignTo(actual);
+
+ final INDArray expected = Nd4j.create(new double[] {0,1,2,3,4,5,6,7,8,9,10,11, 17,18,19,20,21,22,23,24,25,26});
+ assertEquals("Viewed array was not changed!", expected, actual);
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/impl/gradview/parname/ConcatTest.java b/src/test/java/ode/vertex/impl/gradview/parname/ConcatTest.java
new file mode 100644
index 0000000..53f6d63
--- /dev/null
+++ b/src/test/java/ode/vertex/impl/gradview/parname/ConcatTest.java
@@ -0,0 +1,38 @@
+package ode.vertex.impl.gradview.parname;
+
+import org.junit.Test;
+import org.nd4j.linalg.primitives.Pair;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Test cases for {@link Concat}
+ *
+ * @author Christian Skarby
+ */
+public class ConcatTest {
+
+ /**
+ * Test that input is concatenated
+ */
+ @Test
+ public void map() {
+ assertEquals("vertex-param", new Concat().map("vertex", "param"));
+ }
+
+ /**
+ * Test that input is de-concatenated
+ */
+ @Test
+ public void reverseMap() {
+ assertEquals(new Pair<>("vertex", "param"), new Concat().reverseMap("vertex-param"));
+ }
+
+ /**
+ * Test that an error is thrown for illegal input
+ */
+ @Test(expected = IllegalArgumentException.class)
+ public void reverseMapIllegal() {
+ new Concat().reverseMap("aaa-vertex-param");
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/impl/gradview/parname/PrefixTest.java b/src/test/java/ode/vertex/impl/gradview/parname/PrefixTest.java
new file mode 100644
index 0000000..6dfc909
--- /dev/null
+++ b/src/test/java/ode/vertex/impl/gradview/parname/PrefixTest.java
@@ -0,0 +1,30 @@
+package ode.vertex.impl.gradview.parname;
+
+import org.junit.Test;
+import org.nd4j.linalg.primitives.Pair;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Test cases for {@link Prefix}
+ *
+ * @author Christian Skarby
+ */
+public class PrefixTest {
+
+ /**
+ * Test that a prefix is added
+ */
+ @Test
+ public void map() {
+ assertEquals("prefix_vertex-param", new Prefix("prefix_", new Concat()).map("vertex", "param"));
+ }
+
+ /**
+ * Test that prefixing is reversed correctly
+ */
+ @Test
+ public void reverseMap() {
+ assertEquals(new Pair<>("vertex", "param"), new Prefix("prefix_", new Concat()).reverseMap("prefix_vertex-param"));
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/impl/helper/backward/MultiStepAdjointTest.java b/src/test/java/ode/vertex/impl/helper/backward/MultiStepAdjointTest.java
new file mode 100644
index 0000000..562a9ca
--- /dev/null
+++ b/src/test/java/ode/vertex/impl/helper/backward/MultiStepAdjointTest.java
@@ -0,0 +1,327 @@
+package ode.vertex.impl.helper.backward;
+
+import ode.solve.conf.SolverConfig;
+import ode.solve.impl.DormandPrince54Solver;
+import ode.vertex.conf.OdeVertex;
+import ode.vertex.conf.helper.InputStep;
+import ode.vertex.impl.gradview.NonContiguous1DView;
+import ode.vertex.impl.helper.backward.timegrad.CalcMultiStepTimeGrad;
+import ode.vertex.impl.helper.backward.timegrad.NoMultiStepTimeGrad;
+import org.deeplearning4j.nn.api.OptimizationAlgorithm;
+import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.graph.vertex.GraphVertex;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.jetbrains.annotations.NotNull;
+import org.junit.Test;
+import org.nd4j.linalg.activations.impl.ActivationIdentity;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.primitives.Pair;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertNotEquals;
+
+/**
+ * Test cases for {@link MultiStepAdjoint}
+ *
+ * @author Christian Skarby
+ */
+public class MultiStepAdjointTest {
+
+ /**
+ * Smoke test for backwards solve without time gradients
+ */
+ @Test
+ public void solveNoTime() {
+ final int nrofInputs = 7;
+ final int nrofTimeSteps = 5;
+ final ComputationGraph graph = SingleStepAdjointTest.getTestGraph(nrofInputs);
+ final OdeHelperBackward.InputArrays inputArrays = getTestInputArrays(nrofInputs, nrofTimeSteps, graph);
+
+ final INDArray time = Nd4j.arange(nrofTimeSteps);
+ final OdeHelperBackward helper = new MultiStepAdjoint(
+ new DormandPrince54Solver(new SolverConfig(1e-3, 1e-3, 0.1, 10)),
+ time, NoMultiStepTimeGrad.factory);
+
+ INDArray[] gradients = helper.solve(graph, inputArrays, new OdeHelperBackward.MiscPar(
+ false,
+ LayerWorkspaceMgr.noWorkspacesImmutable()));
+
+ assertEquals("Incorrect number of input gradients!", 1, gradients.length);
+
+ final INDArray inputGrad = gradients[0];
+ assertArrayEquals("Incorrect input gradient shape!", inputArrays.getLastInputs()[0].shape(), inputGrad.shape());
+
+ final INDArray parGrad = graph.getGradientsViewArray();
+ assertNotEquals("Expected non-zero parameter gradient!", 0.0, parGrad.maxNumber().doubleValue(), 1e-10);
+ assertNotEquals("Expected non-zero input gradient!", 0.0, inputGrad.maxNumber().doubleValue(), 1e-10);
+ }
+
+ /**
+ * Smoke test for backwards solve with time gradients
+ */
+ @Test
+ public void solveWithTime() {
+ final int nrofInputs = 5;
+ final int nrofTimeSteps = 7;
+ final ComputationGraph graph = SingleStepAdjointTest.getTestGraph(nrofInputs);
+ final OdeHelperBackward.InputArrays inputArrays = getTestInputArrays(nrofInputs, nrofTimeSteps, graph);
+
+ final INDArray time = Nd4j.arange(nrofTimeSteps);
+ final OdeHelperBackward helper = new MultiStepAdjoint(
+ new DormandPrince54Solver(new SolverConfig(1e-3, 1e-3, 0.1, 10)),
+ time, new CalcMultiStepTimeGrad.Factory(time, 1));
+
+ INDArray[] gradients = helper.solve(graph, inputArrays, new OdeHelperBackward.MiscPar(
+ false,
+ LayerWorkspaceMgr.noWorkspacesImmutable()));
+
+ assertEquals("Incorrect number of input gradients!", 2, gradients.length);
+
+ final INDArray inputGrad = gradients[0];
+ final INDArray timeGrad = gradients[1];
+ assertArrayEquals("Incorrect input gradient shape!", inputArrays.getLastInputs()[0].shape(), inputGrad.shape());
+ assertArrayEquals("Incorrect time gradient shape!", time.shape(), timeGrad.shape());
+
+ final INDArray parGrad = graph.getGradientsViewArray();
+ assertNotEquals("Expected non-zero parameter gradient!", 0.0, parGrad.maxNumber().doubleValue(), 1e-10);
+ assertNotEquals("Expected non-zero input gradient!", 0.0, inputGrad.maxNumber().doubleValue(), 1e-10);
+ assertNotEquals("Expected non-zero time gradient!", 0.0, timeGrad.maxNumber().doubleValue(), 1e-10);
+ }
+
+ @NotNull
+ private static OdeHelperBackward.InputArrays getTestInputArrays(int nrofInputs, int nrofTimeSteps, ComputationGraph graph) {
+ final int batchSize = 3;
+ final INDArray input = Nd4j.arange(batchSize * nrofInputs).reshape(batchSize, nrofInputs);
+ final INDArray output = Nd4j.arange(batchSize * nrofInputs * nrofTimeSteps).reshape(batchSize, nrofInputs, nrofTimeSteps);
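+ // epsilon plays the role of the loss gradient w.r.t. the output, one value per element and time step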
+ final INDArray epsilon = Nd4j.ones(batchSize, nrofInputs, nrofTimeSteps).assign(0.01);
+ final NonContiguous1DView realGrads = new NonContiguous1DView();
+ realGrads.addView(graph.getGradientsViewArray());
+
+ return new OdeHelperBackward.InputArrays(
+ new INDArray[]{input},
+ output,
+ epsilon,
+ realGrads
+ );
+ }
+
+ /**
+ * Test that an exception is thrown if time array is not sorted
+ */
+ @Test(expected = IllegalArgumentException.class)
+ public void timeNotSorted() {
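+ // 1.5 after 2 makes the time array non-monotonic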
+ final INDArray time = Nd4j.create(new double[]{0, 1, 2, 1.5, 3, 4});
+ new MultiStepAdjoint(new DormandPrince54Solver(
+ new SolverConfig(1, 1, 1e-1, 1)),
+ time, new CalcMultiStepTimeGrad.Factory(time, 1));
+ }
+
+ /**
+ * Test the result against the result from the original repo. Reimplementation of test_adjoint in
+ * https://github.com/rtqichen/torchdiffeq/blob/master/tests/gradient_tests.py.
+ *
+ * Numbers below are from the following test ODE:
+ *
+ * class Linear1DODE(torch.nn.Module):
+ *
+ * def __init__(self, device):
+ * super(Linear1DODE, self).__init__()
+ * self.a = torch.nn.Parameter(torch.tensor([0.2, 0.2, 0.2, 0.2]).reshape(2,2).to(device))
+ * self.b = torch.nn.Parameter(torch.tensor([3.0, 3.0]).reshape(1,2).to(device))
+ *
+ * def forward(self, t, y):
+ * return y.matmul(self.a)+ self.b
+ *
+ * def y_exact(self, t):
+ * return t
+ *
+ * Test code (to set same inputs as below):
+ * f, y0, t_points, _ = construct_problem(TEST_DEVICE, ode='linear1D')
+ *
+ * y0 = torch.linspace(-1.23,2.34,4).reshape(2, 2).type_as(y0)
+ * y0.requires_grad = True
+ *
+ * print('t: ', t_points)
+ * print('y0: ', y0)
+ *
+ * params = (list(f.parameters()))
+ * optimizer = torch.optim.SGD(params, lr=0.0001)
+ *
+ *
+ * optimizer.zero_grad()
+ * func = lambda y0, t_points: torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
+ * ys = func(y0, t_points).permute(1, 2, 0)
+ *
+ * gradys = torch.linspace(3, 13, ys.numel()).type_as(ys).reshape(ys.size())
+ * for i in range (0, int(gradys.size()[0]/2)):
+ * gradys[i*2, :, i*2+1] = -gradys[i*2,:,i*2+1]
+ *
+ * ys.backward(gradys)
+ *
+ */
+ @Test
+ public void testGradVsReferenceLinear1dOdeForward() {
+
+ final long nrofTimeSteps = 10;
+ final long nrofDims = 2;
+
+ final GraphVertex odevert = createOdeVertex(nrofTimeSteps, nrofDims);
+
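+ // t and y0 chosen to match the values printed by the reference test code above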
+ final INDArray time = Nd4j.linspace(1, 8, nrofTimeSteps);
+ final INDArray y0 = Nd4j.linspace(-1.23, 2.34, 2 * nrofDims).reshape(2, nrofDims);
+
+ odevert.setInputs(y0, time);
+ final INDArray ys = odevert.doForward(true, LayerWorkspaceMgr.noWorkspacesImmutable());
+
+ // From original repo
+ final double[][][] expectedYs = {{{-1.23, 1.2753194351313588, 4.694932254437282, 9.36250433337755, 15.73345986285669, 24.429438907671752, 36.29894018833409, 52.50011823313958, 74.61376567400269, 104.79757535834692},
+ {-0.040000000000000036, 2.4653194351313603, 5.884932254437281, 10.552504333377543, 16.92345986285671, 25.61943890767168, 37.488940188334105, 53.69011823313958, 75.8037656740027, 105.98757535834685}},
+ {{1.15, 4.523878831433274, 9.129023844467975, 15.414778231911933, 23.994455416185012, 35.70520942482531, 51.689701681157864, 73.50760277718506, 103.28774415967294, 143.93586077027197},
+ {2.34, 5.7138788314332745, 10.319023844467976, 16.604778231911936, 25.184455416185003, 36.89520942482521, 52.879701681157854, 74.69760277718505, 104.47774415967294, 145.12586077027197}}};
+
+ for (int i = 0; i < expectedYs.length; i++) {
+ for (int j = 0; j < expectedYs[i].length; j++) {
+ assertArrayEquals("Incorrect ys: ",
+ expectedYs[i][j],
+ ys.reshape(ys.size(0), nrofDims, nrofTimeSteps).getRow(i).getRow(j).toDoubleVector(), 1e-3);
+ }
+ }
+
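+ // Mirrors gradys from the reference test: linspace(3, 13) with every other time slice negated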
+ final INDArray lossgrad = Nd4j.linspace(3, 13, ys.length()).reshape(ys.shape());
+ for (int i = 0; i < lossgrad.size(0) / 2; i++) {
+ lossgrad.get(NDArrayIndex.point(i * 2), NDArrayIndex.all(), NDArrayIndex.point(i * 2 + 1)).negi();
+ }
+
+ odevert.setEpsilon(lossgrad.reshape(ys.shape()));
+ final Pair<Gradient, INDArray[]> grads = odevert.doBackward(false, LayerWorkspaceMgr.noWorkspacesImmutable());
+
+ final double[][] expectedYsGrad = {{330.3413671858052, 350.8541876986256}, {641.5288548171546, 667.1698804581804}};
+ for (int i = 0; i < expectedYsGrad.length; i++) {
+ assertArrayEquals("Incorrect loss gradient: ",
+ expectedYsGrad[i],
+ grads.getSecond()[0].getRow(i).toDoubleVector(), 1e-3);
+ }
+
+ // Note: Elements 1 and 2 are in swapped order compared to the reference output
+ final double[] expectedParsGrad = {26399.28182908193, 28805.83177942673, 29982.48959443847, 32597.88284962647, 2022.3108828170273, 2197.8094583156044};
+ // Compare one by one due to large dynamic range
+ for (int i = 0; i < expectedParsGrad.length; i++) {
+ assertEquals("Incorrect parameter gradient for param " + i + "!",
+ 1,
+ grads.getFirst().gradient().getDouble(i) / expectedParsGrad[i], 1e-4);
+ }
+
+ // First element from reference implementation sure looks weird...
+ final double[] expectedTimeGrad = {-6617.017543489908, 63.564528808863464, 185.79311916695133, 262.00021152296233, 369.0851188922797, 519.4357040639363,
+ 730.3690674996153, 1026.0795986642863, 1440.3518309918456, 2020.3383638791681};
+ // Compare one by one due to large dynamic range
+ for (int i = 0; i < expectedTimeGrad.length; i++) {
+ assertEquals("Incorrect time gradient for param " + i + "!", 1, grads.getSecond()[1].getDouble(i) / expectedTimeGrad[i], 1e-4);
+ }
+ }
+
+ /**
+ * Same test as above but with time running backwards
+ */
+ @Test
+ public void testGradVsReferenceLinear1dOdeBackward() {
+
+ final long nrofTimeSteps = 10;
+ final long nrofDims = 2;
+
+ final GraphVertex odevert = createOdeVertex(nrofTimeSteps, nrofDims);
+
+ final INDArray time = Nd4j.linspace(-1, -8, nrofTimeSteps);
+ final INDArray y0 = Nd4j.linspace(-1.23, 2.34, 2 * nrofDims).reshape(2, nrofDims);
+
+ odevert.setInputs(y0, time);
+ final INDArray ys = odevert.doForward(true, LayerWorkspaceMgr.noWorkspacesImmutable());
+
+ // From original repo
+ final double[][][] expectedYs = {{{-1.23, -3.0654783048539187, -4.410209436983333, -5.395402672421891, -6.117186729574334, -6.645989569834943, -7.033407758920375, -7.317242594011386, -7.525189149078903, -7.677538128370835},
+ {-0.040000000000000036, -1.875478304853917, -3.2202094369833336, -4.205402672421891, -4.927186729574334, -5.4559895698349425, -5.843407758920375, -6.127242594011386, -6.335189149078852, -6.487538128370829}},
+ {{1.15, -1.321813099544714, -3.132743808435675, -4.459489833436336, -5.4315063823619125, -6.143637811088716, -6.665368496900053, -7.047604920850001, -7.327643653785053, -7.532809904849017},
+ {2.34, -0.13181309954471376, -1.9427438084356758, -3.2694898334363294, -4.241506382361924, -4.953637811088717, -5.475368496900053, -5.857604920850002, -6.137643653785053, -6.342809904849003}}};
+
+ for (int i = 0; i < expectedYs.length; i++) {
+ for (int j = 0; j < expectedYs[i].length; j++) {
+ assertArrayEquals("Incorrect ys: ",
+ expectedYs[i][j],
+ ys.reshape(ys.size(0), nrofDims, nrofTimeSteps).getRow(i).getRow(j).toDoubleVector(), 1e-3);
+ }
+ }
+
+ final INDArray lossgrad = Nd4j.linspace(3, 13, ys.length()).reshape(ys.shape());
+ for (int i = 0; i < lossgrad.size(0) / 2; i++) {
+ lossgrad.get(NDArrayIndex.point(i * 2), NDArrayIndex.all(), NDArrayIndex.point(i * 2 + 1)).negi();
+ }
+
+ odevert.setEpsilon(lossgrad.reshape(ys.shape()));
+ final Pair<Gradient, INDArray[]> grads = odevert.doBackward(false, LayerWorkspaceMgr.noWorkspacesImmutable());
+
+ final double[][] expectedYsGrad = {{0.4791608727137171, 20.99198138553419}, {22.890949753489018, 48.53197539451459}};
+ for (int i = 0; i < expectedYsGrad.length; i++) {
+ assertArrayEquals("Incorrect loss gradient: ",
+ expectedYsGrad[i],
+ grads.getSecond()[0].getRow(i).toDoubleVector(), 1e-3);
+ }
+
+ // Note: Elements 1 and 2 are in swapped order compared to the reference output
+ final double[] expectedParsGrad = {902.9542328576254, 696.559656699922, 1683.4912664748028, 1268.25338547379, -173.44082030059053, -348.93939579916616};
+ // Compare one by one due to large dynamic range
+ for (int i = 0; i < expectedParsGrad.length; i++) {
+ assertEquals("Incorrect parameter gradient for param " + i + "!",
+ 1,
+ grads.getFirst().gradient().getDouble(i) / expectedParsGrad[i], 1e-4);
+ }
+
+ // First element from reference implementation sure looks weird...
+ final double[] expectedTimeGrad = {-229.96656912353825, 34.11827941396485, 53.52715859322072, 40.515245326331545, 30.634856397062812, 23.14160607429379,
+ 17.465314677923114, 13.17005445759693, 9.923109065345606, 7.470945117798939};
+ // Compare one by one due to large dynamic range
+ for (int i = 0; i < expectedTimeGrad.length; i++) {
+ assertEquals("Incorrect time gradient for param " + i + "!", 1, grads.getSecond()[1].getDouble(i) / expectedTimeGrad[i], 1e-4);
+ }
+ }
+
+ GraphVertex createOdeVertex(long nrofTimeSteps, long nrofDims) {
+ final ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder()
+ .seed(666)
+ .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
+ .graphBuilder()
+ .setInputTypes(InputType.feedForward(nrofDims), InputType.feedForward(nrofTimeSteps));
+
+ String next = "y0";
+ builder.addInputs(next, "time");
+
+ builder.addVertex("ode", new OdeVertex.Builder("0", new DenseLayer.Builder()
+ .activation(new ActivationIdentity())
+ .weightInit(new ConstantDistribution(0.2))
+ .biasInit(3)
+ .nOut(nrofDims)
+ .build())
+ .odeConf(new InputStep(
+ new ode.solve.conf.DormandPrince54Solver(
+ new SolverConfig(1e-12, 1e-6, 1e-20, 1e2)),
+ 1, true, true))
+ .build(), next, "time");
+
+ builder.allowNoOutput(true);
+
+ final ComputationGraph graph = new ComputationGraph(builder.build());
+ graph.init();
+ graph.initGradientsView();
+
+ return graph.getVertex("ode");
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/impl/helper/backward/SingleStepAdjointTest.java b/src/test/java/ode/vertex/impl/helper/backward/SingleStepAdjointTest.java
new file mode 100644
index 0000000..462fe32
--- /dev/null
+++ b/src/test/java/ode/vertex/impl/helper/backward/SingleStepAdjointTest.java
@@ -0,0 +1,126 @@
+package ode.vertex.impl.helper.backward;
+
+import ode.solve.conf.SolverConfig;
+import ode.solve.impl.DormandPrince54Solver;
+import ode.vertex.impl.gradview.NonContiguous1DView;
+import ode.vertex.impl.helper.backward.timegrad.CalcTimeGrad;
+import ode.vertex.impl.helper.backward.timegrad.NoTimeGrad;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.jetbrains.annotations.NotNull;
+import org.junit.Test;
+import org.nd4j.linalg.activations.impl.ActivationIdentity;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertNotEquals;
+
+/**
+ * Test cases for {@link SingleStepAdjoint}
+ *
+ * @author Christian Skarby
+ */
+public class SingleStepAdjointTest {
+
+ /**
+ * Smoke test for backwards solve without time gradients
+ */
+ @Test
+ public void solveNoTime() {
+ final int nrofInputs = 7;
+ final ComputationGraph graph = getTestGraph(nrofInputs);
+ final OdeHelperBackward.InputArrays inputArrays = getTestInputArrays(nrofInputs, graph);
+
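+ // Integrate a single step from t0 = 0 to t1 = 1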
+ final INDArray time = Nd4j.arange(2);
+ final OdeHelperBackward helper = new SingleStepAdjoint(
+ new DormandPrince54Solver(new SolverConfig(1e-3, 1e-3, 0.1, 10)),
+ time, NoTimeGrad.factory);
+
+ INDArray[] gradients = helper.solve(graph, inputArrays, new OdeHelperBackward.MiscPar(
+ false,
+ LayerWorkspaceMgr.noWorkspaces()));
+
+ assertEquals("Incorrect number of input gradients!", 1, gradients.length);
+
+ final INDArray inputGrad = gradients[0];
+ assertArrayEquals("Incorrect input gradient shape!", inputArrays.getLastInputs()[0].shape(), inputGrad.shape());
+
+ final INDArray parGrad = graph.getGradientsViewArray();
+ assertNotEquals("Expected non-zero parameter gradient!", 0.0, parGrad.sumNumber().doubleValue(),1e-10);
+ assertNotEquals("Expected non-zero input gradient!", 0.0, inputGrad.sumNumber().doubleValue(), 1e-10);
+ }
+
+ /**
+ * Smoke test for backwards solve with time gradients
+ */
+ @Test
+ public void solveWithTime() {
+ final int nrofInputs = 5;
+ final ComputationGraph graph = getTestGraph(nrofInputs);
+ final OdeHelperBackward.InputArrays inputArrays = getTestInputArrays(nrofInputs, graph);
+
+ final INDArray time = Nd4j.arange(2);
+ final OdeHelperBackward helper = new SingleStepAdjoint(
+ new DormandPrince54Solver(new SolverConfig(1e-3, 1e-3, 0.1, 10)),
+ time, new CalcTimeGrad.Factory(inputArrays.getLossGradient(), 1));
+
+ INDArray[] gradients = helper.solve(graph, inputArrays, new OdeHelperBackward.MiscPar(
+ false,
+ LayerWorkspaceMgr.noWorkspaces()));
+
+ assertEquals("Incorrect number of input gradients!", 2, gradients.length);
+
+ final INDArray inputGrad = gradients[0];
+ final INDArray timeGrad = gradients[1];
+ assertArrayEquals("Incorrect input gradient shape!", inputArrays.getLastInputs()[0].shape(), inputGrad.shape());
+ assertArrayEquals("Incorrect time gradient shape!", time.shape(), timeGrad.shape());
+
+ final INDArray parGrad = graph.getGradientsViewArray();
+ assertNotEquals("Expected non-zero parameter gradient!", 0.0, parGrad.sumNumber().doubleValue(),1e-10);
+ assertNotEquals("Expected non-zero input gradient!", 0.0, inputGrad.sumNumber().doubleValue(), 1e-10);
+ assertNotEquals("Expected non-zero time gradient!", 0.0, timeGrad.sumNumber().doubleValue(), 1e-10);
+ }
+
+ @NotNull
+ private static OdeHelperBackward.InputArrays getTestInputArrays(int nrofInputs, ComputationGraph graph) {
+ final INDArray input = Nd4j.arange(nrofInputs);
+ final INDArray output = input.add(1);
+ final INDArray epsilon = Nd4j.ones(nrofInputs).assign(0.01);
+ final NonContiguous1DView realGrads = new NonContiguous1DView();
+ realGrads.addView(graph.getGradientsViewArray());
+
+ return new OdeHelperBackward.InputArrays(
+ new INDArray[]{input},
+ output,
+ epsilon,
+ realGrads
+ );
+ }
+
+ /**
+ * Create a simple graph for testing
+ * @param nrofInputs Determines the number of inputs (nIn) to the graph
+ * @return a {@link ComputationGraph}
+ */
+ @NotNull
+ static ComputationGraph getTestGraph(int nrofInputs) {
+ final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder()
+ .weightInit(new ConstantDistribution(0.01))
+ .graphBuilder()
+ .setInputTypes(InputType.feedForward(nrofInputs))
+ .addInputs("input")
+ .addLayer("dense", new DenseLayer.Builder().nOut(nrofInputs).activation(new ActivationIdentity()).build(), "input")
+ .allowNoOutput(true)
+ .build());
+ graph.init();
+ graph.initGradientsView();
+ return graph;
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/impl/helper/forward/ForwardPassTest.java b/src/test/java/ode/vertex/impl/helper/forward/ForwardPassTest.java
new file mode 100644
index 0000000..9a2ff81
--- /dev/null
+++ b/src/test/java/ode/vertex/impl/helper/forward/ForwardPassTest.java
@@ -0,0 +1,53 @@
+package ode.vertex.impl.helper.forward;
+
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.junit.Test;
+import org.nd4j.linalg.activations.impl.ActivationIdentity;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Test cases for {@link ForwardPass}
+ *
+ * @author Christian Skarby
+ */
+public class ForwardPassTest {
+
+ /**
+ * Test that the derivative is a forward pass through the layers
+ */
+ @Test
+ public void calculateDerivative() {
+ final long nrofInputs = 5;
+ final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder()
+ .graphBuilder()
+ .setInputTypes(InputType.feedForward(nrofInputs))
+ .allowNoOutput(true)
+ .addInputs("input")
+ // Dense layer with identity weights so that output equals input (scaled by mul below)
+ .addLayer("dense", new DenseLayer.Builder()
+ .nOut(nrofInputs)
+ .hasBias(false)
+ .weightInit(WeightInit.IDENTITY)
+ .activation(new ActivationIdentity())
+ .build(), "input")
+ .build());
+ graph.init();
+ double mul = 1.23;
+ graph.params().muli(mul);
+
+ final INDArray input = Nd4j.arange(nrofInputs);
+ final INDArray expected = input.mul(mul);
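+ // For this graph the derivative is just a forward pass: dy/dt = mul * y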
+ final INDArray actual = new ForwardPass(graph, LayerWorkspaceMgr.noWorkspaces(), false, new INDArray[]{input})
+ .calculateDerivative(input, Nd4j.scalar(0), input.dup());
+
+ assertEquals("Incorrect output!", expected, actual);
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/impl/helper/forward/InputStepTest.java b/src/test/java/ode/vertex/impl/helper/forward/InputStepTest.java
new file mode 100644
index 0000000..d444be6
--- /dev/null
+++ b/src/test/java/ode/vertex/impl/helper/forward/InputStepTest.java
@@ -0,0 +1,69 @@
+package ode.vertex.impl.helper.forward;
+
+import ode.solve.api.FirstOrderSolver;
+import ode.solve.conf.SolverConfig;
+import ode.solve.impl.DormandPrince54Solver;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.ops.transforms.Transforms;
+
+import static org.junit.Assert.assertArrayEquals;
+
+/**
+ * Test cases for {@link InputStep}. Also tests {@link FixedStep} as {@link InputStep} depends on it
+ *
+ * @author Christian Skarby
+ */
+public class InputStepTest {
+
+ /**
+ * Test if a simple ODE can be solved
+ */
+ @Test
+ public void solveSingleStep() {
+ final double exponent = 2;
+ final int nrofInputs = 7;
+ final ComputationGraph graph = SingleStepTest.getSimpleExpGraph(exponent, nrofInputs);
+
+ final INDArray input = Nd4j.arange(nrofInputs);
+ final INDArray expected = input.mul(Math.exp(exponent));
+
+ final FirstOrderSolver solver = new DormandPrince54Solver(new SolverConfig(1e-10, 1e-10, 1e-10, 100));
+ final OdeHelperForward helper = new InputStep(
+ solver,
+ 1, false);
+
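+ // The time vector [0, 1] is passed as a second input to the graph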
+ final INDArray actual = helper.solve(graph, LayerWorkspaceMgr.noWorkspaces(), new INDArray[]{input, Nd4j.linspace(0, 1, 2)});
+
+ assertArrayEquals("Incorrect answer!", expected.toDoubleVector(), actual.toDoubleVector(), 1e-3);
+ }
+
+ /**
+ * Test if a simple ODE can be solved
+ */
+ @Test
+ public void solveMultiStep() {
+ final INDArray t = Nd4j.arange(5);
+ final double exponent = 0.12;
+ final int nrofInputs = 4;
+ final ComputationGraph graph = SingleStepTest.getSimpleExpGraph(exponent, nrofInputs);
+
+ final INDArray input = Nd4j.arange(nrofInputs);
+ final INDArray expected = input.transpose().mmul(Transforms.exp(t.mul(exponent)));
+
+ final FirstOrderSolver solver = new DormandPrince54Solver(new SolverConfig(1e-10, 1e-10, 1e-10, 100));
+ final OdeHelperForward helper = new InputStep(
+ solver,
+ 1, false);
+
+ final INDArray actual = helper.solve(graph, LayerWorkspaceMgr.noWorkspaces(), new INDArray[]{input, t}).reshape(expected.shape());
+
+ for(int row = 0; row < actual.rows(); row++) {
+ assertArrayEquals("Incorrect answer!", expected.getRow(row).toDoubleVector(), actual.getRow(row).toDoubleVector(), 1e-3);
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/impl/helper/forward/MultiStepTest.java b/src/test/java/ode/vertex/impl/helper/forward/MultiStepTest.java
new file mode 100644
index 0000000..c7caf56
--- /dev/null
+++ b/src/test/java/ode/vertex/impl/helper/forward/MultiStepTest.java
@@ -0,0 +1,47 @@
+package ode.vertex.impl.helper.forward;
+
+import ode.solve.api.FirstOrderSolver;
+import ode.solve.conf.SolverConfig;
+import ode.solve.impl.DormandPrince54Solver;
+import ode.solve.impl.SingleSteppingMultiStepSolver;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.ops.transforms.Transforms;
+
+import static org.junit.Assert.assertArrayEquals;
+
+/**
+ * Test cases for {@link MultiStep}
+ *
+ * @author Christian Skarby
+ */
+public class MultiStepTest {
+
+ /**
+ * Test if a simple ODE can be solved
+ */
+ @Test
+ public void solve() {
+ final INDArray t = Nd4j.arange(4);
+ final double exponent = 1.23;
+ final int nrofInputs = 3;
+ final ComputationGraph graph = SingleStepTest.getSimpleExpGraph(exponent, nrofInputs);
+
+ final INDArray input = Nd4j.arange(nrofInputs);
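+ // Analytical solution y(t) = y0 * e^(exponent*t); expected holds one column per time step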
+ final INDArray expected = input.transpose().mmul(Transforms.exp(t.mul(exponent)));
+
+ final FirstOrderSolver solver = new DormandPrince54Solver(new SolverConfig(1e-10, 1e-10, 1e-10, 100));
+ final OdeHelperForward helper = new MultiStep(
+ new SingleSteppingMultiStepSolver(solver),
+ t);
+
+ final INDArray actual = helper.solve(graph, LayerWorkspaceMgr.noWorkspaces(), new INDArray[]{input}).reshape(expected.shape());
+
+ for(int row = 0; row < actual.rows(); row++) {
+ assertArrayEquals("Incorrect answer!", expected.getRow(row).toDoubleVector(), actual.getRow(row).toDoubleVector(), 1e-3);
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/ode/vertex/impl/helper/forward/SingleStepTest.java b/src/test/java/ode/vertex/impl/helper/forward/SingleStepTest.java
new file mode 100644
index 0000000..7ba9f13
--- /dev/null
+++ b/src/test/java/ode/vertex/impl/helper/forward/SingleStepTest.java
@@ -0,0 +1,66 @@
+package ode.vertex.impl.helper.forward;
+
+import ode.solve.api.FirstOrderSolver;
+import ode.solve.conf.SolverConfig;
+import ode.solve.impl.DormandPrince54Solver;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.graph.ScaleVertex;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.jetbrains.annotations.NotNull;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import static org.junit.Assert.assertArrayEquals;
+
+/**
+ * Test cases for {@link SingleStep}
+ *
+ * @author Christian Skarby
+ */
+public class SingleStepTest {
+
+ /**
+ * Test if a simple ODE can be solved
+ */
+ @Test
+ public void solve() {
+ final double exponent = 2;
+ final int nrofInputs = 7;
+ final ComputationGraph graph = getSimpleExpGraph(exponent, nrofInputs);
+
+ final INDArray input = Nd4j.arange(nrofInputs);
+ final INDArray expected = input.mul(Math.exp(exponent));
+
+ final FirstOrderSolver solver = new DormandPrince54Solver(new SolverConfig(1e-10, 1e-10, 1e-10, 100));
+ final OdeHelperForward helper = new SingleStep(
+ solver,
+ Nd4j.linspace(0, 1, 2));
+
+ final INDArray actual = helper.solve(graph, LayerWorkspaceMgr.noWorkspaces(), new INDArray[]{input});
+
+ assertArrayEquals("Incorrect answer!", expected.toDoubleVector(), actual.toDoubleVector(), 1e-3);
+ }
+
+ /**
+ * dy/dt = exponent*y => y = y0*e^(exponent*t)
+ * @param exponent Exponent of e
+ * @param nrofInputs Number of elements in y
+ * @return a {@link ComputationGraph}
+ */
+ @NotNull
+ static ComputationGraph getSimpleExpGraph(double exponent, int nrofInputs) {
+ final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder()
+ .graphBuilder()
+ .setInputTypes(InputType.feedForward(nrofInputs))
+ .addInputs("input")
+
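+ // ScaleVertex multiplies its input by exponent, so the graph computes dy/dt = exponent * y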
+ .addVertex("scale", new ScaleVertex(exponent), "input")
+ .allowNoOutput(true)
+ .build());
+ graph.init();
+ return graph;
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/util/listen/step/MaskTest.java b/src/test/java/util/listen/step/MaskTest.java
index 4e68ee1..53190c9 100644
--- a/src/test/java/util/listen/step/MaskTest.java
+++ b/src/test/java/util/listen/step/MaskTest.java
@@ -1,6 +1,7 @@
package util.listen.step;
import ode.solve.api.StepListener;
+import ode.solve.impl.util.StateContainer;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@@ -34,13 +35,13 @@ public void backward() {
private void testMask(ProbeStepListener probeStepListener, StepListener mask, INDArray masked, INDArray notMasked) {
mask.begin(masked, Nd4j.zeros(0));
- mask.step(Nd4j.zeros(0), Nd4j.zeros(0), Nd4j.zeros(0), Nd4j.zeros(0));
+ mask.step(new StateContainer(0, new double[] {0}, new double[] {0}), Nd4j.zeros(0), Nd4j.zeros(0));
mask.done();
probeStepListener.assertNrofCalls(0,0,0);
mask.begin(notMasked, Nd4j.zeros(0));
- mask.step(Nd4j.zeros(0), Nd4j.zeros(0), Nd4j.zeros(0), Nd4j.zeros(0));
+ mask.step(new StateContainer(1, new double[] {0}, new double[] {0}), Nd4j.zeros(0), Nd4j.zeros(0));
mask.done();
probeStepListener.assertNrofCalls(1,1,1);
diff --git a/src/test/java/util/listen/step/ProbeStepListener.java b/src/test/java/util/listen/step/ProbeStepListener.java
index 0659237..a8ad9c4 100644
--- a/src/test/java/util/listen/step/ProbeStepListener.java
+++ b/src/test/java/util/listen/step/ProbeStepListener.java
@@ -1,6 +1,7 @@
package util.listen.step;
import ode.solve.api.StepListener;
+import ode.solve.impl.util.SolverState;
import org.nd4j.linalg.api.ndarray.INDArray;
import static org.junit.Assert.assertEquals;
@@ -23,7 +24,7 @@ public void begin(INDArray t, INDArray y0) {
}
@Override
- public void step(INDArray currTime, INDArray step, INDArray error, INDArray y) {
+ public void step(SolverState solverState, INDArray step, INDArray error) {
nrofStep++;
}
diff --git a/src/test/java/util/listen/step/StepCounterTest.java b/src/test/java/util/listen/step/StepCounterTest.java
index 979d0bd..409d3ed 100644
--- a/src/test/java/util/listen/step/StepCounterTest.java
+++ b/src/test/java/util/listen/step/StepCounterTest.java
@@ -1,6 +1,7 @@
package util.listen.step;
import ode.solve.api.StepListener;
+import ode.solve.impl.util.StateContainer;
import org.junit.Test;
import org.nd4j.linalg.factory.Nd4j;
@@ -24,8 +25,8 @@ public void countAndReport() {
for(int i = 0; i < 7; i++) {
listener.begin(Nd4j.linspace(0, 1, 2), Nd4j.zeros(1));
- listener.step(Nd4j.ones(1), Nd4j.ones(1), Nd4j.zeros(1), Nd4j.zeros(0));
- listener.step(Nd4j.ones(1), Nd4j.ones(1), Nd4j.zeros(1), Nd4j.zeros(0));
+ listener.step(new StateContainer(1, new double[] {0}, new double[] {0}), Nd4j.ones(1), Nd4j.zeros(1));
+ listener.step(new StateContainer(1, new double[] {0}, new double[] {0}), Nd4j.ones(1), Nd4j.zeros(1));
listener.done();
assertEquals("Incorrect number of calls!", (i+1) / 3, probe.nrofCalls);
diff --git a/src/test/java/util/random/SeededRandomFactoryTest.java b/src/test/java/util/random/SeededRandomFactoryTest.java
new file mode 100644
index 0000000..5287293
--- /dev/null
+++ b/src/test/java/util/random/SeededRandomFactoryTest.java
@@ -0,0 +1,40 @@
+package util.random;
+
+import org.junit.Test;
+import org.nd4j.linalg.api.rng.Random;
+import org.nd4j.linalg.factory.Nd4j;
+
+import static org.junit.Assert.assertEquals;
+
+public class SeededRandomFactoryTest {
+
+ /**
+ * Test that the random seed can be reset to generate the same sequence again
+ */
+ @Test
+ public void getRandom() {
+ final long baseSeed = 666;
+ final long[] shape = {2,3,4};
+ SeededRandomFactory.setNd4jSeed(baseSeed);
+ final Random first = Nd4j.getRandom();
+ SeededRandomFactory.setNd4jSeed(baseSeed);
+ final Random second = Nd4j.getRandom();
+
+ assertEquals("Not same random!", first.nextGaussian(shape), second.nextGaussian(shape));
+ }
+
+ /**
+ * Test that new random instances generate the same sequence after the seed is reset
+ */
+ @Test
+ public void getNewRandomInstance() {
+ final long baseSeed = 666;
+ final long[] shape = {2,3,4};
+ SeededRandomFactory.setNd4jSeed(baseSeed);
+ final Random first = Nd4j.getRandomFactory().getNewRandomInstance();
+ SeededRandomFactory.setNd4jSeed(baseSeed);
+ final Random second = Nd4j.getRandomFactory().getNewRandomInstance();
+
+ assertEquals("Not same random!", first.nextGaussian(shape), second.nextGaussian(shape));
+ }
+}
\ No newline at end of file