e_coco_dataset.py
import random
import torch
import torchvision.datasets as dset
from torch.utils.data import Dataset
from corpus import Corpus
from file_path_manager import FilePathManager


class ECocoDataset(Dataset):
    """COCO captions dataset that returns an image, embeddings of its own
    captions, and embeddings of captions sampled from other images."""

    def __init__(self, corpus: Corpus, evaluator: bool = True, transform=None, captions_per_image=2):
        self.corpus = corpus
        self.evaluator = evaluator
        self.captions = dset.CocoCaptions(root=FilePathManager.resolve('data/train'),
                                          annFile=FilePathManager.resolve(
                                              "data/annotations/captions_train2017.json"),
                                          transform=transform)
        self.captions_per_image = captions_per_image
        self.length = len(self.captions)

    def __getitem__(self, index):
        image, caption = self.captions[index]
        # print(f"real: {caption}")
        # Embed the first `captions_per_image` ground-truth captions of this image.
        captions = torch.stack(
            [self.corpus.embed_sentence(caption[i], one_hot=False) for i in range(self.captions_per_image)])
        # Sample mismatched captions from other images.
        others = []
        s = set(range(self.length))
        s.remove(index)
        s = list(s)
        for i in range(self.captions_per_image):
            other_index = random.choice(s)
            if index == other_index:  # sanity check: should never trigger since `index` was removed from `s`
                print(f"index: {index}, other: {other_index}")
            other_caption = self.get_captions(other_index)
            # Choose one of the first `captions_per_image` captions of that image and embed it.
            other_index = random.choice(range(self.captions_per_image))
            other_caption = other_caption[other_index]
            # print(f"other: {other_caption}")
            other_caption = self.corpus.embed_sentence(other_caption, one_hot=False)
            others.append(other_caption)
        others = torch.stack(others)
        return image, captions, others

    def get_captions(self, index):
        # Fetch all raw caption strings for the image at `index` via the COCO API.
        coco = self.captions.coco
        img_id = self.captions.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        anns = coco.loadAnns(ann_ids)
        return [ann['caption'] for ann in anns]

    def __len__(self):
        return self.length
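

# Minimal usage sketch (not part of the original file): it assumes a `Corpus`
# instance can be obtained from the project's corpus module; how to construct
# one is project-specific, so a placeholder is used below. The dataset is then
# wrapped in a standard torch DataLoader.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    import torchvision.transforms as transforms

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    corpus = ...  # placeholder: obtain a Corpus instance here (see corpus.py)
    dataset = ECocoDataset(corpus, transform=transform, captions_per_image=2)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    for image, captions, others in loader:
        # image: batch of transformed images
        # captions: embeddings of each image's own captions
        # others: embeddings of captions sampled from other images
        print(image.shape, captions.shape, others.shape)
        break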