-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathauto_train.py
77 lines (59 loc) · 1.92 KB
/
auto_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import torch.nn as nn
import torch.optim as optim
from extempauto import ExTempAuto
from dataset import scaffold_loaders
from constants import DEVICE
from constants import NUM_EPOCHS
# Build the model and place it on the target device. nn.Module.to() moves the
# parameters and returns the module itself, so the call can be chained.
model = ExTempAuto().to(DEVICE)

# Training objectives, summed per batch by the training loop:
c_loss = nn.CrossEntropyLoss()  # classification term: (model's first output, labels)
a_loss = nn.MSELoss()  # reconstruction term: (decoded output, input batch)

# TODO: plug in the real datasets; every split is currently a None placeholder
# handed straight to scaffold_loaders.
train_data = test_data = None
train_labels = test_labels = None
train_loader, test_loader = scaffold_loaders(
    train_data,
    train_labels,
    test_data,
    test_labels,
)

# Adam over all model parameters, using PyTorch's default hyper-parameters.
opt = optim.Adam(model.parameters())
# Per-epoch metrics: mean combined loss and accuracy (in percent) for the
# training pass ("loss", "acc") and the evaluation pass ("test_loss", "test_acc").
history = {"loss": [], "test_loss": [], "acc": [], "test_acc": []}


def _run_epoch(loader, train):
    """Run one pass over *loader*; return ``(mean_loss, accuracy_percent)``.

    The per-batch loss is ``c_loss(softmax, y) + a_loss(decoded, x)``: a
    classification term plus a reconstruction term between the decoded
    output and the input. With ``train=True`` the model is put in training
    mode and the optimizer is stepped after every batch. With ``train=False``
    the model is put in eval mode and gradients are disabled.

    The loss is averaged over the batches of *loader* itself. The previous
    inline code divided the test loss by ``len(train_loader)``. If *loader*
    yields no samples, both values are NaN rather than raising
    ZeroDivisionError.
    """
    # Toggle dropout/batch-norm behaviour. The old loop never called eval(),
    # so evaluation silently ran in training mode.
    model.train(train)
    correct, total = 0, 0
    running_loss = 0.0
    with torch.set_grad_enabled(train):
        for x, y in loader:
            # The model lives on DEVICE. This is a no-op if the loader already
            # yields tensors there.
            x, y = x.to(DEVICE), y.to(DEVICE)
            softmax, decoded = model(x)
            predictions = softmax.argmax(dim=1)
            total += x.size(0)
            correct += (predictions == y).sum().item()
            # NOTE(review): nn.CrossEntropyLoss expects raw logits. If the
            # model's first output is already softmax-normalised (as its name
            # suggests), softmax is effectively applied twice. Confirm against
            # ExTempAuto.
            loss = c_loss(softmax, y) + a_loss(decoded, x)
            if train:
                opt.zero_grad()
                loss.backward()
                opt.step()
            running_loss += loss.item()
    if total == 0:
        # Empty loader: there is nothing to average over.
        return float("nan"), float("nan")
    return running_loss / len(loader), 100.0 * correct / total


for epoch in range(NUM_EPOCHS):
    print(f"Epoch: {epoch}", end=" ")

    train_loss, train_acc = _run_epoch(train_loader, train=True)
    history["loss"].append(train_loss)
    history["acc"].append(train_acc)
    print(f"Training Loss: {train_loss} Accuracy: {train_acc}", end=" ")

    test_loss, test_acc = _run_epoch(test_loader, train=False)
    history["test_loss"].append(test_loss)
    history["test_acc"].append(test_acc)
    print(f"Testing Loss: {test_loss} Accuracy: {test_acc}")