Skip to content

Commit

Permalink
Fix nas und backend integration
Browse files Browse the repository at this point in the history
  • Loading branch information
cgerum committed Aug 28, 2024
1 parent 0e31fc1 commit 0117f80
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 22 deletions.
32 changes: 32 additions & 0 deletions hannah/backends/profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#
# Copyright (c) 2024 Hannah contributors.
#
# This file is part of hannah.
# See https://github.com/ekut-es/hannah for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from hydra.utils import instantiate


def profile_backend(config, lit_module):
metrics = {}
if config.get("backend"):
backend = instantiate(config.backend)
backend.prepare(lit_module)

backend_results = backend.profile(lit_module.example_input_array) # noqa

metrics = backend_results.metrics

return metrics
19 changes: 12 additions & 7 deletions hannah/nas/search/model_trainer/simple_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

from hannah.backends.profile import profile_backend
from hannah.nas.functional_operators.executor import BasicExecutor
from hannah.nas.parameters.parametrize import set_parametrization
from hannah.nas.search.utils import save_graph_to_file, setup_callbacks
Expand Down Expand Up @@ -86,20 +87,24 @@ def run_training(self, model, num, global_num, config):

reset_seed()
trainer.validate(ckpt_path=ckpt_path, verbose=False)

backend_metrics = profile_backend(config, module)

res = opt_callback.result(dict=True)

res.update(backend_metrics)
save_graph_to_file(global_num, res, module)
except Exception as e:
msglogger.critical("Training failed with exception")
msglogger.critical(str(e))
print(traceback.format_exc())
sys.exit(1)

res = {}
for monitor in opt_monitor:
# res[monitor] = float("inf")
res[monitor] = (
1 # FIXME: "inf" causes errors in performance prediction. Find "worst" value for each respective metric?
)
res = {}
for monitor in opt_monitor:
# res[monitor] = float("inf")
res[monitor] = (
1 # FIXME: "inf" causes errors in performance prediction. Find "worst" value for each respective metric?
)

return res
finally:
Expand Down
23 changes: 8 additions & 15 deletions hannah/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .utils import clear_outputs, common_callbacks, git_version, log_execution_env_state
from .utils.dvclive import DVCLIVE_AVAILABLE, DVCLogger
from .utils.logger import JSONLogger
from .backends.profile import profile_backend

msglogger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -89,13 +90,12 @@ def instantiate_module(config) -> LightningModule:
def train(
config: DictConfig,
) -> Union[float, Dict[Any, float], List[Union[float, Dict[Any, float]]]]:

test_output = []
val_output = []
results = []

backend_output = []

if isinstance(config.seed, int):
config.seed = [config.seed]
validate_output = False
Expand Down Expand Up @@ -192,17 +192,10 @@ def train(
test_output.append(opt_callback.test_result())

results.append(opt_callback.result())
# Final inference run if a backend is given

# Final inference run if a backend is given
if "backend" in config:
backend = instantiate(config.backend)
backend.prepare(lit_module)

backend_results = backend.profile(lit_module.example_input_array) # noqa

metrics = backend_results.metrics

backend_output.append(metrics)
backend_output.append(profile_backend(config, lit_module))

@rank_zero_only
def summarize_stage(stage: str, output: Mapping["str", float]) -> None:
Expand Down Expand Up @@ -257,9 +250,9 @@ def summarize_stage(stage: str, output: Mapping["str", float]) -> None:

summarize_stage("test", test_output)
summarize_stage("val", val_output)

if len(backend_output) > 0:
summarize_stage("backend", backend_output)
summarize_stage("backend", backend_output)

if len(results) == 1:
return results[0]
Expand Down

0 comments on commit 0117f80

Please sign in to comment.