Skip to content

Commit

Permalink
Merge branch 'f/nas_mwe' into 'main'
Browse files Browse the repository at this point in the history
Add Simple NAS Example

See merge request es/ai/hannah/hannah!411
  • Loading branch information
moreib committed Oct 11, 2024
2 parents 1338383 + b07b12b commit 8a4588d
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 209 deletions.
55 changes: 55 additions & 0 deletions experiments/nas_mwe/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
##
## Copyright (c) 2022 University of Tübingen.
##
## This file is part of hannah.
## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
##
## Licensed under the Apache License, Version 2.0 (the "License");
## you may not use this file except in compliance with the License.
## You may obtain a copy of the License at
##
## http://www.apache.org/licenses/LICENSE-2.0
##
## Unless required by applicable law or agreed to in writing, software
## distributed under the License is distributed on an "AS IS" BASIS,
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
## See the License for the specific language governing permissions and
## limitations under the License.
##
defaults:
- base_config
- override dataset: cifar10 # Dataset configuration name
- override nas: aging_evolution_nas # which NAS algorithm to use
- override features: identity # Feature extractor configuration name (use identity for vision datasets)
- override model: simple_conv_search_space # in case of NAS -> search space name
- override scheduler: 1cycle # learning rate scheduler config name
- override optimizer: sgd # Optimizer config name
- override normalizer: null # Feature normalizer (used for quantized neural networks)
- override module: image_classifier # Lightning module config for the training loop (image classifier for image classification tasks)
- _self_


dataset:
data_folder: ${oc.env:HANNAH_DATA_FOLDER,${hydra:runtime.cwd}/../../datasets/}

trainer:
max_epochs: 10
limit_train_batches: 0.1 # Set this to 1.0 to use the whole training set


nas:
budget: 2000
n_jobs: 1
population_size: 10
# Note: If we choose a different NAS algorithm, it might be necessary
# to use different config fields (e.g. population size might throw an error for random search).

# The NAS samples {total_candidates} candidates, sorts them by a {sort_key} (currently the predicted val_error,
# which is 0 if no predictor is chosen or trained) and selects only {num_selected_candidates} of them for training.
# Here, we simply keep all sampled candidates and use as many as the population size of the aging evolution (AE).
total_candidates: 10
num_selected_candidates: 10


fx_mac_summary: True # This has to be set to use current NAS search spaces
experiment_id: nas_mwe
17 changes: 0 additions & 17 deletions hannah/conf/model/lazy_convnet.yaml

This file was deleted.

3 changes: 3 additions & 0 deletions hannah/conf/model/simple_conv_search_space.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_target_: hannah.models.convnet.models.convnet
name: simple_conv_space
num_classes: 10
237 changes: 68 additions & 169 deletions hannah/models/convnet/models.py
Original file line number Diff line number Diff line change
@@ -1,169 +1,68 @@
import torch
import torch.nn as nn
from hannah.nas.expressions.shapes import conv2d_shape, identity_shape
from hannah.nas.parameters.lazy import Lazy
from hannah.nas.parameters.parametrize import parametrize
from hannah.nas.parameters.iterators import RangeIterator
from hannah.nas.parameters.parameters import IntScalarParameter, CategoricalParameter
from hannah.nas.expressions.arithmetic import Ceil
from hannah.nas.expressions.choice import SymbolicAttr, Choice

# Lazy wrappers around the torch layers used by the classes below. Each wrapper
# is called with an id plus the layer arguments and is turned into a module
# later through ``instantiate()`` (see ``ConvReluBn.initialize``).
# NOTE(review): ``shape_func`` presumably lets the wrapper derive its output
# shape symbolically -- confirm against ``hannah.nas.parameters.lazy``.
conv2d = Lazy(nn.Conv2d, shape_func=conv2d_shape)
linear = Lazy(nn.Linear)
batch_norm = Lazy(nn.BatchNorm2d, shape_func=identity_shape)
relu = Lazy(nn.ReLU)


