-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #17 from DrChainsaw/spiraldemo
Add first set of classes for spiral demo
- Loading branch information
Showing
137 changed files
with
7,836 additions
and
387 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package examples.spiral; | ||
|
||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
import org.nd4j.linalg.dataset.api.MultiDataSet; | ||
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor; | ||
import org.nd4j.linalg.factory.Nd4j; | ||
import org.nd4j.linalg.indexing.INDArrayIndex; | ||
import org.nd4j.linalg.indexing.NDArrayIndex; | ||
|
||
/** | ||
* Adds a label for KLD loss | ||
* | ||
* @author Christian Skarby | ||
*/ | ||
public class AddKLDLabel implements MultiDataSetPreProcessor { | ||
|
||
private final double mean; | ||
private final double logvar; | ||
private final long nrofLatentDims; | ||
|
||
public AddKLDLabel(double mean, double var, long nrofLatentDims) { | ||
this.mean = mean; | ||
this.logvar = Math.log(var); | ||
this.nrofLatentDims = nrofLatentDims; | ||
} | ||
|
||
@Override | ||
public void preProcess(MultiDataSet multiDataSet) { | ||
final INDArray label0 = multiDataSet.getLabels(0); | ||
final long batchSize = label0.size(0); | ||
final INDArray kldLabel = Nd4j.zeros(batchSize, 2*nrofLatentDims); | ||
kldLabel.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, nrofLatentDims)}, mean); | ||
kldLabel.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(nrofLatentDims, 2*nrofLatentDims)}, logvar); | ||
multiDataSet.setLabels(new INDArray[]{label0, kldLabel}); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
package examples.spiral; | ||
|
||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; | ||
|
||
/** | ||
* A simple block of layers | ||
* | ||
* @author Christian Skarby | ||
*/ | ||
interface Block { | ||
|
||
/** | ||
* Add layers to given builder | ||
* @param builder Builder to add layers to | ||
* @param prev previous layers | ||
* @return name of last layer added | ||
*/ | ||
String add(ComputationGraphConfiguration.GraphBuilder builder, String... prev); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package examples.spiral; | ||
|
||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; | ||
import org.deeplearning4j.nn.conf.graph.PreprocessorVertex; | ||
import org.deeplearning4j.nn.conf.layers.DenseLayer; | ||
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor; | ||
import org.nd4j.linalg.activations.impl.ActivationIdentity; | ||
import org.nd4j.linalg.activations.impl.ActivationReLU; | ||
|
||
/** | ||
* Simple decoder using {@link DenseLayer}s. Also uses a {@link FeedForwardToRnnPreProcessor} as it assumes 3D input. | ||
* | ||
* @author Christian Skarby | ||
*/ | ||
public class DenseDecoderBlock implements Block { | ||
|
||
private final long nrofHidden; | ||
private final long nrofOutputs; | ||
|
||
public DenseDecoderBlock(long nrofHidden, long nrofOutputs) { | ||
this.nrofHidden = nrofHidden; | ||
this.nrofOutputs = nrofOutputs; | ||
} | ||
|
||
@Override | ||
public String add(ComputationGraphConfiguration.GraphBuilder builder, String... prev) { | ||
builder | ||
.addLayer("dec0", new DenseLayer.Builder() | ||
.nOut(nrofHidden) | ||
.activation(new ActivationReLU()) | ||
.build(), prev) | ||
.addLayer("dec1", new DenseLayer.Builder() | ||
.nOut(nrofOutputs) | ||
.activation(new ActivationIdentity()) | ||
.build(), "dec0") | ||
.addVertex("decodedOutput", | ||
new PreprocessorVertex( | ||
new FeedForwardToRnnPreProcessor()), | ||
"dec1"); | ||
|
||
return "decodedOutput"; | ||
} | ||
} |
Oops, something went wrong.