Skip to content

Commit

Permalink
Add support for external optimizers
Browse files Browse the repository at this point in the history
  • Loading branch information
cgerum committed Aug 27, 2024
1 parent d529c82 commit 72668f0
Show file tree
Hide file tree
Showing 23 changed files with 1,256 additions and 67 deletions.
11 changes: 5 additions & 6 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ variables:
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
POETRY_HOME: "$CI_PROJECT_DIR/.poetry"
POETRY_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pypoetry"
POETRY_VIRTUALENVS_CREATE: false
GIT_SUBMODULE_STRATEGY: recursive
DEBIAN_FRONTEND: "noninteractive"

Expand All @@ -36,7 +35,7 @@ test_prebuilt_docker:
image: ghcr.io/ekut-es/hannah_hannah:latest
script:
- set -e
- poetry install -E vision
- poetry install --all-extras
- poetry run python3 -m pytest -v --cov=hannah test hannah
tags:
- docker
Expand All @@ -51,7 +50,7 @@ test_python_39:
script:
- set -e
- poetry config installer.max-workers 10
- poetry install -E vision
- poetry install --all-extras
- poetry run python3 -m pytest -v --cov=hannah test hannah --integration
tags:
- docker
Expand All @@ -66,7 +65,7 @@ test_python_310:
script:
- set -e
- poetry config installer.max-workers 10
- poetry install -E vision
- poetry install --all-extras
- "echo 'import coverage; coverage.process_startup()' > sitecustomize.py"
- export PYTHONPATH=$PWD
- export COVERAGE_PROCESS_START=$PWD/.coveragerc
Expand All @@ -93,7 +92,7 @@ test_python_311:
script:
- set -e
- poetry config installer.max-workers 10
- poetry install -E vision
- poetry install --all-extras
- "echo 'import coverage; coverage.process_startup()' > sitecustomize.py"
- export PYTHONPATH=$PWD
- export COVERAGE_PROCESS_START=$PWD/.coveragerc
Expand Down Expand Up @@ -146,7 +145,7 @@ run_sca:
- export POETRY_HOME=/root/.local
- export PATH=${POETRY_HOME}/bin:${PATH}
- poetry config installer.max-workers 10
- poetry install -E vision
- poetry install --all-extras
- ulimit -a
script:
- set -e
Expand Down
4 changes: 1 addition & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,5 @@ COPY poetry.lock pyproject.toml /deps/


# Install dependencies
RUN poetry config virtualenvs.create false \
&& poetry install --no-interaction --no-ansi --all-extras --no-root \
RUN poetry install --no-interaction --no-ansi --all-extras --no-root \
&& rm -rf $POETRY_CACHE_DIR

4 changes: 2 additions & 2 deletions experiments/cifar10/config.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
defaults:
- base_config
- override dataset: cifar10 # Dataset configuration name
#- override dataset: cifar10 # Dataset configuration name
- override features: identity # Feature extractor configuration name (use identity for vision datasets)
- override model: timm_resnet18 #timm_mobilenetv3_small_100 # Neural network name (for now timm_resnet50 or timm_efficientnet_lite1)
#- override model: timm_resnet18 #timm_mobilenetv3_small_100 # Neural network name (for now timm_resnet50 or timm_efficientnet_lite1)
- override optimizer: sgd # Optimizer config name
- override normalizer: null # Feature normalizer (used for quantized neural networks)
- override module: image_classifier # Lightning module config for the training loop (image classifier for image classification tasks)
Expand Down
22 changes: 22 additions & 0 deletions experiments/cifar10/experiment/ae_nas.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# @package _global_
defaults:
- override /nas: aging_evolution_nas
- override /model: embedded_vision_net
- override /dataset: cifar10

model:
num_classes: 10
module:
batch_size: 128
nas:
budget: 300
n_jobs: 8


trainer:
max_epochs: 10

seed: [1234]

experiment_id: "ae_nas"

23 changes: 23 additions & 0 deletions experiments/cifar10/experiment/nsga2_nas.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# @package _global_
defaults:
- override /nas: aging_evolution_nas
- override /nas/sampler: nsga2
- override /model: embedded_vision_net
- override /dataset: cifar10

model:
num_classes: 10
module:
batch_size: 128
nas:
budget: 300
n_jobs: 8


trainer:
max_epochs: 10

seed: [1234]

experiment_id: "ae_nas"

Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ trainer:

seed: [1234]

