-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
76 lines (68 loc) · 2.63 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
from matplotlib import pyplot as plt
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
from torchvision.transforms import ToTensor
# NOTE(review): presumably set to avoid the Intel OpenMP "OMP: Error #15"
# abort raised when two copies of the libiomp runtime get loaded in one
# process (e.g. via torch and numpy/matplotlib) — confirm. This masks the
# duplicate runtime rather than removing it.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
class dataset(Dataset):
    """Vimeo triplet dataset yielding ``(frame1, frame2, frame3)`` tensors.

    Sequence names are read from ``<data_root>/tri_trainlist.txt`` when
    ``train`` is True, otherwise from ``<data_root>/tri_testlist.txt``. Each
    name is a directory under ``<data_root>/sequences`` holding ``im1.png``,
    ``im2.png`` and ``im3.png``. Each frame is converted with torchvision's
    ``ToTensor``.

    Args:
        train: Use the training list if True, the test list otherwise.
        max_num: Upper bound on the number of sequences used. Only the
            first ``max_num`` non-blank entries of the list file are kept.
        preload: If True, decode every triplet in ``__init__`` and keep the
            tensors in ``self.images``. Otherwise decode lazily in
            ``__getitem__``.
        data_root: Directory containing the list files and ``sequences/``.
            Defaults to the relative path ``'vimeo_triplet'``, matching the
            previously hard-coded location.
    """

    def __init__(self, train=True, max_num=10000000, preload=True,
                 data_root='vimeo_triplet'):
        super().__init__()
        # Directory holding one sub-directory per sequence name.
        self.root = os.path.join(data_root, 'sequences')
        list_name = 'tri_trainlist.txt' if train else 'tri_testlist.txt'
        data_dic_txt = os.path.join(data_root, list_name)
        with open(data_dic_txt, 'r', encoding='utf-8') as f:
            # strip() instead of line[:-1]: the final entry may lack a
            # trailing newline, and [:-1] would then chop a real character.
            # Blank lines are dropped because they would resolve to the
            # sequences directory itself and fail on open.
            names = [line.strip() for line in f if line.strip()]
        self.length = min(len(names), max_num)
        # Sequence names relative to self.root, without line terminators.
        self.lines = names[:self.length]
        self.preload = preload
        self.images = []
        if preload:
            for i, name in enumerate(self.lines):
                self.images.append(self._load_triplet(name))
                if i % 10 == 9:
                    print("Preloaded {} data.".format(i + 1))
            print("Data preloaded.")

    def _load_triplet(self, name):
        """Load ``im1.png``, ``im2.png``, ``im3.png`` of sequence *name* as a tuple of tensors."""
        to_tensor = ToTensor()
        frames = []
        for fname in ('im1.png', 'im2.png', 'im3.png'):
            # The context manager closes the file handle deterministically
            # instead of leaving it to PIL / garbage collection.
            with Image.open(os.path.join(self.root, name, fname)) as img:
                frames.append(to_tensor(img))
        return tuple(frames)

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        """Return the ``(x1, x2, x3)`` tensor triplet at index *item*."""
        if self.preload:
            return self.images[item]
        return self._load_triplet(self.lines[item])
if __name__ == "__main__":
    print('begin')
    # Smoke test: build the train split, then the test split, and iterate
    # every batch so that any picture that fails to load (or to batch)
    # surfaces as an exception. Batch indices are printed every 10 batches.
    for use_train_split in (True, False):
        data = dataset(train=use_train_split)
        dataloader = DataLoader(data, batch_size=16)
        for batch_idx, (_x1, _x2, _x3) in enumerate(dataloader):
            if batch_idx % 10 == 0:
                print(batch_idx)
    # Optional visual check of the loaded frames (debugging aid; uses the
    # `plt` import at the top of the file):
    # for i, (x1, x2, x3) in enumerate(dataloader):
    #     print(x1)
    #     print(x1.min())
    #     print(x1.max())
    #     print(x1.shape)
    #     x1 = x1[0]
    #     img_array = x1.numpy().transpose((1, 2, 0))
    #     plt.imshow(img_array)
    #     plt.show()