-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathg_coco_dataset.py
75 lines (65 loc) · 2.95 KB
/
g_coco_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import pickle
import torch
import torchvision.datasets as dset
from joblib import cpu_count
from pretrainedmodels import utils
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from corpus import Corpus
from file_path_manager import FilePathManager
from vgg16_extractor import Vgg16Extractor
class GCocoDataset(Dataset):
    """COCO-captions dataset yielding (image, embedded_caption, caption_indices) triples.

    Wraps ``torchvision.datasets.CocoCaptions`` over the train2017 split and
    uses ``corpus`` to turn the raw caption string into dense embeddings
    (model inputs) and vocabulary indices (training targets).
    """

    def __init__(self, corpus: Corpus, transform=None, captions_per_image: int = 2):
        """
        :param corpus: provides ``embed_sentence`` / ``sentence_indices`` for captions.
        :param transform: optional image transform forwarded to CocoCaptions.
        :param captions_per_image: number of captions intended per image.
            NOTE(review): currently unused — ``__getitem__`` only embeds the
            first caption (see the commented-out multi-caption variant there).
        """
        self.corpus = corpus
        # No placeholders in these paths, so plain string literals (not f-strings).
        self.captions = dset.CocoCaptions(root=FilePathManager.resolve('data/train'),
                                          annFile=FilePathManager.resolve(
                                              "data/annotations/captions_train2017.json"),
                                          transform=transform)
        self.captions_per_image = captions_per_image

    def __getitem__(self, index):
        """Return (image, embedded first caption, index-encoded first caption)."""
        image, caption = self.captions[index]
        # Only the first of the image's captions is used; ``captions_per_image``
        # is not consulted here — the stacked multi-caption version is below.
        inputs = self.corpus.embed_sentence(caption[0], one_hot=False)
        targets = self.corpus.sentence_indices(caption[0])
        # inputs = torch.stack(
        #     [self.corpus.embed_sentence(caption[i], one_hot=False) for i in range(self.captions_per_image)])
        # targets = torch.stack([self.corpus.sentence_indices(caption[i]) for i in range(self.captions_per_image)])
        return image, inputs, targets

    def __len__(self):
        """Number of images in the underlying COCO captions dataset."""
        return len(self.captions)
if __name__ == '__main__':
    # One-off preprocessing script: run every train image through VGG16 and
    # pickle the resulting feature batches to data/embedded_images.pkl.
    extractor = Vgg16Extractor(transform=False)
    captions = dset.CocoCaptions(root=FilePathManager.resolve('data/train'),
                                 annFile=FilePathManager.resolve(
                                     "data/annotations/captions_train2017.json"),
                                 transform=utils.TransformImage(extractor.cnn))
    batch_size = 3
    dataloader = DataLoader(captions, batch_size=batch_size, shuffle=True, num_workers=cpu_count())
    print(f"number of images = {len(captions.coco.imgs)}")
    images = []
    # enumerate replaces the manual `i = 1 ... i += 1` counter.
    for i, (image, _) in enumerate(dataloader, start=1):
        print(f"batch = {i}")
        # Extract CNN features and move them to CPU so the pickle holds
        # plain (non-GPU, detached) tensors.
        item = extractor.forward(image).cpu().data
        images.append(item)
    with open(FilePathManager.resolve("data/embedded_images.pkl"), "wb") as f:
        pickle.dump(images, f)
    # corpus = Corpus.load(FilePathManager.resolve("data/corpus.pkl"))
    # one_hot = []
    # i = 1
    # for _, capts in captions:
    #     print(f"caption = {i}")
    #     for capt in capts:
    #         one_hot.append(corpus.embed_sentence(capt, one_hot=True))
    #     i += 1
    # with open(FilePathManager.resolve("data/one_hot_sentences.pkl"), "wb") as f:
    #     pickle.dump(one_hot, f)
    # i = 1
    # embedded_sentences = []
    # for _, capts in captions:
    #     print(f"caption = {i}")
    #     for capt in capts:
    #         embedded_sentences.append(corpus.embed_sentence(capt, one_hot=False))
    #     i += 1
    # with open(FilePathManager.resolve("data/embedded_sentences.pkl"), "wb") as f:
    #     pickle.dump(embedded_sentences, f)