-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest_models.py
125 lines (103 loc) · 4.06 KB
/
test_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from models.resnet import QuantizedResNet18, QuantizedResNet34, QuantizedResNet50, QuantizedResNet101, QuantizedResNet152
from models.vgg import QuantizedVGG11, QuantizedVGG11_bn, QuantizedVGG13, QuantizedVGG13_bn, QuantizedVGG16, QuantizedVGG16_bn, QuantizedVGG19, QuantizedVGG19_bn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch
from tqdm import tqdm
import random
def train(model, batch_size=240, num_workers=8, dataset_folder_path='/home/marcelo/datasets/ILSVRC2012/ILSVRC2012_train'):
    """Fine-tune ``model`` for two epochs on an image-folder dataset, then evaluate it.

    The model is moved to the first CUDA device when CUDA is available
    (otherwise it stays on the CPU) and is wrapped in
    ``torch.nn.DataParallel`` only when more than one GPU is visible.
    ``model.quantize()`` is invoked before every forward pass.

    Args:
        model: Network to train. Must expose a ``quantize()`` method; it is
            moved to the selected device in place.
        batch_size: Number of samples per training batch.
        num_workers: Number of ``DataLoader`` worker processes.
        dataset_folder_path: Root directory handed to
            ``torchvision.datasets.ImageFolder``.

    Returns:
        None. Progress is reported through ``tqdm`` and the final top-1 /
        top-5 accuracies returned by ``test`` are printed.
    """
    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # NOTE: is_available must be *called*; the bare function object is always truthy.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    data_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(dataset_folder_path, transform),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=use_cuda,  # pinned host memory only helps host-to-GPU copies
    )

    model.to(device)
    # Keep `model` as the bare network so quantize() can be called on it directly,
    # and use `net` (possibly a DataParallel wrapper) for the forward pass.
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(model)
    else:
        net = model

    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)
    net.train()

    for epoch in range(2):
        running_loss = 0.0
        batches_since_log = 0
        for i, (images, target) in enumerate(tqdm(data_loader)):
            optimizer.zero_grad()
            target = target.to(device)
            images = images.to(device)
            # presumably refreshes the quantized weights after the last
            # optimizer step -- confirm against the model implementation
            model.quantize()
            outputs = net(images)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batches_since_log += 1
            if i % 100 == 0 and i != 0:
                # Average over the batches actually accumulated since the last
                # report instead of a hard-coded divisor.
                tqdm.write("[%d, %d] loss: %.5f" % (epoch+1, i+1, running_loss/batches_since_log))
                running_loss = 0.0
                batches_since_log = 0

    # Evaluate the bare (unwrapped) model once training has finished.
    correct1, correct5, total = test(model)
    print(correct1/total, correct5/total)
def test(model, batch_size=2048, num_workers=8, dataset_folder_path='/home/marcelo/datasets/ILSVRC2012/ILSVRC2012_val'):
    """Measure top-1 and top-5 classification hits of ``model`` on an image-folder dataset.

    The model is moved to the first CUDA device when CUDA is available
    (otherwise the CPU), quantized once via ``model.quantize()``, switched to
    evaluation mode, and wrapped in ``torch.nn.DataParallel`` only when more
    than one GPU is visible. The caller's model is left on that device and in
    evaluation mode.

    Args:
        model: Network to evaluate. Must expose a ``quantize()`` method.
        batch_size: Number of samples per evaluation batch.
        num_workers: Number of ``DataLoader`` worker processes.
        dataset_folder_path: Root directory handed to
            ``torchvision.datasets.ImageFolder``.

    Returns:
        A ``(correct1, correct5, total)`` tuple of ints: the number of samples
        whose label equals the highest-scoring class, the number whose label
        is among the five highest-scoring classes, and the number of samples
        seen.
    """
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # NOTE: is_available must be *called*; the bare function object is always truthy.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    data_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(dataset_folder_path, transform),
        batch_size=batch_size,
        shuffle=False,  # sample order does not affect the hit counts
        num_workers=num_workers,
        pin_memory=use_cuda,
    )

    model = model.to(device)
    # Quantize the bare model; a DataParallel wrapper would hide the method
    # behind `.module`, and no wrapper exists on CPU / single-GPU hosts.
    model.quantize()
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(model)
    else:
        net = model
    net.eval()

    correct1, correct5, total = 0, 0, 0
    with torch.no_grad():
        for images, target in tqdm(data_loader):
            target = target.to(device)
            outputs = net(images.to(device))
            # One sorted top-k call serves both metrics; column 0 is the arg-max.
            # Clamp k so models with fewer than five outputs still work.
            k = min(5, outputs.size(1))
            _, top_k = outputs.topk(k, dim=1)
            total += target.size(0)
            correct1 += (top_k[:, 0] == target).sum().item()
            correct5 += (top_k == target.unsqueeze(1)).sum().item()
    return correct1, correct5, total
from DSConv.nn.dsconv2d import DSConv2d
def counting_dsconv(model):
    """Return how many modules yielded by ``model.modules()`` are ``DSConv2d`` instances."""
    return sum(1 for layer in model.modules() if isinstance(layer, DSConv2d))
if __name__ == "__main__":
    # Alternative kept from the original: a single bit width for every layer.
    # model = QuantizedResNet50(4, 32, pretrained=True)

    # Draw one random bit width in [2, 8] per entry counted by
    # QuantizedResNet50.number_bits, and keep an independent copy of the list.
    bits = [random.randint(2, 8) for _ in range(QuantizedResNet50.number_bits)]
    bits2 = list(bits)
    print(bits)
    input('')  # pause until the user presses Enter

    model = QuantizedResNet50(bits, 32, pretrained=True)
    print(model)

    # Re-apply the copied bit widths and show the model again for comparison.
    model.update_bits(bits2)
    print(model)
    input('')  # pause until the user presses Enter

    train(model)

    # Final evaluation of the trained model: top-1 and top-5 accuracy.
    top1_hits, top5_hits, n_samples = test(model)
    print(top1_hits/n_samples, top5_hits/n_samples)