-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathDuplicateScalarToShape.java
78 lines (66 loc) · 2.48 KB
/
DuplicateScalarToShape.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
package util.preproc;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.Arrays;
/**
 * Duplicates scalar input to given shape. Typically used for when time is used as input to a layer in an OdeVertex.
 * <p>
 * Forward pass: the value of the scalar input is written to every element of a new array with the configured shape.
 * Backward pass: since every output element is a copy of the same input value, the gradient passed back is the sum
 * of all incoming output gradients.
 *
 * @author Christian Skarby
 */
@Data
@EqualsAndHashCode
public class DuplicateScalarToShape implements InputPreProcessor {

    /** Desired output shape. Element 0 may be -1, which is replaced by the mini batch size at use time. */
    private final long[] shape;

    /**
     * Creates a preprocessor with shape {@code [-1, 1]}, i.e. one copy of the scalar per example in the mini batch.
     */
    public DuplicateScalarToShape() {
        this(new long[] {-1, 1});
    }

    /**
     * Constructor.
     *
     * @param shape Desired shape. Set element 0 to -1 in order to use given mini batch size. The array is copied, so
     *              later modifications of the caller's array do not affect this preprocessor.
     * @throws IllegalArgumentException if {@code shape} is null or has no dimensions
     */
    public DuplicateScalarToShape(@JsonProperty("shape") long[] shape) {
        if (shape == null || shape.length == 0) {
            throw new IllegalArgumentException("Shape must have at least one dimension. Got: " + Arrays.toString(shape));
        }
        // Defensive copy: the array is mutable and must not be shared with the caller (or with clones).
        this.shape = shape.clone();
    }

    /**
     * Broadcasts the scalar {@code input} into a new activations array of the configured shape.
     *
     * @param input         input array; must be a scalar
     * @param miniBatchSize substituted for element 0 of the shape when that element is -1
     * @param workspaceMgr  workspace manager used to allocate the output
     * @return array of the configured shape where every element equals the input value
     * @throws IllegalArgumentException if {@code input} is not a scalar
     */
    @Override
    public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        if (!input.isScalar()) {
            throw new IllegalArgumentException("Can only process scalar input. Got: " + Arrays.toString(input.shape()));
        }
        long[] tmpShape = getShapeFor(miniBatchSize);
        // sumNumber() of a scalar is just its value; createUninitialized is safe because assign overwrites everything.
        return workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, tmpShape).assign(input.sumNumber());
    }

    /**
     * Resolves the configured shape for a given mini batch size.
     *
     * @param miniBatchSize value used for element 0 if it is configured as -1
     * @return a fresh copy of the shape with element 0 resolved
     */
    long[] getShapeFor(int miniBatchSize) {
        long[] tmpShape = shape.clone();
        if (tmpShape[0] == -1) {
            tmpShape[0] = miniBatchSize;
        }
        return tmpShape;
    }

    /**
     * Computes the gradient w.r.t. the scalar input: every output element is a copy of the input, so the gradient is
     * the sum of all output gradients.
     *
     * @param output        gradient w.r.t. the output of {@link #preProcess}
     * @param miniBatchSize unused
     * @param workspaceMgr  workspace manager the result is leveraged to
     * @return the summed gradient
     */
    @Override
    public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.sum());
    }

    /**
     * Returns an independent copy; the constructor copies the shape array, so the clone does not share it.
     */
    @Override
    public InputPreProcessor clone() {
        return new DuplicateScalarToShape(shape);
    }

    /**
     * Infers the output type from the configured shape, using a mini batch size of 1 for a -1 leading dimension.
     * Note: this allocates a temporary uninitialized array solely to run type inference on it.
     */
    @Override
    public InputType getOutputType(InputType inputType) {
        long[] tmpShape = getShapeFor(1);
        return InputType.inferInputType(Nd4j.createUninitialized(tmpShape));
    }

    /**
     * Mask propagation. A null mask is passed through unchanged together with the current mask state; masking a
     * scalar input is not supported.
     *
     * @throws UnsupportedOperationException if a non-null mask is given
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        if (maskArray == null) {
            // Nothing to transform: pass through instead of failing.
            return new Pair<>(null, currentMaskState);
        }
        // Mask a scalar??
        throw new UnsupportedOperationException("Not implemented!");
    }
}