def padding_expression(kernel_size, stride, dilation=1):
    """Build the symbolic padding of a convolution.

    The padding is chosen such that the output dimension is kept the same
    (``stride=1``) or halved (``stride=2``) for the given kernel size,
    stride and dilation.

    Note: if the input dimension is 1 and ``stride=2``, the calculated
    padding results in an output that also has dimension 1.

    Parameters
    ----------
    kernel_size : Union[int, Expression]
    stride : Union[int, Expression]
    dilation : Union[int, Expression], optional
        Dilation of the kernel, by default 1

    Returns
    -------
    Expression
        ``Ceil((dilation * (kernel_size - 1) - stride + 1) / 2)``
    """
    dilated_extent = dilation * (kernel_size - 1)
    half_overhang = (dilated_extent - stride + 1) / 2
    return Ceil(half_overhang)

def stride_product(expressions: list):
    """Multiply all entries of ``expressions`` from left to right.

    Parameters
    ----------
    expressions : list
        Factors to multiply. Entries only need to support ``*`` (plain
        numbers as well as symbolic expressions/parameters).

    Returns
    -------
    The product ``expressions[0] * expressions[1] * ...``, the single entry
    for a one-element list, or ``None`` for an empty list.
    """
    res = None
    for expr in expressions:
        # Compare against the ``None`` sentinel instead of relying on
        # truthiness: a running product of 0 must not be thrown away, and
        # symbolic expressions should not be coerced to ``bool``.
        if res is None:
            res = expr
        else:
            res = res * expr
    return res


