Skip to content

Commit

Permalink
Add bias in Conv2d and Linear operators
Browse files Browse the repository at this point in the history
  • Loading branch information
mikhaeldj authored and moreib committed Oct 29, 2024
1 parent ba4488a commit 0829856
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 29 deletions.
47 changes: 47 additions & 0 deletions hannah/callbacks/summaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,53 @@ def get_attn2d(node, output, args, kwargs):
return num_weights, macs, attrs


# quantized conv
def get_qconv(node, output, args, kwargs):
    """Estimate weights, MACs and an attribute string for a quantized conv.

    The counts are first computed as plain element/operation counts and are
    then scaled by ``bits / 32``, i.e. expressed relative to 32-bit values.

    Args:
        node: Graph node being summarized (not used by this function).
        output: Result of the operation; only ``output.shape`` is read.
        args: Positional call arguments; ``args[1]`` is the weight and only
            its ``shape`` is read (index 0: output channels, index 1: input
            channels, index 2: kernel size).
        kwargs: Keyword call arguments. Must contain ``"qconfig"`` (with
            ``["weight"]["bits"]`` and ``["activation"]["bits"]``),
            ``"stride"``, ``"groups"`` and ``"dilation"``.

    Returns:
        Tuple ``(num_weights, macs, attrs)`` where ``attrs`` is a
        human-readable description of kernel, stride, groups, depthwise
        flag and dilation.
    """
    qconfig = kwargs["qconfig"]
    w_bits = qconfig["weight"]["bits"]
    a_bits = qconfig["activation"]["bits"]

    ofm_volume = prod(output.shape)
    kernel = args[1]
    c_out = kernel.shape[0]
    c_in = kernel.shape[1]
    # Only the third weight dimension is read; the kernel is treated as square.
    k = kernel.shape[2]

    # Counts as if every value were 32 bits wide.
    full_precision_weights = np.prod(kernel.shape)
    full_precision_macs = ofm_volume * c_in * k**2

    # Rescale to the configured bit widths.
    num_weights = (full_precision_weights / 32) * w_bits
    macs = (full_precision_macs / 32) * a_bits

    # "dsc" is True when input channels, output channels and groups all match.
    parts = [
        "k=(%d, %d)" % (k, k),
        "s=(%d, %d)" % (kwargs["stride"], kwargs["stride"]),
        "g=(%d)" % kwargs["groups"],
        "dsc=(%s)" % str(c_in == c_out == kwargs["groups"]),
        "d=(%d, %d)" % (kwargs["dilation"], kwargs["dilation"]),
    ]
    attrs = ", ".join(parts)
    return num_weights, macs, attrs


# quantized linear
def get_qlinear(node, output, args, kwargs):
    """Estimate weights and MACs for a quantized linear layer.

    Both counts start from ``in_features * out_features`` and are scaled by
    ``bits / 32``, i.e. expressed relative to 32-bit values.

    Args:
        node: Graph node being summarized (not used by this function).
        output: Result of the operation (not used by this function).
        args: Positional call arguments; ``args[1]`` is the weight, whose
            ``shape[0]`` is read as the input feature count and ``shape[1]``
            as the output feature count.
        kwargs: Keyword call arguments. Must contain ``"qconfig"`` with
            ``["weight"]["bits"]`` and ``["activation"]["bits"]``.

    Returns:
        Tuple ``(num_weights, macs, attrs)``; ``attrs`` is always the empty
        string.
    """
    qconfig = kwargs["qconfig"]
    w_bits = qconfig["weight"]["bits"]
    a_bits = qconfig["activation"]["bits"]

    weight_shape = args[1].shape
    in_features, out_features = weight_shape[0], weight_shape[1]

    # Weight count and MAC count coincide before bit-width scaling.
    full_precision = in_features * out_features

    num_weights = (full_precision / 32) * w_bits
    macs = (full_precision / 32) * a_bits
    return num_weights, macs, ""


