conditional_generator_trainer.py
import os
import time

import torch
from pretrainedmodels import utils
from torch import nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence
from torch.optim import Adam
from torch.utils.data import DataLoader

from conditional_generator import ConditionalGenerator
from corpus import Corpus
from file_path_manager import FilePathManager
from g_coco_dataset import GCocoDataset
from vgg16_extractor import Vgg16Extractor
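
# Training script for the conditional caption generator: VGG16 image features
# condition the generator, which is trained with cross-entropy on COCO captions.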
if not os.path.exists(FilePathManager.resolve("models")):
    os.makedirs(FilePathManager.resolve("models"))
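
# VGG16 feature extractor on the GPU; TransformImage builds the matching
# input preprocessing (resize / normalization) for the dataset.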
extractor = Vgg16Extractor(use_gpu=True, transform=False)
tf_img = utils.TransformImage(extractor.cnn)
corpus = Corpus.load(FilePathManager.resolve("data/corpus.pkl"))
print("Corpus loaded")
captions_per_image = 2
max_length = 16
batch_size = 32
dataset = GCocoDataset(corpus, transform=tf_img, captions_per_image=captions_per_image)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
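
# Generator, loss, and optimizer; the model and loss run on the GPU.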
generator = ConditionalGenerator(corpus).cuda()
criterion = nn.CrossEntropyLoss().cuda()
optimizer = Adam(generator.parameters(), lr=1e-2, weight_decay=1e-5)
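
# Optionally resume from a previously saved checkpoint.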
pretrained = False
if pretrained:
    st = torch.load("./models/generator.pth")
    generator.load_state_dict(st['state_dict'])
    generator.eval()
    optimizer.load_state_dict(st['optimizer'])
generator.train(True)
epochs = 100
print(f"number of batches = {len(dataset) // batch_size}")
start = time.time()
print("Begin Training")
for epoch in range(epochs):
    epoch_loss = 0
    for i, (images, inputs, targets) in enumerate(dataloader, 0):
        # print(f"Batch = {i}, Time: {time.time() - start}, Loss: {epoch_loss}")
        images = Variable(images).cuda()
        images = extractor.forward(images)
        k = images.shape[0]
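        # Repeat each image's feature vector once per caption so that image rows
        # line up with the flattened (batch * captions_per_image) captions.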
        images = torch.stack([images] * captions_per_image).permute(1, 0, 2).contiguous().view(-1, images.shape[-1])
        inputs = inputs.view(-1, max_length, inputs.shape[-1])
        targets = targets.view(-1, max_length)
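        # Teacher forcing: the generator is fed the caption without its final token
        # and predicts the caption shifted left by one position; both tensors are
        # packed as fixed-length (max_length - 1) sequences.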
        inputs = pack_padded_sequence(inputs[:, :-1], [max_length - 1] * captions_per_image * k, True).cuda()
        targets = pack_padded_sequence(targets[:, 1:], [max_length - 1] * captions_per_image * k, True).cuda()[0]
        optimizer.zero_grad()
        outputs = generator.forward(images, inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    end = time.time()
    print(f"Epoch: {epoch}, Time: {end - start}, Loss: {epoch_loss}")
    start = end
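    # Checkpoint the generator and optimizer state after every epoch.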
    torch.save({"state_dict": generator.state_dict(), 'optimizer': optimizer.state_dict()},
               FilePathManager.resolve("models/generator.pth"))
    print(f"Epoch = {epoch + 1}")