-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathcount_flops.py
143 lines (112 loc) · 4.36 KB
/
count_flops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
from torch.autograd import Variable
from functools import reduce
import operator
# Module-level running totals. measure_layer() adds each layer's contribution
# to them, and measure_model() resets both to zero before it starts counting.
count_ops = 0
count_params = 0
def get_num_gen(gen):
    """Count the items produced by iterating ``gen`` to exhaustion.

    The iterable is consumed; a generator passed in cannot be reused.
    """
    total = 0
    for _ in gen:
        total += 1
    return total
def is_pruned(layer):
    """Return True when ``layer`` exposes a ``mask`` attribute, else False.

    Only the presence of the attribute matters; its value is not inspected.
    """
    try:
        getattr(layer, 'mask')
    except AttributeError:
        return False
    else:
        return True
def is_leaf(model):
    """Return True when ``model.children()`` yields no items."""
    child_count = get_num_gen(model.children())
    return child_count == 0
def get_layer_info(layer):
    """Return the type name of ``layer`` as it appears in ``str(layer)``.

    The name is everything before the first ``'('`` in the string form of the
    layer, with surrounding whitespace removed (``"Conv2d(3, 8)"`` gives
    ``"Conv2d"``).

    Args:
        layer: object whose string form starts with its type name.

    Returns:
        The stripped text preceding the first ``'('``, or the whole stripped
        string when it contains no parenthesis.
    """
    # partition() yields the full string when '(' is absent. The previous
    # slice on str.find() used -1 in that case and dropped the last character.
    type_name, _, _ = str(layer).partition('(')
    return type_name.strip()
def get_layer_param(model):
    """Return the total element count over everything in ``model.parameters()``.

    Each parameter contributes the product of the dimensions reported by its
    ``size()``; a parameter with no dimensions contributes 1.
    """
    total = 0
    for param in model.parameters():
        total += reduce(operator.mul, param.size(), 1)
    return total
def _as_pair(value):
    """Return ``value`` as a 2-tuple, duplicating a scalar (3 -> (3, 3))."""
    if isinstance(value, (tuple, list)):
        if len(value) == 1:
            return (value[0], value[0])
        return (value[0], value[1])
    return (value, value)


def _window_out_size(in_size, kernel, stride, padding, dilation=1):
    """Return the output length of a sliding window along one axis."""
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