def get_type(node):
try:
return node.name.split("_")[-2]
Expand Down
2 changes: 1 addition & 1 deletion hannah/conf/model/resnet.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
_target_: hannah.models.resnet.models.search_space
_target_: hannah.models.resnet.models.resnet
name: resnet
num_classes: 10
4 changes: 3 additions & 1 deletion hannah/models/resnet/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from hannah.models.embedded_vision_net.expressions import expr_product
from hannah.nas.parameters.parameters import CategoricalParameter, IntScalarParameter
from hannah.nas.functional_operators.op import search_space
from hannah.models.resnet.operators import dynamic_depth
from hannah.models.resnet.blocks import block, conv_relu_bn, classifier_head


def search_space(name, input, num_classes=10):
@search_space
def resnet(name, input, num_classes=10):
out_channels = IntScalarParameter(16, 64, step_size=4, name='out_channels')
kernel_size = CategoricalParameter([3, 5, 7, 9], name='kernel_size')
stride = CategoricalParameter([1, 2], name='stride')
Expand Down
8 changes: 6 additions & 2 deletions hannah/models/resnet/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ def conv2d(input, out_channels, kernel_size=1, stride=1, dilation=1, groups=1, p
shape=(out_channels, in_channels, kernel_size, kernel_size),
axis=('O', 'I', 'kH', 'kW'),
grad=True)
bias = Tensor(name='bias', shape=(out_channels,), axis=('O',), grad=True)

conv = Conv2d(stride=stride, dilation=dilation, groups=groups, padding=padding)(input, weight)
conv = Conv2d(stride=stride, dilation=dilation, groups=groups, padding=padding)(
input, weight, bias
)
return conv


Expand All @@ -22,8 +25,9 @@ def linear(input, out_features):
shape=(in_features, out_features),
axis=('in_features', 'out_features'),
grad=True)
bias = Tensor(name='bias', shape=(out_features,), axis=('O',), grad=True)

out = Linear()(input, weight)
out = Linear()(input, weight, bias)
return out


Expand Down
4 changes: 2 additions & 2 deletions hannah/modules/vision/image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ def common_step(self, step_name, batch, batch_idx):
provs = torch.softmax(logits, dim=1)

# For HMM and Viterbi Post-Processing
if ((step_name is "train") or (step_name is 'val')) and self.current_epoch == self.trainer.max_epochs-1:
if ((step_name == "train") or (step_name == 'val')) and self.current_epoch == self.trainer.max_epochs-1:
metadata = batch.get('metadata', {}).copy()
metadata.update({'preds_cnn': preds.cpu().numpy(), 'labels': labels.cpu().numpy()})
df = pd.DataFrame(metadata)
df.to_csv(os.getcwd() + f'_cnn_{step_name}_output', mode='a', index=False, header=True)

if step_name is 'test':
if step_name == 'test':
metadata = batch.get('metadata', {}).copy()
metadata.update({'preds_cnn': preds.cpu().numpy(), 'labels': labels.cpu().numpy()})
df = pd.DataFrame(metadata)
Expand Down
21 changes: 17 additions & 4 deletions hannah/nas/functional_operators/executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from copy import deepcopy
from typing import Iterator, Tuple
import math
import torch
from hannah.nas.functional_operators.op import ChoiceOp, Op, Tensor, get_nodes
from collections import defaultdict
Expand All @@ -17,7 +18,7 @@ def __init__(self, net, input_node_name='input', init=None) -> None:
if init is not None:
self.init = init
else:
self.init = torch.nn.init.xavier_uniform_
self.init = torch.nn.init.kaiming_uniform_
self.nodes = []

def initialize(self):
Expand All @@ -27,9 +28,21 @@ def initialize_tensor(self, node):
if isinstance(node, Tensor):
node_name = node.id.replace(".", "_")
if node.grad:
data = torch.empty(node.current_shape())
data = torch.nn.Parameter(self.init(data))
self.register_parameter(node_name, data)
if node.name == 'bias':
# get weight data
weight_name = node_name.replace('bias', 'weight')
weight_param = self.get_parameter(weight_name)
fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(weight_param.data)
# register bias
if fan_in != 0:
bound = 1 / math.sqrt(fan_in)
data = torch.empty(node.current_shape())
data = torch.nn.Parameter(torch.nn.init.uniform_(data, -bound, bound))
self.register_parameter(node_name, data)
else: # weight tensor
data = torch.empty(node.current_shape())
data = torch.nn.Parameter(self.init(data, a=math.sqrt(5)))
self.register_parameter(node_name, data)
if node.name == self.input_node_name:
self.input = node
if node.name == 'running_mean':
Expand Down
37 changes: 25 additions & 12 deletions hannah/nas/functional_operators/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,11 @@ def conv1d(input, weight, stride, padding, dilation, groups, *, id):


