infer.py
import torch
from torch.backends import cudnn
from backbone import EfficientDetBackbone
import numpy as np
import os
import json
from efficientdet.utils import BBoxTransform, ClipBoxes
from utils.utils import preprocess, invert_affine, postprocess
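# NOTE: backbone, efficientdet.utils and utils.utils are local modules; the import
# paths are assumed to follow the Yet-Another-EfficientDet-Pytorch project layout.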
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# load data
# count the images in the top-level test/ directory (assumed to contain only the
# numbered .png test images)
num_files = 0
for subdir, dirs, files in os.walk('test'):
    num_files = len(files)
    break
tags = np.linspace(1, num_files, num_files)
names = []
for tag in tags:
    names.append('test/' + str(int(tag)) + '.png')
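# names now holds 'test/1.png' ... 'test/<num_files>.png', i.e. the test images are
# expected to be numbered consecutively starting at 1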
# settings
anchor_ratios = [(1.0, 1.0), (1.4, 0.7), (1.5, 0.5)]
anchor_scales = [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]
threshold = 0.2
iou_threshold = 0.2
compound_coef = 3
load_path = 'weights.pth'
use_cuda = torch.cuda.is_available()
use_float16 = False
cudnn.fastest = True
cudnn.benchmark = True
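# default input resolutions for EfficientDet-D0..D8; force_input_size, when set,
# overrides the size that compound_coef would otherwise pick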
force_input_size = 896
input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
if force_input_size:
    input_size = force_input_size
else:
    input_size = input_sizes[compound_coef]
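# ten classes, one per digit (0-9)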
obj_list_num = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '0']
# model
model = EfficientDetBackbone(compound_coef=compound_coef,
                             num_classes=len(obj_list_num),
                             ratios=anchor_ratios, scales=anchor_scales)
try:
    model.load_state_dict(torch.load(load_path, map_location='cpu'))
except RuntimeError as e:
    print(f'[Warning] Ignoring {e}')
model.requires_grad_(False)
model.eval()
if use_cuda:
    model = model.cuda()
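# preprocess() resizes/pads each image to input_size and returns the metadata that
# invert_affine() later uses to map boxes back to the original image coordinates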
# inferring
preds = []
for name in names:
    ori_imgs, framed_imgs, framed_metas = preprocess(name, max_size=input_size)
    if use_cuda:
        x = torch.stack([torch.from_numpy(fi).cuda() for fi in framed_imgs], 0)
    else:
        x = torch.stack([torch.from_numpy(fi) for fi in framed_imgs], 0)
    # NHWC -> NCHW
    x = x.to(torch.float32 if not use_float16 else torch.float16).permute(0, 3, 1, 2)
    with torch.no_grad():
        features, regression, classification, anchors = model(x)
        regressBoxes = BBoxTransform()
        clipBoxes = ClipBoxes()
        out = postprocess(x,
                          anchors, regression, classification,
                          regressBoxes, clipBoxes,
                          threshold, iou_threshold)
    # map boxes from the padded/resized frame back to the original image
    out = invert_affine(framed_metas, out)
    o_bbox = out[0]['rois']
    o_cls = out[0]['class_ids']
    o_score = out[0]['scores']
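    # swap to (y1, x1, y2, x2) order and cast to int (rois appear to be returned as
    # (x1, y1, x2, y2) by postprocess); class ids are shifted to 1-based labels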
    for i, ele in enumerate(o_bbox):
        o_bbox[i][0], o_bbox[i][1] = int(o_bbox[i][1]), int(o_bbox[i][0])
        o_bbox[i][2], o_bbox[i][3] = int(o_bbox[i][3]), int(o_bbox[i][2])
        o_cls[i] += 1
    pred = dict()
    pred['bbox'] = o_bbox.tolist()
    pred['score'] = o_score.tolist()
    pred['label'] = o_cls.tolist()
    preds.append(pred)
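# one dict per test image, each with parallel 'bbox', 'score' and 'label' lists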
with open('predictions.json', 'w') as f:
    json.dump(preds, f)