### The input batch size should be 1 to call this function
def measure_layer(layer, x):
    """Add the op and parameter counts of one layer to the global totals.

    The counting rule is selected by the type name returned by
    get_layer_info(). Results are accumulated into the module-level
    ``count_ops`` and ``count_params``; nothing is returned.

    Args:
        layer: the module about to be executed.
        x: the input the layer receives. Conv and pooling rules read
            ``x.size()[2]`` and ``x.size()[3]`` as height and width.

    Raises:
        TypeError: if the layer's type name has no counting rule.
    """
    global count_ops, count_params
    delta_ops = 0
    delta_params = 0
    multi_add = 1
    type_name = get_layer_info(layer)

    ### ops_conv
    if type_name in ['Conv2d']:
        k_h, k_w = _as_pair(layer.kernel_size)
        s_h, s_w = _as_pair(layer.stride)
        p_h, p_w = _as_pair(layer.padding)
        # Dilation widens the effective kernel and therefore shrinks the
        # output; with dilation 1 this reduces to (in + 2p - k) / s + 1.
        d_h, d_w = _as_pair(getattr(layer, 'dilation', 1))
        out_h = _window_out_size(x.size()[2], k_h, s_h, p_h, d_h)
        out_w = _window_out_size(x.size()[3], k_w, s_w, p_w, d_w)
        delta_ops = layer.in_channels * layer.out_channels * k_h * k_w * \
            out_h * out_w / layer.groups * multi_add
        delta_params = get_layer_param(layer)

    elif type_name in ['ConvTranspose2d']:
        _, _, in_h, in_w = x.size()
        k_h, k_w = _as_pair(layer.kernel_size)
        s_h, s_w = _as_pair(layer.stride)
        p_h, p_w = _as_pair(layer.padding)
        op_h, op_w = _as_pair(layer.output_padding)
        d_h, d_w = _as_pair(getattr(layer, 'dilation', 1))
        # With dilation 1 this is (in - 1) * s - 2p + k + output_padding.
        out_h = (in_h - 1) * s_h - 2 * p_h + d_h * (k_h - 1) + op_h + 1
        out_w = (in_w - 1) * s_w - 2 * p_w + d_w * (k_w - 1) + op_w + 1
        delta_ops = layer.in_channels * layer.out_channels * k_h * k_w * \
            out_h * out_w / layer.groups * multi_add
        delta_params = get_layer_param(layer)

    ### ops_nonlinearity
    elif type_name in ['ReLU']:
        delta_ops = x.numel()
        delta_params = get_layer_param(layer)

    ### ops_pooling
    elif type_name in ['AvgPool2d', 'MaxPool2d']:
        # kernel_size / stride / padding may each be an int or a pair, and the
        # input need not be square, so both axes are handled independently.
        k_h, k_w = _as_pair(layer.kernel_size)
        stride = layer.stride if layer.stride is not None else layer.kernel_size
        s_h, s_w = _as_pair(stride)
        p_h, p_w = _as_pair(layer.padding)
        out_h = _window_out_size(x.size()[2], k_h, s_h, p_h)
        out_w = _window_out_size(x.size()[3], k_w, s_w, p_w)
        kernel_ops = k_h * k_w
        delta_ops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops
        delta_params = get_layer_param(layer)

    ### ops_linear
    elif type_name in ['Linear']:
        weight_ops = layer.weight.numel() * multi_add
        # A layer built without a bias has ``bias`` set to None.
        bias_ops = layer.bias.numel() if layer.bias is not None else 0
        # One matrix-vector product per input row; equals x.size()[0] for the
        # usual 2-D (batch, in_features) input.
        rows = x.numel() // layer.in_features
        delta_ops = rows * (weight_ops + bias_ops)
        delta_params = get_layer_param(layer)

    ### ops_nothing
    elif type_name in ['BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout']:
        delta_params = get_layer_param(layer)

    ### sequential takes no extra time
    elif type_name in ['Sequential']:
        pass

    ### unknown layer type
    else:
        raise TypeError('unknown layer type: %s' % type_name)

    count_ops += delta_ops
    count_params += delta_params
    return
def measure_model(model, H, W):
    """Count ops and parameters of ``model`` for one ``(1, 3, H, W)`` input.

    Every child module that is a leaf or is pruned has its ``forward``
    temporarily replaced by a wrapper that calls measure_layer() before
    delegating to the original. One forward pass on a zero input is then run
    and the original ``forward`` methods are put back.

    Args:
        model: the network to measure; only its descendants are wrapped, not
            the root module itself.
        H: input height.
        W: input width.

    Returns:
        Tuple ``(count_ops, count_params)`` accumulated during the pass.

    Raises:
        TypeError: propagated from measure_layer() for unknown layer types.
            The original ``forward`` methods are restored before it propagates.
    """
    global count_ops, count_params
    count_ops = 0
    count_params = 0
    data = Variable(torch.zeros(1, 3, H, W))

    def should_measure(module):
        return is_leaf(module) or is_pruned(module)

    def is_wrapped(module):
        # ``old_forward`` is reset to None after a restore, so a bare
        # hasattr() check is not enough.
        return getattr(module, 'old_forward', None) is not None

    def make_forward(module):
        def measured_forward(x):
            measure_layer(module, x)
            return module.old_forward(x)
        return measured_forward

    # this is a dangerous recursive function
    def modify_forward(parent):
        for child in parent.children():
            if should_measure(child):
                # A module reachable through two parents must be wrapped only
                # once; wrapping it again would make old_forward point at the
                # wrapper and recurse forever.
                if not is_wrapped(child):
                    child.old_forward = child.forward
                    child.forward = make_forward(child)
            else:
                modify_forward(child)

    def restore_forward(parent):
        for child in parent.children():
            if is_wrapped(child):
                # Covers pruned non-leaf modules as well as leaves, mirroring
                # exactly what modify_forward() wrapped.
                child.forward = child.old_forward
                child.old_forward = None
            else:
                restore_forward(child)

    modify_forward(model)
    try:
        model.forward(data)
    finally:
        # Never leave the model patched, even if the forward pass fails.
        restore_forward(model)

    return count_ops, count_params