-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathranking.py
61 lines (51 loc) · 1.79 KB
/
ranking.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# -*- coding: utf-8 -*-
import argparse
import torch
import torchvision.models
import torchvision.transforms as transforms
from PIL import Image
# Device used for both the model and the input tensors: CUDA when
# torch.cuda.is_available() reports a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def prepare_image(image):
    """Convert a PIL image into a batch of one tensor on ``device``.

    Non-RGB images are converted to RGB first. The image is then resized to
    224x224 and turned into a tensor with ``transforms.ToTensor``. No mean/std
    normalization is applied.

    Args:
        image: A ``PIL.Image.Image``.

    Returns:
        A tensor with a leading batch dimension of size 1, moved to the
        module-level ``device``.
    """
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    pipeline = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
    ])
    # unsqueeze(0) adds the batch axis the model expects.
    batch = pipeline(rgb_image).unsqueeze(0)
    return batch.to(device)
def predict(image_path, model):
    """Score a single image file with ``model``, print the score and return it.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a path or file object).
        model: A callable module that returns a single-element tensor for a
            batch of one image. It is presumably already in eval mode and on
            ``device`` (as ``rank`` prepares it).

    Returns:
        The model output as a Python ``float``.
    """
    # Open the image in a context manager so the file handle is released.
    # Preprocessing (convert/resize) reads the pixels, so it must happen inside
    # the ``with`` block.
    with Image.open(image_path) as img:
        batch = prepare_image(img)
    with torch.no_grad():
        preds = model(batch)
    # .item() requires a single-element output; call it only once.
    score = float(preds.item())
    print("Popularity score: %.2f" % score)
    return score
def _load_model(weights_path="model/model-resnet50.pth"):
    """Build the single-output ResNet-50 regressor, load its weights and return it on ``device``."""
    model = torchvision.models.resnet50()
    # Replace the classification head with a single-output regression layer.
    model.fc = torch.nn.Linear(in_features=2048, out_features=1)
    # NOTE(review): torch.load unpickles the checkpoint, so only load trusted
    # files. On torch versions that support it, consider weights_only=True.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval().to(device)
    return model


def rank(images_paths, weights_path="model/model-resnet50.pth"):
    """Score every image in ``images_paths`` and map each path to its score.

    The model is built and its checkpoint read only once, on the first image,
    and then reused for every image. (Previously the model was reloaded for
    each path.) An empty input never touches the checkpoint. Each score is
    also printed by ``predict``.

    Args:
        images_paths: Iterable of image paths (any value ``PIL.Image.open``
            accepts that is also hashable, since it becomes a dict key).
        weights_path: Location of the ResNet-50 state dict. The default keeps
            the original hard-coded path.

    Returns:
        dict mapping each path to its float score. A repeated path keeps the
        score from its last occurrence.
    """
    ranking_map = {}
    model = None
    for image_path in images_paths:
        if model is None:
            model = _load_model(weights_path)
        ranking_map[image_path] = predict(image_path, model)
    return ranking_map
def main():
    """Command-line entry point: score one image and print its popularity."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path', type=str, default='images/0.jpg',
                        help='image file to score')
    config = parser.parse_args()
    # rank() loads the model and scores the image; predict() (which it calls)
    # prints the result. Both take file paths, not already-opened PIL images.
    rank([config.image_path])


if __name__ == '__main__':
    main()