forked from Yancey1989/gotorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmnist_test.go
58 lines (51 loc) · 1.46 KB
/
mnist_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
package gotorch_test
import (
"log"
"time"
torch "github.com/wangkuiyi/gotorch"
nn "github.com/wangkuiyi/gotorch/nn"
F "github.com/wangkuiyi/gotorch/nn/functional"
"github.com/wangkuiyi/gotorch/vision/datasets"
"github.com/wangkuiyi/gotorch/vision/transforms"
)
// MLPMNISTSequential is the model type trained by ExampleTrainMNISTSequential:
// a stack of layers held in a single nn.SequentialModule, with the
// flatten-and-LogSoftmax wrapping done in its Forward method.
type MLPMNISTSequential struct {
	// Embedded so the wrapper picks up module behavior; presumably this is where
	// the Init, ZeroGrad and Parameters methods called on the model come from
	// (confirm against package nn). NOTE(review): field order may matter if
	// Init discovers submodules by reflection, so do not reorder casually.
	nn.Module
	// Layers is the sequential stack applied to each flattened batch.
	Layers *nn.SequentialModule
}
// Forward runs a batch through the network and returns log-probabilities.
//
// x is first reshaped to (-1, 28*28). This assumes every sample holds exactly
// 28*28 = 784 values, e.g. one 28x28 MNIST image (TODO confirm the loader's
// layout). The flattened batch is passed through s.Layers. Its result is
// type-asserted to torch.Tensor, and the assertion panics if Layers yields
// anything else. LogSoftmax is then applied along dimension 1.
func (s *MLPMNISTSequential) Forward(x torch.Tensor) torch.Tensor {
	const pixelsPerImage = 28 * 28
	flat := torch.View(x, []int64{-1, pixelsPerImage})
	out := s.Layers.Forward(flat).(torch.Tensor)
	return out.LogSoftmax(1)
}
func ExampleTrainMNISTSequential() {
net := &MLPMNISTSequential{Layers: nn.Sequential(
nn.Linear(28*28, 512, false),
nn.Functional(torch.Tanh),
nn.Linear(512, 512, false),
nn.Functional(torch.Tanh),
nn.Linear(512, 10, false))}
net.Init(net)
net.ZeroGrad()
mnist := datasets.MNIST("",
[]transforms.Transform{transforms.Normalize(0.1307, 0.3081)})
opt := torch.SGD(0.1, 0.5, 0, 0, false)
opt.AddParameters(net.Parameters())
epochs := 1
startTime := time.Now()
for i := 0; i < epochs; i++ {
trainLoader := datasets.NewMNISTLoader(mnist, 64)
for trainLoader.Scan() {
batch := trainLoader.Batch()
opt.ZeroGrad()
pred := net.Forward(batch.Data)
loss := F.NllLoss(pred, batch.Target, torch.Tensor{}, -100, "mean")
loss.Backward()
opt.Step()
}
trainLoader.Close()
}
throughput := float64(60000*epochs) / time.Since(startTime).Seconds()
log.Printf("Throughput: %f samples/sec", throughput)
mnist.Close()
torch.FinishGC()
// Output:
}