
Merge branch 'f/multiple_backends' into 'main'
Define metrics estimator interface and allow multiple estimators

See merge request es/ai/hannah/hannah!380
cgerum committed Apr 15, 2024
2 parents eb4be47 + 2b2034c commit f7bf3a3
Showing 26 changed files with 550 additions and 581 deletions.
4 changes: 2 additions & 2 deletions doc/deployment/tenssorrt.md → doc/deployment/tensorrt.md
@@ -1,5 +1,5 @@
<!--
Copyright (c) 2023 Hannah contributors.
Copyright (c) 2024 Hannah contributors.
This file is part of hannah.
See https://github.com/ekut-es/hannah for further info.
@@ -26,7 +26,7 @@ The tensorrt module in the poetry shell needs to be installed separately via pip
```
poetry shell
pip install tensorrt
```
```

## Configuration

36 changes: 23 additions & 13 deletions doc/nas/metrics.md
@@ -1,22 +1,32 @@
-# Performance Metrics in Neural Architecture Search
+<!--
+Copyright (c) 2024 Hannah contributors.
+This file is part of hannah.
+See https://github.com/ekut-es/hannah for further info.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+# Performance Metrics in Neural Architecture Search

Currently there are two main sources of Performance Metrics used in hannah's NAS subsystem.

-1. The main training loop generates performance metrics using the training loops. These metrics are logged using the lightning logging system during training, are then extracted using the 'HydraOptCallback', and are only available for optimization purposes after a training has been run. These kinds of metrics are also generated for normal training runs.
+1. Backend generated metrics. Backend generated metrics are returned by the backend's `profile` method. They are usually generated by running the neural networks, either on real target hardware or on accurate simulators. We currently do not enforce accuracy requirements on the reported metrics, but we will consider them as golden reference results for the evaluation and, if necessary, the training of the performance estimators, so they should be as accurate as possible.
2. Estimators can provide metrics before the neural networks have been trained. Predictors are used in presampling phases of the neural architecture search. Predictors are not and will not be used outside of neural architecture search.

There are 2 subclasses of predictors.
- Machine Learning based predictors: These predictors provide an interface based on `predict`, `update`, `load` and `train`
- Analytical predictors: the interface of these predictors only contains `predict`

-The current implementation has a few problems:
-
-- not using a unified interface for both predictors induces breakage at a lot of places in the nas flow
-- it is currently not possible to configure more than one predictor at the same time, which has led to things like hardcoding additional predictors in the NAS loops: https://es-git.cs.uni-tuebingen.de/es/ai/hannah/hannah/-/blob/main/hannah/nas/search/search.py?ref_type=heads#L136
-- Currently it is not immediately clear how device metrics should be generated, and especially how device metrics obtained from device execution should be generated.
-
-There have been a few approaches to this.
-1. The BackendPredictor, it instantiates a backend and then calls the predict method on the untrained model https://es-git.cs.uni-tuebingen.de/es/ai/hannah/hannah/-/blob/main/hannah/nas/performance_prediction/simple.py?ref_type=heads#L42
-2. Target Specific estimators, like this mlonmcu predictor: https://es-git.cs.uni-tuebingen.de/es/ai/hannah/hannah/-/merge_requests/378/diffs#a45007495fb172b95977d0692f84cff89bf6d692
+The predictor interfaces are defined in `hannah.nas.performance_prediction.protocol` as python protocols.
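The protocol classes named in the updated documentation are exported from `hannah/nas/performance_prediction/__init__.py` in this commit (see below). As a rough orientation, here is a minimal sketch of what the two predictor protocols could look like; the class names match the new exports, but the method signatures, parameter names, and return types are assumptions derived only from the method lists above, not the verbatim definitions in `hannah.nas.performance_prediction.protocol`.

```python
# Hypothetical sketch of the two predictor protocols described above.
# Only the class names are taken from this commit; signatures are assumed.
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class Predictor(Protocol):
    """Analytical predictors: only predict() is required."""

    def predict(self, model: Any, input: Any = None) -> Any:
        """Estimate metrics for an untrained model (assumed to return a mapping of metric names to values)."""
        ...


@runtime_checkable
class FitablePredictor(Predictor, Protocol):
    """Machine-learning based predictors: additionally support load/update/train."""

    def load(self, result_folder: str) -> None:
        """Restore a previously trained predictor from disk."""
        ...

    def update(self, new_data: Any, input: Any = None) -> None:
        """Refine the predictor with freshly measured (golden reference) results."""
        ...

    def train(self, dataset: Any) -> None:
        """Train the predictor from scratch on a dataset of measured metrics."""
        ...
```

Because these are structural `typing.Protocol` classes, an estimator does not have to inherit from them; any object exposing the listed methods satisfies the interface, which is what allows analytical and machine-learning based predictors to be used interchangeably in the NAS flow.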
98 changes: 4 additions & 94 deletions hannah/backends/base.py
@@ -102,97 +102,7 @@ def export(self) -> None:
        logger.critical("Exporting model is not implemented for this backend")


class InferenceBackendBase(AbstractBackend, Callback):
    """Base class to wrap backends as a lightning callback"""

    def __init__(
        self, val_batches=1, test_batches=1, val_frequency=10, tune: bool = True
    ):
        self.test_batches = test_batches
        self.val_batches = val_batches
        self.val_frequency = val_frequency
        self.validation_epoch = 0
        self.tune = tune

    def on_validation_epoch_start(self, trainer, pl_module):
        if not self.tune:
            return

        if self.val_batches > 0:
            if self.validation_epoch % self.val_frequency == 0:
                pl_module = self.quantize(pl_module)
                self.prepare(pl_module)

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=-1
    ):
        if not self.tune:
            return

        if batch_idx < self.val_batches:
            if self.validation_epoch % self.val_frequency == 0:
                result = self.run_batch(inputs=batch[0])
                if not isinstance(result, torch.Tensor):
                    logging.warning("Could not calculate MSE on target device")
                    return
                target = pl_module.forward(batch[0].to(pl_module.device))
                mse = torch.nn.functional.mse_loss(
                    result.to(pl_module.device),
                    target.to(pl_module.device),
                    reduction="mean",
                )
                pl_module.log("val_backend_mse", mse)
                logging.info("val_backend_mse: %f", mse)

    def on_validation_epoch_end(self, trainer, pl_module):
        self.validation_epoch += 1

    def on_test_epoch_start(self, trainer, pl_module):
        logger.info("Exporting module")

        pl_module = self.quantize(pl_module)
        self.prepare(pl_module)
        self.export()

    def quantize(self, pl_module: torch.nn.Module) -> torch.nn.Module:
        qconfig_mapping = getattr(pl_module, "qconfig_mapping", None)
        if qconfig_mapping is None:
            logger.info("No qconfig found in module, leaving module unquantized")
            return pl_module

        pl_module = copy.deepcopy(pl_module)
        pl_module.cpu()

        logger.info("Quantizing module")

        example_inputs = next(iter(pl_module.train_dataloader()))[0]

        model = torch.ao.quantization.quantize_fx.prepare_fx(
            pl_module.model, qconfig_mapping, example_inputs
        )
        model = torch.ao.quantization.quantize_fx.convert_fx(model)
        pl_module.model = model

        return pl_module

    def on_test_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=-1
    ):
        if batch_idx < self.test_batches:
            # decode batches from target device
            if isinstance(batch, Mapping) or isinstance(batch, dict):
                inputs = batch["data"]
            else:
                inputs = batch[0]

            result = self.run_batch(inputs=inputs)
            target = pl_module(inputs.to(pl_module.device))
            target = target[: result.shape[0]]

            mse = torch.nn.functional.mse_loss(
                result.to(pl_module.device),
                target.to(pl_module.device),
                reduction="mean",
            )
            pl_module.log("test_backend_mse", mse)
            logging.info("test_backend_mse: %f", mse)


class InferenceBackendBase(AbstractBackend):
    """Base class for backends, it is only here for backwards compatibility reasons, use AbstractBackend instead"""

    pass
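With the lightning `Callback` machinery removed, a backend is now just an `AbstractBackend` subclass that prepares a model, runs batches, and reports measured metrics. The sketch below is hypothetical: the method names `prepare`, `run_batch` and `profile` follow the usage visible in this merge request, but the exact abstract interface and the shape of the returned metrics are assumptions.

```python
# Hypothetical minimal backend under the simplified interface (in hannah this
# would derive from hannah.backends.base.AbstractBackend). Names and the
# returned metrics dict are assumptions, not the verbatim API.
import time

import torch


class DummyCPUBackend:
    """Runs the model eagerly on the CPU and reports wall-clock latency."""

    def __init__(self, warmup: int = 2, repeat: int = 10):
        self.warmup = warmup
        self.repeat = repeat
        self.model = None

    def prepare(self, module: torch.nn.Module) -> None:
        # Put the model into inference mode on the host CPU.
        self.model = module.eval().cpu()

    def run_batch(self, inputs: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return self.model(inputs)

    def profile(self, inputs: torch.Tensor) -> dict:
        # Warm up, then report the average latency over `repeat` runs; these
        # measurements would serve as golden reference results for estimators.
        for _ in range(self.warmup):
            self.run_batch(inputs)
        start = time.perf_counter()
        for _ in range(self.repeat):
            self.run_batch(inputs)
        latency = (time.perf_counter() - start) / self.repeat
        return {"latency_s": latency}
```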
7 changes: 0 additions & 7 deletions hannah/backends/onnxrt.py
@@ -40,16 +40,9 @@ class OnnxruntimeBackend(InferenceBackendBase):

    def __init__(
        self,
        val_batches=1,
        test_batches=1,
        val_frequency=10,
        repeat=10,
        warmup=2,
    ):
        super(OnnxruntimeBackend, self).__init__(
            val_batches=val_batches, test_batches=test_batches, val_frequency=10
        )

        self.repeat = repeat
        self.warmup = warmup

10 changes: 1 addition & 9 deletions hannah/backends/tensorrt.py
@@ -77,15 +77,7 @@ def cuda_call(call):


class TensorRTBackend(InferenceBackendBase):
    def __init__(
        self, val_batches=1, test_batches=1, val_frequency=10, warmup=10, repeat=30
    ):
        super().__init__(
            val_batches=val_batches,
            test_batches=test_batches,
            val_frequency=val_frequency,
        )

    def __init__(self, warmup=10, repeat=30):
        if trt is None or cuda is None or cudart is None:
            raise RuntimeError(
                "TensorRT is not available, please install with tensorrt extra activated."
6 changes: 1 addition & 5 deletions hannah/backends/torch_mobile.py
@@ -25,11 +25,7 @@
class TorchMobileBackend(InferenceBackendBase):
    """Inference backend for torch mobile"""

    def __init__(
        self, val_batches=1, test_batches=1, val_frequency=1, warmup=2, repeat=10
    ):
        super().__init__(val_batches, test_batches, val_frequency)

    def __init__(self, warmup=2, repeat=10):
        self.warmup = warmup
        self.repeat = repeat
        self.script_module = None
10 changes: 9 additions & 1 deletion hannah/callbacks/summaries.py
@@ -319,7 +319,15 @@ def _do_summary(self, pl_module, input=None, print_log=True):
"""
        dummy_input = input
        if dummy_input is None:
            dummy_input = pl_module.example_feature_array
            if hasattr(pl_module, "example_feature_array"):
                dummy_input = pl_module.example_feature_array
            elif hasattr(pl_module, "example_input_array"):
                dummy_input = pl_module.example_input_array
            else:
                raise ValueError(
                    "No example_input_array or example_feature_array found in pl_module"
                )

        dummy_input = dummy_input.to(pl_module.device)

        total_macs = 0.0
3 changes: 0 additions & 3 deletions hannah/conf/backend/tensorrt.yaml
@@ -19,6 +19,3 @@


_target_: hannah.callbacks.backends.TensorRTBackend
val_batches: 10
test_batches: 10
val_frequency: 10
3 changes: 0 additions & 3 deletions hannah/conf/backend/torchmobile.yaml
@@ -19,6 +19,3 @@


_target_: hannah.callbacks.backends.TorchMobileBackend
val_batches: 10
test_batches: 10
val_frequency: 10
5 changes: 3 additions & 2 deletions hannah/conf/nas/aging_evolution_nas.yaml
@@ -17,13 +17,14 @@
## limitations under the License.
##
defaults:
- predictor: gcn
- predictor:
    - macs
    - gcn
- sampler: aging_evolution
- model_trainer: simple
- constraint_model: random_walk
- presampler: null


_target_: hannah.nas.search.search.DirectNAS
budget: 2000
n_jobs: 10
19 changes: 0 additions & 19 deletions hannah/conf/nas/predictor/backend.yaml

This file was deleted.

10 changes: 6 additions & 4 deletions hannah/conf/nas/predictor/gcn.yaml
@@ -18,7 +18,9 @@
##


_target_: hannah.nas.performance_prediction.simple.GCNPredictor
model:
  _target_: hannah.nas.performance_prediction.gcn.predictor.GaussianProcessPredictor
  input_feature_size: 31

gcn:
  _target_: hannah.nas.performance_prediction.simple.GCNPredictor
  model:
    _target_: hannah.nas.performance_prediction.gcn.predictor.GaussianProcessPredictor
    input_feature_size: 31
5 changes: 4 additions & 1 deletion hannah/conf/nas/predictor/macs.yaml
@@ -16,4 +16,7 @@
## See the License for the specific language governing permissions and
## limitations under the License.
##
_target_: hannah.nas.performance_prediction.simple.MACPredictor

macs:
  _target_: hannah.nas.performance_prediction.simple.MACPredictor
  predictor: fx
4 changes: 3 additions & 1 deletion hannah/conf/nas/random_nas.yaml
@@ -20,7 +20,9 @@ defaults:
- sampler: random
- model_trainer: simple
- constraint_model: z3
- predictor: gcn
- predictor:
    - gcn
    - macs
- presampler: single_range_checker

_target_: hannah.nas.search.search.DirectNAS
6 changes: 4 additions & 2 deletions hannah/nas/constraints/constraint_model.py
@@ -1,5 +1,5 @@
#
# Copyright (c) 2023 Hannah contributors.
# Copyright (c) 2024 Hannah contributors.
#
# This file is part of hannah.
# See https://github.com/ekut-es/hannah for further info.
@@ -151,7 +151,9 @@ def naive_search(self, solver, module, key=None, parameters=None):
    def solve(self, module, parameters=None, key=None, fix_vars=[]):
        self.soft_constrain_current_parametrization(module, parameters, key, fix_vars)

    def soft_constrain_current_parametrization(self, module, parameters=None, key=None, fix_vars=[]):
    def soft_constrain_current_parametrization(
        self, module, parameters=None, key=None, fix_vars=[]
    ):
        self.solver = []
        self.build_model(module._conditions)
        for solver in self.solver:
10 changes: 8 additions & 2 deletions hannah/nas/performance_prediction/__init__.py
@@ -1,8 +1,8 @@
#
# Copyright (c) 2022 University of Tübingen.
# Copyright (c) 2024 Hannah contributors.
#
# This file is part of hannah.
# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
# See https://github.com/ekut-es/hannah for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,3 +16,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#


from .protocol import FitablePredictor, Predictor
from .simple import GCNPredictor, MACPredictor

__all__ = ["MACPredictor", "GCNPredictor", "Predictor", "FitablePredictor"]
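As a closing illustration, the sketch below shows how the newly exported names could be combined now that several estimators can be configured side by side (cf. the `- predictor:` lists in the NAS configs above). The `predictor="fx"` argument mirrors `conf/nas/predictor/macs.yaml`; treating `predict()` as returning a mapping of metric names to values is an assumption.

```python
# Hypothetical usage sketch; exact constructor arguments and the shape of the
# predict() result are assumptions based on the configs in this commit.
from hannah.nas.performance_prediction import MACPredictor

predictors = {
    "macs": MACPredictor(predictor="fx"),  # analytical estimator, cf. conf/nas/predictor/macs.yaml
    # "gcn": GCNPredictor(model=...),      # ML-based estimator, would need load()/train()
}


def estimate(model, input=None):
    """Collect metrics from every configured estimator, namespaced by estimator name."""
    metrics = {}
    for name, predictor in predictors.items():
        for key, value in predictor.predict(model, input).items():
            metrics[f"{name}.{key}"] = value
    return metrics
```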