@torch.fx.wrap
def conv2d(input, weight, stride=1, padding=1, dilation=1, groups=1, *, id):
def conv2d(input, weight, bias, stride=1, padding=1, dilation=1, groups=1, *, id):
return F.conv2d(
input=input,
weight=weight,
bias=bias,
stride=lazy(stride),
padding=lazy(padding),
dilation=lazy(dilation),
Expand All @@ -68,8 +69,8 @@ def conv2d(input, weight, stride=1, padding=1, dilation=1, groups=1, *, id):


@torch.fx.wrap
def linear(input, weight, *, id):
return F.linear(input=input, weight=weight.T)
def linear(input, weight, bias, *, id):
return F.linear(input=input, weight=weight.T, bias=bias)


@torch.fx.wrap
Expand Down Expand Up @@ -105,12 +106,12 @@ def adaptive_avg_pooling1d(input, output_size=(1, 1), *, id):


@torch.fx.wrap
def max_pool(input, kernel_size, stride, padding, dilation):
def max_pool(input, kernel_size, stride, padding, dilation, *, id):
return F.max_pool2d(input, kernel_size, stride, padding, dilation)


@torch.fx.wrap
def avg_pool(input, kernel_size, stride, padding):
def avg_pool(input, kernel_size, stride, padding, *, id):
return F.avg_pool2d(input, kernel_size, stride, padding)


Expand Down Expand Up @@ -142,7 +143,7 @@ def self_attention2d(q, k, v, num_heads, d_model, *, id):
k: Tensor, shape ``[B, h*d, H, W]``
v: Tensor, shape ``[B, h*d, H, W]``
"""
scale = d_model**-0.5
scale = d_model ** -0.5
b, _, h, w = q.shape
q = q.view(b, num_heads, d_model, h * w)
k = k.view(b, num_heads, d_model, h * w)
Expand Down Expand Up @@ -253,7 +254,9 @@ def shape_fun(self):
@parametrize
class Conv2d(Op):
def __init__(self, stride=1, dilation=1, groups=1, padding=None) -> None:
super().__init__(name="Conv2d", stride=stride, dilation=dilation, groups=groups)
super().__init__(
name="Conv2d", stride=stride, dilation=dilation, groups=groups,
)
self.stride = stride
self.dilation = dilation
self.groups = groups
Expand All @@ -264,6 +267,8 @@ def __call__(self, *operands) -> Any:
input_shape = operands[0].shape()
weight_shape = operands[1].shape()
operands[1].id = f"{new_conv.id}.{operands[1].id}"
if len(operands) >= 3:
operands[2].id = f"{new_conv.id}.{operands[2].id}"

new_conv.in_channels = input_shape[1]
new_conv.out_channels = weight_shape[0]
Expand All @@ -275,10 +280,11 @@ def __call__(self, *operands) -> Any:

return new_conv

def _forward_implementation(self, x, weight):
def _forward_implementation(self, input, weight, bias=None):
return conv2d(
x,
input,
weight,
bias,
stride=lazy(self.stride),
padding=lazy(self.padding),
dilation=lazy(self.dilation),
Expand Down Expand Up @@ -306,14 +312,19 @@ def __call__(self, *operands) -> Any:
new_linear.in_features = operands[1].shape()[0]
new_linear.out_features = operands[1].shape()[1]
operands[1].id = f"{new_linear.id}.{operands[1].id}"
if len(operands) >= 3:
operands[2].id = f"{new_linear.id}.{operands[2].id}"
return new_linear

def shape_fun(self):
return linear_shape(*self.operands)

def _forward_implementation(self, input, weight):
def _forward_implementation(self, input, weight, bias=None):
input = torch.flatten(input, start_dim=1)
return linear(input, weight, id=self.id)
return linear(
input, weight, bias,
id=self.id
)


@parametrize
Expand Down Expand Up @@ -376,7 +387,6 @@ def _forward_implementation(self, *operands):
class Requantize(Op):
def __init__(self, *args, **kwargs) -> None:
super().__init__(name="Quantize")

self.quantize = FakeQuantize()

@property
Expand Down Expand Up @@ -460,6 +470,7 @@ def _forward_implementation(self, *operands):
stride=lazy(self.stride),
padding=lazy(self.padding),
dilation=lazy(self.dilation),
id=self.id
)


Expand Down Expand Up @@ -488,6 +499,7 @@ def _forward_implementation(self, *operands):
kernel_size=lazy(self.kernel_size),
stride=lazy(self.stride),
padding=lazy(self.padding),
id=self.id
)


Expand Down Expand Up @@ -516,6 +528,7 @@ def _forward_implementation(self, *operands):
kernel_size=lazy(self.kernel_size),
stride=lazy(self.stride),
padding=lazy(self.padding),
id=self.id
)


Expand Down
34 changes: 27 additions & 7 deletions hannah/nas/graph_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,10 +495,16 @@ def add_nodes_conv_fun(self, target, mod, args, kwargs, output):
attrs["padding"] = to_int(kwargs["padding"])

# FIXME: How to handle quantization
weight_attrs = {"quant": None, "shape": args[1].tensor.shape}
qconfig = kwargs.get("qconfig", None)
weight_quant_attrs = qconfig["weight"] if qconfig else None
weight_attrs = {"quant": weight_quant_attrs, "shape": args[1].tensor.shape}
# weight_attrs = {"quant": None, "shape": args[1].tensor.shape}

bias_attrs = None
# FIXME: Bias missing
if hasattr(args[2], 'tensor'):
bias_attrs = {"quant": None, "shape": args[2].tensor.shape}
else:
bias_attrs = None

name = target + "_conv"
input_attrs = self.extract_input_attrs([args[0]])
Expand All @@ -514,7 +520,10 @@ def add_nodes_conv_fun(self, target, mod, args, kwargs, output):
output=output_attr,
)

input_names = [arg.name for arg in args]
input_names = list()
for arg in args:
if hasattr(arg, 'name'):
input_names.append(arg.name)
for input_name in input_names:
self.nx_graph.add_edge(input_name, name)

Expand All @@ -530,10 +539,18 @@ def add_nodes_linear_fun(self, target, mod, args, kwargs, output):
attrs["in_features"] = args[1].tensor.shape[0]
attrs["out_features"] = args[1].tensor.shape[0]

weight_attrs = {"quant": None, "shape": args[1].tensor.shape}
bias_attrs = None
qconfig = kwargs.get("qconfig", None)
weight_quant_attrs = qconfig["weight"] if qconfig else None
weight_attrs = {"quant": weight_quant_attrs, "shape": args[1].tensor.shape}
# weight_attrs = {"quant": None, "shape": args[1].tensor.shape}

if hasattr(args[2], 'tensor'):
bias_attrs = {"quant": None, "shape": args[2].tensor.shape}
else:
bias_attrs = None

name = target + "_linear"
input_attrs = self.extract_input_attrs(args)
input_attrs = self.extract_input_attrs([args[0]])
output_quant = {"dtype": "float", "bits": 32, "method": "none"}
output_attr = {"name": name, "quant": output_quant, "shape": output.shape}
self.nx_graph.add_node(
Expand All @@ -546,7 +563,10 @@ def add_nodes_linear_fun(self, target, mod, args, kwargs, output):
output=output_attr,
)

input_names = [arg.name for arg in args]
input_names = list()
for arg in args:
if hasattr(arg, 'name'):
input_names.append(arg.name)
for input_name in input_names:
self.nx_graph.add_edge(input_name, name)
quantization = None
Expand Down
Loading

0 comments on commit 0829856

Please sign in to comment.