diff --git a/README.md b/README.md
index cd17d85c..d87634f3 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,8 @@ limitations under the License.
 
 # HANNAH - Hardware Accelerator and Neural network searcH
 
 # Getting Started
+!!! note
+    For more information, visit the [documentation](https://ekut-es.github.io/hannah/).
 
 ## Installing dependencies
diff --git a/doc/assets/choice_node.jpg b/doc/assets/choice_node.jpg
new file mode 100755
index 00000000..218360c8
Binary files /dev/null and b/doc/assets/choice_node.jpg differ
diff --git a/doc/assets/graph.jpg b/doc/assets/graph.jpg
new file mode 100755
index 00000000..daaa7554
Binary files /dev/null and b/doc/assets/graph.jpg differ
diff --git a/doc/assets/graph.png b/doc/assets/graph.png
new file mode 100755
index 00000000..a0f1adc6
Binary files /dev/null and b/doc/assets/graph.png differ
diff --git a/doc/nas/eval.md b/doc/nas/eval.md
new file mode 100644
index 00000000..d57b6c40
--- /dev/null
+++ b/doc/nas/eval.md
@@ -0,0 +1,3 @@
+# Evaluation
+!!! note
+    Coming soon
\ No newline at end of file
diff --git a/doc/nas/legacy.md b/doc/nas/legacy.md
new file mode 100644
index 00000000..49205d75
--- /dev/null
+++ b/doc/nas/legacy.md
@@ -0,0 +1,116 @@
+!!! warning
+    This is the old documentation for the NAS. A lot has changed, and a lot will still change in the future, so handle with care.
+
+# Neural architecture search
+
+In contrast to hyperparameter optimization, neural architecture search explores new neural network architectures.
+
+An aging-evolution-based neural architecture search has been implemented as a Hydra plugin:
+
+    hannah-train --config-name config_unas
+
+To launch multiple configuration jobs in parallel, use the joblib launcher:
+
+    hannah-train --config-name config_unas hydra/launcher=joblib
+
+Parametrizations for neural architecture search currently need to be given as *YAML* configuration files.
+For an example, see: `speech_recognition/conf/config_unas.yaml`
+
+## Parametrization
+
+The parametrization contains the following elements:
+
+### Choice Parameter
+
+Choice parameters select an option from a list of choices. They are configured as a list of options in
+the parameters. Example:
+
+    conv_size: [1,3,5,7,9,11]
+
+### Choice List Parameters
+
+Choice list parameters represent a variable-length list of choices. They are configured with the following parameters:
+
+`min`
+: Minimum length of the list
+
+`max`
+: Maximum length of the list + 1
+
+`choices`
+: List of choices
+
+Example:
+
+    min: 4
+    max: 10
+    choices:
+      - _target_: "torch.nn.Conv2d"
+        size: 3
+      - _target_: "torch.nn.MaxPool2d"
+        size: 7
+
+*Warning*: Mutations for interval parameters currently always sample randomly from the range of values.
+
+### Interval Parameters
+
+Interval parameters represent a scalar value from an interval of values.
+They are configured with the following parameters:
+
+`lower`
+: lower bound of the interval [lower, upper[
+
+`upper`
+: upper bound of the interval [lower, upper[
+
+`int`
+: set to true to generate integers
+
+`log`
+: set to true to generate a log-scaled distribution
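+
+Example (a sketch; the values here are purely illustrative):
+
+    lower: 8
+    upper: 512
+    int: true
+    log: true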
+
+### Subset Parameter
+
+Subset parameters select a subset of a list of choices.
+
+They are configured using the following parameters:
+
+`choices`
+: List of choices to sample from
+
+`size`
+: Size of the subset to generate
+
+### Partition Parameter
+
+Partition parameters split the list of choices into a predefined number of partitions.
+
+They are configured using the following parameters:
+
+`choices`
+: List of choices to partition
+
+`partition`
+: Number of partitions to generate
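+
+For illustration, a subset parameter and a partition parameter might be configured as follows (the values are made up):
+
+    # subset parameter
+    choices: [0, 1, 2, 3, 4, 5]
+    size: 3
+
+    # partition parameter
+    choices: [0, 1, 2, 3, 4, 5]
+    partition: 2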
diff --git a/doc/nas/parametrization.md b/doc/nas/parametrization.md
new file mode 100644
index 00000000..4bd3ece3
--- /dev/null
+++ b/doc/nas/parametrization.md
@@ -0,0 +1,37 @@
+!!! note
+    Coming soon
+
+# Parametrization
+
+## Usage in Search Space
+
+## Basic Parameter Types
+
+### IntScalarParameter
+
+### FloatScalarParameter
+
+### CategoricalParameter
+
+## Advanced Parameters
+
+### Choice
+
+### SymbolicAttr
+
+### SymbolicSequence
+
+## Placeholders
+
+## Expressions
\ No newline at end of file
diff --git a/doc/nas/search.md b/doc/nas/search.md
new file mode 100644
index 00000000..9a237617
--- /dev/null
+++ b/doc/nas/search.md
@@ -0,0 +1,3 @@
+# Search
+!!! note
+    Coming soon
\ No newline at end of file
diff --git a/doc/nas/search_spaces.md b/doc/nas/search_spaces.md
new file mode 100644
index 00000000..80194e92
--- /dev/null
+++ b/doc/nas/search_spaces.md
@@ -0,0 +1,550 @@
+!!! warning
+    The search spaces in HANNAH are currently under construction. If you run into bugs, please contact us.
+
+# Search Spaces
+
+Search spaces in HANNAH are directed acyclic graphs (DAGs) whose nodes are **Ops** or **Tensors** and whose edges indicate data movement.
+
+!!! note
+    Search spaces are not executable themselves but need an [Executor](#executor), which uses the current parametrization state to
+    build a `forward` function.
+
+```python
+from hannah.nas.functional_operators.operators import Conv2d
+```
+
+![Graph illustration](../../assets/graph.jpg)
+
+## Basic Building Blocks
+
+### Ops & Tensors
+
+**Op** nodes represent the operators used in the networks of the search space. Their basic syntax is
+
+```python
+# usage
+var_name = Operator(param0=val0, param1=val1, ...)(*operands)
+```
+
+**Tensor** nodes indicate the presence of data in the graph. They do not themselves contain actual values when
+the search space graph is defined (the actual data is managed by the [Executor](#executor)). A tensor node
+defines the attributes that the data has at this point in the graph (e.g., shape, axis names, datatypes, ...).
+
+```python
+from hannah.nas.functional_operators.operators import Conv2d
+from hannah.nas.functional_operators.op import Tensor
+
+input = Tensor(name='input', shape=(1, 3, 32, 32), axis=("N", "C", "H", "W"))
+weight = Tensor(name='weight', shape=(32, 3, 1, 1), axis=("O", "I", "kH", "kW"))
+
+conv = Conv2d(stride=2, dilation=1)  # Define operator and parametrization
+graph = conv(input, weight)          # Define/create/extend graph
+graph
+```
+
+    Conv2d(Conv2d_0)
+
+A set of basic operators is implemented in HANNAH, among them
+
+* Convolution (1D, 2D)
+* Linear
+* BatchNorm
+* Relu
+* Add
+
+and more operators will be added in the future. It is also easy to
+define a new operator, see [Custom Ops](#custom-ops).
+
+## Parametrization & Expressions
+
+!!! note
+    For more information about parametrization and expressions, see [Parametrization](parametrization.md).
+
+To build a search space, it is not sufficient to feed scalar values to operator parameters. Instead, one can use
+*parameters*.
+
+```python
+from hannah.nas.parameters.parameters import CategoricalParameter, IntScalarParameter
+
+input = Tensor(name='input', shape=(1, 3, 32, 32), axis=("N", "C", "H", "W"))
+weight = Tensor(name='weight', shape=(IntScalarParameter(min=8, max=64, name='out_channels'), 3, 1, 1), axis=("O", "I", "kH", "kW"))
+
+# a search space with stride-1 and stride-2 convolutions
+graph = Conv2d(stride=CategoricalParameter(name='stride', choices=[1, 2]))(input, weight)
+graph.parametrization(flatten=True)
+```
+
+    {'Conv2d_0.weight.out_channels': IntScalarParameter(rng = Generator(PCG64), name = out_channels, id = Conv2d_0.weight.out_channels, min = 8, max = 64, step_size = 1, current_value = 8),
+     None: CategoricalParameter(rng = Generator(PCG64), name = stride, id = None, choices = [1, 2], current_value = 1)}
+
+As further explained in [Parametrization](parametrization.md), parameters are *expressions* and can be combined into more complex *expressions*,
+encoding properties of the search space symbolically.
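+For instance, arithmetic on parameters yields new expressions (a small sketch; the parameter names here are illustrative):
+
+```python
+from hannah.nas.functional_operators.lazy import lazy
+from hannah.nas.parameters.parameters import FloatScalarParameter, IntScalarParameter
+
+out_channels = IntScalarParameter(min=8, max=64, name='out_channels')
+expansion = FloatScalarParameter(min=1, max=6, name='expansion')
+
+# Multiplying two parameters does not produce a number but a symbolic
+# expression; it is only evaluated once concrete values are needed
+# (see the `lazy` helper below).
+expanded_channels = out_channels * expansion
+print(lazy(expanded_channels))
+```
+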
+One common use case is symbolically expressing shapes. Consider, for example, the following:
+
+```python
+in_channel = 3
+kernel_size = 1
+input = Tensor(name='input',
+               shape=(1, in_channel, 32, 32),
+               axis=('N', 'C', 'H', 'W'))
+
+weight_0 = Tensor(name='weight',
+                  shape=(IntScalarParameter(min=8, max=64, name='out_channels'), in_channel, kernel_size, kernel_size),
+                  axis=("O", "I", "kH", "kW"))
+conv_0 = Conv2d(stride=CategoricalParameter(name='stride', choices=[1, 2]))(input, weight_0)
+```
+
+How can we know the output shape of `conv_0`, e.g., to put it into the weight tensor of a following convolution, without knowing which value
+the ``out_channels`` parameter has? Each node has a method `.shape()`, which returns the shape as an expression that can be used interchangeably with actual values. These expressions
+are only evaluated at sampling time and during the forward pass.
+
+```python
+print("Input shape: ", input.shape())
+print("Weight shape: ", weight_0.shape())
+print("Convolution output shape:", conv_0.shape())
+```
+
+    Input shape:  (1, 3, 32, 32)
+    Weight shape:  (IntScalarParameter(rng = Generator(PCG64), name = out_channels, id = Conv2d_0.weight.out_channels, min = 8, max = 64, step_size = 1, current_value = 8), 3, 1, 1)
+    Convolution output shape: (1, IntScalarParameter(rng = Generator(PCG64), name = out_channels, id = Conv2d_0.weight.out_channels, min = 8, max = 64, step_size = 1, current_value = 8), , )
+
+The `lazy` keyword can be used to evaluate values which *might* be parameters (but could also be `int` or similar).
+
+```python
+from hannah.nas.functional_operators.lazy import lazy
+
+print("Input shape: ", [lazy(i) for i in input.shape()])
+print("Weight shape: ", [lazy(i) for i in weight_0.shape()])
+print("Convolution output shape:", [lazy(i) for i in conv_0.shape()])
+```
+
+    Input shape:  [1, 3, 32, 32]
+    Weight shape:  [8, 3, 1, 1]
+    Convolution output shape: [1, 8, 16, 16]
+
+When defining an operator, one also has to define a `shape` function (the default shape function is the identity, i.e., ``output_shape == input_shape``). Tensors return their own shape.
+
+## Graphs and Hierarchy
+
+As seen in the simple examples above, we can chain op and tensor nodes together to create graphs and use parameters to span search spaces.
+
+```python
+from hannah.nas.functional_operators.operators import Relu
+
+input = Tensor(name='input',
+               shape=(1, 3, 32, 32),
+               axis=('N', 'C', 'H', 'W'))
+
+weight_0 = Tensor(name='weight', shape=(IntScalarParameter(min=8, max=64, name='out_channels'), 3, 1, 1), axis=("O", "I", "kH", "kW"))
+
+conv_0 = Conv2d(stride=CategoricalParameter(name='stride', choices=[1, 2]))(input, weight_0)
+relu_0 = Relu()(conv_0)
+
+weight_1 = Tensor(name='weight', shape=(IntScalarParameter(min=32, max=64, name='out_channels'), conv_0.shape()[1], 3, 3), axis=("O", "I", "kH", "kW"))
+conv_1 = Conv2d(stride=CategoricalParameter(name='stride', choices=[1, 2]))(relu_0, weight_1)
+relu_1 = Relu()(conv_1)
+```
+
+```python
+relu_1.parametrization(flatten=True)
+```
+
+    {'Conv2d_1.weight.out_channels': IntScalarParameter(rng = Generator(PCG64), name = out_channels, id = Conv2d_1.weight.out_channels, min = 32, max = 64, step_size = 1, current_value = 32),
+     'Conv2d_0.weight.out_channels': IntScalarParameter(rng = Generator(PCG64), name = out_channels, id = Conv2d_0.weight.out_channels, min = 8, max = 64, step_size = 1, current_value = 8),
+     'Conv2d_0.stride': CategoricalParameter(rng = Generator(PCG64), name = stride, id = Conv2d_0.stride, choices = [1, 2], current_value = 2),
+     'Conv2d_1.stride': CategoricalParameter(rng = Generator(PCG64), name = stride, id = Conv2d_1.stride, choices = [1, 2], current_value = 2)}
+
+Nodes have *operands* for backward traversal and *users* for forward traversal.
+With helper functions like `get_nodes`, one can iterate over all graph nodes.
+
+```python
+from hannah.nas.functional_operators.op import get_nodes
+
+print("Relu Operands: ", relu_1.operands)
+print("Conv Users: ", relu_1.operands[0].users)
+
+print("\nNodes:")
+for node in get_nodes(relu_1):
+    print('Node:', node)
+    print('\tOperands: ', node.operands)
+```
+
+    Relu Operands:  [Conv2d(Conv2d_1)]
+    Conv Users:  [Relu(Relu_1)]
+
+    Nodes:
+    Node: Relu(Relu_1)
+        Operands:  [Conv2d(Conv2d_1)]
+    Node: Conv2d(Conv2d_1)
+        Operands:  [Relu(Relu_0), Tensor(Conv2d_1.weight)]
+    Node: Tensor(Conv2d_1.weight)
+        Operands:  []
+    Node: Relu(Relu_0)
+        Operands:  [Conv2d(Conv2d_0)]
+    Node: Conv2d(Conv2d_0)
+        Operands:  [Tensor(input), Tensor(Conv2d_0.weight)]
+    Node: Tensor(Conv2d_0.weight)
+        Operands:  []
+    Node: Tensor(input)
+        Operands:  []
+
+### Blocks
+
+Creating large graphs with many operators and tensors by hand can get tedious and convoluted.
+Instead, we can define search space graphs in a hierarchical manner by encapsulating them in functions:
+
+```python
+def conv_relu(input, kernel_size, out_channels, stride):
+    in_channels = input.shape()[1]
+    weight = Tensor(name='weight',
+                    shape=(out_channels, in_channels, kernel_size, kernel_size),
+                    axis=('O', 'I', 'kH', 'kW'),
+                    grad=True)
+
+    conv = Conv2d(stride=stride)(input, weight)
+    relu = Relu()(conv)
+    return relu
+```
+
+```python
+input = Tensor(name='input',
+               shape=(1, 3, 32, 32),
+               axis=('N', 'C', 'H', 'W'))
+
+kernel_size = CategoricalParameter(name="kernel_size", choices=[1, 3, 5])
+stride = CategoricalParameter(name="stride", choices=[1, 2])
+out_channels = IntScalarParameter(name="out_channels", min=8, max=64)
+net = conv_relu(input, kernel_size=kernel_size, out_channels=out_channels, stride=stride)
+net = conv_relu(net, kernel_size=kernel_size, out_channels=out_channels, stride=stride)
+
+for n in get_nodes(net):
+    print(n)
+```
+
+    Relu(Relu_1)
+    Conv2d(Conv2d_1)
+    Tensor(Conv2d_1.weight)
+    Relu(Relu_0)
+    Conv2d(Conv2d_0)
+    Tensor(Conv2d_0.weight)
+    Tensor(input)
+
+```python
+net.parametrization(flatten=True)
+```
+
+    {'Conv2d_0.weight.kernel_size': CategoricalParameter(rng = Generator(PCG64), name = kernel_size, id = Conv2d_0.weight.kernel_size, choices = [1, 3, 5], current_value = 5),
+     'Conv2d_0.weight.out_channels': IntScalarParameter(rng = Generator(PCG64), name = out_channels, id = Conv2d_0.weight.out_channels, min = 8, max = 64, step_size = 1, current_value = 8),
+     'Conv2d_0.stride': CategoricalParameter(rng = Generator(PCG64), name = stride, id = Conv2d_0.stride, choices = [1, 2], current_value = 1)}
+
+Note how there is just one set of parameters: defined this way, both blocks share their parameters. To define separate parameters, one can use `param.new()`:
+
+```python
+input = Tensor(name='input',
+               shape=(1, 3, 32, 32),
+               axis=('N', 'C', 'H', 'W'))
+
+kernel_size = CategoricalParameter(name="kernel_size", choices=[1, 3, 5])
+stride = CategoricalParameter(name="stride", choices=[1, 2])
+out_channels = IntScalarParameter(name="out_channels", min=8, max=64)
+net = conv_relu(input, kernel_size=kernel_size.new(), out_channels=out_channels.new(), stride=stride.new())
+net = conv_relu(net, kernel_size=kernel_size.new(), out_channels=out_channels.new(), stride=stride.new())
+
+net.parametrization(flatten=True)
+```
+
+    {'Conv2d_1.weight.kernel_size': CategoricalParameter(rng = Generator(PCG64), name = kernel_size, id = Conv2d_1.weight.kernel_size, choices = [1, 3, 5], current_value = 3),
+     'Conv2d_1.weight.out_channels': IntScalarParameter(rng = Generator(PCG64), name = out_channels, id = Conv2d_1.weight.out_channels, min = 8, max = 64, step_size = 1, current_value = 8),
+     'Conv2d_0.weight.kernel_size': CategoricalParameter(rng = Generator(PCG64), name = kernel_size, id = Conv2d_0.weight.kernel_size, choices = [1, 3, 5], current_value = 3),
+     'Conv2d_0.weight.out_channels': IntScalarParameter(rng = Generator(PCG64), name = out_channels, id = Conv2d_0.weight.out_channels, min = 8, max = 64, step_size = 1, current_value = 8),
+     'Conv2d_0.stride': CategoricalParameter(rng = Generator(PCG64), name = stride, id = Conv2d_0.stride, choices = [1, 2], current_value = 2),
+     'Conv2d_1.stride': CategoricalParameter(rng = Generator(PCG64), name = stride, id = Conv2d_1.stride, choices = [1, 2], current_value = 2)}
+
+These function blocks can be nested as desired.
+
+```python
+def block(input):
+    kernel_size = CategoricalParameter(name="kernel_size", choices=[1, 3, 5])
+    stride = CategoricalParameter(name="stride", choices=[1, 2])
+    out_channels = IntScalarParameter(name="out_channels", min=8, max=64)
+    net = conv_relu(input, kernel_size=kernel_size.new(), out_channels=out_channels.new(), stride=stride.new())
+    net = conv_relu(net, kernel_size=kernel_size.new(), out_channels=out_channels.new(), stride=stride.new())
+    net = conv_relu(net, kernel_size=kernel_size.new(), out_channels=out_channels.new(), stride=stride.new())
+    return net
+
+input = Tensor(name='input',
+               shape=(1, 3, 32, 32),
+               axis=('N', 'C', 'H', 'W'))
+net = block(input)
+net = block(net)
+
+for n in get_nodes(net):
+    print(n)
+```
+
+    Relu(Relu_5)
+    Conv2d(Conv2d_5)
+    Tensor(Conv2d_5.weight)
+    Relu(Relu_4)
+    Conv2d(Conv2d_4)
+    Tensor(Conv2d_4.weight)
+    Relu(Relu_3)
+    Conv2d(Conv2d_3)
+    Tensor(Conv2d_3.weight)
+    Relu(Relu_2)
+    Conv2d(Conv2d_2)
+    Tensor(Conv2d_2.weight)
+    Relu(Relu_1)
+    Conv2d(Conv2d_1)
+    Tensor(Conv2d_1.weight)
+    Relu(Relu_0)
+    Conv2d(Conv2d_0)
+    Tensor(Conv2d_0.weight)
+    Tensor(input)
+
+### Scopes
+
+As seen above, while the *definition* of the graph is made in a hierarchical manner, the actual graph and its nodes are "flat" and do not have any inherent hierarchy. To make the graph clearer and more readable, one can use **scopes** with the `@scope` decorator for blocks. Note that `@scope` does not have any effect on the inherent structure of the graph; it only affects the node `id`s.
+
+```python
+from hannah.nas.functional_operators.op import scope
+
+
+@scope
+def conv_relu(input, kernel_size, out_channels, stride):
+    in_channels = input.shape()[1]
+    weight = Tensor(name='weight',
+                    shape=(out_channels, in_channels, kernel_size, kernel_size),
+                    axis=('O', 'I', 'kH', 'kW'),
+                    grad=True)
+
+    conv = Conv2d(stride=stride)(input, weight)
+    relu = Relu()(conv)
+    return relu
+
+@scope
+def block(input):
+    kernel_size = CategoricalParameter(name="kernel_size", choices=[1, 3, 5])
+    stride = CategoricalParameter(name="stride", choices=[1, 2])
+    out_channels = IntScalarParameter(name="out_channels", min=8, max=64)
+    net = conv_relu(input, kernel_size=kernel_size.new(), out_channels=out_channels.new(), stride=stride.new())
+    net = conv_relu(net, kernel_size=kernel_size.new(), out_channels=out_channels.new(), stride=stride.new())
+    net = conv_relu(net, kernel_size=kernel_size.new(), out_channels=out_channels.new(), stride=stride.new())
+    return net
+
+input = Tensor(name='input',
+               shape=(1, 3, 32, 32),
+               axis=('N', 'C', 'H', 'W'))
+net = block(input)
+net = block(net)
+
+for n in get_nodes(net):
+    print(n)
+```
+
+    Relu(block_1.conv_relu_2.Relu_0)
+    Conv2d(block_1.conv_relu_2.Conv2d_0)
+    Tensor(block_1.conv_relu_2.Conv2d_0.weight)
+    Relu(block_1.conv_relu_1.Relu_0)
+    Conv2d(block_1.conv_relu_1.Conv2d_0)
+    Tensor(block_1.conv_relu_1.Conv2d_0.weight)
+    Relu(block_1.conv_relu_0.Relu_0)
+    Conv2d(block_1.conv_relu_0.Conv2d_0)
+    Tensor(block_1.conv_relu_0.Conv2d_0.weight)
+    Relu(block_0.conv_relu_2.Relu_0)
+    Conv2d(block_0.conv_relu_2.Conv2d_0)
+    Tensor(block_0.conv_relu_2.Conv2d_0.weight)
+    Relu(block_0.conv_relu_1.Relu_0)
+    Conv2d(block_0.conv_relu_1.Conv2d_0)
+    Tensor(block_0.conv_relu_1.Conv2d_0.weight)
+    Relu(block_0.conv_relu_0.Relu_0)
+    Conv2d(block_0.conv_relu_0.Conv2d_0)
+    Tensor(block_0.conv_relu_0.Conv2d_0.weight)
+    Tensor(input)
+
+## Choice Ops
+
+A choice op is a special kind of node that allows multiple paths in the graph which exclude each other (or exhibit other specialized behaviour).
+
+```python
+from hannah.nas.functional_operators.operators import Identity
+from functools import partial
+from hannah.nas.functional_operators.op import ChoiceOp
+
+@scope
+def choice_block(input):
+    kernel_size = CategoricalParameter([1, 3, 5], name='kernel_size')
+    out_channels = IntScalarParameter(min=4, max=64, name='out_channels')
+    stride = CategoricalParameter([1, 2], name='stride')
+
+    identity = Identity()
+    optional_conv = partial(conv_relu, out_channels=out_channels.new(), stride=stride.new(), kernel_size=kernel_size.new())
+    net = ChoiceOp(identity, optional_conv)(input)
+    return net
+```
+
+```python
+input = Tensor(name='input', shape=(1, 3, 32, 32), axis=('N', 'C', 'H', 'W'))
+conv = conv_relu(input, out_channels=out_channels.new(), stride=stride.new(), kernel_size=kernel_size.new())
+net = choice_block(conv)
+
+net.parametrization(flatten=True)
+```
+
+    {'choice_block_0.ChoiceOp_0.choice': IntScalarParameter(rng = Generator(PCG64), name = choice, id = choice_block_0.ChoiceOp_0.choice, min = 0, max = 1, step_size = 1, current_value = 0),
+     'choice_block_0.conv_relu_1.Conv2d_0.stride': CategoricalParameter(rng = Generator(PCG64), name = stride, id = choice_block_0.conv_relu_1.Conv2d_0.stride, choices = [1, 2], current_value = 2),
+     'choice_block_0.conv_relu_1.Conv2d_0.weight.kernel_size': CategoricalParameter(rng = Generator(PCG64), name = kernel_size, id = choice_block_0.conv_relu_1.Conv2d_0.weight.kernel_size, choices = [1, 3, 5], current_value = 5),
+     'choice_block_0.conv_relu_1.Conv2d_0.weight.out_channels': IntScalarParameter(rng = Generator(PCG64), name = out_channels, id = choice_block_0.conv_relu_1.Conv2d_0.weight.out_channels, min = 4, max = 64, step_size = 1, current_value = 4),
+     'conv_relu_0.Conv2d_0.stride': CategoricalParameter(rng = Generator(PCG64), name = stride, id = conv_relu_0.Conv2d_0.stride, choices = [1, 2], current_value = 2),
+     'conv_relu_0.Conv2d_0.weight.kernel_size': CategoricalParameter(rng = Generator(PCG64), name = kernel_size, id = conv_relu_0.Conv2d_0.weight.kernel_size, choices = [1, 3, 5], current_value = 3),
+     'conv_relu_0.Conv2d_0.weight.out_channels': IntScalarParameter(rng = Generator(PCG64), name = out_channels, id = conv_relu_0.Conv2d_0.weight.out_channels, min = 8, max = 64, step_size = 1, current_value = 8)}
+
+![Choice node](../../assets/choice_node.jpg)
+
+!!! note
+    When defining options for a choice node, one can either use ops directly (see ``Identity()`` above) or use block functions (``conv_relu``). For block functions, one has to use ``functools.partial`` to enable
+    the choice node to perform the respective integration in the graph.
+
+During execution, the choice node can be leveraged to define the behaviour (e.g., select one and only one path, or execute all paths and return a parametrized sum for differentiable NAS, ...). Choice nodes can, for example, be used to search over different operator types or operator patterns, or to implement dynamic depth (a variable number of layers/blocks).
+
+```python
+def dynamic_depth(*exits, switch):
+    return ChoiceOp(*exits, switch=switch)()
+```
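+
+For illustration, dynamic depth could then be wired up roughly like this (a hypothetical sketch; how the `switch` parameter selects an exit is up to the executor):
+
+```python
+# Sketch: expose the outputs of successive blocks as possible exits
+# and let a `depth` parameter choose which exit is used.
+exit_1 = block(input)
+exit_2 = block(exit_1)
+exit_3 = block(exit_2)
+
+depth = IntScalarParameter(min=0, max=2, name='depth')
+net = dynamic_depth(exit_1, exit_2, exit_3, switch=depth)
+```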
+
+## Custom Ops
+
+To define custom operators, one can inherit from the ``Op`` class. One can then override the ``__call__(self, *operands)`` method to perform specific actions, e.g., saving certain attributes of the operands as fields of the operator instance that is returned. Don't forget to call ``super().__call__(*operands)``, which integrates the new operator instance into the graph.
+
+Then, one has to provide a ``_forward_implementation(self, *args)``, which defines the computation that the operator performs.
+
+Lastly, a ``shape_fun(self)`` defines the output shape of the operator.
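+
+A rough sketch of such an operator (the method bodies and the assumed ``Op`` constructor signature are for illustration only, not the exact built-in implementation):
+
+```python
+from hannah.nas.functional_operators.op import Op
+
+
+class ClippedRelu(Op):
+    """Hypothetical example operator: a ReLU clipped to a maximum value."""
+
+    def __init__(self, max_value=6.0):
+        super().__init__(name='ClippedRelu')  # assumed constructor signature
+        self.max_value = max_value
+
+    def __call__(self, *operands):
+        # super().__call__ integrates the new operator instance into the graph
+        op = super().__call__(*operands)
+        op.max_value = self.max_value  # save attributes on the returned instance
+        return op
+
+    def _forward_implementation(self, x):
+        # the computation the operator performs on the data
+        return x.clamp(min=0.0, max=self.max_value)
+
+    def shape_fun(self):
+        # element-wise operator: the output shape equals the input shape
+        return self.operands[0].shape()
+```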
+
+## Executor
+
+The search space graphs are not themselves executable. For that, one needs an ``Executor``. The ``BasicExecutor`` analyzes the graph to find dependencies and a valid node order (e.g., to compute the results of operands before they are added in an ``Add`` operation) and builds a ``forward`` function. It also registers torch parameters and buffers for training. The executor can then be used like a normal ``torch.nn.Module``. One can define custom executors, e.g., for weight-sharing NAS or differentiable NAS.
+
+```python
+import torch
+from hannah.nas.functional_operators.executor import BasicExecutor
+
+
+input = Tensor(name='input',
+               shape=(1, 3, 32, 32),
+               axis=('N', 'C', 'H', 'W'))
+net = block(input)
+net = block(net)
+
+model = BasicExecutor(net)
+model.initialize()
+
+x = torch.randn(input.shape())
+model.forward(x)
+```
+
+    tensor([[[[0.0000, 0.0000, 0.0000, 0.0000],
+              [0.0000, 0.0000, 0.0000, 0.0000],
+              [0.0000, 0.0000, 0.0000, 0.0000],
+              [0.0000, 0.0000, 0.0000, 0.0000]],
+
+             [[0.0255, 0.0000, 0.0000, 0.0000],
+              [0.0000, 0.0152, 0.0000, 0.0000],
+              [0.0000, 0.0000, 0.0898, 0.0000],
+              [0.1132, 0.0894, 0.0094, 0.0138]],
+
+             [[0.0000, 0.0000, 0.0365, 0.0000],
+              [0.0000, 0.1532, 0.0000, 0.2529],
+              [0.0000, 0.0859, 0.0396, 0.0000],
+              [0.0000, 0.2311, 0.0757, 0.0000]],
+
+             [[0.0000, 0.1285, 0.1754, 0.0000],
+              [0.1788, 0.1729, 0.1973, 0.1036],
+              [0.1823, 0.2994, 0.2293, 0.2580],
+              [0.0554, 0.2454, 0.1355, 0.3018]],
+
+             [[0.0000, 0.0234, 0.0000, 0.0000],
+              [0.0725, 0.0212, 0.0615, 0.0960],
+              [0.1040, 0.0960, 0.1613, 0.0927],
+              [0.1025, 0.0846, 0.0000, 0.0424]],
+
+             [[0.0000, 0.0000, 0.0672, 0.0818],
+              [0.0000, 0.1420, 0.0404, 0.0326],
+              [0.0000, 0.0000, 0.0000, 0.1140],
+              [0.0000, 0.1518, 0.1521, 0.2088]],
+
+             [[0.0000, 0.0995, 0.1362, 0.0000],
+              [0.0000, 0.1206, 0.0000, 0.0000],
+              [0.0000, 0.1001, 0.0000, 0.0000],
+              [0.0000, 0.0000, 0.0000, 0.0435]],
+
+             [[0.0000, 0.0000, 0.0000, 0.0245],
+              [0.0000, 0.0938, 0.0000, 0.0763],
+              [0.0000, 0.0000, 0.0000, 0.0000],
+              [0.0000, 0.0000, 0.0000, 0.0000]]]], grad_fn=<ReluBackward0>)
diff --git a/doc/nas/usage.md b/doc/nas/usage.md
new file mode 100644
index 00000000..ceae994b
--- /dev/null
+++ b/doc/nas/usage.md
@@ -0,0 +1,29 @@
+!!! note
+    Coming soon
+
+# Introduction
+
+# Usage
+
+## NAS CLI
+
+## CLI options
+
+## Debugging in VSCode
\ No newline at end of file
diff --git a/hannah/models/capsule_net/models.py b/hannah/models/capsule_net/models.py
index dc13135b..49084efc 100644
--- a/hannah/models/capsule_net/models.py
+++ b/hannah/models/capsule_net/models.py
@@ -41,7 +41,7 @@ def __init__(self, params, id) -> None:
         self.mods = nn.ModuleDict({"relu": relu_act, "identity": identity_act})
         # self.mods = {"relu": relu_act, "identity": identity_act}
 
-        self.choice = handle_parameter(self, params, f"{self.id}.choice")
+        self.choice = handle_parameter(self, params, "choice")
         self.active_module = Choice(self.mods, self.choice)
 
     def initialize(self):
@@ -68,8 +68,8 @@ def __init__(self, params, id, inputs) -> None:
         self.in_channels = input_shape[1]
 
         # FIXME: Share parameters over patterns
-        self.out_channels = handle_parameter(self, self.params.convolution.out_channels, f"{self.id}.out_channels")
-        self.expand_ratio = handle_parameter(self, self.params.expand_reduce.ratio, f"{self.id}.expand_ratio")
+        self.out_channels = handle_parameter(self, self.params.convolution.out_channels, "out_channels")
+        self.expand_ratio = handle_parameter(self, self.params.expand_reduce.ratio, "expand_ratio")
 
         self.expanded_channels = Int(self.expand_ratio * self.in_channels)
         self.expansion = PointwiseConvolution(self.expanded_channels, self.id + '.expand', inputs)
@@ -126,8 +126,8 @@ def __init__(self, params, id, inputs) -> None:
         self.in_channels = self.input_shape[1]
 
         # FIXME: Share parameters over patterns
-        self.out_channels = handle_parameter(self, self.params.convolution.out_channels, f"{self.id}.out_channels")
-        self.reduce_ratio = handle_parameter(self, self.params.reduce_expand.ratio, f"{self.id}.reduce_ratio")
+        self.out_channels = handle_parameter(self, self.params.convolution.out_channels, "out_channels")
+        self.reduce_ratio = handle_parameter(self, self.params.reduce_expand.ratio, "reduce_ratio")
 
         self.reduced_channels = Int(self.reduce_ratio * self.in_channels)
         self.reduction = PointwiseConvolution(self.reduced_channels, self.id + '.expand', inputs)
@@ -233,8 +233,8 @@ def __init__(self, params, input_shape, id) -> None:
         self.id = id
 
         # shared parameters
-        self.stride = handle_parameter(self, params.stride, f"{self.id}.stride")
-        self.out_channels = handle_parameter(self, params.out_channels, f"{self.id}.out_channels")
+        self.stride = handle_parameter(self, params.stride, "stride")
+        self.out_channels = handle_parameter(self, params.out_channels, "out_channels")
 
         conv_params = OmegaConf.create({'convolution': {'stride': self.stride, 'out_channels': self.out_channels},
                                         'pooling': {'stride': self.stride}},
                                        flags={"allow_objects": True})
@@ -289,7 +289,7 @@ def __init__(self, params, input_shape, id) -> None:
         super().__init__()
         self.id = id
         self.input_shape = input_shape
-        self.depth = handle_parameter(self, params.depth, f'{self.id}.depth')
+        self.depth = handle_parameter(self, params.depth, 'depth')
 
         self.mods = nn.ModuleList()
         self.params = params
diff --git a/hannah/nas/functional_operators/op.py b/hannah/nas/functional_operators/op.py
index 9f8d2b27..f36f8fff 100644
--- a/hannah/nas/functional_operators/op.py
+++ b/hannah/nas/functional_operators/op.py
@@ -166,16 +166,13 @@ def __init__(self, name, shape, axis, dtype=FloatType(), grad=False) -> None:
         self.dtype = dtype
 
         # FIXME: Maybe check for lists/tuples in @parametrize?
-        # FIXME: What if a parameter is defined elsewhere (e.g. conv)
+        # FIXME: What if a parameter is defined elsewhere (e.g. conv) --> not good
         for s in self._shape:
             if is_parametrized(s):
                 # FIXME: IDs of parameters
-                self._PARAMETERS[self.id + '.' + s.name] = s
-            elif isinstance(s, Expression):
-                params = extract_parameter_from_expression(s)
-                for p in params:
-                    self._PARAMETERS[self.id + '.' + p.name] = p
-
+                if s.id is None:  # Else: parameter is registered elsewhere
+                    s.id = self.id + '.' + s.name
+                self._PARAMETERS[s.id] = s
         self.axis = axis
         self.users = []
         self.operands = []
diff --git a/hannah/nas/parameters/parameters.py b/hannah/nas/parameters/parameters.py
index f02b4ff6..bfad3b7e 100644
--- a/hannah/nas/parameters/parameters.py
+++ b/hannah/nas/parameters/parameters.py
@@ -35,7 +35,7 @@ class Parameter(Expression):
     def __init__(
         self,
-        name: Optional[str] = None,
+        name: Optional[str] = "",
         rng: Optional[Union[np.random.Generator, int]] = None,
     ) -> None:
         super().__init__()
@@ -102,7 +102,7 @@ def __init__(
         min: Union[int, IntScalarParameter],
         max: Union[int, IntScalarParameter],
         step_size: int = 1,
-        name: Optional[str] = None,
+        name: Optional[str] = "",
         rng: Optional[Union[np.random.Generator, int]] = None,
     ) -> None:
         super().__init__(name, rng)
@@ -157,7 +157,7 @@ def __init__(
         self,
         min,
         max,
-        name: Optional[str] = None,
+        name: Optional[str] = "",
         rng: Optional[Union[np.random.Generator, int]] = None,
     ) -> None:
         super().__init__(name, rng)
@@ -193,7 +193,7 @@ class CategoricalParameter(Parameter):
     def __init__(
         self,
         choices,
-        name: Optional[str] = None,
+        name: Optional[str] = "",
         rng: Optional[Union[np.random.Generator, int]] = None,
     ) -> None:
         super().__init__(name, rng)
@@ -238,7 +238,7 @@ def __init__(
         choices,
         min,
         max,
-        name: Optional[str] = None,
+        name: Optional[str] = "",
         rng: Optional[Union[np.random.Generator, int]] = None,
     ) -> None:
         super().__init__(name, rng)
diff --git a/hannah/nas/parameters/parametrize.py b/hannah/nas/parameters/parametrize.py
index ff85b69e..acb161c9 100644
--- a/hannah/nas/parameters/parametrize.py
+++ b/hannah/nas/parameters/parametrize.py
@@ -168,6 +168,8 @@ def get_parameters(
         if hasattr(current, "_PARAMETERS"):
             for param in current._PARAMETERS.values():
                 if param.id is None:
+                    param.id = param.name
+
+                if isinstance(param, Parameter) and param.id not in visited:
+                    param.id = current.id + '.' + param.name
                 if param.id not in visited:
                     queue.append(param)
diff --git a/hannah/nas/test/test_dataflow.py b/hannah/nas/test/test_dataflow.py
index b3e0e3b1..8ea0cc38 100644
--- a/hannah/nas/test/test_dataflow.py
+++ b/hannah/nas/test/test_dataflow.py
@@ -24,9 +24,9 @@ def conv_relu(input: TensorExpression,
 @dataflow
 def block(input: TensorExpression,
           expansion=FloatScalarParameter(1, 6, name='expansion'),
-          output_channel=IntScalarParameter(4, 64),
-          kernel_size=CategoricalParameter([1, 3, 5]),
-          stride=CategoricalParameter([1, 2])):
+          output_channel=IntScalarParameter(4, 64, name='out_channels'),
+          kernel_size=CategoricalParameter([1, 3, 5], name='kernel_size'),
+          stride=CategoricalParameter([1, 2], name='stride')):
 
     out = conv_relu(input, output_channel=output_channel.new()*expansion.new(), kernel_size=kernel_size.new(), stride=DefaultInt(1))
     out = conv_relu(out, output_channel=output_channel.new(), kernel_size=DefaultInt(1), stride=stride.new())
@@ -81,7 +81,7 @@ def test_flatten():
 def test_parameter_extraction():
     input = batched_image_tensor(name='input')
 
-    out = block(input, stride=IntScalarParameter(min=1, max=2))
+    out = block(input, stride=IntScalarParameter(min=1, max=2, name='stride'))
     out = block(out)
     # flattened_graph = flatten(out)
 
     params = out.parametrization(include_empty=True, flatten=True)
diff --git a/pydoc-markdown.yml b/pydoc-markdown.yml
index 429c84e4..fcbbb841 100644
--- a/pydoc-markdown.yml
+++ b/pydoc-markdown.yml
@@ -23,6 +23,11 @@ processors:
   - type: filter
   - type: smart
   - type: crossref
+
+hooks:
+  pre-render:
+    - cp -r doc/assets build/docs/content
+
 renderer:
   type: mkdocs
   pages:
@@ -60,16 +65,33 @@ renderer:
           name: compression/decomposition
           source: doc/compression/decomposition.md
         - title: Knowledge Distillation
-          name: compresseion/knowledge_distillation
+          name: compression/knowledge_distillation
           source: doc/compression/knowledge_distillation.md
     - title: Optimization
       children:
         - title: Hyperparameter Optimization
          name: optimization/hyperparameter
          source: doc/optimization/hyperparameters.md
-        - title: Neural Architecture Search
-          name: optimization/nas
-          source: doc/optimization/nas.md
+        - title: Neural Architecture Search
+          children:
+            - title: NAS (Legacy)
+              name: nas/legacy
+              source: doc/nas/legacy.md
+            - title: Usage
+              name: nas/usage
+              source: doc/nas/usage.md
+            - title: Search Spaces
+              name: nas/search_spaces
+              source: doc/nas/search_spaces.md
+            - title: Parametrization
+              name: nas/parametrization
+              source: doc/nas/parametrization.md
+            - title: Search
+              name: nas/search
+              source: doc/nas/search.md
+            - title: Evaluating Results
+              name: nas/eval
+              source: doc/nas/eval.md
     - title: Deployment
       children:
         - title: Torch Mobile
@@ -109,6 +131,7 @@ renderer:
     markdown_extensions:
       - def_list
       - admonition
+      - codehilite
      - pymdownx.arithmatex:
          generic: true
    extra_javascript: