diff --git a/hannah/callbacks/summaries.py b/hannah/callbacks/summaries.py
index bea60477..d0f6b6f3 100644
--- a/hannah/callbacks/summaries.py
+++ b/hannah/callbacks/summaries.py
@@ -512,6 +512,53 @@ def get_attn2d(node, output, args, kwargs):
     return num_weights, macs, attrs
 
 
+# quantized conv
+def get_qconv(node, output, args, kwargs):
+    weight_bits = kwargs["qconfig"]["weight"]["bits"]
+    activation_bits = kwargs["qconfig"]["activation"]["bits"]
+
+    volume_ofm = prod(output.shape)
+    weight = args[1]
+    out_channels = weight.shape[0]
+    in_channels = weight.shape[1]
+    kernel_size = weight.shape[2]
+
+    # default weights and macs in 32 bits
+    num_weights = np.prod(weight.shape)
+    macs = volume_ofm * in_channels * kernel_size**2
+
+    # convert to lower bits
+    num_weights = (num_weights / 32) * weight_bits
+    macs = (macs / 32) * activation_bits
+
+    attrs = "k=" + "(%d, %d)" % (kernel_size, kernel_size)
+    attrs += ", s=" + "(%d, %d)" % (kwargs["stride"], kwargs["stride"])
+    attrs += ", g=(%d)" % kwargs["groups"]
+    attrs += ", dsc=(%s)" % str(in_channels == out_channels == kwargs["groups"])
+    attrs += ", d=" + "(%d, %d)" % (kwargs["dilation"], kwargs["dilation"])
+    return num_weights, macs, attrs
+
+
+# quantized linear
+def get_qlinear(node, output, args, kwargs):
+    weight_bits = kwargs["qconfig"]["weight"]["bits"]
+    activation_bits = kwargs["qconfig"]["activation"]["bits"]
+
+    weight = args[1]
+    in_features = weight.shape[0]
+    out_features = weight.shape[1]
+
+    # default weights and macs in 32 bits
+    num_weights = macs = in_features * out_features
+
+    # convert to lower bits
+    num_weights = (num_weights / 32) * weight_bits
+    macs = (macs / 32) * activation_bits
+
+    attrs = ""
+    return num_weights, macs, attrs
+
+
 def get_type(node):
     try:
         return node.name.split("_")[-2]
diff --git a/hannah/conf/model/resnet.yaml b/hannah/conf/model/resnet.yaml
index 088ea3ba..fce12abc 100644
--- a/hannah/conf/model/resnet.yaml
+++ b/hannah/conf/model/resnet.yaml
@@ -1,3 +1,3 @@
-_target_: hannah.models.resnet.models.search_space
+_target_: hannah.models.resnet.models.resnet
 name: resnet
 num_classes: 10
\ No newline at end of file
diff --git a/hannah/models/resnet/models.py b/hannah/models/resnet/models.py
index 22e562a4..667fc056 100644
--- a/hannah/models/resnet/models.py
+++ b/hannah/models/resnet/models.py
@@ -1,10 +1,12 @@
 from hannah.models.embedded_vision_net.expressions import expr_product
 from hannah.nas.parameters.parameters import CategoricalParameter, IntScalarParameter
+from hannah.nas.functional_operators.op import search_space
 from hannah.models.resnet.operators import dynamic_depth
 from hannah.models.resnet.blocks import block, conv_relu_bn, classifier_head
 
 
-def search_space(name, input, num_classes=10):
+@search_space
+def resnet(name, input, num_classes=10):
     out_channels = IntScalarParameter(16, 64, step_size=4, name='out_channels')
     kernel_size = CategoricalParameter([3, 5, 7, 9], name='kernel_size')
     stride = CategoricalParameter([1, 2], name='stride')
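# ---------------------------------------------------------------------------
# Illustrative example (not part of the patch): a minimal sketch of how the
# bit-scaled estimates in get_qconv() above behave for a hypothetical 3x3
# convolution. All shapes and bit widths below are made-up example values.
import numpy as np

weight_shape = (16, 8, 3, 3)      # (out_channels, in_channels, kH, kW)
output_shape = (1, 16, 32, 32)    # (N, C, H, W) of the output feature map
weight_bits, activation_bits = 8, 8

volume_ofm = np.prod(output_shape)
in_channels, kernel_size = weight_shape[1], weight_shape[2]

# 32-bit baselines, as in get_qconv()
num_weights = np.prod(weight_shape)                  # 1152
macs = volume_ofm * in_channels * kernel_size ** 2   # 1179648

# rescale the 32-bit counts by the configured bit widths
num_weights_q = (num_weights / 32) * weight_bits     # 288.0
macs_q = (macs / 32) * activation_bits               # 294912.0
print(num_weights_q, macs_q)
# ---------------------------------------------------------------------------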
diff --git a/hannah/models/resnet/operators.py b/hannah/models/resnet/operators.py
index 678403ce..84a97d06 100644
--- a/hannah/models/resnet/operators.py
+++ b/hannah/models/resnet/operators.py
@@ -10,8 +10,11 @@ def conv2d(input, out_channels, kernel_size=1, stride=1, dilation=1, groups=1, p
                     shape=(out_channels, in_channels, kernel_size, kernel_size),
                     axis=('O', 'I', 'kH', 'kW'),
                     grad=True)
+    bias = Tensor(name='bias', shape=(out_channels,), axis=('O',), grad=True)
 
-    conv = Conv2d(stride=stride, dilation=dilation, groups=groups, padding=padding)(input, weight)
+    conv = Conv2d(stride=stride, dilation=dilation, groups=groups, padding=padding)(
+        input, weight, bias
+    )
     return conv
 
 
@@ -22,8 +25,9 @@ def linear(input, out_features):
                     shape=(in_features, out_features),
                     axis=('in_features', 'out_features'),
                     grad=True)
+    bias = Tensor(name='bias', shape=(out_features,), axis=('O',), grad=True)
 
-    out = Linear()(input, weight)
+    out = Linear()(input, weight, bias)
     return out
 
 
diff --git a/hannah/modules/vision/image_classifier.py b/hannah/modules/vision/image_classifier.py
index 7156ebb6..15414588 100644
--- a/hannah/modules/vision/image_classifier.py
+++ b/hannah/modules/vision/image_classifier.py
@@ -86,13 +86,13 @@ def common_step(self, step_name, batch, batch_idx):
         provs = torch.softmax(logits, dim=1)
 
         # For HMM and Viterbi Post-Processing
-        if ((step_name is "train") or (step_name is 'val')) and self.current_epoch == self.trainer.max_epochs-1:
+        if ((step_name == "train") or (step_name == 'val')) and self.current_epoch == self.trainer.max_epochs-1:
             metadata = batch.get('metadata', {}).copy()
             metadata.update({'preds_cnn': preds.cpu().numpy(), 'labels': labels.cpu().numpy()})
             df = pd.DataFrame(metadata)
             df.to_csv(os.getcwd() + f'_cnn_{step_name}_output', mode='a', index=False, header=True)
 
-        if step_name is 'test':
+        if step_name == 'test':
             metadata = batch.get('metadata', {}).copy()
             metadata.update({'preds_cnn': preds.cpu().numpy(), 'labels': labels.cpu().numpy()})
             df = pd.DataFrame(metadata)
diff --git a/hannah/nas/functional_operators/executor.py b/hannah/nas/functional_operators/executor.py
index 8e7dcfd8..cc9aadd1 100644
--- a/hannah/nas/functional_operators/executor.py
+++ b/hannah/nas/functional_operators/executor.py
@@ -1,5 +1,6 @@
 from copy import deepcopy
 from typing import Iterator, Tuple
+import math
 import torch
 from hannah.nas.functional_operators.op import ChoiceOp, Op, Tensor, get_nodes
 from collections import defaultdict
@@ -17,7 +18,7 @@ def __init__(self, net, input_node_name='input', init=None) -> None:
         if init is not None:
             self.init = init
         else:
-            self.init = torch.nn.init.xavier_uniform_
+            self.init = torch.nn.init.kaiming_uniform_
         self.nodes = []
 
     def initialize(self):
@@ -27,9 +28,21 @@ def initialize_tensor(self, node):
         if isinstance(node, Tensor):
            node_name = node.id.replace(".", "_")
            if node.grad:
-                data = torch.empty(node.current_shape())
-                data = torch.nn.Parameter(self.init(data))
-                self.register_parameter(node_name, data)
+                if node.name == 'bias':
+                    # get weight data
+                    weight_name = node_name.replace('bias', 'weight')
+                    weight_param = self.get_parameter(weight_name)
+                    fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(weight_param.data)
+                    # register bias
+                    if fan_in != 0:
+                        bound = 1 / math.sqrt(fan_in)
+                        data = torch.empty(node.current_shape())
+                        data = torch.nn.Parameter(torch.nn.init.uniform_(data, -bound, bound))
+                        self.register_parameter(node_name, data)
+                else:  # weight tensor
+                    data = torch.empty(node.current_shape())
+                    data = torch.nn.Parameter(self.init(data, a=math.sqrt(5)))
+                    self.register_parameter(node_name, data)
            if node.name == self.input_node_name:
                self.input = node
            if node.name == 'running_mean':
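# ---------------------------------------------------------------------------
# Illustrative example (not part of the patch): what the executor's new
# initialization does for one weight/bias pair, written out with plain
# PyTorch. This mirrors torch.nn.Conv2d.reset_parameters; the tensor shapes
# are example values only.
import math
import torch

weight = torch.empty(16, 8, 3, 3)
torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))   # weight branch above

fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(weight)
bias = torch.empty(16)
if fan_in != 0:
    bound = 1 / math.sqrt(fan_in)                        # bias bound derived from fan-in
    torch.nn.init.uniform_(bias, -bound, bound)          # bias branch above
# ---------------------------------------------------------------------------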
diff --git a/hannah/nas/functional_operators/operators.py b/hannah/nas/functional_operators/operators.py
index 139e044c..79d42720 100644
--- a/hannah/nas/functional_operators/operators.py
+++ b/hannah/nas/functional_operators/operators.py
@@ -56,10 +56,11 @@ def conv1d(input, weight, stride, padding, dilation, groups, *, id):
 
 
 @torch.fx.wrap
-def conv2d(input, weight, stride=1, padding=1, dilation=1, groups=1, *, id):
+def conv2d(input, weight, bias, stride=1, padding=1, dilation=1, groups=1, *, id):
     return F.conv2d(
         input=input,
         weight=weight,
+        bias=bias,
         stride=lazy(stride),
         padding=lazy(padding),
         dilation=lazy(dilation),
@@ -68,8 +69,8 @@ def conv2d(input, weight, stride=1, padding=1, dilation=1, groups=1, *, id):
 
 
 @torch.fx.wrap
-def linear(input, weight, *, id):
-    return F.linear(input=input, weight=weight.T)
+def linear(input, weight, bias, *, id):
+    return F.linear(input=input, weight=weight.T, bias=bias)
 
 
 @torch.fx.wrap
@@ -105,12 +106,12 @@ def adaptive_avg_pooling1d(input, output_size=(1, 1), *, id):
 
 
 @torch.fx.wrap
-def max_pool(input, kernel_size, stride, padding, dilation):
+def max_pool(input, kernel_size, stride, padding, dilation, *, id):
     return F.max_pool2d(input, kernel_size, stride, padding, dilation)
 
 
 @torch.fx.wrap
-def avg_pool(input, kernel_size, stride, padding):
+def avg_pool(input, kernel_size, stride, padding, *, id):
     return F.avg_pool2d(input, kernel_size, stride, padding)
 
 
@@ -142,7 +143,7 @@ def self_attention2d(q, k, v, num_heads, d_model, *, id):
     k: Tensor, shape ``[B, h*d, H, W]``
     v: Tensor, shape ``[B, h*d, H, W]``
     """
-    scale = d_model**-0.5
+    scale = d_model ** -0.5
     b, _, h, w = q.shape
     q = q.view(b, num_heads, d_model, h * w)
     k = k.view(b, num_heads, d_model, h * w)
@@ -253,7 +254,9 @@ def shape_fun(self):
 @parametrize
 class Conv2d(Op):
     def __init__(self, stride=1, dilation=1, groups=1, padding=None) -> None:
-        super().__init__(name="Conv2d", stride=stride, dilation=dilation, groups=groups)
+        super().__init__(
+            name="Conv2d", stride=stride, dilation=dilation, groups=groups,
+        )
         self.stride = stride
         self.dilation = dilation
         self.groups = groups
@@ -264,6 +267,8 @@ def __call__(self, *operands) -> Any:
         input_shape = operands[0].shape()
         weight_shape = operands[1].shape()
         operands[1].id = f"{new_conv.id}.{operands[1].id}"
+        if len(operands) >= 3:
+            operands[2].id = f"{new_conv.id}.{operands[2].id}"
 
         new_conv.in_channels = input_shape[1]
         new_conv.out_channels = weight_shape[0]
@@ -275,10 +280,11 @@ def __call__(self, *operands) -> Any:
 
         return new_conv
 
-    def _forward_implementation(self, x, weight):
+    def _forward_implementation(self, input, weight, bias=None):
         return conv2d(
-            x,
+            input,
             weight,
+            bias,
             stride=lazy(self.stride),
             padding=lazy(self.padding),
             dilation=lazy(self.dilation),
@@ -306,14 +312,19 @@ def __call__(self, *operands) -> Any:
         new_linear.in_features = operands[1].shape()[0]
         new_linear.out_features = operands[1].shape()[1]
         operands[1].id = f"{new_linear.id}.{operands[1].id}"
+        if len(operands) >= 3:
+            operands[2].id = f"{new_linear.id}.{operands[2].id}"
 
         return new_linear
 
     def shape_fun(self):
         return linear_shape(*self.operands)
 
-    def _forward_implementation(self, input, weight):
+    def _forward_implementation(self, input, weight, bias=None):
         input = torch.flatten(input, start_dim=1)
-        return linear(input, weight, id=self.id)
+        return linear(
+            input, weight, bias,
+            id=self.id
+        )
 
 
@@ -376,7 +387,6 @@ def _forward_implementation(self, *operands):
 class Requantize(Op):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(name="Quantize")
-
         self.quantize = FakeQuantize()
 
     @property
@@ -460,6 +470,7 @@ def _forward_implementation(self, *operands):
             stride=lazy(self.stride),
             padding=lazy(self.padding),
             dilation=lazy(self.dilation),
+            id=self.id
         )
 
 
@@ -488,6 +499,7 @@ def _forward_implementation(self, *operands):
             kernel_size=lazy(self.kernel_size),
             stride=lazy(self.stride),
             padding=lazy(self.padding),
+            id=self.id
         )
 
 
@@ -516,6 +528,7 @@ def _forward_implementation(self, *operands):
             kernel_size=lazy(self.kernel_size),
             stride=lazy(self.stride),
             padding=lazy(self.padding),
+            id=self.id
         )
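# ---------------------------------------------------------------------------
# Illustrative example (not part of the patch): what the updated linear()
# wrapper above computes. The search space stores the weight with shape
# (in_features, out_features), so it is transposed before F.linear, and the
# bias is now forwarded as well. Shapes are example values.
import torch
import torch.nn.functional as F

x = torch.randn(4, 32)        # (batch, in_features)
weight = torch.randn(32, 10)  # (in_features, out_features), as in the Tensor node
bias = torch.randn(10)

out = F.linear(input=x, weight=weight.T, bias=bias)
assert out.shape == (4, 10)
# ---------------------------------------------------------------------------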
diff --git a/hannah/nas/graph_conversion.py b/hannah/nas/graph_conversion.py
index 3474966d..568d8827 100644
--- a/hannah/nas/graph_conversion.py
+++ b/hannah/nas/graph_conversion.py
@@ -495,10 +495,16 @@ def add_nodes_conv_fun(self, target, mod, args, kwargs, output):
         attrs["padding"] = to_int(kwargs["padding"])
 
         # FIXME: How to handle quantization
-        weight_attrs = {"quant": None, "shape": args[1].tensor.shape}
+        qconfig = kwargs.get("qconfig", None)
+        weight_quant_attrs = qconfig["weight"] if qconfig else None
+        weight_attrs = {"quant": weight_quant_attrs, "shape": args[1].tensor.shape}
+        # weight_attrs = {"quant": None, "shape": args[1].tensor.shape}
 
-        bias_attrs = None  # FIXME: Bias missing
+        if hasattr(args[2], 'tensor'):
+            bias_attrs = {"quant": None, "shape": args[2].tensor.shape}
+        else:
+            bias_attrs = None
 
         name = target + "_conv"
         input_attrs = self.extract_input_attrs([args[0]])
@@ -514,7 +520,10 @@ def add_nodes_conv_fun(self, target, mod, args, kwargs, output):
             output=output_attr,
         )
 
-        input_names = [arg.name for arg in args]
+        input_names = list()
+        for arg in args:
+            if hasattr(arg, 'name'):
+                input_names.append(arg.name)
 
         for input_name in input_names:
             self.nx_graph.add_edge(input_name, name)
@@ -530,10 +539,18 @@ def add_nodes_linear_fun(self, target, mod, args, kwargs, output):
         attrs["in_features"] = args[1].tensor.shape[0]
         attrs["out_features"] = args[1].tensor.shape[0]
 
-        weight_attrs = {"quant": None, "shape": args[1].tensor.shape}
-        bias_attrs = None
+        qconfig = kwargs.get("qconfig", None)
+        weight_quant_attrs = qconfig["weight"] if qconfig else None
+        weight_attrs = {"quant": weight_quant_attrs, "shape": args[1].tensor.shape}
+        # weight_attrs = {"quant": None, "shape": args[1].tensor.shape}
+
+        if hasattr(args[2], 'tensor'):
+            bias_attrs = {"quant": None, "shape": args[2].tensor.shape}
+        else:
+            bias_attrs = None
+
         name = target + "_linear"
-        input_attrs = self.extract_input_attrs(args)
+        input_attrs = self.extract_input_attrs([args[0]])
         output_quant = {"dtype": "float", "bits": 32, "method": "none"}
         output_attr = {"name": name, "quant": output_quant, "shape": output.shape}
         self.nx_graph.add_node(
@@ -546,7 +563,10 @@ def add_nodes_linear_fun(self, target, mod, args, kwargs, output):
             output=output_attr,
         )
 
-        input_names = [arg.name for arg in args]
+        input_names = list()
+        for arg in args:
+            if hasattr(arg, 'name'):
+                input_names.append(arg.name)
 
         for input_name in input_names:
             self.nx_graph.add_edge(input_name, name)
         quantization = None
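# ---------------------------------------------------------------------------
# Illustrative example (not part of the patch): the weight/bias attributes the
# converter now derives for a conv node. The kwargs and shapes below are
# hypothetical example inputs, not values taken from the code.
kwargs = {"qconfig": {"weight": {"dtype": "int", "bits": 8, "method": "symmetric"}}}
weight_shape, bias_shape = (16, 8, 3, 3), (16,)

qconfig = kwargs.get("qconfig", None)
weight_attrs = {"quant": qconfig["weight"] if qconfig else None, "shape": weight_shape}
bias_attrs = {"quant": None, "shape": bias_shape}  # stays None if the op has no bias tensor
print(weight_attrs, bias_attrs)
# ---------------------------------------------------------------------------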
diff --git a/hannah/quantization/utils.py b/hannah/quantization/utils.py
new file mode 100644
index 00000000..f8071c1a
--- /dev/null
+++ b/hannah/quantization/utils.py
@@ -0,0 +1,99 @@
+import torch
+import torch.autograd as autograd
+from hannah.nas.functional_operators.lazy import lazy
+
+
+class QConfig():
+    def __init__(self, weight_bits, activation_bits, per_channel):
+        self.weight_bits = weight_bits
+        self.activation_bits = activation_bits
+        self.per_channel = per_channel
+
+    def create(self):
+        # set parameterized values
+        qconfig = {
+            "weight": {
+                "dtype": "int",
+                "bits": lazy(self.weight_bits),
+                "method": "symmetric",
+                "per_channel": lazy(self.per_channel)
+            },
+            "activation": {
+                "dtype": "int",
+                "bits": lazy(self.activation_bits),
+                "method": "symmetric",
+                "per_channel": False
+            }
+        }
+        return qconfig
+
+
+def quantize(input, scale, zero_point):
+    """
+    Range-based Linear Quantization
+    """
+    return torch.round(torch.div(input, scale) - zero_point)
+
+
+def dequantize(q_input, scale, zero_point):
+    """
+    Dequantization of linear-quantized input
+    """
+    return (q_input + zero_point) * (scale)
+
+
+def calculate_qparams(
+    bits, min_range, max_range,
+    mode='symmetric', per_channel=False
+):
+    """
+    Calculate scaling factor and zero-point
+
+    Parameters:
+        bits: number of bits for quantization
+        min_range: min quantization range
+        max_range: max quantization range
+        mode: symmetric or asymmetric quantization
+        per_channel: calculate scaling factor per channel
+    """
+
+    with torch.no_grad():
+        n = 2.0 ** (bits - 1) - 1
+
+        # Symmetric quantization mode
+        if per_channel:
+            scale, _ = torch.max(
+                torch.stack([min_range.abs(), max_range.abs()], dim=1), dim=1
+            )
+            scale = torch.clamp(scale, min=1e-8) / n
+        else:
+            scale = max(min_range.abs(), max_range.abs())
+            scale = torch.clamp(scale, min=1e-8) / n
+
+        zero_point = torch.tensor(0.)
+        # TODO: add asymmetric quantization mode (for activations)
+
+    return scale, zero_point
+
+
+class SymmetricQuantization(autograd.Function):
+    """
+    Symmetric quantization of floating-point values,
+    given quantization bits and scale.
+    """
+    @staticmethod
+    def forward(ctx, x, bits, scale):
+        n = 2.0 ** (bits - 1) - 1
+        zero_point = torch.tensor(0.)
+
+        # Quantization: scale, round, clamp
+        x_q = quantize(x, scale, zero_point)
+        x_q = torch.clamp(x_q, -n - 1, n)
+
+        ctx.scale = scale
+        return x_q
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        scale = ctx.scale
+        return grad_output.clone() / scale, None, None
\ No newline at end of file
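# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the patch), assuming this patch is
# applied and hannah is importable: symmetric fake-quantization of a tensor
# with the helpers added in hannah/quantization/utils.py.
import torch
from hannah.quantization.utils import (
    SymmetricQuantization,
    calculate_qparams,
    dequantize,
)

x = torch.randn(4, 8, requires_grad=True)
scale, zero_point = calculate_qparams(
    bits=8, min_range=x.min(), max_range=x.max(),
    mode='symmetric', per_channel=False,
)
x_q = SymmetricQuantization.apply(x, 8, scale)  # values on the integer grid
x_dq = dequantize(x_q, scale, zero_point)       # back to float; backward scales grads by 1/scale
# ---------------------------------------------------------------------------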