experiment_id: "ae_nas_cifar10"
experiment_id: "ae_nas"
24 changes: 24 additions & 0 deletions experiments/mnist/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
defaults:
- base_config
- override dataset: mnist # Dataset configuration name
- override features: identity # Feature extractor configuration name (use identity for vision datasets)
- override model: timm_resnet18 #timm_mobilenetv3_small_100 # Neural network name (for now timm_resnet50 or timm_efficientnet_lite1)
- override optimizer: sgd # Optimizer config name
- override normalizer: null # Feature normalizer (used for quantized neural networks)
- override module: image_classifier # Lightning module config for the training loop (image classifier for image classification tasks)
- _self_


monitor:
metric: val_accuracy
direction: maximize

module:
batch_size: 2048

trainer:
max_epochs: 50
precision: 16

optimizer:
lr: 0.3
22 changes: 22 additions & 0 deletions hannah/conf/dataset/cifar100.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
##
## Copyright (c) 2022 University of Tübingen.
##
## This file is part of hannah.
## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
##
## Licensed under the Apache License, Version 2.0 (the "License");
## you may not use this file except in compliance with the License.
## You may obtain a copy of the License at
##
## http://www.apache.org/licenses/LICENSE-2.0
##
## Unless required by applicable law or agreed to in writing, software
## distributed under the License is distributed on an "AS IS" BASIS,
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
## See the License for the specific language governing permissions and
## limitations under the License.
##
data_folder: ${hydra:runtime.cwd}/datasets/
cls: hannah.datasets.vision.Cifar100Dataset
dataset: cifar100
val_percent: 0.1
22 changes: 22 additions & 0 deletions hannah/conf/dataset/mnist.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
##
## Copyright (c) 2022 University of Tübingen.
##
## This file is part of hannah.
## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
##
## Licensed under the Apache License, Version 2.0 (the "License");
## you may not use this file except in compliance with the License.
## You may obtain a copy of the License at
##
## http://www.apache.org/licenses/LICENSE-2.0
##
## Unless required by applicable law or agreed to in writing, software
## distributed under the License is distributed on an "AS IS" BASIS,
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
## See the License for the specific language governing permissions and
## limitations under the License.
##
data_folder: ${hydra:runtime.cwd}/datasets/
cls: hannah.datasets.vision.MNISTDataset
dataset: mnist
val_percent: 0.1
22 changes: 22 additions & 0 deletions hannah/conf/dataset/svhn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
##
## Copyright (c) 2022 University of Tübingen.
##
## This file is part of hannah.
## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
##
## Licensed under the Apache License, Version 2.0 (the "License");
## you may not use this file except in compliance with the License.
## You may obtain a copy of the License at
##
## http://www.apache.org/licenses/LICENSE-2.0
##
## Unless required by applicable law or agreed to in writing, software
## distributed under the License is distributed on an "AS IS" BASIS,
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
## See the License for the specific language governing permissions and
## limitations under the License.
##
data_folder: ${hydra:runtime.cwd}/datasets/
cls: hannah.datasets.vision.SVHNDataset
dataset: svhn
val_percent: 0.1
3 changes: 3 additions & 0 deletions hannah/conf/nas/sampler/nsga2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_target_: hannah.nas.search.sampler.pymoo.PyMOOSampler
algorithm: nsga2
population_size: 100
9 changes: 7 additions & 2 deletions hannah/datasets/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2023 Hannah contributors.
# Copyright (c) 2024 Hannah contributors.
#
# This file is part of hannah.
# See https://github.com/ekut-es/hannah for further info.
Expand All @@ -16,18 +16,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .cifar import Cifar10Dataset
from .cifar import Cifar10Dataset, Cifar100Dataset
from .dresden_capsule import DresdenCapsuleDataset
from .fake import FakeDataset
from .kvasir import KvasirCapsuleDataset
from .kvasir_unlabeled import KvasirCapsuleUnlabeled
from .ri_capsule import RICapsuleDataset
from .mnist import MNISTDataset
from .svhn import SVHNDataset

__all__ = [
"DresdenCapsuleDataset",
"KvasirCapsuleDataset",
"FakeDataset",
"Cifar10Dataset",
"Cifar100Dataset",
"KvasirCapsuleUnlabeled",
"RICapsuleDataset",
"MNISTDataset",
"SVHNDataset",
]
Loading

0 comments on commit 72668f0

Please sign in to comment.