-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathmodel.py
61 lines (48 loc) · 2.16 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class VGGNet(nn.Module):
    """Pretrained VGG19-BN used as a fixed-length image feature extractor.

    The network runs the convolutional ``features`` stack of
    ``torchvision.models.vgg19_bn`` and then the classifier head with its
    last two child modules removed, so ``forward`` returns a single 2-D
    tensor of shape ``(batch, D)`` rather than per-layer convolutional maps.

    NOTE(review): for the stock torchvision classifier, dropping the last two
    children should leave the 4096-d activation after the second fully
    connected layer -- confirm against the installed torchvision version.
    """

    def __init__(self):
        """Build the backbone and the truncated classifier head."""
        super(VGGNet, self).__init__()
        # NOTE(review): newer torchvision releases prefer the ``weights=``
        # argument; ``pretrained=True`` is kept so older installs keep working.
        self.vgg = models.vgg19_bn(pretrained=True)
        self.vgg_features = self.vgg.features
        # Keep every classifier child except the final two.
        self.fc_features = nn.Sequential(*list(self.vgg.classifier.children())[:-2])

    def forward(self, x):
        """Return the truncated-classifier features for a batch of images.

        Args:
            x: image batch; the first dimension is treated as the batch size.

        Returns:
            Tensor of shape ``(x.shape[0], D)`` produced by ``fc_features``.
        """
        feature_maps = self.vgg_features(x)
        # torchvision's VGG applies ``avgpool`` between ``features`` and
        # ``classifier``. Reuse it when the backbone provides one so the
        # flattened size does not depend on the input resolution; fall back
        # to the plain flatten on backbones without that attribute.
        pool = getattr(self.vgg, "avgpool", None)
        if pool is not None:
            feature_maps = pool(feature_maps)
        flat = feature_maps.reshape(x.shape[0], -1)
        return self.fc_features(flat)
class ImgNN(nn.Module):
    """Image branch: one fully connected layer followed by a ReLU.

    Args:
        input_dim: size of the last dimension of the incoming features.
        output_dim: size of the representation produced by the layer.
    """

    def __init__(self, input_dim=4096, output_dim=1024):
        super(ImgNN, self).__init__()
        # The attribute name feeds into state_dict keys, so it stays as is.
        self.denseL1 = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Project ``x`` with ``denseL1`` and clamp negatives to zero."""
        projected = self.denseL1(x)
        return F.relu(projected)
class TextNN(nn.Module):
    """Text branch: a single linear projection with a ReLU on top.

    Args:
        input_dim: size of the last dimension of the incoming text features.
        output_dim: size of the representation produced by the layer.
    """

    def __init__(self, input_dim=1024, output_dim=1024):
        super(TextNN, self).__init__()
        dense = nn.Linear(input_dim, output_dim)
        # Registered under the original attribute name to keep checkpoint
        # keys unchanged.
        self.denseL1 = dense

    def forward(self, x):
        """Return the rectified output of ``denseL1`` applied to ``x``."""
        hidden = self.denseL1(x)
        activated = F.relu(hidden)
        return activated
class IDCM_NN(nn.Module):
    """Two-branch network mapping images and text through shared layers.

    Image features go through ``ImgNN`` and text features through ``TextNN``.
    Both branch outputs are then passed through the *same* ``linearLayer``
    and the *same* ``linearLayer2``.

    Args:
        img_input_dim: feature size fed to the image branch.
        img_output_dim: output size of the image branch.
        text_input_dim: feature size fed to the text branch.
        text_output_dim: output size of the text branch; must equal
            ``img_output_dim`` because both branches feed one shared layer.
        minus_one_dim: output size of the shared ``linearLayer``.
        output_dim: output size of the shared ``linearLayer2``.

    Raises:
        ValueError: if ``img_output_dim`` and ``text_output_dim`` differ.
    """

    def __init__(self, img_input_dim=4096, img_output_dim=2048,
                 text_input_dim=1024, text_output_dim=2048, minus_one_dim=1024, output_dim=10):
        # Fail fast: ``linearLayer`` is sized from ``img_output_dim`` but is
        # applied to the text branch too, so unequal widths could only ever
        # fail later inside ``forward`` with an opaque shape error.
        if img_output_dim != text_output_dim:
            raise ValueError(
                "img_output_dim ({}) must equal text_output_dim ({}) because "
                "both branches share linearLayer".format(img_output_dim, text_output_dim)
            )
        super(IDCM_NN, self).__init__()
        self.img_net = ImgNN(img_input_dim, img_output_dim)
        self.text_net = TextNN(text_input_dim, text_output_dim)
        self.linearLayer = nn.Linear(img_output_dim, minus_one_dim)
        self.linearLayer2 = nn.Linear(minus_one_dim, output_dim)

    def forward(self, img, text):
        """Run both branches and the shared layers.

        Args:
            img: input passed to the image branch.
            text: input passed to the text branch.

        Returns:
            Tuple ``(view1_feature, view2_feature, view1_predict,
            view2_predict)``: the ``linearLayer`` outputs for image and text,
            followed by the ``linearLayer2`` outputs computed from them.
        """
        # Branch-specific layers, then the shared projection.
        view1_feature = self.linearLayer(self.img_net(img))
        view2_feature = self.linearLayer(self.text_net(text))
        # Shared second layer; no activation is applied in between.
        view1_predict = self.linearLayer2(view1_feature)
        view2_predict = self.linearLayer2(view2_feature)
        return view1_feature, view2_feature, view1_predict, view2_predict