@parametrize
class ConvReluBn(nn.Module):
    """Convolution, batch norm and ReLU described through lazy layers.

    The constructor only creates the lazy layer descriptions;
    ``initialize`` instantiates them and ``forward`` applies them in the
    order convolution, batch norm, ReLU.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, id, inputs) -> None:
        super().__init__()
        self.id = id
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride

        # Padding is derived from kernel size and stride (default dilation).
        conv_padding = padding_expression(kernel_size, stride)
        self.conv = conv2d(
            self.id + ".conv",
            inputs=inputs,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=conv_padding,
        )

        # The block exposes the shape of its convolution as its own shape.
        self.shape = self.conv.shape
        self.bn = batch_norm(self.id + ".bn", num_features=out_channels)
        self.relu = relu(self.id + ".relu")

    def initialize(self):
        """Instantiate the lazy layers as ``tconv``, ``tbn`` and ``trelu``."""
        lazy_layers = (
            ("tconv", self.conv),
            ("tbn", self.bn),
            ("trelu", self.relu),
        )
        for attribute, lazy_layer in lazy_layers:
            setattr(self, attribute, lazy_layer.instantiate())

    def forward(self, x):
        """Apply convolution, batch norm and ReLU to ``x``."""
        return self.trelu(self.tbn(self.tconv(x)))


@parametrize
class ConvReluBlock(nn.Module):
    """Stack of ``ConvReluBn`` layers whose number is given by ``depth``.

    For every index produced by ``RangeIterator(depth, instance=False)`` the
    constructor registers an ``out_channels``, ``kernel_size`` and ``stride``
    parameter and appends one ``ConvReluBn`` layer to ``self.mods``.
    """

    def __init__(self, params, input_shape, id, depth) -> None:
        """Register the per-layer parameters and create the layers.

        Parameters
        ----------
        params
            Configuration object; the fields ``conv.out_channels.{min,max,step}``,
            ``conv.kernel_size.choices`` and ``conv.stride.choices`` are read.
        input_shape
            Shape of the block input; index 1 is used as the input channel
            count of the first layer and index 2 in the stride condition.
        id
            Prefix for the names of all registered parameters and layers.
        depth
            Depth parameter passed to ``add_param`` and ``RangeIterator``.
        """
        super().__init__()
        self.input_shape = input_shape
        self.depth = self.add_param(f'{id}.depth', depth)
        self.mods = nn.ModuleList()
        self.id = id
        # NOTE(review): this overwrites the attribute assigned from
        # ``add_param`` above; whether both are the same object depends on
        # what ``add_param`` returns -- confirm.
        self.depth = depth
        self.params = params

        # Strides of all layers, collected for the condition at the end.
        strides = []

        # Input of the next layer: the block input first, then the previous layer.
        previous = input_shape
        for d in RangeIterator(self.depth, instance=False):
            # The first layer reads its channels from the input shape, later
            # layers use the out_channels parameter of the preceding layer.
            in_channels = self.input_shape[1] if d == 0 else self._PARAMETERS[f'{self.id}.conv{d-1}.out_channels']
            out_channels = self.add_param(f'{self.id}.conv{d}.out_channels', IntScalarParameter(self.params.conv.out_channels.min,
                                                                                                self.params.conv.out_channels.max,
                                                                                                self.params.conv.out_channels.step))
            kernel_size = self.add_param(f'{self.id}.conv{d}.kernel_size', CategoricalParameter(self.params.conv.kernel_size.choices))
            stride = self.add_param(f'{self.id}.conv{d}.stride', CategoricalParameter(self.params.conv.stride.choices))

            strides.append(stride)

            layer = ConvReluBn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, id=f'{self.id}.{d}', inputs=[previous])
            self.mods.append(layer)
            previous = layer

        # The product of all strides must not exceed ``input_shape[2]``.
        # NOTE(review): ``cond`` presumably registers this as a constraint on
        # the search space -- confirm against ``parametrize``.
        self.cond(stride_product(strides) <= self.input_shape[2])

    def initialize(self):
        """Call ``initialize`` on every layer index yielded by the iterator."""
        for d in RangeIterator(self.depth, instance=False):
            self.mods[d].initialize()

    def forward(self, x):
        """Feed ``x`` through the layers selected by the depth iterator."""
        out = x
        # NOTE(review): ``instance=True`` presumably iterates only up to the
        # currently sampled depth -- confirm against ``RangeIterator``.
        for d in RangeIterator(self.depth, instance=True):
            out = self.mods[d](out)
        return out


@parametrize
class ConvNet(nn.Module):
    """Search space consisting of a ``ConvReluBlock`` and a lazy linear layer."""

    def __init__(self, name, params, input_shape, labels) -> None:
        """Create the depth parameter, the conv block and the lazy classifier.

        Parameters
        ----------
        name
            Not used inside this constructor.
        params
            Configuration object; ``depth.min`` and ``depth.max`` are read
            here, the rest is forwarded to ``ConvReluBlock``.
        input_shape
            Shape of the network input, forwarded to ``ConvReluBlock``.
        labels
            Used as ``out_features`` of the linear layer.
        """
        super().__init__()
        self.input_shape = input_shape
        self.labels = labels
        self.depth = IntScalarParameter(params.depth.min, params.depth.max)
        self.conv_block = self.add_param("convs", ConvReluBlock(params, self.input_shape, 'convs', self.depth))

        # Symbolic choice of the conv layer at index ``depth - 1``.
        last = Choice(self.conv_block.mods, self.depth - 1)
        # Input features of the classifier: out_channels times the entries 2
        # and 3 of the chosen layer's shape.
        in_features = last.get('out_channels') * last.get('shape')[2] * last.get('shape')[3]

        # Alternatively to the following, one can create a parametrized class "Classifier" which
        # wraps the linear layer.
        self._linear = self.add_param('linear',
                                      linear("linear",
                                             inputs=[last],
                                             in_features=in_features,
                                             out_features=self.labels))

    def initialize(self):
        """Initialize the conv block and instantiate the linear layer."""
        self.conv_block.initialize()
        self.linear = self._linear.instantiate()

    def forward(self, x):
        """Apply the conv block, flatten per sample and apply the linear layer."""
        out = self.conv_block(x)
        # Keep the first dimension and flatten everything else.
        out = out.view(out.shape[0], -1)
        out = self.linear(out)
        return out

    def get_hparams(self):
        """Return a dict mapping each flattened parameter key to its current value."""
        params = {}
        for key, param in self.parametrization(flatten=True).items():
            params[key] = param.current_value.item()

        return params


def create_cnn(name, input_shape, labels, params=None):
    """Create a ``ConvNet`` search space.

    ``ConvNet`` takes ``name``, ``params``, ``input_shape`` and ``labels``.
    The arguments are forwarded by keyword so that each value reaches the
    parameter of the same name.

    Parameters
    ----------
    name
        Name forwarded to ``ConvNet``.
    input_shape
        Input shape forwarded to ``ConvNet``.
    labels
        Number of output features forwarded to ``ConvNet``.
    params : optional
        Search space configuration forwarded to ``ConvNet``. It is required
        by ``ConvNet`` and therefore has to be supplied.

    Returns
    -------
    ConvNet

    Raises
    ------
    ValueError
        If ``params`` is not given.
    """
    if params is None:
        # ConvNet reads ``params.depth`` right away, so fail with a clear
        # message instead of an attribute error deep inside the constructor.
        raise ValueError("create_cnn requires 'params' describing the ConvNet search space")
    return ConvNet(name=name, params=params, input_shape=input_shape, labels=labels)


# Manual smoke run: build the network, sample and initialize it, then push a
# random batch of shape (3, 3, 32, 32) through it.
if __name__ == '__main__':
    # NOTE(review): ``ConvNet.__init__`` declares ``name``, ``params``,
    # ``input_shape`` and ``labels`` without defaults, so this argument-less
    # call looks like it cannot succeed as written -- confirm.
    net = ConvNet()
    x = torch.randn((3, 3, 32, 32))
    net.sample()
    net.initialize()
    out = net(x)
    # Prints an empty line only; ``out`` itself is not printed.
    print()
from hannah.models.embedded_vision_net.expressions import Tensor
from hannah.nas.functional_operators.op import scope, search_space
from hannah.nas.functional_operators.operators import AdaptiveAvgPooling, BatchNorm, Conv2d, Linear, Relu
from hannah.nas.parameters import CategoricalParameter, IntScalarParameter


def conv2d(input, out_channels, kernel_size=1, stride=1, dilation=1, groups=1, padding=None):
    """Append a ``Conv2d`` operator with a trainable weight to ``input``.

    The weight tensor has the shape
    ``(out_channels, input.shape()[1], kernel_size, kernel_size)``; stride,
    dilation, groups and padding are handed to the ``Conv2d`` operator.
    Returns the result of applying the operator to ``input`` and the weight.
    """
    weight_shape = (out_channels, input.shape()[1], kernel_size, kernel_size)
    kernel = Tensor(name='weight', shape=weight_shape, axis=('O', 'I', 'kH', 'kW'), grad=True)
    operator = Conv2d(stride=stride, dilation=dilation, groups=groups, padding=padding)
    return operator(input, kernel)


def linear(input, out_features):
    """Append a ``Linear`` operator with a trainable weight to ``input``.

    The number of input features is the product of the entries 1, 2 and 3 of
    ``input.shape()``; the weight has shape ``(in_features, out_features)``.
    """
    dims = input.shape()
    flattened_features = dims[1] * dims[2] * dims[3]
    projection = Tensor(
        name='weight',
        shape=(flattened_features, out_features),
        axis=('in_features', 'out_features'),
        grad=True,
    )
    return Linear()(input, projection)


@scope
def batch_norm(input):
    """Append a ``BatchNorm`` operator with running statistics to ``input``.

    Two tensors of shape ``(input.shape()[1],)`` named ``running_mean`` and
    ``running_std`` are created and passed to the operator with ``input``.
    """
    # Background on functional batch norm for 2d input:
    # https://stackoverflow.com/questions/44887446/pytorch-nn-functional-batch-norm-for-2d-input
    stats_shape = (input.shape()[1],)
    # NOTE(review): no initial data is assigned to the statistics here; an
    # earlier draft mentioned zeros (mean) and ones (std) as fine for the
    # first training iteration.
    mean_stat = Tensor(name='running_mean', shape=stats_shape, axis=('c',))
    std_stat = Tensor(name='running_std', shape=stats_shape, axis=('c',))
    normalize = BatchNorm()
    return normalize(input, mean_stat, std_stat)


def relu(input):
    """Append a ``Relu`` operator to ``input`` and return the result."""
    activation = Relu()
    return activation(input)


def adaptive_avg_pooling(input):
    """Append an ``AdaptiveAvgPooling`` operator to ``input`` and return the result."""
    pooling = AdaptiveAvgPooling()
    return pooling(input)


def conv_bn_relu(input, out_channels, kernel_size, stride):
    """Chain ``conv2d``, ``batch_norm`` and ``relu`` on ``input``.

    ``out_channels``, ``kernel_size`` and ``stride`` configure the
    convolution; the remaining convolution arguments keep their defaults.
    """
    convolved = conv2d(input, out_channels=out_channels, kernel_size=kernel_size, stride=stride)
    return relu(batch_norm(convolved))


@search_space
def convnet(name, input, num_classes):
    """Simple search space: two conv-bn-relu stages, pooling and a classifier.

    Every stage draws its own copy (``.new()``) of the ``out_channels``
    (16 to 128, step 8), ``kernel_size`` (3, 5, 7 or 9) and ``stride``
    (1 or 2) parameters. The result is pooled with ``adaptive_avg_pooling``
    and mapped to ``num_classes`` outputs by ``linear``. ``name`` is not
    used inside the body.
    """
    channel_param = IntScalarParameter(16, 128, step_size=8, name="out_channels")
    kernel_param = CategoricalParameter([3, 5, 7, 9], name="kernel_size")
    stride_param = CategoricalParameter([1, 2], name="stride")

    net = input
    for _ in range(2):
        net = conv_bn_relu(
            net,
            out_channels=channel_param.new(),
            kernel_size=kernel_param.new(),
            stride=stride_param.new(),
        )

    return linear(adaptive_avg_pooling(net), num_classes)
23 changes: 0 additions & 23 deletions test/test_graph_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from hannah.models.factory.qconfig import get_trax_qat_qconfig
from hannah.nas.functional_operators.op import Tensor
from hannah.nas.graph_conversion import model_to_graph
from hannah.models.convnet.models import ConvNet


class Model(Module):
Expand Down Expand Up @@ -69,27 +68,6 @@ def test_graph_conversion():
pprint(data, indent=2)


def test_graph_conversion_lazy_convnet():
    """Build, sample and run the lazy ``ConvNet`` and convert it to a graph.

    The test makes no assertions on the results; it passes when neither the
    forward pass nor ``model_to_graph`` raises.
    """
    from omegaconf import OmegaConf

    search_space_config = OmegaConf.merge(
        {
            "depth": {"min": 3, "max": 3},
            "conv": {
                "kernel_size": {"choices": [3, 5, 7]},
                "stride": {"choices": [1, 2]},
                "out_channels": {"min": 16, "max": 64, "step": 4},
            },
        }
    )

    net = ConvNet(name="cnn", params=search_space_config, input_shape=[1, 3, 32, 32], labels=10)
    net.sample()
    net.initialize()

    # Forward pass and graph conversion each receive a fresh random input.
    net(torch.rand((1, 3, 32, 32), dtype=torch.float32))
    model_to_graph(net, torch.rand((1, 3, 32, 32), dtype=torch.float32))


def test_graph_conversion_functional_operators():
from hannah.models.embedded_vision_net.models import embedded_vision_net
from hannah.nas.functional_operators.executor import BasicExecutor
Expand All @@ -114,5 +92,4 @@ def test_graph_conversion_functional_operators():

# Allow running the graph conversion tests directly as a script; the tests
# are called one after another in the order listed.
if __name__ == "__main__":
    test_graph_conversion()
    test_graph_conversion_lazy_convnet()
    test_graph_conversion_functional_operators()

0 comments on commit 8a4588d

Please sign in to comment.