Merge branch 'channel_width_multiplier' into 'main'
Fix Parametrization

See merge request es/ai/hannah/hannah!348
moreib committed Oct 27, 2023
2 parents aafe42a + b5a8c9a commit 3d78e33
Showing 10 changed files with 264 additions and 133 deletions.
@@ -0,0 +1,21 @@
# @package _global_
defaults:
- override /nas: aging_evolution_nas
- override /model: embedded_vision_nas
- override /dataset: cifar10

model:
num_classes: 10
module:
batch_size: 128
nas:
budget: 600
n_jobs: 8


trainer:
max_epochs: 10

seed: [1234]

experiment_id: "ae_nas_cifar10_fixreduce"
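
A note on the NAS settings above: aging_evolution_nas refers to aging
(regularized) evolution. The sketch below shows the kind of loop that
budget: 600 caps, with candidates spread over n_jobs: 8 workers and each
one trained for at most max_epochs: 10; the callables and the candidate
dict shape are assumptions made for this illustration, not hannah's API.

import random
from collections import deque

def aging_evolution(sample_config, mutate, train_and_eval,
                    budget=600, population_size=50, tournament_size=10):
    """Minimal aging-evolution loop; a sketch, not hannah's implementation.

    sample_config()        -> a random architecture configuration
    mutate(config)         -> a perturbed copy of a configuration
    train_and_eval(config) -> a fitness value, e.g. 1 - val_error
    """
    population = deque(maxlen=population_size)  # oldest candidates age out
    history = []
    for _ in range(budget):                     # cf. nas: budget: 600
        if len(population) < population_size:   # warm-up: random sampling
            config = sample_config()
        else:                                   # tournament, then mutation
            tournament = random.sample(list(population), tournament_size)
            parent = max(tournament, key=lambda c: c["fitness"])
            config = mutate(parent["config"])
        candidate = {"config": config, "fitness": train_and_eval(config)}
        population.append(candidate)            # appending evicts the oldest
        history.append(candidate)
    return max(history, key=lambda c: c["fitness"])
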
@@ -59,4 +59,4 @@ scontrol show job $SLURM_JOB_ID
conda activate hannah


hannah-train trainer.gpus=8 experiment=ae_nas_cifar10_v2 model=embedded_vision_net dataset=cifar10 model.num_classes=10 nas.n_jobs=8 fx_mac_summary=True ~normalizer
hannah-train trainer.gpus=8 experiment=ae_nas_cifar10_fixreduce model=embedded_vision_net dataset=cifar10 model.num_classes=10 nas.n_jobs=8 fx_mac_summary=True ~normalizer
4 changes: 4 additions & 0 deletions hannah/conf/model/embedded_vision_net_cwm.yaml
@@ -0,0 +1,4 @@
_target_: hannah.models.embedded_vision_net.models.search_space_cwm
name: embedded_vision_net_cwm
num_classes: 10

6 changes: 3 additions & 3 deletions hannah/conf/nas/aging_evolution_nas.yaml
@@ -25,10 +25,10 @@ defaults:
_target_: hannah.nas.search.search.DirectNAS
budget: 2000
n_jobs: 10
presample: True
total_candidates: 100
presample: False
total_candidates: 50
num_selected_candidates: 20
bounds:
val_error: 0.12
val_error: 0.03
total_macs: 128000000
total_weights: 500000
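
Besides turning presampling off and halving the candidate pool, this hunk
tightens the val_error bound from 0.12 to 0.03. Conceptually, the bounds
act as a keep/discard predicate over candidate metrics; a rough sketch of
that idea (the predicate below is illustrative, not hannah's code):

from typing import Dict

bounds: Dict[str, float] = {
    "val_error": 0.03,          # tightened in this commit (was 0.12)
    "total_macs": 128_000_000,
    "total_weights": 500_000,
}

def within_bounds(metrics: Dict[str, float]) -> bool:
    # Keep a candidate only if every tracked metric stays under its bound.
    return all(metrics[name] <= limit for name, limit in bounds.items())

assert within_bounds({"val_error": 0.02, "total_macs": 9.0e7, "total_weights": 4.0e5})
assert not within_bounds({"val_error": 0.05, "total_macs": 9.0e7, "total_weights": 4.0e5})
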
140 changes: 140 additions & 0 deletions hannah/models/embedded_vision_net/blocks.py
@@ -0,0 +1,140 @@
from functools import partial
from hannah.models.embedded_vision_net.expressions import expr_product
from hannah.nas.expressions.arithmetic import Ceil
from hannah.nas.expressions.types import Int
from hannah.nas.functional_operators.op import scope
from hannah.models.embedded_vision_net.operators import adaptive_avg_pooling, add, conv2d, conv_relu, depthwise_conv2d, dynamic_depth, pointwise_conv2d, linear, relu, batch_norm, choice, identity
# from hannah.nas.functional_operators.visualizer import Visualizer
from hannah.nas.parameters.parameters import CategoricalParameter, IntScalarParameter


@scope
def expansion(input, expanded_channels):
return pointwise_conv2d(input, expanded_channels)


@scope
def spatial_correlation(input, out_channels, kernel_size, stride=1):
return depthwise_conv2d(input, out_channels=out_channels, kernel_size=kernel_size, stride=stride)


@scope
def reduction(input, out_channels):
return pointwise_conv2d(input, out_channels=out_channels)


@scope
def reduce_expand(input, out_channels, reduce_ratio, kernel_size, stride):
in_channels = input.shape()[1]
reduced_channels = Int(in_channels / reduce_ratio)

out = reduction(input, reduced_channels)
# out.add_param('reduce_ratio', reduce_ratio)
out = batch_norm(out)
out = relu(out)
out = conv2d(out, reduced_channels, kernel_size, stride)
out = batch_norm(out)
out = relu(out)
out = expansion(out, out_channels)
out = batch_norm(out)
out = relu(out)
return out


@scope
def expand_reduce(input, out_channels, expand_ratio, kernel_size, stride):
in_channels = input.shape()[1]
expanded_channels = Int(expand_ratio * in_channels)
out = expansion(input, expanded_channels)
out = batch_norm(out)
out = relu(out)
out = spatial_correlation(out, kernel_size=kernel_size, stride=stride, out_channels=expanded_channels)
out = batch_norm(out)
out = relu(out)
out = reduction(out, out_channels)
out = batch_norm(out)
out = relu(out)
return out
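
# --- Editor's note: illustration only, not part of this commit -----------
# expand_reduce is an inverted bottleneck (widen, depthwise-filter, project
# back down), while reduce_expand above is a classic bottleneck. The channel
# arithmetic, with plain ints standing in for hannah's symbolic Int
# expressions:
_in_channels = 32
assert int(3 * _in_channels) == 96  # expand_ratio=3 -> 96-channel peak
assert int(_in_channels / 4) == 8   # reduce_ratio=4 -> 8-channel bottleneck
# Note that reduced_channels is now in_channels / reduce_ratio; the version
# removed from models.py below multiplied by reduce_ratio instead, which
# appears to be part of the parametrization fix this commit makes.
# --------------------------------------------------------------------------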


@scope
def pattern(input, stride, out_channels, kernel_size, expand_ratio, reduce_ratio):
convolution = partial(conv_relu, stride=stride, kernel_size=kernel_size, out_channels=out_channels)
exp_red = partial(expand_reduce, out_channels=out_channels, expand_ratio=expand_ratio, kernel_size=kernel_size, stride=stride)
red_exp = partial(reduce_expand, out_channels=out_channels, reduce_ratio=reduce_ratio, kernel_size=kernel_size, stride=stride)
# TODO: pooling

out = choice(input, convolution, exp_red, red_exp)
return out


@scope
def residual(input, main_branch_output_shape):
input_shape = input.shape()
in_fmap = input_shape[2]
out_channels = main_branch_output_shape[1]
out_fmap = main_branch_output_shape[2]
stride = Int(Ceil(in_fmap / out_fmap))

out = conv2d(input, out_channels=out_channels, kernel_size=1, stride=stride, padding=0)
out = batch_norm(out)
out = relu(out)
return out
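
# --- Editor's note: illustration only, not part of this commit -----------
# The residual branch infers its stride from how much the main branch shrank
# the feature map: stride = ceil(in_fmap / out_fmap). For example:
from math import ceil
assert ceil(32 / 16) == 2  # main branch halved 32x32 -> 16x16: stride 2
assert ceil(16 / 16) == 1  # resolutions match: plain 1x1 projection
# --------------------------------------------------------------------------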


@scope
def block(input, depth, stride, out_channels, kernel_size, expand_ratio, reduce_ratio):
assert isinstance(depth, IntScalarParameter), "block depth must be of type IntScalarParameter"
out = input
exits = []
for i in range(depth.max+1):
out = pattern(out,
stride=stride.new() if i == 0 else 1,
out_channels=out_channels.new(),
kernel_size=kernel_size.new(),
expand_ratio=expand_ratio.new(),
reduce_ratio=reduce_ratio.new())
exits.append(out)

out = dynamic_depth(*exits, switch=depth)
res = residual(input, out.shape())
out = add(out, res)

return out
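
# --- Editor's note: illustration only, not part of this commit -----------
# With depth = IntScalarParameter(0, 2), three patterns are chained and an
# exit is recorded after each one; dynamic_depth then selects a single exit
# per sampled architecture, conceptually like:
_exits = ["after_1st", "after_2nd", "after_3rd"]
_sampled_depth = 1            # the value the NAS assigns to `depth`
assert _exits[_sampled_depth] == "after_2nd"
# so a block's effective depth is searchable without rebuilding the graph.
# --------------------------------------------------------------------------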


@scope
def cwm_block(input, depth, stride, channel_width_multiplier, kernel_size, expand_ratio, reduce_ratio):
assert isinstance(depth, IntScalarParameter), "block depth must be of type IntScalarParameter"
out = input
exits = []
in_channels = out.shape()[1]
for i in range(depth.max+1):
out = pattern(out,
stride=stride.new() if i == 0 else 1,
out_channels=Int(channel_width_multiplier * in_channels),
kernel_size=kernel_size.new(),
expand_ratio=expand_ratio.new(),
reduce_ratio=reduce_ratio.new())
exits.append(out)

out = dynamic_depth(*exits, switch=depth)
res = residual(input, out.shape())
out = add(out, res)

return out
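
# --- Editor's note: illustration only, not part of this commit -----------
# Unlike block() above, cwm_block has no free out_channels parameter: every
# pattern's width is derived from the block input as
# Int(channel_width_multiplier * in_channels). For instance:
assert int(1.4 * 16) == 22  # 16 input channels, multiplier 1.4
# Since in_channels is read once from the block input, the width is constant
# within a block and only changes from one block to the next.
# --------------------------------------------------------------------------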


@scope
def stem(input, kernel_size, stride, out_channels):
out = conv2d(input, out_channels, kernel_size, stride)
out = batch_norm(out)
out = relu(out)
return out


@scope
def classifier_head(input, num_classes):
out = choice(input, adaptive_avg_pooling)
out = linear(out, num_classes)
return out
153 changes: 41 additions & 112 deletions hannah/models/embedded_vision_net/models.py
@@ -3,138 +3,64 @@
from hannah.nas.expressions.arithmetic import Ceil
from hannah.nas.expressions.types import Int
from hannah.nas.functional_operators.executor import BasicExecutor
from hannah.nas.functional_operators.op import Tensor, scope
from hannah.nas.functional_operators.op import Tensor, get_nodes, scope
from hannah.models.embedded_vision_net.operators import adaptive_avg_pooling, add, conv2d, conv_relu, depthwise_conv2d, dynamic_depth, pointwise_conv2d, linear, relu, batch_norm, choice, identity
# from hannah.nas.functional_operators.visualizer import Visualizer
from hannah.nas.parameters.parameters import CategoricalParameter, IntScalarParameter
import time
from hannah.nas.parameters.parameters import CategoricalParameter, FloatScalarParameter, IntScalarParameter
from hannah.models.embedded_vision_net.blocks import block, cwm_block, classifier_head, stem


@scope
def expansion(input, expanded_channels):
return pointwise_conv2d(input, expanded_channels)


@scope
def spatial_correlation(input, out_channels, kernel_size, stride=1):
return depthwise_conv2d(input, out_channels=out_channels, kernel_size=kernel_size, stride=stride)


@scope
def reduction(input, out_channels):
return pointwise_conv2d(input, out_channels=out_channels)


@scope
def reduce_expand(input, out_channels, reduce_ratio, kernel_size, stride):
in_channels = input.shape()[1]
reduced_channels = Int(reduce_ratio * in_channels)

out = reduction(input, reduced_channels)
out = batch_norm(out)
out = relu(out)
out = conv2d(out, reduced_channels, kernel_size, stride)
out = batch_norm(out)
out = relu(out)
out = expansion(out, out_channels)
out = batch_norm(out)
out = relu(out)
return out


@scope
def expand_reduce(input, out_channels, expand_ratio, kernel_size, stride):
in_channels = input.shape()[1]
expanded_channels = Int(expand_ratio * in_channels)
out = expansion(input, expanded_channels)
out = batch_norm(out)
out = relu(out)
out = spatial_correlation(out, kernel_size=kernel_size, stride=stride, out_channels=expanded_channels)
out = batch_norm(out)
out = relu(out)
out = reduction(out, out_channels)
out = batch_norm(out)
out = relu(out)
return out


@scope
def pattern(input, stride, out_channels, kernel_size, expand_ratio, reduce_ratio):
convolution = partial(conv_relu, stride=stride, kernel_size=kernel_size, out_channels=out_channels)
exp_red = partial(expand_reduce, out_channels=out_channels, expand_ratio=expand_ratio, kernel_size=kernel_size, stride=stride)
red_exp = partial(reduce_expand, out_channels=out_channels, reduce_ratio=reduce_ratio, kernel_size=kernel_size, stride=stride)
# TODO: pooling

out = choice(input, convolution, exp_red, red_exp)
return out


@scope
def residual(input, main_branch_output_shape):
input_shape = input.shape()
in_fmap = input_shape[2]
out_channels = main_branch_output_shape[1]
out_fmap = main_branch_output_shape[2]
stride = Int(Ceil(in_fmap / out_fmap))

out = conv2d(input, out_channels=out_channels, kernel_size=1, stride=stride, padding=0)
out = batch_norm(out)
out = relu(out)
return out
def search_space(name, input, num_classes=10):
out_channels = IntScalarParameter(16, 64, step_size=4, name='out_channels')
kernel_size = CategoricalParameter([3, 5, 7, 9], name='kernel_size')
stride = CategoricalParameter([1, 2], name='stride')
expand_ratio = IntScalarParameter(1, 3, name='expand_ratio')
reduce_ratio = IntScalarParameter(2, 4, name='reduce_ratio')

depth = IntScalarParameter(0, 2, name='depth')

@scope
def block(input, depth, stride, out_channels, kernel_size, expand_ratio, reduce_ratio):
assert isinstance(depth, IntScalarParameter), "block depth must be of type IntScalarParameter"
out = input
num_blocks = IntScalarParameter(0, 6, name='num_blocks')
exits = []
for i in range(depth.max+1):
out = pattern(out,
stride=stride.new() if i == 0 else 1,
out_channels=out_channels.new(),
kernel_size=kernel_size.new(),
expand_ratio=expand_ratio.new(),
reduce_ratio=reduce_ratio.new())
exits.append(out)

out = dynamic_depth(*exits, switch=depth)
res = residual(input, out.shape())
out = add(out, res)

return out

stem_kernel_size = CategoricalParameter([3, 5], name="kernel_size")
stem_channels = IntScalarParameter(min=16, max=32, step_size=4, name="out_channels")
out = stem(input, stem_kernel_size, stride.new(), stem_channels)
for i in range(num_blocks.max+1):
out = block(out, depth=depth.new(), stride=stride.new(), out_channels=out_channels.new(), kernel_size=kernel_size.new(),
expand_ratio=expand_ratio.new(), reduce_ratio=reduce_ratio.new())
exits.append(out)

@scope
def stem(input, kernel_size, stride, out_channels):
out = conv2d(input, out_channels, kernel_size, stride)
out = batch_norm(out)
out = relu(out)
return out
out = dynamic_depth(*exits, switch=num_blocks)
out = classifier_head(out, num_classes=num_classes)

strides = [v for k, v in out.parametrization(flatten=True).items() if k.split('.')[-1] == 'stride']
total_stride = expr_product(strides)
out.cond(input.shape()[2] / total_stride > 1)

@scope
def classifier_head(input, num_classes):
out = choice(input, adaptive_avg_pooling)
out = linear(out, num_classes)
return out


def search_space(name, input, num_classes=10):
out_channels = IntScalarParameter(32, 256, step_size=4, name='out_channels')
def search_space_cwm(name, input, num_classes=10):
channel_width_multiplier = CategoricalParameter([1.0, 1.1, 1.2, 1.3, 1.4, 1.5], name="channel_width_multiplier")
kernel_size = CategoricalParameter([3, 5, 7, 9], name='kernel_size')
stride = CategoricalParameter([1, 2], name='stride')
expand_ratio = IntScalarParameter(1, 6, name='expand_ratio')
reduce_ratio = IntScalarParameter(1, 6, name='reduce_ratio')

expand_ratio = IntScalarParameter(2, 6, name='expand_ratio')
reduce_ratio = IntScalarParameter(3, 6, name='reduce_ratio')
depth = IntScalarParameter(0, 2, name='depth')

num_blocks = IntScalarParameter(0, 6, name='num_blocks')
num_blocks = IntScalarParameter(0, 5, name='num_blocks')
exits = []

out = input
stem_kernel_size = CategoricalParameter([3, 5], name="kernel_size")
stem_channels = IntScalarParameter(min=16, max=32, step_size=4, name="out_channels")
out = stem(input, stem_kernel_size, stride.new(), stem_channels)
for i in range(num_blocks.max+1):
out = block(out, depth=depth.new(), stride=stride.new(), out_channels=out_channels.new(), kernel_size=kernel_size.new(),
expand_ratio=expand_ratio.new(), reduce_ratio=reduce_ratio.new())
out = cwm_block(out,
depth=depth.new(),
stride=stride.new(),
channel_width_multiplier=channel_width_multiplier.new(),
kernel_size=kernel_size.new(),
expand_ratio=expand_ratio.new(),
reduce_ratio=reduce_ratio.new())
exits.append(out)

out = dynamic_depth(*exits, switch=num_blocks)
@@ -144,4 +70,7 @@ def search_space(name, input, num_classes=10):
total_stride = expr_product(strides)
out.cond(input.shape()[2] / total_stride > 1)

multipliers = [v for k, v in out.parametrization(flatten=True).items() if k.split('.')[-1] == 'channel_width_multiplier']
max_multiplication = expr_product(multipliers)
out.cond(max_multiplication < 4)
return out
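
Two global constraints close out both search spaces: the product of all
stride choices must keep the feature map larger than 1x1, and in
search_space_cwm the product of the six per-block width multipliers must
stay below 4. A quick check of what those bounds admit, in plain Python
mirroring the expr_product conditions above:

from math import prod

# Stride constraint: input.shape()[2] / total_stride > 1. For 32x32
# CIFAR-10 inputs, at most four stride-2 stages can survive:
assert 32 / prod([2, 2, 2, 2]) > 1         # total_stride 16: allowed
assert not 32 / prod([2, 2, 2, 2, 2]) > 1  # total_stride 32: rejected

# Multiplier constraint: prod(multipliers) < 4 over six cwm_blocks
# (num_blocks runs 0..5). A uniform 1.2 passes, a uniform 1.3 does not:
assert prod([1.2] * 6) < 4      # ~2.99
assert not prod([1.3] * 6) < 4  # ~4.83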