-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathmask_generator.py
167 lines (131 loc) · 5.4 KB
/
mask_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""
Originally inspired by impl at https://github.com/microsoft/unilm/tree/master/beit
Modified by Haoyu Lu to generate spatio-temporal mask positions for the video diffusion transformer.
"""
import random
import math
import numpy as np
import torch
class MaskingGenerator:
    """Generate a 2D block-wise binary mask over a ``height x width`` grid.

    Random rectangles are stamped onto an all-zero grid (1 = masked) until
    ``num_masking_patches`` cells are set, or until no further rectangle
    can be placed. The number of masked cells never exceeds
    ``num_masking_patches``.

    Args:
        input_size: ``(height, width)`` of the grid.
        num_masking_patches: Upper bound on the number of cells set to 1.
        min_num_patches: Lower bound of the target area sampled for a
            single rectangle.
        min_aspect: Smallest height/width ratio of a rectangle; the
            largest ratio is its reciprocal.
    """

    def __init__(
            self, input_size, num_masking_patches, min_num_patches=4,
            min_aspect=0.3,):
        self.height, self.width = input_size

        self.num_patches = self.height * self.width
        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = num_masking_patches

        # Aspect ratios are sampled uniformly in log space so that tall and
        # wide rectangles are equally likely.
        max_aspect = 1 / min_aspect
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

    def __repr__(self):
        repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height, self.width, self.min_num_patches, self.max_num_patches,
            self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1])
        return repr_str

    def get_shape(self):
        """Return the grid shape as ``(height, width)``."""
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        """Stamp one random rectangle onto ``mask`` in place.

        Up to 10 candidate rectangles are sampled; the first one that adds
        at least one and at most ``max_mask_patches`` new cells is applied.

        Args:
            mask: 2D array of shape ``(height, width)``, modified in place.
            max_mask_patches: Maximum number of new cells this call may set.

        Returns:
            Number of cells newly set to 1 (0 if every attempt failed).
        """
        delta = 0
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            # Strict comparison: a rectangle never spans a full row or column.
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                # `region` is a view, so writes below update `mask` itself.
                region = mask[top: top + h, left: left + w]
                num_masked = region.sum()
                # Accept only if the rectangle adds new cells (it may overlap
                # already-masked ones) without exceeding the remaining budget.
                if 0 < h * w - num_masked <= max_mask_patches:
                    unmasked = region == 0
                    delta += int(unmasked.sum())
                    region[unmasked] = 1

                if delta > 0:
                    break
        return delta

    def __call__(self):
        """Build and return a fresh ``(height, width)`` integer mask."""
        # Builtin `int` replaces the `np.int` alias removed in NumPy 1.24;
        # the resulting dtype is the same.
        mask = np.zeros(shape=self.get_shape(), dtype=int)
        mask_count = 0
        while mask_count < self.num_masking_patches:
            max_mask_patches = self.num_masking_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                # No rectangle could be placed; stop below the target count.
                break
            mask_count += delta

        return mask
class VideoMaskGenerator:
    """Generate binary masks of shape ``(length, height, width)`` for video.

    A value of 1 marks a masked position. Masks are either spatial (one 2D
    block-wise mask shared by a run of frames) or temporal (whole frames
    masked according to one of six modes, see ``temporal_mask``).

    Args:
        input_size: ``(length, height, width)`` of the video grid.
        spatial_mask_ratio: Fraction of the ``height * width`` cells the
            spatial generator may mask.
    """

    def __init__(self, input_size, spatial_mask_ratio=0.5):
        self.length, self.height, self.width = input_size

        self.spatial_generator = MaskingGenerator(
            (self.height, self.width),
            spatial_mask_ratio * self.height * self.width)

        # idx = 0 Predict: number of leading frames left unmasked.
        self.predict_given_frame_length = 8
        # idx = 1 Backward: number of trailing frames left unmasked.
        self.backward_given_frame_length = 8
        # idx = 2 Interpolation: every `interpreation_step`-th frame is kept.
        self.interpreation_step = 4
        # idx = 5 MLM: probability that each frame is masked.
        self.mlm_ratio = 0.8

    def __repr__(self):
        repr_str = "Generator(%d, %d, %d)" % (
            self.length, self.height, self.width)
        return repr_str

    def get_shape(self):
        """Return the mask shape as ``(length, height, width)``."""
        return self.length, self.height, self.width

    def spatial_mask(self):
        """Apply one 2D block-wise mask to a contiguous run of frames.

        Between 0 and 3 frames at the start and between 0 and 3 frames at
        the end are left entirely unmasked.

        Returns:
            Integer array of shape ``(length, height, width)``.
        """
        mask = np.zeros(shape=self.get_shape(), dtype=int)
        start_idx = random.randint(0, 3)
        end_idx = random.randint(0, 3)
        spatial_mask = self.spatial_generator()
        # Use an explicit stop index: `mask[start_idx:-end_idx]` is an empty
        # slice when end_idx == 0, which would silently mask nothing.
        mask[start_idx:self.length - end_idx] = spatial_mask
        return mask

    def temporal_mask(self, idx=0):
        """Mask whole frames according to mode ``idx``.

        Modes:
            0 (predict): mask every frame from
                ``predict_given_frame_length`` onward.
            1 (backward): mask every frame except the last
                ``backward_given_frame_length``.
            2 (interpolation): mask all frames except every
                ``interpreation_step``-th one, starting at frame 0.
            3 (unconditional generation): mask all frames.
            4 (single frame): mask all frames except one chosen at random.
            any other value (MLM): mask each frame independently with
                probability ``mlm_ratio``.

        Returns:
            Integer array of shape ``(length, height, width)``.
        """
        # Builtin `int` replaces the `np.int` alias removed in NumPy 1.24.
        mask = np.zeros(shape=self.get_shape(), dtype=int)

        if idx == 0:
            # Predict
            mask[self.predict_given_frame_length:] = 1
        elif idx == 1:
            # Backward
            mask[:-self.backward_given_frame_length] = 1
        elif idx == 2:
            # Interpolation
            mask = np.ones(shape=self.get_shape(), dtype=int)
            mask[::self.interpreation_step] = 0
        elif idx == 3:
            # Unconditional generation
            mask = np.ones(shape=self.get_shape(), dtype=int)
        elif idx == 4:
            # Only one frame is kept
            frame_idx = random.randint(0, mask.shape[0] - 1)
            mask = np.ones(shape=self.get_shape(), dtype=int)
            mask[frame_idx] = 0
        else:
            # MLM
            for frame_idx in range(mask.shape[0]):
                if random.random() < self.mlm_ratio:
                    mask[frame_idx] = 1

        return mask

    def __call__(self, batch_size=1, device=None, idx=-1):
        """Sample one mask and repeat it across the batch.

        Args:
            batch_size: Number of copies of the mask along dimension 0.
            device: Passed to ``Tensor.to``.
            idx: Negative picks a random mode (spatial with probability
                0.2, otherwise a uniformly chosen temporal mode 0-5);
                0-5 selects that temporal mode; 6 or more selects the
                spatial mask.

        Returns:
            Tensor of shape ``(batch_size, length, height, width)``.
        """
        if idx < 0:
            if random.random() < 0.2:
                mask = self.spatial_mask()
            else:
                idx = random.randint(0, 5)
                mask = self.temporal_mask(idx)
        elif idx < 6:
            mask = self.temporal_mask(idx)
        else:
            mask = self.spatial_mask()

        return torch.tensor(mask).unsqueeze(0).repeat(batch_size, 1, 1, 1).to(device)
if __name__ == '__main__':
    # Smoke test: sample one mask, then a batch of four, and check that the
    # batched mask broadcasts against a 5D tensor.
    generator = VideoMaskGenerator((10, 10, 10))
    single_mask = generator()
    print(single_mask)

    mask = generator(4)
    print(mask.shape)

    a = torch.ones(32, 4, 10, 10, 10)
    masked = a[:] * mask
    print(masked.shape)