diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 09355fd2..8a0db357 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -27,6 +27,7 @@ variables:
   PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
   POETRY_HOME: "$CI_PROJECT_DIR/.poetry"
   POETRY_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pypoetry"
+  POETRY_VIRTUALENVS_CREATE: false
   GIT_SUBMODULE_STRATEGY: recursive
   DEBIAN_FRONTEND: "noninteractive"
diff --git a/experiments/rhode_island/model/localization_net.yaml b/experiments/rhode_island/model/localization_net.yaml
index d9f3a1db..a18f1caa 100644
--- a/experiments/rhode_island/model/localization_net.yaml
+++ b/experiments/rhode_island/model/localization_net.yaml
@@ -26,10 +26,10 @@ qconfig:
   _target_: hannah.models.factory.qconfig.get_trax_qat_qconfig
   config:
     bw_b: 8
-    bw_w: 6
+    bw_w: 4
     bw_f: 8
     power_of_2: false # Use power of two quantization for weights
-    noise_prob: 0.7 # Probability of quantizing a value during training
+    noise_prob: 0.9 # Probability of quantizing a value during training
 conv:
   - target: forward
     stride: 1
diff --git a/hannah/__init__.py b/hannah/__init__.py
index bac36f20..4b1836c6 100644
--- a/hannah/__init__.py
+++ b/hannah/__init__.py
@@ -1,8 +1,8 @@
 #
-# Copyright (c) 2022 University of Tübingen.
+# Copyright (c) 2024 Hannah contributors.
 #
 # This file is part of hannah.
-# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
+# See https://github.com/ekut-es/hannah for further info.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,3 +16,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
+
+try:
+    from beartype.claw import beartype_this_package
+
+    beartype_this_package()
+except ImportError:
+    pass  # beartype is not installed in production environment
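A note on the new `beartype_this_package()` hook above: it instruments every annotated function in the hannah package with runtime type checking, and is skipped silently when beartype is absent. A minimal, self-contained illustration of the kind of error this surfaces (the function here is made up for demonstration, not hannah code):

    from beartype import beartype

    @beartype
    def scale(x: float, factor: int) -> float:
        return x * factor

    print(scale(1.0, 2))  # 2.0: arguments match the annotations
    try:
        scale(1.0, 2.5)  # second argument violates the int annotation
    except Exception as err:
        print(type(err).__name__)  # BeartypeCallHintParamViolation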
diff --git a/hannah/callbacks/optimization.py b/hannah/callbacks/optimization.py
index 7f7fc2a2..b705c20b 100644
--- a/hannah/callbacks/optimization.py
+++ b/hannah/callbacks/optimization.py
@@ -1,8 +1,8 @@
 #
-# Copyright (c) 2022 University of Tübingen.
+# Copyright (c) 2024 Hannah contributors.
 #
 # This file is part of hannah.
-# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
+# See https://github.com/ekut-es/hannah for further info.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@
 from collections import defaultdict
 from typing import Any, Iterable, List, Mapping, Union
 
+import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
 from torch import Tensor
 
@@ -94,8 +95,15 @@ def _add_monitor_mapping(self, monitor):
         else:
             self.directions.append(-1.0)
 
-    def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: 'STEP_OUTPUT', batch: Any, batch_idx: int) -> None: # noqa: F821
-        callback_metrics = trainer.callback_metrics
+    def on_train_batch_end(
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        outputs: Any,
+        batch: Any,
+        batch_idx: int,
+    ) -> None:
+        callback_metrics = trainer.callback_metrics
 
         for k, v in callback_metrics.items():
             if k.startswith("train"):
@@ -112,7 +120,7 @@ def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
                 monitor_val = callback_metrics[monitor] * direction
                 if monitor.startswith("train"):
                     self._curves[monitor][trainer.global_step] = monitor_val
-
+
                     self.values[monitor] = monitor_val
 
     def on_test_end(self, trainer, pl_module):
@@ -156,7 +164,7 @@ def on_validation_end(self, trainer, pl_module):
         # Skip evaluation of validation metrics during sanity check
         if trainer.sanity_checking:
             return
-
+
         callback_metrics = trainer.callback_metrics
 
         for k, v in callback_metrics.items():
@@ -168,7 +176,7 @@ def on_validation_end(self, trainer, pl_module):
             try:
                 monitor_val = float(callback_metrics[monitor])
                 directed_monitor_val = monitor_val * direction
-
+
                 self.values[monitor] = directed_monitor_val
                 self._curves[monitor][trainer.global_step] = directed_monitor_val
             except Exception:
@@ -230,4 +238,3 @@ def curves(self, dict=False):
             return list(return_values.values())[0]
 
         return return_values
-
diff --git a/hannah/conf/config_dd_direct_angle.yaml b/hannah/conf/config_dd_direct_angle.yaml
deleted file mode 100644
index 8a602ee5..00000000
--- a/hannah/conf/config_dd_direct_angle.yaml
+++ /dev/null
@@ -1,40 +0,0 @@
-defaults:
-  - dataset: directional
-  - model: tc-res8
-  - scheduler: 1cycle
-  - optimizer: adamw
-  - features: raw
-  - normalizer: null
-  - module: direct_angle_classifier
-  - compress: null
-  - trainer: default
-  - checkpoint: default
-  - backend: null
-  - early_stopping: null
-  - profiler: null
-
-type: train
-experiment_id: dd_direct_angle.raw
-output_dir: trained_models
-auto_lr: false
-seed: [1234]
-
-module:
-  num_workers: 4
-  batch_size: 64
-
-trainer:
-  max_epochs: 100
-
-model:
-  n_labels: 1
-
-hydra:
-  run:
-    dir: ${output_dir}/${experiment_id}/${model.name}/
-  sweep:
-    dir: ${output_dir}/${experiment_id}/${model.name}/${hydra.job.name}
-
-#TODO:
-dump_test: false
-input_file: ''
\ No newline at end of file
diff --git a/hannah/conf/config_dd_direct_angle_phase.yaml b/hannah/conf/config_dd_direct_angle_phase.yaml
deleted file mode 100644
index 1cf5d740..00000000
--- a/hannah/conf/config_dd_direct_angle_phase.yaml
+++ /dev/null
@@ -1,40 +0,0 @@
-defaults:
-  - dataset: directional
-  - model: tc-res8
-  - scheduler: 1cycle
-  - optimizer: adamw
-  - features: phase
-  - normalizer: null
-  - module: direct_angle_classifier
-  - compress: null
-  - trainer: default
-  - checkpoint: default
-  - backend: null
-  - early_stopping: null
-  - profiler: null
-
-type: train
-experiment_id: dd_direct_angle.phase
-output_dir: trained_models
-auto_lr: false
-seed: [1234]
-
-module:
-  num_workers: 4
-  batch_size: 64
-
-trainer:
-  max_epochs: 100
-
-model:
-  n_labels: 1
-
-hydra:
-  run:
-    dir: ${output_dir}/${experiment_id}/${model.name}/
-  sweep:
-    dir: ${output_dir}/${experiment_id}/${model.name}/${hydra.job.name}
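Stepping back to the `on_train_batch_end`/`on_validation_end` hunks in `hannah/callbacks/optimization.py` above: each monitored metric is multiplied by a direction of +1.0 or -1.0 before being stored, presumably so downstream optimization code can always treat smaller values as better. A toy sketch of that convention (plain Python with made-up metric names, not the hannah implementation):

    directions = {"val_loss": 1.0, "val_accuracy": -1.0}  # -1.0 flips "higher is better"

    def directed(metrics):
        # Rescale metrics so that the minimum always marks the best result.
        return {k: v * directions.get(k, 1.0) for k, v in metrics.items()}

    print(directed({"val_loss": 0.30, "val_accuracy": 0.91}))
    # {'val_loss': 0.3, 'val_accuracy': -0.91}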
-
-#TODO:
-dump_test: false
-input_file: ''
\ No newline at end of file
diff --git a/hannah/conf/dataset/sensor/naneye-raw.yaml b/hannah/conf/dataset/sensor/naneye-raw.yaml
new file mode 100644
index 00000000..3de11831
--- /dev/null
+++ b/hannah/conf/dataset/sensor/naneye-raw.yaml
@@ -0,0 +1,24 @@
+##
+## Copyright (c) 2022 University of Tübingen.
+##
+## This file is part of hannah.
+## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
+##
+## Licensed under the Apache License, Version 2.0 (the "License");
+## you may not use this file except in compliance with the License.
+## You may obtain a copy of the License at
+##
+## http://www.apache.org/licenses/LICENSE-2.0
+##
+## Unless required by applicable law or agreed to in writing, software
+## distributed under the License is distributed on an "AS IS" BASIS,
+## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+## See the License for the specific language governing permissions and
+## limitations under the License.
+##
+
+# Very approximate simulation of a raw naneye-m sensor image with some noise model
+
+name: naneye-m
+resolution: [320,320]
+pattern: GRBG
diff --git a/hannah/conf/dataset/sensor/naneye.yaml b/hannah/conf/dataset/sensor/naneye.yaml
index 32cb68e5..81d7d988 100644
--- a/hannah/conf/dataset/sensor/naneye.yaml
+++ b/hannah/conf/dataset/sensor/naneye.yaml
@@ -18,3 +18,4 @@
 ##
 name: naneye-m
 resolution: [320,320]
+pattern: [1,1]
diff --git a/hannah/conf/model/conv-net-2d.yaml b/hannah/conf/model/conv-net-2d.yaml
index 79160778..f9cfc7b4 100644
--- a/hannah/conf/model/conv-net-2d.yaml
+++ b/hannah/conf/model/conv-net-2d.yaml
@@ -26,17 +26,17 @@ qconfig:
   _target_: hannah.models.factory.qconfig.get_trax_qat_qconfig
   config:
     bw_b: 8
-    bw_w: 6
+    bw_w: 8
    bw_f: 8
     power_of_2: false # Use power of two quantization for weights
-    noise_prob: 0.7 # Probability of quantizing a value during training
+    noise_prob: 0.5 # Probability of quantizing a value during training
 conv:
   - target: forward
-    stride: 1
+    stride: 2
     blocks:
       - target: conv2d
         kernel_size: 3
-        act: false
+        act: true
         norm: true
         out_channels: 16
   - target: residual
@@ -46,16 +46,16 @@ conv:
         kernel_size: 3
         act: true
         norm: true
-        out_channels: 24
+        out_channels: 32
      - target: conv2d
         kernel_size: 1
         parallel: true
-        out_channels: 24
+        out_channels: 32
      - target: conv2d
         kernel_size: 3
         act: true
         norm: true
-        out_channels: 24
+        out_channels: 32
   - target: residual
     stride: 2
     blocks:
@@ -63,12 +63,12 @@ conv:
         kernel_size: 3
         act: true
         norm: true
-        out_channels: 32
+        out_channels: 64
      - target: conv2d
         kernel_size: 3
         act: true
         norm: true
-        out_channels: 32
+        out_channels: 64
   - target: residual
     stride: 2
     blocks:
@@ -76,13 +76,13 @@ conv:
         kernel_size: 3
         act: true
         norm: true
-        out_channels: 48
+        out_channels: 128
      - target: conv2d
         kernel_size: 1
         parallel: true
-        out_channels: 48
+        out_channels: 128
      - target: conv2d
         kernel_size: 3
         act: true
         norm: true
-        out_channels: 48
+        out_channels: 128
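The new `naneye-raw.yaml` above declares a GRBG Bayer pattern for the simulated raw sensor: a GRBG mosaic samples green/red on even rows and blue/green on odd rows. An illustrative NumPy sketch of applying such a mosaic to an RGB image (a hand-rolled example, not the hannah simulation code):

    import numpy as np

    def mosaic_grbg(rgb):
        # Apply a GRBG Bayer pattern to an HxWx3 image (illustrative only).
        h, w, _ = rgb.shape
        raw = np.zeros((h, w), dtype=rgb.dtype)
        raw[0::2, 0::2] = rgb[0::2, 0::2, 1]  # G at even rows, even cols
        raw[0::2, 1::2] = rgb[0::2, 1::2, 0]  # R at even rows, odd cols
        raw[1::2, 0::2] = rgb[1::2, 0::2, 2]  # B at odd rows, even cols
        raw[1::2, 1::2] = rgb[1::2, 1::2, 1]  # G at odd rows, odd cols
        return raw

    raw = mosaic_grbg(np.random.rand(320, 320, 3))
    print(raw.shape)  # (320, 320): one sample per pixel, as on the sensor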
diff --git a/hannah/conf/nas_new.yaml b/hannah/conf/nas_new.yaml
index b79c976d..8775a931 100644
--- a/hannah/conf/nas_new.yaml
+++ b/hannah/conf/nas_new.yaml
@@ -26,8 +26,8 @@ defaults:
   - override normalizer: fixedpoint
   - override module: stream_classifier
   - override checkpoint: default
-  - override backend: trax_ut
-  - override nas: aging_evolution_nas
+  #- override backend: trax_ut
+  - override nas: aging_evolution_nas_legacy
   - _self_
 
 experiment_id: test
@@ -39,9 +39,9 @@ trainer:
 
 nas:
   parametrization:
-    backend:
-      cols: [2,4,6,8,16]
-      rows: null
+    #backend:
+    #  cols: [2,4,6,8,16]
+    #  rows: null
     model:
       qconfig:
         config:
diff --git a/hannah/conf/trainer/default.yaml b/hannah/conf/trainer/default.yaml
index 421f9777..b8d082b4 100644
--- a/hannah/conf/trainer/default.yaml
+++ b/hannah/conf/trainer/default.yaml
@@ -34,3 +34,4 @@ plugins: null
 strategy: auto
 reload_dataloaders_every_n_epochs: 0
 precision: 32
+enable_model_summary: False
diff --git a/hannah/datasets/speech.py b/hannah/datasets/speech.py
index 4f926836..80a478c4 100644
--- a/hannah/datasets/speech.py
+++ b/hannah/datasets/speech.py
@@ -51,7 +51,6 @@ def snr_factor(snr, psig, pnoise):
 
 def _load_audio(file_name, sr=16000, backend="torchaudio"):
     if backend == "torchaudio":
-        torchaudio.set_audio_backend("sox_io")
         try:
             data, samplingrate = torchaudio.load(file_name)
         except Exception as e:
@@ -60,7 +59,6 @@ def _load_audio(file_name, sr=16000, backend="torchaudio"):
             msglogger.warning(
                 "Could not load %s with default backend trying sndfile", str(file_name)
             )
-            torchaudio.set_audio_backend("soundfile")
             data, samplingrate = torchaudio.load(file_name)
     if samplingrate != sr:
         data = torchaudio.transforms.Resample(samplingrate, sr).forward(data)
diff --git a/hannah/models/ai8x/models.py b/hannah/models/ai8x/models.py
index 075886f8..ad1716da 100644
--- a/hannah/models/ai8x/models.py
+++ b/hannah/models/ai8x/models.py
@@ -115,8 +115,8 @@ def block(
     weight1_quantized = quantize_weight(weight1)
     bias1 = Tensor(
         "b1",
-        (Int(channels)),
-        axis=["O", "I", "kH", "kW"],
+        (Int(channels),),
+        axis=["C"],
         grad=True,
     )
     bias1_quantized = quantize_weight(bias1)
diff --git a/hannah/models/ai8x/models_simplified.py b/hannah/models/ai8x/models_simplified.py
index 0ba99895..8767b6fc 100644
--- a/hannah/models/ai8x/models_simplified.py
+++ b/hannah/models/ai8x/models_simplified.py
@@ -101,7 +101,7 @@ def block(input, channels: int, kernel_size: int):
     weight_quantized = quantize_weight(weight)
     bias = Tensor(
         "b1",
-        (Int(channels)),
+        (Int(channels),),
         axis=["O", "I", "kH", "kW"],
         grad=True,
    )
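The two `ai8x` fixes above correct a classic Python pitfall: `(Int(channels))` is just a parenthesized expression, while the trailing comma in `(Int(channels),)` creates a one-element tuple, which is what the `Tensor` shape argument expects. For illustration:

    channels = 8
    not_a_tuple = (channels)   # parentheses alone do nothing: still an int
    a_tuple = (channels,)      # the comma is what creates the 1-element tuple
    print(type(not_a_tuple).__name__, type(a_tuple).__name__)  # int tuple

The `speech.py` hunks drop `torchaudio.set_audio_backend()`, which was deprecated in the torchaudio 2.x line; recent releases select a working backend automatically, so the explicit calls are no longer needed.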
diff --git a/hannah/models/embedded_vision_net/parameters.py b/hannah/models/embedded_vision_net/parameters.py
index 6fb82507..c46a85f9 100644
--- a/hannah/models/embedded_vision_net/parameters.py
+++ b/hannah/models/embedded_vision_net/parameters.py
@@ -1,5 +1,25 @@
+#
+# Copyright (c) 2024 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
 from typing import Optional, Union
+
 import numpy as np
+
 from hannah.nas.functional_operators.lazy import lazy
 from hannah.nas.parameters.parameters import Parameter
@@ -16,11 +36,13 @@ def channels_per_group(n):
 
 
 class Groups(Parameter):
-    def __init__(self,
-                 in_channels: Parameter,
-                 out_channels: Parameter,
-                 name: Optional[str] = "",
-                 rng: Optional[Union[np.random.Generator, int]] = None,):
+    def __init__(
+        self,
+        in_channels: Union[int, Parameter],
+        out_channels: Union[int, Parameter],
+        name: Optional[str] = "",
+        rng: Optional[Union[np.random.Generator, int]] = None,
+    ):
         super().__init__(name, rng)
         self.name = name
         self.in_channels = in_channels
@@ -32,7 +54,9 @@ def get_possible_values(self):
         out_channels = lazy(self.out_channels)
         possible_values_in = channels_per_group(in_channels)
         possible_values_out = channels_per_group(out_channels)
-        possible_values = list(set(possible_values_in).intersection(possible_values_out))
+        possible_values = list(
+            set(possible_values_in).intersection(possible_values_out)
+        )
         return possible_values
 
     def instantiate(self):
@@ -50,7 +74,11 @@ def sample(self):
     def check(self, value):
         possible_values = self.get_possible_values()
         if value not in possible_values:
-            raise ValueError("{} channels per group not valid with {} in channels and {} out channels".format(value, self.in_channels.evaluate(), self.out_channels.evaluate()))
+            raise ValueError(
+                "{} channels per group not valid with {} in channels and {} out channels".format(
+                    value, self.in_channels.evaluate(), self.out_channels.evaluate()
+                )
+            )
 
     def set_current(self, x):
         possible_values = self.get_possible_values()
@@ -59,11 +87,9 @@ def set_current(self, x):
             self.current_value = int(possible_values[np.argmin(diff)])
         else:
             self.current_value = int(x)
-
-
+
     def from_float(self, val):
-        possible_values = self.get_possible_values()
         val = int(val * (len(possible_values) - 1))
-
+
         return possible_values[val]
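On the `Groups` parameter above: a grouped convolution requires the group count to divide both the input and output channel counts, so the set of legal values is the intersection of the divisors of the two channel numbers. A standalone sketch of that logic (the real `channels_per_group` in hannah may differ in detail):

    def channels_per_group(n):
        # All divisors of n: every legal group count for n channels.
        return [g for g in range(1, n + 1) if n % g == 0]

    def possible_groups(in_channels, out_channels):
        # Legal group counts must divide both channel numbers.
        return sorted(
            set(channels_per_group(in_channels)) & set(channels_per_group(out_channels))
        )

    print(possible_groups(16, 24))  # [1, 2, 4, 8]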
diff --git a/hannah/models/factory/factory.py b/hannah/models/factory/factory.py
index c2ba87d3..f23c1e66 100644
--- a/hannah/models/factory/factory.py
+++ b/hannah/models/factory/factory.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2023 Hannah contributors.
+# Copyright (c) 2024 Hannah contributors.
 #
 # This file is part of hannah.
 # See https://github.com/ekut-es/hannah for further info.
@@ -230,7 +230,7 @@ def conv2d(
         norm: Union[BNConfig, bool] = False,
         act: Union[ActConfig, bool] = False,
         bias: bool = False,
-    ) -> None:
+    ) -> Any:
         """
 
         Args:
@@ -468,12 +468,12 @@ def conv1d(
         self,
         input_shape: Tuple[int, ...],
         out_channels: int,
-        kernel_size: int,
-        stride: int = 1,
+        kernel_size: Union[int, Tuple[int]],
+        stride: Union[int, Tuple[int]] = 1,
         bias: bool = False,
-        padding: Union[int, bool] = True,
-        dilation: int = 1,
-        groups: int = 1,
+        padding: Union[int, bool, Tuple[int]] = True,
+        dilation: Union[int, Tuple[int]] = 1,
+        groups: Union[int, Tuple[int]] = 1,
         norm: Union[BNConfig, bool] = False,
         act: Union[ActConfig, bool] = False,
         out_quant: bool = True,
@@ -513,18 +513,19 @@ def conv1d(
 
         in_channels = input_shape[1]
 
+        if isinstance(kernel_size, tuple):
+            kernel_size = kernel_size[0]
+        if isinstance(stride, tuple):
+            stride = stride[0]
+        if isinstance(dilation, tuple):
+            dilation = dilation[0]
+
         if padding is True:
             # Calculate full padding
             padding = self._padding(kernel_size, stride, dilation)
 
         if isinstance(padding, tuple):
             padding = padding[0]
-        if isinstance(dilation, tuple):
-            dilation = dilation[0]
-        if isinstance(kernel_size, tuple):
-            kernel_size = kernel_size[0]
-        if isinstance(stride, tuple):
-            stride = stride[0]
 
         if padding is False:
             padding = 0
diff --git a/hannah/modules/__init__.py b/hannah/modules/__init__.py
index 77117e8d..284a8804 100644
--- a/hannah/modules/__init__.py
+++ b/hannah/modules/__init__.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2023 Hannah contributors.
+# Copyright (c) 2024 Hannah contributors.
 #
 # This file is part of hannah.
 # See https://github.com/ekut-es/hannah for further info.
@@ -16,11 +16,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from .angle_classifier import (
-    CartesianClassifierModule,
-    DirectAngleClassifierModule,
-    SINCOSClassifierModule,
-)
 from .classifier import (
     CrossValidationStreamClassifierModule,
     SpeechClassifierModule,
@@ -34,7 +29,4 @@
     "StreamClassifierModule",
     "AnomalyDetectionModule",
     "ObjectDetectionModule",
-    "CartesianClassifierModule",
-    "DirectAngleClassifierModule",
-    "SINCOSClassifierModule",
 ]
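The `conv1d` reordering above is a genuine bug fix, not just formatting: `self._padding(kernel_size, stride, dilation)` does arithmetic on its arguments, so 1-element tuples such as `(3,)` have to be unwrapped before the padding is computed, not after. A simplified sketch of the corrected order (the padding formula here is the usual "same" padding, assumed rather than copied from hannah):

    def normalize_1d(value):
        # Unwrap 1-element tuples like (3,) into plain ints.
        return value[0] if isinstance(value, tuple) else value

    def full_padding(kernel_size, stride, dilation):
        # Tuples must be unwrapped *before* any arithmetic on them;
        # stride is accepted only to mirror the factory signature.
        kernel_size = normalize_1d(kernel_size)
        stride = normalize_1d(stride)
        dilation = normalize_1d(dilation)
        return dilation * (kernel_size - 1) // 2

    print(full_padding((3,), 1, (1,)))  # 1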
diff --git a/hannah/modules/angle_classifier.py b/hannah/modules/angle_classifier.py
deleted file mode 100644
index 8a3bb95d..00000000
--- a/hannah/modules/angle_classifier.py
+++ /dev/null
@@ -1,352 +0,0 @@
-#
-# Copyright (c) 2023 Hannah contributors.
-#
-# This file is part of hannah.
-# See https://github.com/ekut-es/hannah for further info.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import logging
-import math
-import platform
-from abc import abstractmethod
-from typing import Dict, Optional, Union
-
-import numpy as np
-import torch
-import torch.utils.data as data
-from hydra.utils import get_class, instantiate
-from omegaconf import DictConfig
-from pytorch_lightning import LightningModule
-from sklearn.metrics import auc
-from torchaudio.transforms import FrequencyMasking, TimeMasking, TimeStretch
-from torchmetrics import Metric, MetricCollection
-
-from hannah.datasets.collate import ctc_collate_fn
-
-from ..models.factory.qat import QAT_MODULE_MAPPINGS
-from ..utils import set_deterministic
-from .config_utils import get_loss_function, get_model
-
-msglogger = logging.getLogger(__name__)
-
-from .classifier import StreamClassifierModule
-
-
-class AngleClassifierModule(StreamClassifierModule):
-    def setup(self, stage):
-        # TODO stage variable is not used!
-        msglogger.info("Setting up model")
-        if self.logger:
-            msglogger.info("Model setup already completed skipping setup")
-            self.logger.log_hyperparams(self.hparams)
-
-        if self.initialized:
-            return
-
-        self.initialized = True
-
-        if self.hparams.dataset is not None:
-            # trainset needed to set values in hparams
-            self.train_set, self.dev_set, self.test_set = self.get_split()
-
-            self.num_classes = self.hparams.model.n_labels
-
-        # Create example input
-        device = self.device
-        self.example_input_array = self.get_example_input_array()
-        dummy_input = self.example_input_array.to(device)
-        logging.info("Example input array shape: %s", str(dummy_input.shape))
-
-        # Instantiate features
-        self.features = instantiate(self.hparams.features)
-        self.features.to(device)
-
-        features = self._extract_features(dummy_input)
-        self.example_feature_array = features.to(self.device)
-
-        # Instantiate Model
-        if hasattr(self.hparams.model, "_target_") and self.hparams.model._target_:
-            print(self.hparams.model._target_)
-            self.model = instantiate(
-                self.hparams.model,
-                input_shape=self.example_feature_array.shape,
-                labels=self.num_classes,
-                _recursive_=False,
-            )
-        else:
-            self.hparams.model.width = self.example_feature_array.size(2)
-            self.hparams.model.height = self.example_feature_array.size(1)
-            self.hparams.model.n_labels = self.num_classes
-            self.model = get_model(self.hparams.model)
-
-        # loss function
-        self.criterion = self.loss_function
-
-        # Metrics
-        self.train_metrics = MetricCollection(
-            {
-                "train_accuracy": self.get_accuracy_metric(),
-                "train_error": self.get_error_metric(),
-            }
-        )
-        self.val_metrics = MetricCollection(
-            {
-                "val_accuracy": self.get_accuracy_metric(),
-                "val_error": self.get_error_metric(),
-            }
-        )
-        self.test_metrics = MetricCollection(
-            {
-                "test_accuracy": self.get_accuracy_metric(),
-                "test_error": self.get_error_metric(),
-            }
-        )
-
-    def calculate_batch_metrics(self, output, y, loss, metrics, prefix):
-        if isinstance(output, list):
-            for idx, out in enumerate(output):
-                metrics(out, y)
-                self.log_dict(metrics, batch_size=self.batch_size)
-        else:
-            try:
-                metrics(output, y)
-                self.log_dict(metrics, batch_size=self.batch_size)
-            except ValueError:
-                logging.critical("Could not calculate batch metrics: {outputs}")
-        self.log(f"{prefix}_loss", loss, batch_size=self.batch_size)
-
-    # TRAINING CODE
-    def training_step(self, batch, batch_idx):
-        x, x_len, y, y_len = batch
-
-        output = self(x)
-        loss = self.criterion(output, y)
-
-        # METRICS
-        self.calculate_batch_metrics(output, y, loss, self.train_metrics, "train")
-
-        return loss
-
-    # VALIDATION CODE
-    def validation_step(self, batch, batch_idx):
-        # dataloader provides these four entries per batch
-        x, x_length, y, y_length = batch
-
-        # INFERENCE
-        output = self(x)
-        loss = self.criterion(output, y)
-
-        # METRICS
-        self.calculate_batch_metrics(output, y, loss, self.val_metrics, "val")
-        return loss
-
-    # TEST CODE
-    def test_step(self, batch, batch_idx):
-        # dataloader provides these four entries per batch
-        x, x_length, y, y_length = batch
-
-        output = self(x)
-        loss = self.criterion(output, y)
-
-        # METRICS
-        self.calculate_batch_metrics(output, y, loss, self.test_metrics, "test")
-
-        return loss
-
-    def forward(self, x):
-        x = self._extract_features(x)
-        x = self.model(x)
-        return x
-
-
-class CartesianClassifierModule(AngleClassifierModule):
-    @staticmethod
-    def get_angle_diff(scores, labels):
-        assert scores.shape[0] == labels.shape[0]
-        assert scores.shape[1] == 2
-        assert labels.shape[1] == 2
-
-        labels_norm = torch.nn.functional.normalize(labels)
-
-        scores_norm = torch.nn.functional.normalize(scores)
-
-        x_hat = labels_norm[:, 0]
-        y_hat = labels_norm[:, 1]
-
-        x = scores_norm[:, 0]
-        y = scores_norm[:, 1]
-
-        result = torch.acos(x_hat * x + y_hat * y)
-
-        return result
-
-    def get_dist(self, scores, labels):
-        assert scores.shape[0] == labels.shape[0]
-        assert scores.shape[1] == 2
-        assert labels.shape[1] == 2
-
-        x_hat, y_hat = labels[:, 0], labels[:, 1]
-        x, y = scores[:, 0], scores[:, 1]
-
-        dist = torch.sqrt(torch.square(x - x_hat) + torch.square(y - y_hat))
-        return dist
-
-    def loss_function(self, scores, labels):
-        return torch.mean(self.get_dist(scores, labels))
-
-    def get_error_metric(self):
-        return self.AngleError()
-
-    def get_accuracy_metric(self):
-        return self.AngleAccuracy()
-
-    class AngleError(Metric):
-        def __init__(self, dist_sync_on_step=False):
-            super().__init__(dist_sync_on_step=dist_sync_on_step)
-
-            self.add_state("errors", default=torch.Tensor())
-
-        def update(self, preds: torch.Tensor, target: torch.Tensor):
-            error = CartesianClassifierModule.get_angle_diff(preds, target)
-            self.errors = torch.cat((self.errors, error))
-
-        def compute(self):
-            return torch.mean(self.errors) / (2 * math.pi) * 360.0
-
-    class AngleAccuracy(AngleError):
-        def compute(self):
-            return 1.0 - super().compute() / 180.0
-
-
-class DirectAngleClassifierModule(AngleClassifierModule):
-    def forward(self, x):
-        x = self._extract_features(x)
-        x = self.model(x)
-        x = torch.nn.functional.hardtanh(x, min_val=-math.pi, max_val=math.pi)
-        return x
-
-    @staticmethod
-    def get_angle_diff(scores, labels, e=1e-7):
-        assert scores.shape[0] == labels.shape[0]
-        assert scores.shape[1] == 1
-        assert labels.shape[1] == 2
-
-        labels_norm = torch.nn.functional.normalize(labels)
-
-        x_hat = labels_norm[:, 0]
-        y_hat = labels_norm[:, 1]
-
-        x = torch.sin(scores.squeeze())
-        y = torch.cos(scores.squeeze())
-
-        result = torch.acos(torch.clamp(x_hat * x + y_hat * y, -1.0 + e, 1.0 - e))
-
-        return result
-
-    def loss_function(self, scores, labels):
-        return torch.mean(self.get_angle_diff(scores, labels))
-
-    def get_error_metric(self):
-        return self.AngleError()
-
-    def get_accuracy_metric(self):
-        return self.AngleAccuracy()
-
-    class AngleError(Metric):
-        def __init__(self, dist_sync_on_step=False):
-            super().__init__(dist_sync_on_step=dist_sync_on_step)
-
-            self.add_state("errors", default=torch.Tensor())
-
-        def update(self, preds: torch.Tensor, target: torch.Tensor):
-            error = DirectAngleClassifierModule.get_angle_diff(preds, target)
-            self.errors = torch.cat((self.errors, error))
-
-        def compute(self):
-            return torch.mean(self.errors) / (2 * math.pi) * 360.0
-
-    class AngleAccuracy(AngleError):
-        def compute(self):
-            return 1.0 - super().compute() / 180.0
-
-
-class SINCOSClassifierModule(AngleClassifierModule):
-    def forward(self, x):
-        x = self._extract_features(x)
-        x = self.model(x)
-        x = torch.nn.functional.hardtanh(x, min_val=-1, max_val=1)
-        return x
-
-    @staticmethod
-    def get_loss(scores, labels):
-        assert scores.shape[0] == labels.shape[0]
-        assert scores.shape[1] == 2
-        assert labels.shape[1] == 2
-
-        labels_norm = torch.nn.functional.normalize(labels)
-
-        sin_hat = labels_norm[:, 0]
-        cos_hat = labels_norm[:, 1]
-
-        sin = scores[:, 0]
-        cos = scores[:, 1]
-
-        return torch.mean(torch.abs(sin_hat - sin) + torch.abs(cos_hat - cos))
-
-    @staticmethod
-    def get_angle_diff(scores, labels, e=1e-7):
-        assert scores.shape[0] == labels.shape[0]
-        assert scores.shape[1] == 2
-        assert labels.shape[1] == 2
-
-        labels_norm = torch.nn.functional.normalize(labels)
-
-        scores_norm = torch.nn.functional.normalize(scores)
-
-        x_hat = labels_norm[:, 0]
-        y_hat = labels_norm[:, 1]
-
-        x = scores_norm[:, 0]
-        y = scores_norm[:, 1]
-
-        result = torch.acos(torch.clamp(x_hat * x + y_hat * y, -1.0 + e, 1.0 - e))
-
-        return result
-
-    def loss_function(self, scores, labels):
-        return torch.mean(SINCOSClassifierModule.get_loss(scores, labels))
-
-    def get_error_metric(self):
-        return self.AngleError()
-
-    def get_accuracy_metric(self):
-        return self.AngleAccuracy()
-
-    class AngleError(Metric):
-        def __init__(self, dist_sync_on_step=False):
-            super().__init__(dist_sync_on_step=dist_sync_on_step)
-
-            self.add_state("errors", default=torch.Tensor())
-
-        def update(self, preds: torch.Tensor, target: torch.Tensor):
-            error = SINCOSClassifierModule.get_angle_diff(preds, target)
-            self.errors = torch.cat((self.errors, error))
-
-        def compute(self):
-            return torch.mean(self.errors) / (2 * math.pi) * 360.0
-
-    class AngleAccuracy(AngleError):
-        def compute(self):
-            return 1.0 - super().compute() / 180.0
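For reference, the angular error that all three removed modules share reduces to: normalize predictions and targets to unit vectors, clamp the dot product into acos's domain, and convert radians to degrees. A compact re-derivation in PyTorch (illustrative, not a drop-in replacement for the deleted metrics):

    import math

    import torch

    def angle_error_degrees(pred, target, eps=1e-7):
        # Unit vectors -> clamped cosine similarity -> angle in degrees.
        pred = torch.nn.functional.normalize(pred)
        target = torch.nn.functional.normalize(target)
        cos_sim = (pred * target).sum(dim=1).clamp(-1.0 + eps, 1.0 - eps)
        return torch.acos(cos_sim) / (2 * math.pi) * 360.0

    print(angle_error_degrees(torch.tensor([[0.0, 1.0]]), torch.tensor([[1.0, 0.0]])))
    # tensor([90.0000]): orthogonal direction vectors are 90 degrees apart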
diff --git a/hannah/modules/augmentation/transforms/registry.py b/hannah/modules/augmentation/transforms/registry.py
index 78066b48..13baabde 100644
--- a/hannah/modules/augmentation/transforms/registry.py
+++ b/hannah/modules/augmentation/transforms/registry.py
@@ -16,15 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-class TransformRegistry:
-    def __init__(self):
-        self.transforms = {}
 
-    def register(self, cls):
-        self.transforms[cls.__name__] = cls
+from hannah.utils.registry import Registry
 
-    def instantiate(self, name, **params):
-        return self.transforms[name](**params)
-
-
-registry = TransformRegistry()
+registry = Registry()
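The transform registry now delegates to the shared `hannah.utils.registry.Registry`. Judging from the deleted `TransformRegistry`, the shared class presumably offers at least `register` and `instantiate`; a minimal sketch of such a registry (names assumed here, see `hannah/utils/registry.py` for the real implementation):

    class Registry:
        def __init__(self):
            self.registered_classes = {}

        def register(self, cls):
            # Store the class under its own name; returning it also
            # makes this usable as a decorator.
            self.registered_classes[cls.__name__] = cls
            return cls

        def instantiate(self, name, **params):
            return self.registered_classes[name](**params)

    registry = Registry()

    @registry.register
    class RandomCrop:
        def __init__(self, size=32):
            self.size = size

    crop = registry.instantiate("RandomCrop", size=64)
    print(crop.size)  # 64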
diff --git a/hannah/modules/vision/base.py b/hannah/modules/vision/base.py
index a60500c4..97ac5cdd 100644
--- a/hannah/modules/vision/base.py
+++ b/hannah/modules/vision/base.py
@@ -22,7 +22,6 @@
 from typing import Any, Optional, Sequence, Tuple
 
 import kornia
-import kornia.augmentation as K
 import matplotlib.pyplot as plt
 import torch
 import torch.utils.data as data
@@ -67,7 +66,9 @@ def setup(self, stage):
         msglogger.info("  Train Set (Labeled): %d", len(self.train_set))
         msglogger.info(
             "  Train Set (Unlabeled): %d",
-            len(self.train_set_unlabeled) if self.train_set_unlabeled is not None else 0,
+            len(self.train_set_unlabeled)
+            if self.train_set_unlabeled is not None
+            else 0,
         )
         msglogger.info("  Dev Set: %d", len(self.dev_set))
         msglogger.info("  Test Set: %d", len(self.test_set))
@@ -75,7 +76,9 @@ def setup(self, stage):
         if self.test_set is not None:
             example_data = self._decode_batch(self.test_set[0])["data"].unsqueeze(0)
         else:
-            msglogger.warning("No test set found, using random data as example input array")
+            msglogger.warning(
+                "No test set found, using random data as example input array"
+            )
             example_data = torch.randn(1, 3, 224, 224, device=self.device)
 
         if not isinstance(example_data, torch.Tensor):
@@ -115,7 +118,9 @@ def setup(self, stage):
             self.train_set.mean,
             self.train_set.std,
         )
-        self.input_normalizer = kornia.enhance.Normalize(self.train_set.mean, self.train_set.std)
+        self.input_normalizer = kornia.enhance.Normalize(
+            self.train_set.mean, self.train_set.std
+        )
 
         # Setup Augmentations
         self.default_augmentation = torch.nn.Identity()
@@ -126,21 +131,35 @@ def setup(self, stage):
 
         # Setup Metrics
         metrics = {}
         if self.num_classes > 0:
-            self.test_confusion = ConfusionMatrix("multiclass", num_classes=self.num_classes)
+            self.test_confusion = ConfusionMatrix(
+                "multiclass", num_classes=self.num_classes
+            )
 
-            self.test_roc = ROC("multiclass", num_classes=self.num_classes, thresholds=10)
-            self.test_pr_curve = PrecisionRecallCurve("multiclass", num_classes=self.num_classes, thresholds=10)
+            self.test_roc = ROC(
+                "multiclass", num_classes=self.num_classes, thresholds=10
+            )
+            self.test_pr_curve = PrecisionRecallCurve(
+                "multiclass", num_classes=self.num_classes, thresholds=10
+            )
 
             for step_name in ["train", "val", "test"]:
                 step_metrics = MetricCollection(
                     {
-                        f"{step_name}_accuracy": Accuracy("multiclass", num_classes=self.num_classes),
-                        f"{step_name}_error": Error("multiclass", num_classes=self.num_classes),
+                        f"{step_name}_accuracy": Accuracy(
+                            "multiclass", num_classes=self.num_classes
+                        ),
+                        f"{step_name}_error": Error(
+                            "multiclass", num_classes=self.num_classes
+                        ),
                         f"{step_name}_precision": Precision(
                             "multiclass", num_classes=self.num_classes, average="macro"
                         ),
-                        f"{step_name}_recall": Recall("multiclass", num_classes=self.num_classes, average="macro"),
-                        f"{step_name}_f1": F1Score("multiclass", num_classes=self.num_classes, average="macro"),
+                        f"{step_name}_recall": Recall(
+                            "multiclass", num_classes=self.num_classes, average="macro"
+                        ),
+                        f"{step_name}_f1": F1Score(
+                            "multiclass", num_classes=self.num_classes, average="macro"
+                        ),
                     }
                 )
                 metrics[f"{step_name}_metrics"] = step_metrics
@@ -175,7 +194,10 @@ def _setup_datasets(self):
         data_splits: Tuple[Any] = dataset_cls.splits(self.hparams.dataset)
 
         if len(data_splits) == 3:
-            warnings.warn("Vision datasets should return a length 4 tuple, will assume unlabeled data is None")
+            warnings.warn(
+                "Vision datasets should return a length 4 tuple, will assume unlabeled data is None"
+            )
+
             data_splits = (data_splits[0], None, data_splits[1], data_splits[2])
 
         (
@@ -187,7 +209,9 @@ def _setup_datasets(self):
 
         if self.hparams.unlabeled_data:
             unlabeled_cls = get_class(self.hparams.unlabeled_data.cls)
-            self.train_set_unlabeled, _, _ = unlabeled_cls.splits(self.hparams.unlabeled_data)
+            self.train_set_unlabeled, _, _ = unlabeled_cls.splits(
+                self.hparams.unlabeled_data
+            )
 
     def _decode_batch(self, batch):
         if isinstance(batch, Sequence):
@@ -254,7 +278,9 @@ def augment(
 
         if batch_idx == 0:
             pipeline_name = pipeline if pipeline is not None else "default"
-            self._log_batch_images(f"augmented_{pipeline_name}", batch_idx, augmented_norm_data)
+            self._log_batch_images(
+                f"augmented_{pipeline_name}", batch_idx, augmented_norm_data
+            )
 
         return augmented_norm_data, images
 
@@ -263,7 +289,9 @@ def setup_augmentations(self, pipeline_configs):
         augmentations = defaultdict(list)
 
         if pipeline_configs is None:
-            msglogger.warning("No data augmentations have been defined, make sure that this is intentional")
+            msglogger.warning(
+                "No data augmentations have been defined, make sure that this is intentional"
+            )
             self.default_augmentation = torch.nn.Identity()
             return
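The metric setup above is standard torchmetrics usage: one `MetricCollection` per train/val/test step, keyed by prefixed names. A small self-contained usage example (two metrics instead of five, three classes; assumes a torchmetrics version with the task-based API):

    import torch
    from torchmetrics import Accuracy, MetricCollection, Precision

    metrics = MetricCollection(
        {
            "val_accuracy": Accuracy(task="multiclass", num_classes=3),
            "val_precision": Precision(task="multiclass", num_classes=3, average="macro"),
        }
    )
    preds = torch.tensor([0, 2, 1, 2])
    target = torch.tensor([0, 1, 1, 2])
    print(metrics(preds, target))  # dict with one entry per metric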
diff --git a/hannah/nas/constraints/dfg_constraint_model.py b/hannah/nas/constraints/dfg_constraint_model.py
deleted file mode 100644
index 83f50ee6..00000000
--- a/hannah/nas/constraints/dfg_constraint_model.py
+++ /dev/null
@@ -1,286 +0,0 @@
-#
-# Copyright (c) 2022 University of Tübingen.
-#
-# This file is part of hannah.
-# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-from z3 import Int, Or, Solver
-
-from hannah.nas.dataflow.dataflow_graph import DataFlowGraph, flatten
-from hannah.nas.dataflow.op_type import OpType
-from hannah.nas.dataflow.tensor import Tensor
-from hannah.nas.expressions.arithmetic import Add, Floor, Mul, Sub, Truediv
-from hannah.nas.expressions.op import BinaryOp, UnaryOp
-from hannah.nas.expressions.placeholder import (
-    Categorical,
-    DefaultInt,
-    IntRange,
-    Placeholder,
-)
-from hannah.nas.ops import batched_image_tensor
-from hannah.nas.parameters.parameters import (
-    CategoricalParameter,
-    IntScalarParameter,
-    Parameter,
-)
-from hannah.nas.test.network import residual_block
-
-
-class ConstraintModel:
-    def __init__(self) -> None:
-        self.solver = Solver()
-        self.vars = {}
-
-        self.input_dict = {}
-        self.output_dict = {}
-        self.enter_dict = {}
-
-    def build_model(self, graph):
-        queue = [graph]
-        visited = [graph]
-
-        while queue:
-            graph = queue.pop(-1)
-
-            if isinstance(graph, DataFlowGraph):
-                self.process_dataflow(graph)
-
-                if graph.output not in visited:
-                    queue.append(graph.output)
-                    visited.append(graph.output)
-
-            elif isinstance(graph, OpType):
-                self.process_optype(graph)
-
-                for o in graph.operands:
-                    if o not in visited:
-                        queue.append(o)
-                        visited.append(o)
-
-            elif isinstance(graph, Tensor):
-                self.process_tensor(graph)
-
-    def process_dataflow(self, graph):
-        # currently not needed because we use a flattened graph
-        # with only OpTypes and Tensors
-        pass
-
-    def process_optype(self, op: OpType):
-        """Extracts the constraints based on the type of op.
-        New variables are added to self.vars and the constraints
-        are added to the solver.
-
-        Parameters
-        ----------
-        op : OpType
-        """
-        if op.name == "Conv2d":
-            self.extract_conv_constraints(op)
-        elif op.name == "Add":
-            self.extract_add_constraints(op)
-        else:
-            self.extract_passthrough_constraints(op)
-
-    def process_tensor(self, tensor: Tensor):
-        """Goes through all axis and extracts the constraints for
-        the respective axis sizes
-
-        Parameters
-        ----------
-        tensor : Tensor
-        """
-        for name, ax in tensor.tensor_type().axis.items():
-            self.build_constraint_from_expression(ax.size, [])
-
-    def extract_conv_constraints(self, op: OpType):
-        input_tensor = op.operands[0].tensor_type()
-        output_tensor = op.tensor_type()
-
-        for ax_name, ax in output_tensor.axis.items():
-            con = self.build_constraint_from_expression(
-                output_tensor[ax_name].size, [input_tensor[ax_name].size]
-            )
-            var = Int(f"{op.id}.{ax_name}.size")
-            self.vars[str(var)] = var
-            self.solver.add(var == con)
-
-        padding = None
-        kernel_size = []
-        for name, var in self.vars.items():
-            if "padding" in name:
-                padding = var
-            elif "kh" in name or "kw" in name:
-                kernel_size.append(var)
-
-        for ks in kernel_size:
-            self.solver.add(ks / 2 == padding)
-
-    def extract_add_constraints(self, op):
-        output_tensor = op.tensor_type()
-
-        for name, ax in output_tensor.axis.items():
-            ax_out = Int(f"{op.id}.{name}.size")
-
-            for operand in op.operands:
-                input_tensor = operand.tensor_type()
-                for in_name, in_ax in input_tensor.axis.items():
-                    ax_in = Int(f"{operand.id}.{name}.size")
-                    self.solver.add(ax_in == ax_out)
-
-    def extract_passthrough_constraints(self, op):
-        input_tensor = op.operands[0].tensor_type()
-        output_tensor = op.tensor_type()
-
-        for ax_name, ax in output_tensor.axis.items():
-            con = self.build_constraint_from_expression(
-                output_tensor[ax_name].size, [input_tensor[ax_name].size]
-            )
-            var = Int(f"{op.id}.{ax_name}.size")
-            self.solver.add(var == con)
-    def extract_parameter(self, expr):
-        if isinstance(expr, (IntScalarParameter, IntRange)):
-            return self.extract_int_range(expr)
-        elif isinstance(expr, (CategoricalParameter, Categorical)):
-            return self.extract_categorical(expr)
-        elif isinstance(expr, DefaultInt):
-            return self.extract_defaultint(expr)
-        elif isinstance(expr, int):
-            var = Int(expr.id)
-            self.solver.add(var == expr)
-            return var
-
-    def extract_int_range(self, expr):
-        if expr.id:
-            var = Int(expr.id)
-        else:
-            var = Int(f"IntRange({expr.min}, {expr.max})")
-            # TODO: unique scope ids for DFG parameters
-        self.vars[str(var)] = var
-        self.solver.add(var >= expr.min)
-        self.solver.add(var <= expr.max)
-        if hasattr(expr, "step_size") and expr.step_size != 1:
-            self.solver.add((var - expr.min) % expr.step_size == 0)
-
-        return var
-
-    def extract_categorical(self, expr):
-        var = Int(expr.id)
-        self.vars[expr.id] = var
-        cons = []
-        for val in expr.choices:
-            cons.append(var == val)
-        self.solver.add(Or(cons))
-        return var
-
-    def extract_defaultint(self, expr):
-        if expr.id:
-            var = Int(expr.id)
-        else:
-            var = Int(f"DefaultInt({expr.value})")
-        self.vars[str(var)] = var
-        self.solver.add(var == expr.value)
-        return var
-
-    def build_constraint_from_expression(self, expr, inputs):
-        for inp in inputs:
-            if check_for_id(expr, inp):
-                in_var = Int(inp.id)
-                self.vars[inp.id] = in_var
-                return in_var
-        if isinstance(expr, Parameter):
-            var = self.extract_parameter(expr)
-            self.vars[str(var)] = var
-            return var
-        elif isinstance(expr, Placeholder):
-            var = self.extract_parameter(expr)
-            return var
-        elif isinstance(expr, Add):
-            lhs = self.build_constraint_from_expression(expr.lhs, inputs)
-            rhs = self.build_constraint_from_expression(expr.rhs, inputs)
-            con = lhs + rhs
-            return con
-        elif isinstance(expr, Truediv):
-            lhs = self.build_constraint_from_expression(expr.lhs, inputs)
-            rhs = self.build_constraint_from_expression(expr.rhs, inputs)
-            con = lhs / rhs
-            return con
-        elif isinstance(expr, Mul):
-            lhs = self.build_constraint_from_expression(expr.lhs, inputs)
-            rhs = self.build_constraint_from_expression(expr.rhs, inputs)
-            con = lhs * rhs
-            return con
-        elif isinstance(expr, Sub):
-            lhs = self.build_constraint_from_expression(expr.lhs, inputs)
-            rhs = self.build_constraint_from_expression(expr.rhs, inputs)
-            con = lhs - rhs
-            return con
-        elif isinstance(expr, Floor):
-            con = self.build_constraint_from_expression(expr.operand, inputs)
-            return con
-        elif isinstance(expr, int):
-            var = Int(f"Literal({expr})")
-            self.solver.add(var == expr)
-            return var
-
-
-def check_for_id(a, b):
-    return hasattr(a, "id") and hasattr(b, "id") and a.id and b.id and a.id == b.id
-
-
-def find_operand_in_expression(operand, expr):
-    queue = [expr]
-    visited = [expr]
-
-    while queue:
-        current = queue.pop(-1)
-        if isinstance(current, UnaryOp):
-            print("Check Unary")
-            if check_for_id(current.operand, operand):
-                print("found")
-            else:
-                queue.append(current.operand)
-                visited.append(current.operand)
-        elif isinstance(current, BinaryOp):
-            print("Check Binary")
-            if check_for_id(operand, current.lhs):
-                print("Found lhs")
-            elif check_for_id(operand, current.rhs):
-                print("Found rhs")
-            else:
-                queue.append(current.lhs)
-                queue.append(current.rhs)
-                visited.append(current.lhs)
-                visited.append(current.rhs)
-
-
-if __name__ == "__main__":
-    cm = ConstraintModel()
-    input = batched_image_tensor(
-        shape=(1, 3, 32, 32),
-        dtype=CategoricalParameter(choices=["int6", "int8"]),
-        name="input",
-    )
-    graph = residual_block(
-        input,
-        stride=IntScalarParameter(1, 2),
-        output_channel=IntScalarParameter(4, 512, 4),
-    )
-    graph = flatten(graph)
-    cm = ConstraintModel()
-    cm.build_model(graph)
-    inp = input.tensor_type()
-    blck = input.users[0].tensor_type()
-    print()
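The deleted `ConstraintModel` encoded search-space parameters as z3 integer variables: an `IntScalarParameter(min, max, step)` becomes bound and step-size constraints, a categorical becomes a disjunction. A toy reconstruction of that encoding with standard z3 calls (variable names made up):

    from z3 import Int, Or, Solver, sat

    solver = Solver()
    channels = Int("conv0.out_channels")
    solver.add(channels >= 4, channels <= 512)        # IntScalarParameter(4, 512, ...)
    solver.add((channels - 4) % 4 == 0)               # step size of 4
    solver.add(Or(channels == 8, channels == 16))     # extra categorical restriction

    if solver.check() == sat:
        print(solver.model()[channels])               # a value satisfying all constraints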
diff --git a/hannah/nas/dataflow/__init__.py b/hannah/nas/dataflow/__init__.py
deleted file mode 100644
index 76a5e433..00000000
--- a/hannah/nas/dataflow/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .dataflow_graph import DataFlowGraph, dataflow
diff --git a/hannah/nas/dataflow/analysis/dataflow_analysis.py b/hannah/nas/dataflow/analysis/dataflow_analysis.py
deleted file mode 100644
index 45dbbfba..00000000
--- a/hannah/nas/dataflow/analysis/dataflow_analysis.py
+++ /dev/null
@@ -1,71 +0,0 @@
-#
-# Copyright (c) 2022 Hannah contributors.
-#
-# This file is part of hannah.
-# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-from copy import deepcopy
-
-from hannah.nas.dataflow.dataflow_graph import DataFlowGraph
-from hannah.nas.dataflow.op_type import OpType
-from hannah.nas.dataflow.repeat import Repeat
-from hannah.nas.dataflow.tensor import Tensor
-from hannah.nas.dataflow.tensor_expression import TensorExpression
-
-
-class DataFlowAnalysis:
-    def __init__(self) -> None:
-        pass
-
-    def initial_env(self):
-        pass
-
-    def create_worklist(self):
-        pass
-
-    def analyze(self, expr: TensorExpression, env=None):
-        worklist = self.create_worklist(expr)
-        if env is None:
-            env = self.initial_env()
-        while worklist:
-            expr = worklist.pop()
-            if isinstance(expr, Repeat):
-                changed = self.visit_repeat(expr, env)
-            elif isinstance(expr, DataFlowGraph):
-                changed = self.visit_dataflowgraph(expr, env)
-            elif isinstance(expr, OpType):
-                changed = self.visit_optype(expr, env)
-            elif isinstance(expr, Tensor):
-                changed = self.visit_tensor(expr, env)
-
-            if changed:
-                self.extend_worklist(expr)
-
-    def visit_dataflowgraph(self, dfg, env) -> bool:
-        old_env = deepcopy(env)
-        self.analyze(dfg, env)
-        return old_env == env
-
-    def visit_repeat(self, repeat, env) -> bool:
-        pass
-
-    def visit_optype(self, op, env) -> bool:
-        pass
-
-    def visit_tensor(self, tensor, env) -> bool:
-
-        pass
diff --git a/hannah/nas/dataflow/axis_type.py b/hannah/nas/dataflow/axis_type.py
deleted file mode 100644
index ef609882..00000000
--- a/hannah/nas/dataflow/axis_type.py
+++ /dev/null
@@ -1,69 +0,0 @@
-from hannah.nas.dataflow.scoping_utils import get_id
-from hannah.nas.expressions.placeholder import UndefinedInt
-from .compression_type import CompressionType
-from typing import Optional
-from copy import deepcopy
-from hannah.nas.parameters.parametrize import parametrize
-from ..core.parametrized import is_parametrized
-
-
-@parametrize
-class AxisType:
-    def __init__(
-        self,
-        name: str,
-        size: Optional[int] = None,
-        compression: Optional[CompressionType] = None,
-    ):
-        self.name = name
-        if size is None:
-            self.size = UndefinedInt()
-        else:
-            self.size = size
-        self.compression = compression
-
-    def new(self, new_name=None):
-        new_axis = deepcopy(self)
-        if new_name:
-            new_axis.name = new_name
-        return new_axis
-
-    def set_scope(self, current_scope, counters, visited):
-        scope_id = get_id(current_scope, counters)
-        self.id = f'{scope_id}.axis.{self.name}'
-        self.set_param_scopes()
-
-
-@parametrize
-class AxisTuple:
-    """Used to have the axis dict as a parametrized object
-    """
-    def __init__(self, *axis) -> None:
-        self.axis = {}
-        # reset parameters to improve naming
-        self._PARAMETERS = {}
-        for ax in axis:
-            self.axis[ax.name] = ax
-            if is_parametrized(ax):
-                self._PARAMETERS[ax.name] = ax
-
-    def set_scope(self, current_scope, counters, visited):
-        scope_id = get_id(current_scope, counters)
-        self.id = f'{scope_id}.axis'
-        for _, ax in self.axis.items():
-            ax.set_scope(current_scope, counters, visited)
-
-    def values(self):
-        return self.axis.values()
-
-    def items(self):
-        return self.axis.items()
-
-    def __getitem__(self, key):
-        return self.axis[key]
-
-    def __len__(self):
-        return len(self.axis)
-
-    def __repr__(self) -> str:
-        return str(self.axis)
diff --git a/hannah/nas/dataflow/compression_type.py b/hannah/nas/dataflow/compression_type.py
deleted file mode 100644
index e897c201..00000000
--- a/hannah/nas/dataflow/compression_type.py
+++ /dev/null
@@ -1,3 +0,0 @@
-class CompressionType:
-    def __init__(self, method: str = "rle") -> None:
-        self.method = method
diff --git a/hannah/nas/dataflow/dataflow_graph.py b/hannah/nas/dataflow/dataflow_graph.py
deleted file mode 100644
index 9e009b04..00000000
--- a/hannah/nas/dataflow/dataflow_graph.py
+++ /dev/null
@@ -1,338 +0,0 @@
-from typing import Iterable
-from hannah.nas.dataflow.op_type import OpType
-from hannah.nas.dataflow.tensor import TensorTuple
-from hannah.nas.dataflow.dataflow_utils import find_first_op_in_dfg, find_leaf_nodes
-from hannah.nas.dataflow.scoping_utils import get_id_and_update_counters, update_scope
-from hannah.nas.dataflow.tensor_expression import TensorExpression
-from hannah.nas.expressions.placeholder import DefaultInt
-from hannah.nas.parameters.parametrize import parametrize
-
-import numpy as np
-
-
-@parametrize
-class DataFlowGraph(TensorExpression):
-    def __init__(self, *operands, output, name: str = "dataflow") -> None:
-        super().__init__(*operands, tensor_type=None, name=name)
-        self.inputs = []
-        self.output = output
-        self.enter = []
-        # reset_user(self)
-        self.link_users()
-        self._scopes = {}
-        reset_scope_ids(self)
-        self.set_scopes()
-
-        self.collect_scopes()
-
-    def num_nodes(self):
-        n = len(self.nodes())
-        return n
-
-    def nodes(self):
-        g = flatten(self)
-
-        queue = [g]
-        visited = [g]
-        node_list = []
-        while queue:
-            current = queue.pop(-1)
-            node_list.append(current.id)
-
-            for next_node in current.next_backwards():
-                if next_node not in visited:
-                    queue.append(next_node)
-                    visited.append(next_node)
-
-        return node_list
-
-    def next_backwards(self):
-        return [self.output]
-
-    def link_users(self):
-        """Link the DFG to its users and the users of the DFG to
-        the DFG
-        """
-        def _rewire_to_placeholder(operand, node):
-            """
-            Parameters
-            ----------
-            operand : TensorExpression
-                operand that we want to rewire
-            node : TensorExpression
-                node which uses the operand
-            """
-            if operand in node.operands:
-                last_output = find_first_op_in_dfg(operand)
-                if node in last_output.users:
-                    last_output.users.remove(node)
-                if node not in self.enter:
-                    self.enter.append(node)
-            elif set(self.operands).isdisjoint(node.operands):
-                for n in node.next_backwards():
-                    _rewire_to_placeholder(operand, n)
-
-        for operand in self.operands:
-            _rewire_to_placeholder(operand, self.output)
-            # remove users if it is enclosed in 'self'
-            for user in operand.users:
-                if hasattr(self, 'enter') and user in self.enter:
-                    operand.users.remove(user)
-            if self not in operand.users:
-                operand.users.append(self)
-
-        self.output.users.append(self)
-
-    def set_scope(self, current_scope, counters, visited):
-        current_scope = update_scope(self, current_scope)
-        scope_id = get_id_and_update_counters(current_scope, counters)
-        self.id = scope_id
-        queue = [*self.enter]
-        visited.append(self)
-
-        while queue:
-            node = queue.pop(-1)
-            node.set_scope(current_scope, counters, visited)
-
-            leafs = []
-            find_leaf_nodes(node, leafs, visited)
-
-            while leafs:
-                leaf = leafs.pop(-1)
-                leaf.set_scope(current_scope, counters, visited)
-                visited.append(leaf)
-
-            for u in node.users:
-                if u not in visited:
-                    queue = [u] + queue
-                    visited.append(u)
-
-    def set_scopes(self):
-        visited = []
-        current_scope = []
-        node = find_first_input(self)
-        counters = {}
-        queue = [node]
-        visited.append(node)
-
-        while queue:
-            node = queue.pop(-1)
-            node.set_scope(current_scope, counters, visited)
-
-            leafs = []
-            find_leaf_nodes(node, leafs, visited)
-
-            while leafs:
-                leaf = leafs.pop(-1)
-                leaf.set_scope(current_scope, counters, visited)
-                visited.append(leaf)
-
-            for u in node.users:
-                if u not in visited:
-                    queue = [u] + queue
-                    visited.append(u)
-
-    def tensor_type(self):
-        return self.output.tensor_type()
-
-    def adjacency(self):
-        g = flatten(self)
-        g.collect_scopes()
-
-        node_list = self.nodes()
-        n = len(node_list)
-
-        adj = np.zeros((n, n), dtype=np.int8)
-        indices = {n: i for i, n in enumerate(node_list)}
-
-        for node_id in node_list:
-            node = g._scopes[node_id]
-
-            for user in node.users:
-                adj[indices[node_id]][indices[user.id]] = 1
-
-        return adj, indices
-
-    def __getitem__(self, key):
-        return self._scopes[key]
-
-    def __repr__(self) -> str:
-        return "DataFlowGraph(id={})".format(self.id)
-
-    def __str__(self) -> str:
-        lines = []
-        print_from_input(find_first_input(self), 0, [], lines)
-
-        return_str = "\n".join(lines)
-        return return_str
-        # return self.__repr__()
-
-
-def print_from_input(input, indent, visited, lines):
-    queue = [input]
-    visited.append(input)
-
-    while queue:
-        node = queue.pop(-1)
-
-        leafs = []
-        find_leaf_nodes(node, leafs, visited)
-        while leafs:
-            leaf = leafs.pop(-1)
-            print_from_input(leaf, indent + 1, visited, lines)
-            visited.append(leaf)
-
-        lines.append('\t'*indent + f'{node.id}')
-        if isinstance(node, DataFlowGraph):
-            for e in node.enter:
-                print_from_input(e, indent + 1, visited, lines)
-
-        for u in node.users:
-            if u not in visited:
-                queue = [u] + queue
-                visited.append(u)
-
-
-def dataflow(func):
-    def wrapper_func(*args, **kwargs):
-        name = func.__name__
-        operands = args
-        for key, value in kwargs.items():
-            if isinstance(value, int):
-                kwargs[key] = DefaultInt(value)
-        output = func(*args, **kwargs)
-
-        if isinstance(output, Iterable):
-            output = TensorTuple(output, name=name+".output")
-
-        dfg = DataFlowGraph(*operands, output=output, name=name)
-        return dfg
-
-    return wrapper_func
-
-
-def flatten(graph):
-    delete_users(graph)
-    queue = [graph]
-    visited = []
-
-    while queue:
-        current = queue.pop(-1)
-        visited.append(current)
-        if isinstance(current, DataFlowGraph):
-            if current.output not in visited:
-                queue.append(current.output)
-
-        elif isinstance(current, OpType):
-            # for each operand, traverse potential DFGs and store
-            # the first apperearing op
-            replace_map = {}
-            for i, operand in enumerate(current.operands):
-                op = find_first_op_in_dfg(operand)
-                replace_map[i] = op
-
-                op.users.append(current)
-
-                if operand not in visited:
-                    queue.append(operand)
-
-            # replace operands with non-dfg variant
-            current.operands = list(current.operands)
-            for idx, op in replace_map.items():
-                current.operands[idx] = op
-            current.operands = tuple(current.operands)
-
-    return find_first_op_in_dfg(graph)
-
-
-def unflatten(graph):
-    pass
-
-
-def delete_users(graph, user_to_delete=None):
-    queue = [graph]
-    visited = [graph]
-
-    while queue:
-        current = queue.pop(-1)
-        if user_to_delete:
-            if user_to_delete in current.users:
-                current.users.remove(user_to_delete)
-        else:
-            current.users = []
-
-        for next_node in current.next_backwards():
-            if next_node not in visited:
-                queue.append(next_node)
-                visited.append(next_node)
-
-
-def collect_users(node):
-    """Traverse graph starting from `node` and collect
-    all users (including users from subsequent nodes).
-    If a node_b is NOT in collect_users(node_a), this means
-    that node_b is either BEFORE node_a in the graph OR it is
-    in a parallel branch.
-
-    Parameters
-    ----------
-    node : _type_
-        _description_
-
-    Returns
-    -------
-    _type_
-        _description_
-    """
-    collected_users = []
-    queue = [node]
-    visited = []
-
-    while queue:
-        node = queue.pop(-1)
-        for u in node.users:
-            if u not in visited:
-                queue = [u] + queue
-                visited.append(u)
-                collected_users.append(u)
-
-    return collected_users
-
-
-def reset_scope_ids(node):
-    node.set_id(node.name)
-
-    for next_node in node.next_backwards():
-        reset_scope_ids(next_node)
-
-
-def find_first_input(node):
-    """Recusively traverses the graph from the given node
-    back to its first input. NOTE: The traversal is via OPERANDS
-    and not OUTPUT, meaning that e.g. weight Tensors that are
-    included in Ops in a DFG are not returned
-
-    Parameters
-    ----------
-    node : _type_
-        _description_
-
-    Returns
-    -------
-    _type_
-        _description_
-    """
-    if node.operands:
-        for o in node.operands:
-            return find_first_input(o)
-    else:
-        return node
-
-
-def recursive_traversal(node: TensorExpression, hooks: list = [], hook_parameter: dict = {}, end=None):
-    for hook in hooks:
-        param = hook_parameter.get(hook, {})
-        hook(node, **param)
-    if node != end:
-        for next_node in node.next_backwards():
-            recursive_traversal(next_node, hooks, hook_parameter, end)
diff --git a/hannah/nas/dataflow/dataflow_utils.py b/hannah/nas/dataflow/dataflow_utils.py
deleted file mode 100644
index 457b36c6..00000000
--- a/hannah/nas/dataflow/dataflow_utils.py
+++ /dev/null
@@ -1,61 +0,0 @@
-
-from typing import Optional
-from hannah.nas.core.expression import Expression
-from hannah.nas.dataflow.tensor import Tensor
-from hannah.nas.expressions.placeholder import DefaultInt, UndefinedInt
-
-
-# FIXME: Rename to "last" op
-def find_first_op_in_dfg(node):
-    if hasattr(node, 'output'):
-        return find_first_op_in_dfg(node.output)
-    else:
-        return node
-
-
-def find_next_dataflow(node):
-    if hasattr(node, 'output'):
-        return node
-    else:
-        assert len(node.users) < 2, "Next DataflowGraph is ambiguous"
-        return find_next_dataflow(node.users[0])
-
-
-def remove_old_users(node):
-    if hasattr(node, 'output'):
-        return find_first_op_in_dfg(node.output)
-    else:
-        return node
-
-
-def find_leaf_nodes(node, leafs, visited):
-    for o in node.operands:
-        if o not in visited:
-            if isinstance(o, Tensor):
-                leafs.append(o)
-            else:
-                leafs.append(o)
-                find_leaf_nodes(o, leafs, visited)
-
-
-def traverse_by_users(node):
-    def _traverse_by_users(node, visited):
-        print(node.id)
-        leafs = []
-        visited.append(node)
-        find_leaf_nodes(node, leafs, visited)
-        for leaf in leafs:
-            _traverse_by_users(leaf, visited)
-        for u in node.users:
-            if u not in visited:
-                _traverse_by_users(u, visited)
-    _traverse_by_users(node, [])
-
-
-def process_int(x: Optional[int]):
-    if isinstance(x, int):
-        return DefaultInt(x)
-    elif isinstance(x, Expression):
-        return x
-    elif x is None:
-        return UndefinedInt()
diff --git a/hannah/nas/dataflow/op_type.py b/hannah/nas/dataflow/op_type.py
deleted file mode 100644
index 8e15b496..00000000
--- a/hannah/nas/dataflow/op_type.py
+++ /dev/null
@@ -1,89 +0,0 @@
-from hannah.nas.core.parametrized import is_parametrized
-from hannah.nas.dataflow.dataflow_utils import find_first_op_in_dfg, find_leaf_nodes
-from hannah.nas.dataflow.scoping_utils import get_id_and_update_counters, update_scope
-from hannah.nas.dataflow.tensor_expression import TensorExpression
-import hannah.nas.dataflow.registry as reg
-from hannah.nas.parameters.parametrize import parametrize
-
-
-@parametrize
-class OpType(TensorExpression):
-    def __init__(self, *operands, tensor_type=None, name="", **attributes):
-        super().__init__(*operands, tensor_type=tensor_type, name=name)
-
-        self._PARAMETERS = {}
-        # self._conditions = []
-
-        for i, operand in enumerate(operands):
-            if is_parametrized(operand):
-                self._PARAMETERS[i] = operand
-
-        for name, attribute in attributes.items():
-            setattr(self, name, attribute)
-            if is_parametrized(attribute):
-                self._PARAMETERS[name] = attribute
-
-        self.attributes = attributes
-        self.link_users()
-
-    def next_backwards(self):
-        return list(self.operands)
-
-    def link_users(self):
-        for operand in self.operands:
-            last_output = find_first_op_in_dfg(operand)
-            if self not in last_output.users:
-                last_output.users.append(self)
-
-    def set_scope(self, current_scope, counters, visited):
-        current_scope = update_scope(self, current_scope)
-        scope_id = get_id_and_update_counters(current_scope, counters)
-        self.id = scope_id
-        self.set_param_scopes()
-        visited.append(self)
-
-        leafs = []
-        find_leaf_nodes(self, leafs, visited)
-
-        while leafs:
-            leaf = leafs.pop(-1)
-            leaf.set_scope(current_scope, counters, visited)
-            visited.append(leaf)
-
-    def tensor_type(self):
-        tensortype = reg.shape(self.name)(self)
-
-        return tensortype
-
-    # def sample(self):
-    #     for _key, param in self._PARAMETERS.items():
-    #         param.sample()
-
-    # def set_current(self, value):
-    #     self.set_params(**value)
-    #     self.check(None)  # argument "value" not needed currently
-
-    # def check(self, value):
-    #     for con in self._conditions:
-    #         if not con.evaluate():
-    #             raise Exception("Condition not satisfied: {}".format(con))
-
-    # def instantiate(self):
-    #     instance = deepcopy(self)
-    #     instance._parametrized = False
-    #     self.check(None)
-
-    #     for key, param in instance._PARAMETERS.items():
-    #         instantiated_value = param.instantiate()
-    #         instance._PARAMETERS[key] = instantiated_value
-    #         setattr(instance, key, instantiated_value)
-    #     return instance
-
-    # def parameters(self):
-    #     return self._PARAMETERS
-
-    def convert(self, target):
-        return reg.convert(self.name, target)(self)
-
-    def __repr__(self) -> str:
-        return "Op({})".format(self.id)
diff --git a/hannah/nas/dataflow/ops/__init__.py b/hannah/nas/dataflow/ops/__init__.py
deleted file mode 100644
index 1482d30a..00000000
--- a/hannah/nas/dataflow/ops/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from . import conv2d
\ No newline at end of file
diff --git a/hannah/nas/dataflow/ops/add.py b/hannah/nas/dataflow/ops/add.py
deleted file mode 100644
index 21552552..00000000
--- a/hannah/nas/dataflow/ops/add.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from hannah.nas.dataflow.op_type import OpType
-from hannah.nas.dataflow.tensor_type import TensorType
-from hannah.nas.dataflow.register_ops import add_op, add_shape_func
-
-from hannah.nas.dataflow.tensor_expression import TensorExpression
-
-
-@add_op
-class Add:
-    input: TensorExpression
-    other: TensorExpression
-
-
-@add_shape_func("Add")
-def add_shape(op: OpType):
-    input = op.operands[0].tensor_type()
-    other = op.operands[1].tensor_type()
-
-    assert input.dim() == other.dim()
-    ax = []
-    # constraints = []
-    for ax1, ax2 in zip(input.axis.values(), other.axis.values()):
-        con = ax1.size == ax2.size
-        # constraints.append(con)
-        op.cond(con)
-        # assert con.evaluate(), """Tensor axis sizes do not match: Axis {} with dimension
-        # {} and axis {} with dimension {}""".format(ax1,
-        #                                            input.tensor_type.axis[ax1].size,
-        #                                            ax2,
-        #                                            other.tensor_type.axis[ax2].size)
-        ax.append(ax1.new())
-
-    ax = tuple(ax)
-    return TensorType(ax, dtype=input.dtype)
-class BatchNorm2d: - input: TensorExpression - eps: Expression = DefaultFloat(0.00001) - momentum: Expression = DefaultFloat(0.1) - affine: Expression = DefaultBool(True) - - -@add_shape_func("BatchNorm2d") -def add_shape(op: OpType): - return deepcopy(op.operands[0].tensor_type()) diff --git a/hannah/nas/dataflow/ops/concat.py b/hannah/nas/dataflow/ops/concat.py deleted file mode 100644 index 62c645ed..00000000 --- a/hannah/nas/dataflow/ops/concat.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import List -from hannah.nas.dataflow.axis_type import AxisType -from hannah.nas.dataflow.op_type import OpType -from hannah.nas.dataflow.register_ops import add_op, add_shape_func -from hannah.nas.dataflow.tensor_expression import TensorExpression -from hannah.nas.dataflow.tensor_type import TensorType - - -@add_op -class Concat: - inputs: List[TensorExpression] - axis: int - out_axis_name: str - - -@add_shape_func("Concat") -def add_shape(op: OpType): - tensors = [] - for operand in op.operands: - tensors.append(operand.tensor_type()) - - for tensor in tensors: - assert tensors[0].dim() == tensor.dim() - - ax = [] - # constraints = [] - ax_new = AxisType(name=op.out_axis_name) - for tensor in tensors: - for i, (ax1, ax2) in enumerate(zip(tensors[0].axis.values(), tensor.axis.values())): - if i != op.axis: - con = ax1.size == ax2.size - op.cond(con) - ax.append(ax1.new()) - else: - if ax_new.size: - ax_new.size = ax2.size - else: - ax_new.size = ax_new.size + ax2.size - ax.append(ax_new) - - ax = tuple(ax) - return TensorType(ax, dtype=tensors[0].dtype) diff --git a/hannah/nas/dataflow/ops/conv2d.py b/hannah/nas/dataflow/ops/conv2d.py deleted file mode 100644 index b72b6e46..00000000 --- a/hannah/nas/dataflow/ops/conv2d.py +++ /dev/null @@ -1,65 +0,0 @@ -from hannah.nas.core.expression import Expression -from hannah.nas.dataflow.axis_type import AxisType -from hannah.nas.dataflow.op_type import OpType -from hannah.nas.dataflow.tensor_expression import TensorExpression -from hannah.nas.dataflow.tensor_type import TensorType -from hannah.nas.expressions.placeholder import DefaultInt, IntRange -from hannah.nas.dataflow.register_ops import add_op, add_shape_func, add_conversion -from hannah.nas.expressions.arithmetic import Floor -import torch.nn.functional as F -from torch.nn import Conv2d as torch_conv -from hannah.nas.parameters.parameters import IntScalarParameter - -from hannah.nas.parameters.parametrize import parametrize - - -@add_op -class Conv2d: - input: TensorExpression - weight: TensorExpression - dilation: Expression = DefaultInt(1) - stride : Expression = DefaultInt(1) - padding: Expression = IntScalarParameter(0, 10) - groups: Expression = DefaultInt(1) - - -@add_shape_func("Conv2d") -def conv2d_shape(op: OpType): - def _calc_output_dim(out_dim_name, input_dim, padding, dilation, kernel, stride) -> AxisType: - input_size = input_dim.size - kernel_size = kernel.size - ax = AxisType(name=out_dim_name, size=Floor(((input_size + padding * 2 - dilation * (kernel_size - 1) - 1) / stride) + 1)) - return ax - input = op.operands[0].tensor_type() - weight = op.operands[1].tensor_type() - - batch = input['n'] - out_channel = weight['o'].new('c') - output_height = _calc_output_dim('h', input['h'], op.padding, op.dilation, weight['kh'], op.stride) - output_width = _calc_output_dim('w', input['w'], op.padding, op.dilation, weight['kw'], op.stride) - - # FIXME: Just take inputs dtype? 
- return TensorType((batch, out_channel, output_height, output_width), dtype=input.dtype) - - -# @add_conversion("Conv2d", target="torch") -# def conv2d_torch(op: OpType): -# # kernel_size = op.kernel_size -# # dilation = op.dilation -# # stride = op.stride -# # padding = op.padding - -# # input_tensor = op.input.output_tensor() -# # output_tensor = op.output_tensor() - -# # torch_op = torch_conv(in_channels=input_tensor['c'].size.evaluate(), -# # out_channels=output_tensor['c'].size.evaluate(), -# # kernel_size=kernel_size, -# # stride=stride, -# # padding=padding, -# # dilation=dilation) - -# # # def conv2d_func(input, weight): -# # return F.conv2d(input, weight, None, stride.evaluate()) -# torch_op = None -# return torch_op diff --git a/hannah/nas/dataflow/ops/dropout.py b/hannah/nas/dataflow/ops/dropout.py deleted file mode 100644 index 94c40c61..00000000 --- a/hannah/nas/dataflow/ops/dropout.py +++ /dev/null @@ -1,19 +0,0 @@ -from copy import deepcopy -from typing import Union -from hannah.nas.core.expression import Expression -from hannah.nas.dataflow.op_type import OpType -from hannah.nas.dataflow.tensor_expression import TensorExpression -from hannah.nas.dataflow.register_ops import add_op, add_shape_func -from hannah.nas.expressions.placeholder import FloatRange -from hannah.nas.parameters.parameters import FloatScalarParameter - - -@add_op -class Dropout2d: - input: TensorExpression - p: Expression - - -@add_shape_func("Dropout2d") -def add_shape(op: OpType): - return deepcopy(op.operands[0].tensor_type()) diff --git a/hannah/nas/dataflow/ops/identity.py b/hannah/nas/dataflow/ops/identity.py deleted file mode 100644 index 4d5f5d90..00000000 --- a/hannah/nas/dataflow/ops/identity.py +++ /dev/null @@ -1,14 +0,0 @@ -from copy import deepcopy -from hannah.nas.dataflow.op_type import OpType -from hannah.nas.dataflow.tensor_expression import TensorExpression -from hannah.nas.dataflow.register_ops import add_op, add_shape_func - - -@add_op -class Identity: - input: TensorExpression - - -@add_shape_func("Identity") -def add_shape(op: OpType): - return deepcopy(op.operands[0].tensor_type()) diff --git a/hannah/nas/dataflow/ops/linear.py b/hannah/nas/dataflow/ops/linear.py deleted file mode 100644 index 8db7dd3f..00000000 --- a/hannah/nas/dataflow/ops/linear.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import Union -from hannah.nas.core.expression import Expression -from hannah.nas.dataflow.axis_type import AxisType -from hannah.nas.dataflow.op_type import OpType -from hannah.nas.dataflow.tensor_expression import TensorExpression -from hannah.nas.dataflow.tensor_type import TensorType -from hannah.nas.expressions.placeholder import IntRange -from hannah.nas.dataflow.register_ops import add_op, add_shape_func -from hannah.nas.expressions.arithmetic import Floor -from hannah.nas.parameters.parameters import IntScalarParameter - - -@add_op -class Linear: - input: TensorExpression - out_features: Expression - - -@add_shape_func("Linear") -def conv2d_shape(op: OpType): - input_tensor = op.operands[0].tensor_type() - out_axis = AxisType(name='features', size=op.out_features) - return TensorType((input_tensor.axis['n'].new(), out_axis), dtype=input.dtype) diff --git a/hannah/nas/dataflow/ops/pooling.py b/hannah/nas/dataflow/ops/pooling.py deleted file mode 100644 index 5c70cbe8..00000000 --- a/hannah/nas/dataflow/ops/pooling.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Union -from hannah.nas.core.expression import Expression -from hannah.nas.dataflow.axis_type import AxisType -from 
hannah.nas.dataflow.op_type import OpType -from hannah.nas.dataflow.tensor_expression import TensorExpression -from hannah.nas.dataflow.register_ops import add_op, add_shape_func -from hannah.nas.dataflow.tensor_type import TensorType -from hannah.nas.expressions.placeholder import IntRange -from hannah.nas.parameters.parameters import IntScalarParameter - - -@add_op -class AdaptiveAveragePooling: - input: TensorExpression - output_size: Expression - - -@add_shape_func("AdaptiveAveragePooling") -def add_shape(op: OpType): - tensor = op.operands[0].tensor_type() - new_h = AxisType(name='h', size=op.output_size) - new_w = AxisType(name='w', size=op.output_size) - output_tensor_axis = (tensor.axis['n'].new(), tensor.axis['c'].new(), new_h, new_w) - return TensorType(output_tensor_axis, dtype=tensor.dtype) diff --git a/hannah/nas/dataflow/ops/relu.py b/hannah/nas/dataflow/ops/relu.py deleted file mode 100644 index 191c954d..00000000 --- a/hannah/nas/dataflow/ops/relu.py +++ /dev/null @@ -1,14 +0,0 @@ -from copy import deepcopy -from hannah.nas.dataflow.op_type import OpType -from hannah.nas.dataflow.tensor_expression import TensorExpression -from hannah.nas.dataflow.register_ops import add_op, add_shape_func - - -@add_op -class Relu: - input: TensorExpression - - -@add_shape_func("Relu") -def add_shape(op: OpType): - return deepcopy(op.operands[0].tensor_type()) diff --git a/hannah/nas/dataflow/ops/sum.py b/hannah/nas/dataflow/ops/sum.py deleted file mode 100644 index e2830737..00000000 --- a/hannah/nas/dataflow/ops/sum.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import List -from hannah.nas.dataflow.op_type import OpType -from hannah.nas.dataflow.tensor import TensorTuple -from hannah.nas.dataflow.tensor_expression import TensorExpression -from hannah.nas.dataflow.tensor_type import TensorType -from hannah.nas.dataflow.register_ops import add_op, add_shape_func - - -@add_op -class Sum: - inputs: List[TensorExpression] - - -@add_shape_func("Sum") -def add_shape(op: OpType): - tensors = [] - for operand in op.operands: - tensors.append(operand.tensor_type()) - - for tensor in tensors: - assert tensors[0].dim() == tensor.dim() - ax = [] - # constraints = [] - for tensor in tensors: - for ax1, ax2 in zip(tensors[0].axis.values(), tensor.axis.values()): - con = ax1.size == ax2.size - # constraints.append(con) - op.cond(con) - # assert con.evaluate(), """Tensor axis sizes do not match: Axis {} with dimension - # {} and axis {} with dimension {}""".format(ax1, - # input.tensor_type.axis[ax1].size, - # ax2, - # other.tensor_type.axis[ax2].size) - ax.append(ax1.new()) - - ax = tuple(ax) - return TensorType(ax, dtype=tensors[0].dtype) diff --git a/hannah/nas/dataflow/optional_op.py b/hannah/nas/dataflow/optional_op.py deleted file mode 100644 index c08a0321..00000000 --- a/hannah/nas/dataflow/optional_op.py +++ /dev/null @@ -1,7 +0,0 @@ -class OptionalOp: - def __init__(self, op, default): - self.op = op - self.default = default - - def __str__(self): - return "optional(" + str(self.op) + str(self.default) + ")" diff --git a/hannah/nas/dataflow/quantization_type.py b/hannah/nas/dataflow/quantization_type.py deleted file mode 100644 index 2bb48fed..00000000 --- a/hannah/nas/dataflow/quantization_type.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Optional - -from .axis_type import AxisType - - -class QuantizationType: - def __init__( - self, - axis: Optional[AxisType] = None, - scale: Optional[float] = None, - zero_point: Optional[float] = None, - ) -> None: - self.axis = axis - self.scale = 
scale - self.zero_point = zero_point diff --git a/hannah/nas/dataflow/register_ops.py b/hannah/nas/dataflow/register_ops.py deleted file mode 100644 index 4c62b5fb..00000000 --- a/hannah/nas/dataflow/register_ops.py +++ /dev/null @@ -1,88 +0,0 @@ -from hannah.nas.core.expression import T -from hannah.nas.dataflow.op_type import OpType -from hannah.nas.dataflow.tensor_expression import TensorExpression -from hannah.nas.dataflow.registry import _OPS, _SHAPE_FUNCS, _CONVERSIONS -from typing import List - - -MISSING = 'missing' - - -def add_op(op_class): - operands = {} - attributes_annotations = {} - attributes_defaults = {} - cls_annotations = op_class.__annotations__ - for name, annotation in cls_annotations.items(): - if annotation is TensorExpression: - operands[name] = annotation - elif annotation is List[TensorExpression]: - operands[name] = annotation - else: - attributes_annotations[name] = annotation - default = getattr(op_class, name, MISSING) - attributes_defaults[name] = default - # if hasattr(op_class, name): - # delattr(op_class, name) - op_class.operands = operands - op_class.attributes_annotations = attributes_annotations - op_class.attributes_defaults = attributes_defaults - - op_class.create_op = _create_op - - _OPS[op_class.__name__] = op_class - - return op_class - - -@classmethod -def _create_op(cls, *operands, **attributes): - # If we allow a list of TensorExpressions as input the op can have arbitarily many operands - if not List[TensorExpression] in cls.operands.values(): - # otherwise just check for the same amount of operands, not names - assert len(operands) == len(cls.operands), "{} expects (exactly) the following operands: {}".format(cls.__name__, cls.operands) - for operand, operand_name in zip(operands, cls.operands): - assert isinstance(operand, cls.operands[operand_name]), \ - "Wrong operand type: Expected {} and got {}".format(type(operand), cls.operands[operand_name]) - - missing_keys = set(cls.attributes_annotations.keys()) - set(attributes.keys()) - for name, attribute in attributes.items(): - assert name in cls.attributes_annotations, "{} has no attribute \"{}\"".format(cls.__name__, name) - assert isinstance(attribute, cls.attributes_annotations[name]), "Attribute {} of {} must be of type {}".format(name, cls.__name__, cls.attributes_annotations[name].__name__) - - full_attributes = attributes - for name in missing_keys: - assert cls.attributes_defaults[name] is not MISSING, \ - "{} requires a named attribute {}: {} because no default is specified.".format(cls.__name__, name, cls.attributes_annotations[name]) - full_attributes[name] = cls.attributes_defaults[name] - optype = OpType(*operands, **full_attributes, name=str(cls.__name__)) - - # retrospectively set operands as fields with keyword name - # for operand, operand_name in zip(operands, cls.operands): - # setattr(optype, operand_name, operand) - - return optype - - -def add_shape_func(op_name): - def register_func(func): - _SHAPE_FUNCS[op_name] = func - return func - return register_func - - -def add_conversion(op_name, target): - def wrapper(func): - if op_name in _CONVERSIONS: - _CONVERSIONS[op_name][target] = func - else: - _CONVERSIONS[op_name] = {target: func} - return func - return wrapper - - -# Register default pass-through shape function -@add_shape_func('Default') -def default_shape(op: OpType): - input_tensor_type = op.operands[0].tensor_type() - return input_tensor_type diff --git a/hannah/nas/dataflow/registry.py b/hannah/nas/dataflow/registry.py deleted file mode 100644 index 
f1dd51e8..00000000 --- a/hannah/nas/dataflow/registry.py +++ /dev/null @@ -1,21 +0,0 @@ -_OPS = {} -_SHAPE_FUNCS = {} -_CONVERSIONS = {} - - -def op(name, *operands, **attributes): - return _OPS[name].create_op(*operands, **attributes) - - -def shape(op_name): - if op_name not in _SHAPE_FUNCS: - # if no shape func is registered for the op, just pass-through the tensor type - # of the first operand (which we assume to be the input) - op_name = 'Default' - return _SHAPE_FUNCS[op_name] - - -def convert(op_name, target): - assert op_name in _CONVERSIONS, f"No conversion strategies for op {op_name} found" - assert target in _CONVERSIONS[op_name], f"Op {op_name} has no conversion strategy for target {target}" - return _CONVERSIONS[op_name][target] diff --git a/hannah/nas/dataflow/repeat.py b/hannah/nas/dataflow/repeat.py deleted file mode 100644 index 91ef4d58..00000000 --- a/hannah/nas/dataflow/repeat.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Any -from hannah.nas.dataflow.dataflow_graph import DataFlowGraph, delete_users - - -class Repeater: - def __init__(self, block, num_repeats) -> None: - self.block = block - self.num_repeats = num_repeats - - def __str__(self): - ret = "Repeater" - return ret - - def __call__(self, *args, **kwargs) -> Any: - out_block = self.block(*args, **kwargs) - operands = out_block.operands - output = out_block.output - name = out_block.name - - delete_users(out_block, out_block) - del out_block - - - # create Repeat (i.e. child) instance from DataFlowGraph instance - out = Repeat(*operands, output=output, num_repeats=self.num_repeats, name=name) - return out - - -class Repeat(DataFlowGraph): - def __init__(self, *operands, output, num_repeats, name: str = "dataflow") -> None: - super().__init__(*operands, output=output, name=name) - self.num_repeats = num_repeats - - def dfg_line_representation(self, indent, input_names): - return '\t'*indent + self.id + " (repeats: {})".format(self.num_repeats) + ':' - - def __repr__(self) -> str: - return "DataFlowGraph(id={}) - repeats: ({})".format(self.id, self.num_repeats) - - -def repeat(block, num_repeats=1): - return Repeater(block=block, num_repeats=num_repeats) diff --git a/hannah/nas/dataflow/scoping_utils.py b/hannah/nas/dataflow/scoping_utils.py deleted file mode 100644 index 09a7ec1a..00000000 --- a/hannah/nas/dataflow/scoping_utils.py +++ /dev/null @@ -1,29 +0,0 @@ - - -def get_id_and_update_counters(current_scope, counters): - if len(current_scope) > 1: - scope = '.'.join([current_scope[-2].id, current_scope[-1].name]) - else: - scope = current_scope[-1].name - if scope not in counters: - counters[scope] = 0 - else: - counters[scope] += 1 - - return '{}.{}'.format(scope, counters[scope]) - - -def get_id(current_scope, counters): - if len(current_scope) > 1: - scope = '.'.join([current_scope[-2].id, current_scope[-1].name]) - else: - scope = current_scope[-1].name - - if scope not in counters: - counters[scope] = 0 - - return '{}.{}'.format(scope, counters[scope]) - - -def update_scope(node, current_scope): - return current_scope + [node] diff --git a/hannah/nas/dataflow/tensor.py b/hannah/nas/dataflow/tensor.py deleted file mode 100644 index 190d2675..00000000 --- a/hannah/nas/dataflow/tensor.py +++ /dev/null @@ -1,51 +0,0 @@ -from copy import deepcopy -from typing import List -from hannah.nas.dataflow.scoping_utils import get_id_and_update_counters, update_scope -from hannah.nas.dataflow.tensor_expression import TensorExpression -from hannah.nas.dataflow.tensor_type import TensorType -from 
hannah.nas.parameters.parametrize import parametrize - - -@parametrize -class Tensor(TensorExpression): - def __init__(self, *operands, tensor_type : TensorType = None, name : str = "") -> None: - super().__init__(*operands, tensor_type=tensor_type, name=name) - - def set_scope(self, current_scope, counters, visited): - current_scope = update_scope(self, current_scope) - scope_id = get_id_and_update_counters(current_scope, counters) - self.id = scope_id - - if self._tensor_type: - self._tensor_type.set_scope(current_scope, counters, visited) - - def new(self): - new_tensor = deepcopy(self) - return new_tensor - - @property - def dim(self): - return self.tensor_type.dim() - - def __repr__(self) -> str: - return "Tensor({})".format(self.id) - - def __str__(self) -> str: - return self.__repr__() - - def __getitem__(self, key): - return self._tensor_type.axis[key] - - -class TensorTuple(TensorExpression): - def __init__(self, tensors : List[Tensor], name: str = ""): - super().__init__(name=name) - self.tensors = tensors - self.name = name - - def set_scope(self, current_scope, counters, visited): - current_scope = update_scope(self, current_scope) - scope_id = get_id_and_update_counters(current_scope, counters) - self.id = scope_id - for tensor in self.tensors: - tensor.set_scope(current_scope, counters, visited) diff --git a/hannah/nas/dataflow/tensor_expression.py b/hannah/nas/dataflow/tensor_expression.py deleted file mode 100644 index 95cf7bf6..00000000 --- a/hannah/nas/dataflow/tensor_expression.py +++ /dev/null @@ -1,36 +0,0 @@ -class TensorExpression: - def __init__(self, *operands, tensor_type=None, name="") -> None: - self.operands = operands - self._tensor_type = tensor_type - self.users = [] - self.name = name - self.id = name - self._scopes = {} - - def set_id(self, id): - self.id = id - - def tensor_type(self): - assert self._tensor_type is not None, "Tensor Type has not been set, please run shape inference" - return self._tensor_type - - def next_backwards(self): - return [] - - def next_forwards(self): - return [] - - def collect_scopes(self): - queue = [self] - visited = [self] - - while queue: - current = queue.pop(-1) - self._scopes[current.id] = current - - for next_node in current.next_backwards(): - if next_node not in visited: - queue.append(next_node) - visited.append(next_node) - - self._scopes = dict(sorted(self._scopes.items())) diff --git a/hannah/nas/dataflow/tensor_type.py b/hannah/nas/dataflow/tensor_type.py deleted file mode 100644 index 6e0eeddf..00000000 --- a/hannah/nas/dataflow/tensor_type.py +++ /dev/null @@ -1,51 +0,0 @@ -from hannah.nas.parameters.parametrize import parametrize -from ..hardware_description.memory_type import MemoryType -from .quantization_type import QuantizationType -from .data_type import DataType -from .axis_type import AxisTuple, AxisType -from typing import Optional, Tuple - -from hannah.nas.dataflow.scoping_utils import get_id - - -@parametrize -class TensorType: - def __init__( - self, - axis: Tuple[AxisType, ...], - dtype: DataType, - quantization: Optional[QuantizationType] = None, - memory: Optional[MemoryType] = None, - name: str = "", - ): - self.axis = AxisTuple(*axis) - self._PARAMETERS['axis'] = self.axis - # for ax in axis: - # self.axis[ax.name] = ax - self.dtype = dtype - self.quantization = quantization - self.memory = memory - self.name = name - self.id = name - - def set_scope(self, current_scope, counters, visited): - scope_id = get_id(current_scope, counters) - self.id = f'{scope_id}.tensor_type' - - 
self.axis.set_scope(current_scope, counters, visited) - - def dim(self) -> int: - return len(self.axis) - - def shape(self) -> Tuple[int, ...]: - return tuple((ax.size for ax in self.axis.values)) - - def __getitem__(self, key): - return self.axis[key] - - def __repr__(self) -> str: - # return 'Tensor(name=' + self.name + ", axis=(" + ' '.join(['{}, '.format(a) for a in self.axis.keys()]) + '))' - return "TensorType({})".format(self.name) - - -# OutputType = Union[TensorType, TensorTuple] diff --git a/hannah/nas/dataflow/transformations/graph_tranformer.py b/hannah/nas/dataflow/transformations/graph_tranformer.py deleted file mode 100644 index 3b815b96..00000000 --- a/hannah/nas/dataflow/transformations/graph_tranformer.py +++ /dev/null @@ -1,128 +0,0 @@ - -from hannah.nas.core.parametrized import is_parametrized -from hannah.nas.dataflow.dataflow_graph import DataFlowGraph, delete_users, find_first_input, reset_scope_ids -from hannah.nas.dataflow.op_type import OpType - - -class GraphTransformer: - def __init__(self, graph) -> None: - self.graph = graph - - def transform(self, source, target, transform): - first = find_first_input(self.graph) - queue = [first] - visited = [first] - - while queue: - current = queue.pop(-1) - - if isinstance(current, DataFlowGraph) and self.match(source, current): - self.replace_dataflow_graph(current, target, transform) - - if isinstance(current, DataFlowGraph): - for ent in current.enter: - if ent not in visited: - queue = queue + [ent] - visited.append(ent) - - for user in current.users: - if user not in visited: - queue = [user] + queue - visited.append(user) - # if isinstance(current, DataFlowGraph): - # if current.output not in visited: - # queue = [current.output] + queue - # visited.append(current.output) - # elif isinstance(current, OpType): - # for operand in current.operands: - # if operand not in visited: - # queue = [operand] + queue - # visited.append(operand) - - # self.reset_users() - self.graph._scopes = {} - reset_scope_ids(self.graph) - self.graph.set_scopes() - self.graph.collect_scopes() - - def reset_users(self): - delete_users(self.graph) - - queue = [self.graph] - visited = [self.graph] - - while queue: - current = queue.pop(-1) - - if isinstance(current, DataFlowGraph): - current.link_users() - if current.output not in visited: - queue.append(current.output) - visited.append(current.output) - elif isinstance(current, OpType): - current.link_users() - for operand in current.operands: - if operand not in visited: - queue.append(operand) - visited.append(operand) - - def replace_dataflow_graph(self, source, target, transform): - # create new dfg^ - args, kwargs = transform(source, target) - new_block = target(*args, **kwargs) # FIXME: Correct instantiation (parameters etc) - - # new_block is automatically a user of each operand, this is not - # necessarily correct, therefore remove here and add later if needed - for operand in new_block.operands: - operand.users.remove(new_block) - - # source.output.users.remove(source) - # source.output.users.append(new_block) - print() - - for user in source.users: - if user.output == source: - user.output = new_block - del user._PARAMETERS['output'] - if is_parametrized(new_block): - user._PARAMETERS['output'] = new_block - - new_block.users.append(user) - if source in user.operands: - user.operands = list(user.operands) - user.operands.remove(source) - user.operands.append(new_block) - user.operands = tuple(user.operands) - - for i, operand in enumerate(source.operands): - if source in 
operand.users:
-                operand.users.remove(source)
-            if new_block not in operand.users:
-                operand.users.append(new_block)
-            pass
-            if is_parametrized(new_block):
-                new_block._PARAMETERS[f'operand_{i}'] = operand
-
-        parent_id = ".".join(source.id.split('.')[:-2])
-        if source in self.graph._scopes[parent_id].enter:
-            self.graph._scopes[parent_id].enter.remove(source)
-            self.graph._scopes[parent_id].enter.append(new_block)
-
-        del source
-
-    def match_by_name(self, name, graph):
-        if graph.name == name:
-            return True
-        else:
-            return False
-
-    def match_by_equivalence(self, graph_a, graph_b):
-        pass
-
-    def match(self, graph_a, graph_b):
-        if isinstance(graph_a, str):
-            return self.match_by_name(graph_a, graph_b)
-        elif isinstance(graph_a, DataFlowGraph):
-            return self.match_by_equivalence(graph_a, graph_b)
-        else:
-            raise Exception("Argument 0 must be either str or DataflowGraph but is {}".format(type(graph_a)))
diff --git a/hannah/nas/expressions/placeholder.py b/hannah/nas/expressions/placeholder.py
index 4a81b6b5..956bb95e 100644
--- a/hannah/nas/expressions/placeholder.py
+++ b/hannah/nas/expressions/placeholder.py
@@ -1,8 +1,8 @@
 #
-# Copyright (c) 2022 University of Tübingen.
+# Copyright (c) 2024 Hannah contributors.
 #
 # This file is part of hannah.
-# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
+# See https://github.com/ekut-es/hannah for further info.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -44,6 +44,9 @@ class UndefinedInt(Placeholder):
     def __init__(self, id: Optional[str] = None) -> None:
         super().__init__(id)
 
+    def format(self, indent=2, length=80) -> str:
+        return self.__class__.__name__ + f"(id={self.id})"
+
 
 # TODO:
 class UndefinedFloat(Placeholder):
diff --git a/hannah/nas/dataflow/data_type.py b/hannah/nas/functional_operators/data_type.py
similarity index 55%
rename from hannah/nas/dataflow/data_type.py
rename to hannah/nas/functional_operators/data_type.py
index 57b5bb97..026f1ee5 100644
--- a/hannah/nas/dataflow/data_type.py
+++ b/hannah/nas/functional_operators/data_type.py
@@ -23,60 +23,74 @@ class DataType(ABC):
     @abstractmethod
-    def bits(self) -> int:
-        ...
+    def bits(self) -> int: ...
 
     @abstractmethod
-    def range(self) -> Tuple[Number, Number]:
-        ...
+    def range(self) -> Tuple[Number, Number]: ...
 
     def as_numpy(self) -> str:
         return ""
 
 
 class IntType(DataType):
-    def __init__(self, signed: bool = True, bits: int = 8):
-        self.signed = signed
-        self.bits = bits
+    """Describe the properties of an integer datatype.
+
+    Args:
+        signed (bool, optional): Whether the integer is signed or not. Defaults to True.
+        bits (int, optional): The number of bits used to represent the integer. Defaults to 8.
+        reduce_range (bool, optional): Whether to reduce the range of the integer to make the data range symmetric around zero. Only applies to signed datatypes. Defaults to False.
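+
+    The resulting value ranges follow directly from ``range()`` below; for
+    example::
+
+        >>> IntType(signed=True, bits=8).range()
+        (-128, 127)
+        >>> IntType(signed=True, bits=8, reduce_range=True).range()
+        (-127, 127)
+        >>> IntType(signed=False, bits=8).range()
+        (0, 255)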
+ """ + + def __init__(self, signed: bool = True, bits: int = 8, reduce_range=False): + self._signed = signed + self._bits = bits + self._reduce_range = reduce_range def bits(self) -> int: - return self.bits + return self._bits def signed(self) -> bool: - return self.signed + return self._signed def range(self) -> Tuple[int, int]: - if self.signed: - min_val = -(2 ** (self.bits - 1)) - max_val = 2 ** (self.bits - 1) - 1 + if self._signed: + min_val = -(2 ** (self._bits - 1)) + + if self._reduce_range: + min_val += 1 + max_val = 2 ** (self._bits - 1) - 1 else: min_val = 0 - max_val = 2 ** (self.bits) - 1 + max_val = 2 ** (self._bits) - 1 return (min_val, max_val) def as_numpy(self) -> str: if self.signed: - return f"np.int{self.bits}" + return f"np.int{self._bits}" else: - return f"np.uint{self.bits}" + return f"np.uint{self._bits}" + + def __str__(self): + return f"{'u' if not self._signed else 'i'}{self._bits}" class FloatType(DataType): def __init__(self, signed=True, significand_bits=23, exponent_bits=8): - self.signed = signed - self.significand_bits = significand_bits - self.exponent_bits = exponent_bits + self._signed = signed + self._significand_bits = significand_bits + self._exponent_bits = exponent_bits def bits(self) -> int: - bits = self.significand_bits + self.exponent_bits - if self.signed: + bits = self._significand_bits + self._exponent_bits + if self._signed: bits += 1 return bits def signed(self) -> int: - return self.signed + return self._signed def range(self) -> float: # FIXME: calculate correct range @@ -91,6 +105,9 @@ def range(self) -> float: def as_numpy(self) -> str: return f"float{self.bits()}" + def __str__(self): + return f"f{self.bits()}" + if __name__ == "__main__": fl = FloatType() diff --git a/hannah/nas/functional_operators/executor.py b/hannah/nas/functional_operators/executor.py index 454467be..ea08e49c 100644 --- a/hannah/nas/functional_operators/executor.py +++ b/hannah/nas/functional_operators/executor.py @@ -16,15 +16,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from collections import defaultdict from copy import deepcopy from typing import Iterator, Tuple import math import torch from torch.nn.modules.module import Module -from hannah.nas.functional_operators.op import ChoiceOp, Op, Tensor, get_nodes -from collections import defaultdict from torch.nn.parameter import Parameter +from .op import ChoiceOp, Op, Tensor, get_nodes + class BasicExecutor(torch.nn.Module): def __init__(self, net, input_node_name="input", init=None) -> None: @@ -54,16 +55,20 @@ def initialize_tensor(self, node): if isinstance(node, Tensor): node_name = node.id.replace(".", "_") if node.grad: - if node.name == 'bias': + if node.name == "bias": # get weight data - weight_name = node_name.replace('bias', 'weight') + weight_name = node_name.replace("bias", "weight") weight_param = self.get_parameter(weight_name) - fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(weight_param.data) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out( + weight_param.data + ) # register bias if fan_in != 0: bound = 1 / math.sqrt(fan_in) data = torch.empty(node.current_shape()) - data = torch.nn.Parameter(torch.nn.init.uniform_(data, -bound, bound)) + data = torch.nn.Parameter( + torch.nn.init.uniform_(data, -bound, bound) + ) self.register_parameter(node_name, data) else: # weight tensor data = torch.empty(node.current_shape()) diff --git a/hannah/nas/functional_operators/op.py b/hannah/nas/functional_operators/op.py index 86de194c..ffe76e70 100644 --- a/hannah/nas/functional_operators/op.py +++ b/hannah/nas/functional_operators/op.py @@ -16,19 +16,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from contextlib import contextmanager import sys from abc import ABC, abstractmethod from copy import deepcopy from functools import wraps -from typing import Any +from typing import Any, List, Mapping import torch from hannah.nas.core.expression import Expression from hannah.nas.core.parametrized import is_parametrized -from hannah.nas.dataflow.data_type import FloatType from hannah.nas.expressions.choice import Choice +from hannah.nas.expressions.placeholder import UndefinedInt from hannah.nas.expressions.utils import extract_parameter_from_expression +from hannah.nas.functional_operators.data_type import FloatType from hannah.nas.functional_operators.lazy import lazy from hannah.nas.parameters.parameters import ( CategoricalParameter, @@ -66,15 +68,6 @@ def get_nodes(node): visited.append(o) -_id = 0 - - -def get_unique_id(): - global _id - _id += 1 - return _id - - def get_highest_scope_counter(scope, scope_dict): if scope in scope_dict: scope_dict[scope] += 1 @@ -111,6 +104,24 @@ def set_scope(*args, **kwargs): return set_scope +@contextmanager +def context(): + """A contextmanager to provide a global scope stack for the hannah ir""" + + global global_scope_stack + + old_stack = globals().get("global_scope_stack", None) + + try: + global_scope_stack = [{}] + yield global_scope_stack + finally: + if old_stack is not None: + global_scope_stack = old_stack + else: + del global_scope_stack + + def search_space(function): """Decorator to define a search space. 
For correct scoping, a search space containing functional ops must be enclosed by
     @wraps(function)
     def search_space_limits(*args, **kwargs):
-        global global_scope_stack
-        global_scope_stack = [{}]
-        out = scope(function)(*args, **kwargs)
-        del global_scope_stack
+        with context():
+            out = scope(function)(*args, **kwargs)
+        return out
 
     return search_space_limits
 
 
+class BaseNode(ABC):
+    """
+    Base class for all nodes in the operator description; it defines the basic interface used by all members of the data flow graph.
+    """
+
+    operands: List["BaseNode"] = []
+    users: List["BaseNode"] = []
+    id: str = ""  # Fully qualified name of the node, e.g., "net.res.conv1" or "net.res.conv1.weight"
+
+    def size(self, axis: int):
+        return self.shape()[axis]
+
+    def attributes(self) -> Mapping[str, Any]:
+        res = {}
+        for k, v in self.__dict__.items():
+            if k.startswith("_"):
+                continue
+            if k in ["operands", "users", "id", "name", "executor"]:
+                continue
+
+            if is_parametrized(self) and k in self._PARAMETERS:
+                res[k] = self._PARAMETERS[k]
+                continue
+
+            res[k] = v
+        return res
+
+
 @parametrize
-class Op(torch.nn.Module):
+class Op(torch.nn.Module, BaseNode):
     def __init__(self, name, *args, **kwargs) -> None:
         super().__init__()
         self.operands = []
@@ -227,11 +264,16 @@ def get_tensor_data(tensor):
 
 
 @parametrize
-class Tensor:
+class Tensor(BaseNode):
     def __init__(self, name, shape, axis, dtype=FloatType(), grad=False) -> None:
         super().__init__()
         self.name = name
         self.id = name
+
+        for num, (ax, size) in enumerate(zip(axis, shape)):
+            if size is None:
+                shape[num] = UndefinedInt(f"{name}_{ax}")
+
         self._shape = shape
 
         self.dtype = dtype
@@ -300,10 +342,6 @@ def __repr__(self):
         return f"Tensor({self.id})"
 
 
-# @torch.fx.wrap
-# def choice_forward()
-
-
 @parametrize
 class ChoiceOp(Op):
     def __init__(self, *options, switch=None) -> None:
diff --git a/hannah/nas/functional_operators/operators.py b/hannah/nas/functional_operators/operators.py
index a908512c..c2712228 100644
--- a/hannah/nas/functional_operators/operators.py
+++ b/hannah/nas/functional_operators/operators.py
@@ -143,7 +143,7 @@ def self_attention2d(q, k, v, num_heads, d_model, *, id):
         k: Tensor, shape ``[B, h*d, H, W]``
         v: Tensor, shape ``[B, h*d, H, W]``
     """
-    scale = d_model ** -0.5
+    scale = d_model**-0.5
     b, _, h, w = q.shape
     q = q.view(b, num_heads, d_model, h * w)
     k = k.view(b, num_heads, d_model, h * w)
@@ -260,7 +260,10 @@ class Conv2d(Op):
     def __init__(self, stride=1, dilation=1, groups=1, padding=None) -> None:
         super().__init__(
-            name="Conv2d", stride=stride, dilation=dilation, groups=groups,
+            name="Conv2d",
+            stride=stride,
+            dilation=dilation,
+            groups=groups,
         )
         self.stride = stride
         self.dilation = dilation
@@ -278,6 +281,9 @@ def __call__(self, *operands) -> Any:
         new_conv.in_channels = input_shape[1]
         new_conv.out_channels = weight_shape[0]
         new_conv.kernel_size = weight_shape[2]
+        assert (
+            weight_shape[3] == weight_shape[2]
+        ), "Only square kernels are supported, at the moment."
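+        # The weight tensor is assumed to be laid out as
+        # (out_channels, in_channels / groups, kernel_height, kernel_width),
+        # so weight_shape[2] above is the spatial kernel size and the assert
+        # enforces square kernels.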
if self.padding is None:
             new_conv.padding = padding_expression(
                 new_conv.kernel_size, new_conv.stride, new_conv.dilation
             )
@@ -326,10 +332,7 @@ def shape_fun(self):
 
     def _forward_implementation(self, input, weight, bias=None):
         input = torch.flatten(input, start_dim=1)
-        return linear(
-            input, weight, bias,
-            id=self.id
-        )
+        return linear(input, weight, bias, id=self.id)
 
 
 @parametrize
@@ -475,7 +478,7 @@ def _forward_implementation(self, *operands):
             stride=lazy(self.stride),
             padding=lazy(self.padding),
             dilation=lazy(self.dilation),
-            id=self.id
+            id=self.id,
         )
 
 
@@ -504,7 +507,7 @@ def _forward_implementation(self, *operands):
             kernel_size=lazy(self.kernel_size),
             stride=lazy(self.stride),
             padding=lazy(self.padding),
-            id=self.id
+            id=self.id,
         )
 
 
@@ -533,7 +536,7 @@ def _forward_implementation(self, *operands):
             kernel_size=lazy(self.kernel_size),
             stride=lazy(self.stride),
             padding=lazy(self.padding),
-            id=self.id
+            id=self.id,
         )
 
 
@@ -551,6 +554,38 @@ def shape_fun(self):
         return adaptive_average_pooling_shape(
             *self.operands, output_size=self.output_size
         )
+
+    def _forward_implementation(self, *operands):
+        if self.dim == 1:
+            return adaptive_avg_pooling1d(
+                operands[0], output_size=self.output_size, id=self.id
+            )
+        else:
+            return adaptive_avg_pooling2d(
+                operands[0], output_size=self.output_size, id=self.id
+            )
+
+
+@parametrize
+class MaxPool2d(Op):
+    def __init__(
+        self, kernel_size=3, stride=1, padding=0, dilation=1, ceil_mode=False
+    ) -> None:
+        super().__init__(name="MaxPool2d")
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.ceil_mode = ceil_mode
+
+    def shape_fun(self):
+        return conv_shape(
+            *self.operands,
+            dims=2,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+        )
 
     def _forward_implementation(self, *operands):
         if self.dim == 1:
diff --git a/hannah/nas/functional_operators/quant.py b/hannah/nas/functional_operators/quant.py
new file mode 100644
index 00000000..98435ed5
--- /dev/null
+++ b/hannah/nas/functional_operators/quant.py
@@ -0,0 +1,95 @@
+#
+# Copyright (c) 2024 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from abc import abstractmethod, abstractproperty
+
+import torch
+from torch.ao.quantization import FakeQuantize, FixedQParamsObserver
+
+from .data_type import DataType
+from .op import Op, Tensor
+
+
+def linear_quantize(x, scale, zero_point, dtype):
+    return torch.round(x / scale) + zero_point
+
+
+class BaseQuantize(Op):
+    @abstractproperty
+    def scale(self):
+        ...
+
+    @abstractproperty
+    def zero_point(self):
+        ...
+
+    @abstractproperty
+    def dtype(self):
+        ...
+
+
+class FixedQuantize(BaseQuantize):
+    """A fixed quantizer that quantizes the input tensor to a fixed scale and zero point.
+
+    Args:
+        scale (float): The scale of the quantized values.
+ zero_point (float): The zero point of the quantized values. + dtype (DataType): The datatype of the quantized values. + """ + + def __init__(self, scale: float, zero_point: float, dtype: DataType): + super().__init__(self.__class__.__name__) + + self._scale = scale + self._zero_point = zero_point + self._dtype = dtype + + range = dtype.range() + + self.quantizer = FakeQuantize( + observer=FixedQParamsObserver, + scale=scale, + zero_point=zero_point, + dtype=torch.qint8 if dtype.signed else torch.quint8, + quant_min=range[0], + quant_max=range[1], + ) + + def shape_fun(self): + return self.operands[0].shape() + + def _forward_implementation(self, inputs): + assert len(inputs) == 1 + x = inputs[0] + + if not self._train: + return x + + return self.quantizer(x) + + @property + def scale(self): + return self._scale + + @property + def zero_point(self): + return self._zero_point + + @property + def dtype(self): + return self._dtype diff --git a/hannah/nas/functional_operators/utils/visit.py b/hannah/nas/functional_operators/utils/visit.py new file mode 100644 index 00000000..35aa17a9 --- /dev/null +++ b/hannah/nas/functional_operators/utils/visit.py @@ -0,0 +1,42 @@ +# +# Copyright (c) 2024 Hannah contributors. +# +# This file is part of hannah. +# See https://github.com/ekut-es/hannah for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from ..op import BaseNode + + +def post_order(op: BaseNode): + """Visits the operator graph in post order""" + visited = set() + worklist = [op] + while len(worklist) > 0: + current = worklist[-1] + if current in visited: + worklist.pop() + continue + visited.add(current) + for operand in current.operands: + if operand not in visited: + worklist.append(operand) + if current in worklist: + worklist.remove(current) + yield current + + +def reverse_post_order(op: BaseNode): + """Visits the operator graph in reverse post order""" + return reversed(list(post_order(op))) diff --git a/hannah/nas/functional_operators/utils/viz.py b/hannah/nas/functional_operators/utils/viz.py new file mode 100644 index 00000000..9ac4c156 --- /dev/null +++ b/hannah/nas/functional_operators/utils/viz.py @@ -0,0 +1,62 @@ +# +# Copyright (c) 2023 Hannah contributors. +# +# This file is part of hannah. +# See https://github.com/ekut-es/hannah for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from typing import TYPE_CHECKING + +import networkx as nx + +if TYPE_CHECKING: + from ...functional_operators.op import BaseNode + + +def as_nx_graph(op: "BaseNode") -> nx.DiGraph: + """Returns a networkx graph representation of the operator graph""" + graph = nx.DiGraph() + + visited = set() + worklist = [op] + while len(worklist) > 0: + current = worklist.pop() + if current in visited: + continue + visited.add(current) + for operand in current.operands: + worklist.append(operand) + graph.add_edge(operand.id, current.id) + + return graph + + +def as_string(op: "BaseNode") -> str: + """Returns a string representation of the operator graph""" + return nx.write_network_text(as_nx_graph(op)) + + +def as_dot(op: "BaseNode") -> str: + """Returns a dot representation of the operator graph""" + return nx.nx_pydot.to_pydot(as_nx_graph(op)).to_string() + + +def write_png(op: "BaseNode", filename: str) -> None: + """Writes a png file of the operator graph""" + nx.nx_pydot.to_pydot(as_nx_graph(op)).write_png(filename) + + +def write_pdf(op: "BaseNode", filename: str) -> None: + """Writes a pdf file of the operator graph""" + nx.nx_pydot.to_pydot(as_nx_graph(op)).write_pdf(filename) diff --git a/hannah/nas/functional_operators/visualizer.py b/hannah/nas/functional_operators/visualizer.py deleted file mode 100644 index beb11313..00000000 --- a/hannah/nas/functional_operators/visualizer.py +++ /dev/null @@ -1,32 +0,0 @@ -# import networkx as nx -# import matplotlib.pyplot as plt -# from networkx.drawing.nx_pydot import graphviz_layout - - -# class Visualizer: -# def __init__(self, graph) -> None: -# self.graph = graph -# self.nx_graph = nx.DiGraph() - -# queue = [self.graph] -# visited = [] -# while queue: -# n = queue.pop() -# visited.append(n) -# self.nx_graph.add_node(n.id, type=str(type(n)).split('.')[-1].split('\'')[0]) - -# for operand in n.operands: -# self.nx_graph.add_edge(operand.id, n.id) -# if operand not in visited: -# queue.append(operand) - -# def draw(self): -# pos = graphviz_layout(self.nx_graph, prog="dot", root='input') -# # pos['input'] = (0, 1200) -# self.nx_graph.graph["graph"] = dict(rankdir="LR") -# labels = {} -# for n in self.nx_graph.nodes: -# labels[n] = self.nx_graph.nodes[n]['type'] -# nx.draw(self.nx_graph, pos, node_color='y') -# nx.draw_networkx_labels(self.nx_graph, pos, labels=labels, font_size=8) -# plt.show() diff --git a/hannah/nas/graph_conversion.py b/hannah/nas/graph_conversion.py index 3915bc92..fd23dfeb 100644 --- a/hannah/nas/graph_conversion.py +++ b/hannah/nas/graph_conversion.py @@ -20,7 +20,7 @@ import logging import math from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import networkx as nx import numpy as np @@ -87,7 +87,9 @@ def to_one_hot(val, options): @dataclass class NamedTensor: name: str - tensor: torch.Tensor + tensor: Union[ + torch.Size, torch.Tensor, int + ] # FIXME: this probably is not the intended type quantization: Any = None @@ -507,7 +509,7 @@ def add_nodes_conv_fun(self, target, mod, args, kwargs, output): # weight_attrs = {"quant": None, "shape": args[1].tensor.shape} # FIXME: Bias missing - if hasattr(args[2], 'tensor'): + if hasattr(args[2], "tensor"): bias_attrs = {"quant": None, "shape": args[2].tensor.shape} else: bias_attrs = None @@ -528,7 +530,7 @@ def add_nodes_conv_fun(self, target, mod, args, kwargs, output): input_names = list() for arg in args: - if hasattr(arg, 'name'): + if hasattr(arg, "name"): input_names.append(arg.name) for 
input_name in input_names: self.nx_graph.add_edge(input_name, name) @@ -550,7 +552,7 @@ def add_nodes_linear_fun(self, target, mod, args, kwargs, output): weight_attrs = {"quant": weight_quant_attrs, "shape": args[1].tensor.shape} # weight_attrs = {"quant": None, "shape": args[1].tensor.shape} - if hasattr(args[2], 'tensor'): + if hasattr(args[2], "tensor"): bias_attrs = {"quant": None, "shape": args[2].tensor.shape} else: bias_attrs = None @@ -571,7 +573,7 @@ def add_nodes_linear_fun(self, target, mod, args, kwargs, output): input_names = list() for arg in args: - if hasattr(arg, 'name'): + if hasattr(arg, "name"): input_names.append(arg.name) for input_name in input_names: self.nx_graph.add_edge(input_name, name) diff --git a/hannah/nas/hardware_description/__init__.py b/hannah/nas/hardware_description/__init__.py index e69de29b..bbbe8734 100644 --- a/hannah/nas/hardware_description/__init__.py +++ b/hannah/nas/hardware_description/__init__.py @@ -0,0 +1,21 @@ +# +# Copyright (c) 2023 Hannah contributors. +# +# This file is part of hannah. +# See https://github.com/ekut-es/hannah for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .device import Device + +__all__ = ["Device"] diff --git a/hannah/nas/hardware_description/__main__.py b/hannah/nas/hardware_description/__main__.py new file mode 100644 index 00000000..0fdf6ab5 --- /dev/null +++ b/hannah/nas/hardware_description/__main__.py @@ -0,0 +1,112 @@ +# +# Copyright (c) 2024 Hannah contributors. +# +# This file is part of hannah. +# See https://github.com/ekut-es/hannah for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" Implementation of hannah tooling for handling hardware descriptions, and generating backends. """ + + +import argparse +import sys + +import rich + +# import all devices to register them with the registry +from . 
import devices as _devices # noqa: F401 # pylint: disable=unused-import
+from .device import Device
+from .registry import devices
+
+
+def add_args(parser, known_args):
+    for device in devices:
+        if device in known_args and device.name == known_args.device:
+            pass
+            # device.add_args(parser, known_args)
+
+    return parser
+
+
+def list():
+    for device in devices:
+        print(device.name)
+
+
+def export(args):
+    found = False
+    for device in devices:
+        if device.name == args.device:
+            if args.backend == "tvm":
+                from hannah.nas.hardware_description.backend import TVMBackend
+
+                device: Device = device()
+                backend = TVMBackend()
+                result = backend.generate(device)
+                print(result)
+
+            found = True
+            break
+    if not found:
+        print(f"Device {args.device} not found.")
+        print("Available devices:")
+        list()
+        sys.exit(1)
+
+
+def describe(args):
+    found = False
+    import rich.markdown
+
+    from hannah.nas.hardware_description.backend import MarkdownBackend
+
+    for device in devices:
+        if device.name == args.device:
+            found = True
+            device: Device = device()
+            backend = MarkdownBackend()
+            rich.print(rich.markdown.Markdown(backend.generate(device)))
+            break
+    if not found:
+        print(f"Device {args.device} not found.")
+        print("Available devices:")
+        list()
+        sys.exit(1)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Hardware description generator.")
+    command_parsers = parser.add_subparsers(dest="command")
+    command_parsers.add_parser("list", help="List all available devices.")
+    export_parser = command_parsers.add_parser(
+        "export", help="Generate a hardware backend."
+    )
+    export_parser.add_argument("device", help="Device to generate backend for.")
+    export_parser.add_argument("output", help="Output directory.")
+    export_parser.add_argument("--backend", help="Backend to use.", default="tvm")
+
+    describe_parser = command_parsers.add_parser("describe", help="Describe a device.")
+    describe_parser.add_argument("device", help="Device to describe.")
+
+    known_args, _ = parser.parse_known_args()
+    if known_args.command is not None:
+        parser = add_args(parser, known_args)
+
+    args = parser.parse_args()
+    if args.command == "list":
+        list()
+    elif args.command == "export":
+        export(args)
+    elif args.command == "describe":
+        describe(args)
diff --git a/hannah/nas/hardware_description/backend/__init__.py b/hannah/nas/hardware_description/backend/__init__.py
new file mode 100644
index 00000000..81557080
--- /dev/null
+++ b/hannah/nas/hardware_description/backend/__init__.py
@@ -0,0 +1,29 @@
+#
+# Copyright (c) 2023 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Backends for the hannah target descriptions.
+
+These backend descriptions can be used to translate the target description to different data formats.
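+
+A minimal usage sketch (illustrative; assumes ``device`` is an instance of a
+registered ``Device``)::
+
+    from hannah.nas.hardware_description.backend import MarkdownBackend
+
+    backend = MarkdownBackend(textwidth=80)
+    datasheet = backend.generate(device)  # returns a markdown string
+    print(datasheet)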
+""" + + +from .hannah import HannahBackend +from .markdown import MarkdownBackend +from .tvm import TVMBackend + +__all__ = ["MarkdownBackend", "HannahBackend", "TVMBackend"] diff --git a/hannah/nas/hardware_description/backend/base.py b/hannah/nas/hardware_description/backend/base.py new file mode 100644 index 00000000..6c6d7d2b --- /dev/null +++ b/hannah/nas/hardware_description/backend/base.py @@ -0,0 +1,39 @@ +# +# Copyright (c) 2023 Hannah contributors. +# +# This file is part of hannah. +# See https://github.com/ekut-es/hannah for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from abc import ABC, abstractmethod +from typing import Any, TypeVar + +from hannah.nas.hardware_description.device import Device + + +class DescriptionBackend(ABC): + """Abstract base class for generating tool specific descriptions from target devices.""" + + @abstractmethod + def generate(self, device: Device) -> Any: + """Generates a tool specific description from a target device meta model.""" + pass + + def save(self, device: Device, path: str) -> None: + """Saves a tool specific description to a file. if supported by the backend.""" + + raise NotImplementedError(f"Saving is not supported by {self}.") + + def __str__(self) -> str: + return self.__class__.__name__ diff --git a/hannah/nas/hardware_description/backend/hannah.py b/hannah/nas/hardware_description/backend/hannah.py new file mode 100644 index 00000000..8f12fe49 --- /dev/null +++ b/hannah/nas/hardware_description/backend/hannah.py @@ -0,0 +1,232 @@ +# +# Copyright (c) 2023 Hannah contributors. +# +# This file is part of hannah. +# See https://github.com/ekut-es/hannah for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Translates target patterns to be applied on the hannah neural network search space descriptions. 
+""" + +import logging +from typing import List, Sequence + +from hannah.nas.core.expression import Expression +from hannah.nas.expressions.placeholder import UndefinedInt +from hannah.nas.functional_operators.op import BaseNode, Op, Tensor +from hannah.nas.parameters import CategoricalParameter + +from .base import DescriptionBackend +from .utils import all_nodes + + +class MatchedRegion: + """A matched region of a hannah pattern.""" + + def __init__(self, pattern, nodes): + self.pattern = pattern + self.nodes = nodes + + def __str__(self): + return f"MatchedRegion(pattern={self.pattern.name}, nodes={[n.id for n in self.nodes]})" + + +class HannahPattern: + """Pattern for hannah search space descriptions.""" + + def __init__(self, name, pattern: BaseNode, condition): + self.name = name + self._pattern = pattern + self._condition = condition + + def match(self, nodes): + """Matches the pattern on a search space.""" + + matches = [] + + for node in nodes: + match = self._match_node(node) + if match: + match = MatchedRegion(self._pattern, match) + matches.append(match) + return matches + + def _match_node(self, node) -> List[BaseNode]: + """Matches the pattern on a single node. + Returns a list of nodes in the matched subexpression in post-order. + """ + partial_match = [] + + worklist = [(node, self._pattern)] + while worklist: + current_node, current_pattern = worklist.pop() + print("matching:", current_node.id, current_pattern.id) + + # Dispatch over the different types of nodes + if isinstance(current_node, Op): + matches = self._match_op(current_node, current_pattern) + if not matches: + return [] + elif isinstance(current_node, Tensor): + matches = self._match_tensor(current_node, current_pattern) + if not matches: + return [] + partial_match.append(current_node) + + print("partial match:", [n.id for n in partial_match]) + + if current_pattern.operands: + if len(current_node.operands) != len(current_pattern.operands): + return [] + + for operand, pattern_operand in zip( + current_node.operands, current_pattern.operands + ): + worklist.append((operand, pattern_operand)) + + return partial_match + + def _match_op(self, node: Op, pattern: Op) -> bool: + """Matches an op node.""" + + # Iterate over attributes of op node + for attr in node.__dict__: + if attr.startswith("_"): + continue + + if attr == "operands": + continue + + if attr == "users": + continue + + # FIXME: use regex match here + if attr == "id": + continue + if attr == "name": + continue + if attr == "scope": + continue + + # print("Matching attributes: ", attr) + if not self._is_subset(getattr(node, attr), getattr(pattern, attr)): + print("Failed to match attributes: ", attr) + return False + + return True + + def _match_tensor(self, node: BaseNode, pattern: Tensor) -> bool: + """Matches a tensor pattern, the node pattern can still be an op. + + In the case of matching against an op pattern we only check the shape. 
+ """ + + if isinstance(pattern, Op): + logging.critical( + "Matching a tensor against an op pattern is not implemented correctly, shapes might not match" + ) + return True + + if len(node.shape()) != len(pattern.shape()): + return False + else: + for dim, pattern_dim in zip(node.shape(), pattern.shape()): + if not self._is_subset(dim, pattern_dim): + return False + + return True + + def _is_subset(self, node_attr, pattern_attr): + """Checks if a node attribute is a subset of the pattern attribute.""" + + print("Matching attribute sets: ", node_attr, pattern_attr) + + if isinstance(node_attr, CategoricalParameter): + node_set = set(node_attr.choices) + if isinstance(pattern_attr, UndefinedInt): + return True + elif hasattr(pattern_attr, "values"): + pattern_set = set(pattern_attr.values) + elif isinstance(pattern_attr, Expression): + logging.critical( + "Matching a categorical parameter against an expression is not implemented correctly" + ) + elif isinstance(pattern_attr, int): + pattern_set = set([pattern_attr]) + + if not (node_set & pattern_set): + return False + else: + # FIXME: this is a hack to make sure that the choices are only the ones in the pattern + # FIXME: this can result in unexpected behaviour if the current_value is not set to one of the restricted values + node_attr.choices = [x for x in node_attr.choices if x in pattern_set] + elif isinstance(node_attr, Expression): + return True + elif isinstance(node_attr, int): + if isinstance(pattern_attr, int): + if node_attr != pattern_attr: + return False + elif isinstance(pattern_attr, UndefinedInt): + return True + else: + logging.critical( + f"Matching an int against an expression of type: {type(pattern_attr)} is not implemented correctly" + ) + return True + + return True + + return True + + +class HannahMatcher: + """Matcher for hannah patterns.""" + + def __init__(self, name: str, patterns: List[HannahPattern]): + self.name = name + self._patterns = patterns + self._matched_regions = [] + + def run(self, search_space): + """Runs the matcher on a search space.""" + + nodes = all_nodes(search_space) + for pattern in self._patterns: + self._matched_regions.extend(pattern.match(nodes)) + + return self._matched_regions + + @property + def matches(self) -> Sequence[MatchedRegion]: + """Returns the matched regions.""" + return self._matched_regions + + +class HannahBackend(DescriptionBackend): + """Generator for hannah data sheets from target devices.""" + + def __init__(self): + super().__init__() + + def generate(self, device) -> HannahMatcher: + """Generates a hannah description from a target device meta model.""" + + patterns = [] + for name, op, cond in device.ops: + patterns.append(HannahPattern(name, op, cond)) + + backend = HannahMatcher(device.name, patterns) + + return backend diff --git a/hannah/nas/hardware_description/backend/markdown.py b/hannah/nas/hardware_description/backend/markdown.py new file mode 100644 index 00000000..67d42a80 --- /dev/null +++ b/hannah/nas/hardware_description/backend/markdown.py @@ -0,0 +1,58 @@ +# +# Copyright (c) 2024 Hannah contributors. +# +# This file is part of hannah. +# See https://github.com/ekut-es/hannah for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Translating target descriptions to markdown data sheets."""
+
+from textwrap import dedent
+
+from .base import DescriptionBackend
+
+
+class MarkdownBackend(DescriptionBackend):
+    """Generator for markdown data sheets from target devices."""
+
+    def __init__(self, textwidth: int = 80):
+        super().__init__()
+        self.textwidth = textwidth  # currently unused; reserved for wrapping long descriptions
+
+    def generate(self, device) -> str:
+        """Generates a markdown description from a target device meta model."""
+
+        text = f"# {device.name}\n\n"
+        if device.description:
+            desc = dedent(device.description)
+            text += f"{desc}\n\n"
+
+        text += "## Architecture\n\n"
+
+        text += "### Processing Element\n\n"
+
+        text += "### Memory Hierarchy\n\n"
+
+        memory_table = "Scope | Size | Latency | Orchestration | Bandwidth | Energy | Area | Ports |\n"
+        memory_table += "------|------|---------|---------------|-----------|--------|------|-------|\n"
+        for mem in device.memories:
+            memory_table += f"{mem.scope} | {mem.size} | {mem.latency} | {mem.management.value},{mem.coupling.value} | r: {mem.read_bw} w: {mem.write_bw} | r: {mem.read_energy} w: {mem.write_energy} i: {mem.idle_energy} | {mem.area} | r: {mem.read_port} w: {mem.write_port} rw: {mem.rw_port} |\n"
+        text += memory_table + "\n"
+
+        text += "## Supported Operations\n\n"
+
+        for op in device.ops:
+            text += f"{op.markdown()}\n"
+
+        return text
diff --git a/hannah/nas/hardware_description/backend/tvm.py b/hannah/nas/hardware_description/backend/tvm.py
new file mode 100644
index 00000000..8e992130
--- /dev/null
+++ b/hannah/nas/hardware_description/backend/tvm.py
@@ -0,0 +1,129 @@
+#
+# Copyright (c) 2024 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+
+from ...functional_operators.operators import Choice, Op, Tensor
+from ..device import Device, TargetOp
+from .base import DescriptionBackend
+
+logger = logging.getLogger(__name__)
+
+
+_TVM_OP_TABLE = {
+    "Conv1d": "nn.conv1d",
+    "Conv2d": "nn.conv2d",
+    "Add": "add",
+    "Relu": "nn.relu",
+    "Linear": "nn.dense",  # relay has no nn.linear; linear layers map to nn.dense
+    "MaxPool2d": "nn.max_pool2d",
+}
+
+
+class TVMBackend(DescriptionBackend):
+    """Generator for TVM BYOC pattern tables from target devices."""
+
+    def __init__(self):
+        super().__init__()
+        self._device = None
+        self._pattern_table = ""
+
+    def generate(self, device: Device):
+        self._device = device
+        self._pattern_table = "from tvm import relay\n"
+        self._pattern_table += "from tvm.relay import Expr\n"
+        self._pattern_table += (
+            "from tvm.relay.dataflow_pattern import wildcard, is_op\n"
+        )
+        self._pattern_table += (
+            "from tvm.relay.op.contrib.register import register_pattern_table\n"
+        )
+        self._pattern_table += "\n\n"
+
+        for op in device.ops:
+            self._pattern_table += self._handle_graph(op)
+            self._pattern_table += "\n\n"
+
+        self._pattern_table += f'@register_pattern_table("{self._device.name}")\n'
+        self._pattern_table += "def pattern_table():\n"
+        self._pattern_table += "    return [\n"
+
+        for op in device.ops:
+            self._pattern_table += (
+                f"        ({op.name}.name, {op.name}.pattern(), {op.name}.check),\n"
+            )
+
+        self._pattern_table += "    ]\n"
+
+        logger.debug("Generated TVM pattern table:\n%s", self._pattern_table)
+
+        # compiled = compile(self._pattern_table, "<string>", "exec")
+        # mod = exec(compiled, globals(), locals())
+
+        return self._pattern_table
+
+    def _handle_graph(self, op: TargetOp) -> str:
+        worklist = [op.graph]
+        ops = []
+
+        while worklist:
+            current_op = worklist.pop()
+            if current_op in ops:
+                continue
+
+            ops.append(current_op)
+
+            for child in current_op.operands:
+                worklist.append(child)
+
+        result = f"class {op.name}:\n"
+        result += f'    name = "{self._device.name}.{op.name}"\n'
+
+        # generate Pattern
+        result += "    @classmethod\n"
+        result += "    def pattern(cls):\n"
+
+        id_table = {}
+        for num, node in enumerate(reversed(ops)):
+            node_id = f"{node.name}_{num}".lower()
+
+            id_table[node] = node_id
+
+            if isinstance(node, Tensor):
+                matcher = "wildcard()"  # FIXME: handle consts
+            elif isinstance(node, Choice):
+                matcher = " | ".join(id_table[o] for o in node.options)
+            elif isinstance(node, Op):
+                if node.name not in _TVM_OP_TABLE:
+                    raise NotImplementedError(f"Unsupported op {node.name}")
+
+                tvm_name = _TVM_OP_TABLE[node.name]
+                matcher = f'is_op("{tvm_name}")('
+                for operand in node.operands[:-1]:
+                    matcher += f"{id_table[operand]}, "
+                matcher += f"{id_table[node.operands[-1]]})"
+            else:
+                raise NotImplementedError(f"Unsupported node {node}")
+
+            result += f"        {node_id} = {matcher}\n"
+
+        result += f"        return {id_table[ops[0]]}\n"
+
+        # Generate Verification Code
+        result += "    @classmethod\n"
+        result += '    def check(cls, expr: "Expr") -> bool:\n'
+        result += "        return True\n"
+
+        return result
diff --git a/hannah/nas/hardware_description/backend/utils.py b/hannah/nas/hardware_description/backend/utils.py
new file mode 100644
index 00000000..1a634cd3
--- /dev/null
+++ b/hannah/nas/hardware_description/backend/utils.py
@@ -0,0 +1,42 @@
+#
+# Copyright (c) 2023 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Utilities for the hardware description backend generation.
+
+Most of them should be moved to the general search space module.
+"""
+
+from hannah.nas.functional_operators.op import BaseNode
+
+
+def all_nodes(search_space: BaseNode):
+    """
+    Return all nodes in the search space.
+    """
+
+    nodes = set()
+    worklist = [search_space]
+
+    while worklist:
+        node = worklist.pop()
+        if node in nodes:
+            continue  # operands can be shared; visit each node only once
+        nodes.add(node)
+        worklist.extend(node.operands)
+
+    return nodes
diff --git a/hannah/nas/hardware_description/description.py b/hannah/nas/hardware_description/description.py
deleted file mode 100644
index dc7ba929..00000000
--- a/hannah/nas/hardware_description/description.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from typing import Optional, Tuple
-
-from hannah.nas.dataflow.axis_type import AxisType
-from hannah.nas.dataflow.compression_type import CompressionType
-from hannah.nas.dataflow.data_type import DataType, FloatType, IntType
-from hannah.nas.dataflow.op_type import OpType
-from hannah.nas.dataflow.optional_op import OptionalOp
-from hannah.nas.dataflow.quantization_type import QuantizationType
-from hannah.nas.dataflow.tensor_type import TensorType
-from hannah.nas.hardware_description.device import Ultratrail
-from hannah.nas.hardware_description.memory_type import MemoryType
-from hannah.nas.parameters import IntScalarParameter
-
-if __name__ == "__main__":
-    ultratrail = Ultratrail(
-        weight_bits=IntScalarParameter(min=1, max=8),
-        bias_bits=IntScalarParameter(min=1, max=8),
-        activation_bits=IntScalarParameter(min=1, max=8),
-        accumulator_bits=IntScalarParameter(min=1, max=32),
-        max_weight_bits=IntScalarParameter(min=4, max=8),
-    )
-
-    print(ultratrail)
diff --git a/hannah/nas/hardware_description/device.py b/hannah/nas/hardware_description/device.py
index f332a710..df3ca2d3 100644
--- a/hannah/nas/hardware_description/device.py
+++ b/hannah/nas/hardware_description/device.py
@@ -1,33 +1,179 @@
-from abc import ABC, abstractmethod
-from typing import List
+#
+# Copyright (c) 2024 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+from abc import ABCMeta, abstractmethod
+from functools import wraps
+from typing import Any, List, NamedTuple, Optional, Sequence
+
+from hannah.nas.core.parametrized import is_parametrized
+from hannah.nas.functional_operators.op import context
 
-from ..dataflow.dataflow_graph import DataFlowGraph, dataflow
-from ..dataflow.op_type import OpType
 from ..expressions.placeholder import IntRange, UndefinedFloat, UndefinedInt
-from ..hardware_description.memory_type import MemoryType
-from ..ops import (
-    add,
-    avg_pool,
-    axis,
-    broadcast,
-    int_t,
-    optional,
-    quantization,
-    relu,
-    requantize,
-    tensor,
-)
+from ..functional_operators.utils.visit import reverse_post_order
+from ..hardware_description.memory_type import CouplingType, ManagementType, MemoryType
 from ..parameters.parametrize import parametrize
+from .registry import devices
 
+logger = logging.getLogger(__name__)
+
+# Placeholder aliases until proper constraint and graph types exist
+Constraint = Any
+DataFlowGraph = Any
+
+
+class TargetOp(NamedTuple):
+    name: str
+    graph: DataFlowGraph
+    constraints: Sequence[Constraint]
+
+    def markdown(self) -> str:
+        res = "### " + self.name + "\n"
+
+        def value_to_str(value):
+            if isinstance(value, IntRange):
+                value_str = f"{value.min}..{value.max}"
+            elif isinstance(value, UndefinedInt):
+                value_str = f"?{str(value.id)}"
+            elif isinstance(value, UndefinedFloat):
+                value_str = f"?{str(value.id)}"
+            elif isinstance(value, Sequence) and not isinstance(value, str):
+                value_str = "[" + ", ".join(value_to_str(x) for x in value) + "]"
+            else:
+                value_str = str(value)
+            return value_str
+
+        ids = {}
+        res += "\nGraph:\n"
+        res += "```mlir\n"
+        for num, node in enumerate(reverse_post_order(self.graph)):
+            node_list = []
+            ids[node.id] = f"%{node.id}_{num}"
+            for operand in node.operands:
+                node_list.append(ids[operand.id])
+
+            for attr, value in node.attributes().items():
+                node_list.append(f"{attr}={value_to_str(value)}")
+
+            node_str = f"{node.name}({', '.join(node_list)})"
+
+            res += f"%{node.id}_{num} = {node_str}\n"
+        res += "```\n"
+
+        res += "\nConstraints:\n"
+        for constraint in self.constraints:
+            res += "- " + str(constraint) + "\n"
+
+        return res
+
+
+class DeviceMeta(ABCMeta):
+    def __new__(mcls, name, bases, namespace, /, **kwargs):
+        cls = super().__new__(mcls, name, bases, namespace, **kwargs)
+        if not hasattr(cls, "name"):
+            cls.name = name
+
+        devices.register(cls)
+        return cls
+
+    def __init__(self, name, bases, namespace):
+        super().__init__(name, bases, namespace)
+
+        # Wrap the class local __init__ so that it always runs inside a graph context
+        if hasattr(self, "__init__"):
+            init_method = getattr(self, "__init__")
+
+            @wraps(init_method)
+            def init_wrapper(self, *args, **kwargs):
+                with context():
+                    init_method(self, *args, **kwargs)
+
+            setattr(self, "__init__", init_wrapper)
+
+
+class Device(metaclass=DeviceMeta):
+    name: str = ""
+    description: str = ""
+    _ops: List[TargetOp]
+    _memories: List[MemoryType]
 
-class Device(ABC):
     def __init__(self) -> None:
         super().__init__()
-        self._ops = []
+        self._ops: List[TargetOp] = []
         self._memories = []
 
+    def add_memory(
+        self,
+        scope: str,
+        size: int,
+        latency: int,
+        wordwidth: int = 8,  # FIXME: currently not forwarded to MemoryType
+        management: ManagementType = ManagementType.EXPLICIT,
+        coupling: CouplingType = CouplingType.COUPLED,
+        read_bw: int = 10,
+        write_bw: int = 10,
+        read_energy: int = 10,
+        write_energy: int = 10,
+        idle_energy: int = 10,
+        area: int = 10,
+        read_port: int = 10,
+        write_port: int = 10,
+        rw_port: int = 10,
+    ) -> None:
+        self._memories.append(
+            MemoryType(
+                scope=scope,
+                size=size,
+                latency=latency,
+                management=management,
+                coupling=coupling,
+                read_bw=read_bw,
+                write_bw=write_bw,
+                read_energy=read_energy,
+                write_energy=write_energy,
+                idle_energy=idle_energy,
+                area=area,
+                read_port=read_port,
+                write_port=write_port,
+                rw_port=rw_port,
+            )
+        )
+
+    def add_op(
+        self,
+        name: str,
+        graph: DataFlowGraph,
+        constraints: Optional[Sequence[Constraint]] = None,
+    ) -> None:
+        """Adds an operation to the device."""
+
+        if constraints is None:
+            constraints = []
+
+        self._ops.append(TargetOp(name, graph, constraints))
+
     @property
-    def ops(self) -> List[DataFlowGraph]:
+    def ops(self) -> Sequence[TargetOp]:
         return self._ops
 
     @property
@@ -39,151 +185,7 @@ def __str__(self):
         res += "Ops:\n"
         for op in self.ops:
             res += str(op) + "\n"
+
         for memory in self.memories:
             res += str(memory) + "\n"
 
         return res
-
-
-@dataflow
-def conv(input, weight, stride):
-    return OpType("conv1d", input, weight, stride=stride)
-
-
-def ut_op(
-    weight_bits: int = 8,
-    bias_bits: int = 8,
-    activation_bits: int = 8,
-    accumulator_bits: int = 8,
-    max_weight_bits: int = 8,
-    max_kernel_size: int = 2**4,
-    max_input_length: int = 2**7,
-    max_input_channel_block: int = 2**4,
-    max_output_channel_block: int = 2**4,
-    stride_range: int = 2**3,
-):
-
-    input_data_type = int_t(signed=True, bits=activation_bits)
-    input_quantization = quantization(scale=UndefinedFloat(), zero_point=0)
-    input = tensor(
-        (
-            axis("n", size=1),
-            axis("c", size=UndefinedInt()),
-            axis("w", size=UndefinedInt()),
-        ),
-        dtype=input_data_type,
-        quantization=input_quantization,
-    )
-
-    weight_data_type = int_t(signed=True, bits=weight_bits)
-    weight_quantization = quantization(scale=UndefinedFloat(), zero_point=0)
-
-    weight = tensor(
-        (
-            axis("o", size=UndefinedInt()),
-            axis("i", size=UndefinedInt()),
-            axis("kw", size=IntRange(1, max_kernel_size)),
-        ),
-        dtype=weight_data_type,
-        quantization=weight_quantization,
-    )
-
-    res_input = tensor(
-        (
-            axis("n", size=1),
-            axis("c", size=UndefinedInt()),
-            axis("w", size=UndefinedInt()),
-        ),
-        dtype=input_data_type,
-        quantization=input_quantization,
-    )
-
-    conv_out = conv(input, weight, stride=stride_range)
-
-    accumulator_data_type = int_t(signed=True, bits=accumulator_bits)
-    accumulator_quantization = quantization(scale=UndefinedFloat(), zero_point=0)
-    quant_conv = requantize(
-        conv_out, dtype=accumulator_data_type, quantization=accumulator_quantization
-    )
-
-    bias_data_type = int_t(signed=True, bits=bias_bits)
-    bias_quantization = quantization(scale=UndefinedFloat(), zero_point=0)
-    bias = tensor(
-        (axis("c", size=UndefinedInt()),),
-        dtype=bias_data_type,
-        quantization=bias_quantization,
-    )
-
-    bias_add = optional(
-        add(quant_conv, broadcast(bias, axis=((axis("n"))))), quant_conv
-    )  # FIXME: define broadcasting
-    res_add = optional(add(bias_add, res_input), bias_add)
-    pool = optional(avg_pool(res_add), res_add)
-    activation = optional(relu(pool), pool)
-    requantization = requantize(
-        activation, dtype=input_data_type, quantization=input_quantization
-    )
-
-    return DataFlowGraph(inputs=[input, weight], output=requantization)
-
-
-class HardwareOp(DataFlowGraph):
-    ...
- # performance & energy modelling (e.g., hints) - # memory mapping, alignment constraints - # - - -@parametrize -class Ultratrail(Device): - def __init__( - self, - weight_bits: int = 6, - bias_bits: int = 8, - activation_bits: int = 8, - accumulator_bits: int = 20, - max_weight_bits: int = 8, - rows: int = 8, - cols: int = 8, - ifmap_bits: int = 7, - ic_block_bits: int = 4, - oc_block_bits: int = 4, - kernel_size_bits: int = 4, - stride_bits: int = 3, - ) -> None: - super().__init__() - - self.weight_bits = weight_bits - self.bias_bits = bias_bits - self.activation_bits = activation_bits - self.accumulator_bits = accumulator_bits - self.max_weight_bits = max_weight_bits - self.rows = rows - self.cols = cols - self.ifmap_bits = ifmap_bits - self.ic_block_bits = ic_block_bits - self.oc_block_bits = oc_block_bits - self.kernel_size_bits = kernel_size_bits - self.stride_bits = stride_bits - - max_kernel_size = 2**self.kernel_size_bits - max_input_length = 2**self.ifmap_bits - max_input_channel_block = self.rows * 2**self.ic_block_bits - max_output_channel_block = self.cols * 2**self.oc_block_bits - - stride_range = 2**2**stride_bits - - # self.cond(stride <= 2**2**S_BIT and is_power_of_2(stride)) - - op = ut_op( - weight_bits=self.weight_bits, - bias_bits=self.bias_bits, - activation_bits=self.activation_bits, - accumulator_bits=self.accumulator_bits, - max_weight_bits=self.max_weight_bits, - max_kernel_size=max_kernel_size, - max_input_length=max_input_length, - max_input_channel_block=max_input_channel_block, - max_output_channel_block=max_output_channel_block, - stride_range=stride_range, - ) - self._ops.append(op) diff --git a/hannah/nas/hardware_description/devices/__init__.py b/hannah/nas/hardware_description/devices/__init__.py new file mode 100644 index 00000000..ed188248 --- /dev/null +++ b/hannah/nas/hardware_description/devices/__init__.py @@ -0,0 +1,19 @@ +# +# Copyright (c) 2024 Hannah contributors. +# +# This file is part of hannah. +# See https://github.com/ekut-es/hannah for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from . import eyeriss, simple, vanilla # , ultratrail diff --git a/hannah/nas/hardware_description/devices/eyeriss.py b/hannah/nas/hardware_description/devices/eyeriss.py new file mode 100644 index 00000000..8d0f1308 --- /dev/null +++ b/hannah/nas/hardware_description/devices/eyeriss.py @@ -0,0 +1,121 @@ +# +# Copyright (c) 2024 Hannah contributors. +# +# This file is part of hannah. +# See https://github.com/ekut-es/hannah for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from hannah.nas.expressions.placeholder import UndefinedInt
+from hannah.nas.functional_operators.operators import Conv2d, Tensor
+
+from ..device import Device
+
+
+class EyerissDevice(Device):
+    name = "eyeriss_v1"
+    description = """
+    Eyeriss version 1 hardware accelerator description.
+
+    https://courses.cs.washington.edu/courses/cse550/21au/papers/CSE550.Eyeriss.pdf
+    """
+
+    def __init__(self, precision=8):
+        super().__init__()
+
+        self.precision = precision
+        self._add_conv2d()
+        self._add_memories()
+
+    def _add_conv2d(self):
+        N = UndefinedInt("N")
+        C = UndefinedInt("C")
+        H = UndefinedInt("H")
+        W = UndefinedInt("W")
+
+        input = Tensor("input", shape=[N, C, H, W], axis=["N", "C", "H", "W"])
+
+        OC = UndefinedInt("O")
+        IC = UndefinedInt("I")
+        kh = UndefinedInt("kh")
+        kw = UndefinedInt("kw")
+
+        weight = Tensor(
+            "weight",
+            shape=[OC, IC, kh, kw],
+            axis=["O", "I", "kH", "kW"],
+            grad=True,
+        )
+
+        sh = UndefinedInt("vertical_stride")
+        sw = UndefinedInt("horizontal_stride")
+
+        # stride placeholders wired into the op; they are restricted by the constraints below
+        conv = Conv2d(stride=(sh, sw), padding=0, dilation=1, groups=1)(input, weight)
+
+        self.add_op(
+            "conv2d",
+            conv,
+            [
+                sh <= 4,
+                sh % 2 == 0,
+                sh > 1,
+                sw >= 1,
+                sw <= 12,
+                C <= 1,
+                kh <= 12,
+                kw <= 32,
+                IC <= 1024,
+                OC <= 1024,
+            ],
+        )
+
+    def _add_memories(self):
+        self.add_memory(
+            "ifmap_buffer",
+            wordwidth=16,
+            size=12,
+            latency=1,
+            read_port=1,
+            write_port=1,
+            read_bw=1,
+            write_bw=1,
+        )
+
+        self.add_memory(
+            "weight_buffer",
+            wordwidth=16,
+            size=224,
+            latency=1,
+            read_port=1,
+            write_port=1,
+            read_bw=1,
+            write_bw=1,
+        )
+
+        self.add_memory(
+            "ps_buffer",
+            wordwidth=16,
+            size=32,
+            latency=1,
+            read_port=1,
+            write_port=1,
+            read_bw=1,
+            write_bw=1,
+        )
+
+        self.add_memory(
+            "global_buffer",
+            wordwidth=16,
+            size=1024 * 12,
+            latency=1,
+        )
diff --git a/hannah/nas/hardware_description/devices/simple.py b/hannah/nas/hardware_description/devices/simple.py
new file mode 100644
index 00000000..7d4d4135
--- /dev/null
+++ b/hannah/nas/hardware_description/devices/simple.py
@@ -0,0 +1,11 @@
+from ..device import Device
+
+
+class SimpleDevice(Device):
+    name = "simple_hwa"
+    description = "A simple abstract hardware device with Conv2d -> ReLU acceleration and configurable precision"
+
+    def __init__(self, precision=8):
+        super().__init__()
+        self.precision = precision
diff --git a/hannah/nas/hardware_description/devices/vanilla.py b/hannah/nas/hardware_description/devices/vanilla.py
new file mode 100644
index 00000000..a9b48648
--- /dev/null
+++ b/hannah/nas/hardware_description/devices/vanilla.py
@@ -0,0 +1,57 @@
+#
+# Copyright (c) 2024 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from hannah.nas.expressions.placeholder import UndefinedInt
+from hannah.nas.functional_operators.operators import Conv2d, Tensor
+from hannah.nas.hardware_description import Device
+
+
+class VanillaAccelerator(Device):
+    name = "vanilla_accelerator"
+    description = "A simple abstract hardware device supporting only 2d convolutions with a stride of 1 and same padding"
+
+    def __init__(self):
+        super().__init__()
+        self._add_conv2d()
+
+        self.add_memory(
+            "local",
+            size=1024 * 10,
+            latency=1,
+        )
+
+    def _add_conv2d(self):
+        input = Tensor(
+            "input", shape=[None, None, None, None], axis=["N", "C", "H", "W"]
+        )  # NCHW tensor format
+        weight = Tensor(
+            "weight",
+            shape=[None, None, None, None],
+            axis=["O", "I", "kH", "kW"],
+            grad=True,
+        )
+        padding = UndefinedInt("padding")
+        conv = Conv2d(stride=1, padding=padding, dilation=1, groups=1)(input, weight)
+
+        self.add_op(
+            "conv2d",
+            conv,
+            [padding == weight.size(2) // 2, padding == weight.size(3) // 2],
+        )
diff --git a/hannah/nas/hardware_description/memory_type.py b/hannah/nas/hardware_description/memory_type.py
index f86ee173..899a29ae 100644
--- a/hannah/nas/hardware_description/memory_type.py
+++ b/hannah/nas/hardware_description/memory_type.py
@@ -1,7 +1,68 @@
+#
+# Copyright (c) 2024 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from enum import Enum
 from typing import Optional
 
 
+class ManagementType(Enum):
+    EXPLICIT = "explicit"
+    IMPLICIT = "implicit"
+
+
+class CouplingType(Enum):
+    DECOUPLED = "decoupled"
+    COUPLED = "coupled"
+
+
 class MemoryType:
-    def __init__(self, size: Optional[int] = None, name: Optional[str] = "") -> None:
+    def __init__(
+        self,
+        size: int,
+        scope: str,
+        read_bw: int,
+        write_bw: int,
+        read_energy: int = 0,
+        write_energy: int = 0,
+        idle_energy: int = 0,
+        latency: int = 1,
+        area: int = 1,
+        read_port: int = 1,
+        write_port: int = 1,
+        rw_port: int = 0,
+        management: ManagementType = ManagementType.EXPLICIT,
+        coupling: CouplingType = CouplingType.DECOUPLED,
+    ) -> None:
         self.size = size
-        self.name = name
+        # size is given in bytes
+        self.scope = scope
+        self.management = management
+        self.coupling = coupling
+        self.read_bw = read_bw
+        self.write_bw = write_bw
+        self.read_energy = read_energy
+        self.write_energy = write_energy
+        self.idle_energy = idle_energy
+        self.latency = latency
+        self.area = area
+        self.read_port = read_port
+        self.write_port = write_port
+        self.rw_port = rw_port
+
+        # TODO: add more attributes from the memory model
+        # min_r_granularity and min_w_granularity are used for zigzag; for GPUs we
+        # should model shared memory bank conflicts and global memory coalescing
diff --git a/hannah/nas/hardware_description/notes.md b/hannah/nas/hardware_description/notes.md
deleted file mode 100644
index e69de29b..00000000
diff --git a/hannah/nas/hardware_description/performance_prediction.py b/hannah/nas/hardware_description/performance_prediction.py
new file mode 100644
index 00000000..cc8b1f1d
--- /dev/null
+++ b/hannah/nas/hardware_description/performance_prediction.py
@@ -0,0 +1,20 @@
+#
+# Copyright (c) 2024 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from ..performance_prediction.zigzag import ZigZagConfig, ZigZagPredictor
+
+__all__ = ["ZigZagConfig", "ZigZagPredictor"]
diff --git a/hannah/nas/hardware_description/registry.py b/hannah/nas/hardware_description/registry.py
new file mode 100644
index 00000000..0455ce29
--- /dev/null
+++ b/hannah/nas/hardware_description/registry.py
@@ -0,0 +1,3 @@
+from hannah.utils.registry import Registry
+
+devices = Registry("devices")
diff --git a/hannah/nas/hardware_description/testing/__init__.py b/hannah/nas/hardware_description/testing/__init__.py
new file mode 100644
index 00000000..e97dedc2
--- /dev/null
+++ b/hannah/nas/hardware_description/testing/__init__.py
@@ -0,0 +1,21 @@
+#
+# Copyright (c) 2023 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from .device import get_device
+
+__all__ = ["get_device"]
diff --git a/hannah/nas/hardware_description/testing/device.py b/hannah/nas/hardware_description/testing/device.py
new file mode 100644
index 00000000..fd0d2c34
--- /dev/null
+++ b/hannah/nas/hardware_description/testing/device.py
@@ -0,0 +1,88 @@
+#
+# Copyright (c) 2024 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"A very simple device description supporting only conv2d -> relu -> max_pool2d"
+
+from hannah.nas.expressions.placeholder import UndefinedInt
+from hannah.nas.functional_operators.operators import Conv2d, MaxPool2d, Relu, Tensor
+from hannah.nas.hardware_description import Device
+
+Dyn = UndefinedInt("Dyn")
+
+
+class SimpleDevice(Device):
+    name = "simple_device"
+    description = "A simple abstract hardware device supporting 2d convolutions with a stride of 1 and same padding, max pooling, and optionally relu"
+
+    def __init__(self, relu=False):
+        super().__init__()
+        self._add_conv2d()
+        if relu:
+            self._add_relu()
+        self._add_max_pool2d()
+
+        self.add_memory(
+            "local",
+            size=1024 * 10,
+            latency=1,
+        )
+
+    def _add_conv2d(self):
+        input = Tensor(
+            "input", shape=[None, None, None, None], axis=["N", "C", "H", "W"]
+        )
+        weight = Tensor(
+            "weight",
+            shape=[None, None, None, None],
+            axis=["O", "I", "kH", "kW"],
+            grad=True,
+        )
+        padding = Dyn
+        conv = Conv2d(stride=1, padding=padding, dilation=1, groups=1)(input, weight)
+
+        self.add_op(
+            "conv2d",
+            conv,
+            [padding == weight.size(2) // 2, padding == weight.size(3) // 2],
+        )
+
+    def _add_relu(self):
+        input = Tensor(
+            "input", shape=[None, None, None, None], axis=["N", "C", "H", "W"]
+        )
+        relu = Relu()(input)
+
+        self.add_op("relu", relu)
+
+    def _add_max_pool2d(self):
+        input = Tensor(
+            "input", shape=[None, None, None, None], axis=["N", "C", "H", "W"]
+        )
+        max_pool = MaxPool2d(kernel_size=2, stride=2, padding=0)(input)
+
+        self.add_op("max_pool2d", max_pool)
+
+
+def get_device(name, relu=False, *args, **kwargs):
+    simple_device = SimpleDevice(relu=relu)
+    simple_device.name = name
+
+    return simple_device
diff --git a/hannah/nas/hardware_description/ultratrail.py b/hannah/nas/hardware_description/ultratrail.py
deleted file mode 100644
index c7ad6e70..00000000
--- a/hannah/nas/hardware_description/ultratrail.py
+++ /dev/null
@@ -1,105 +0,0 @@
-from hannah.nas.dataflow.data_type import IntType
-from hannah.nas.expressions.placeholder import IntRange,
UndefinedFloat, UndefinedInt -from hannah.nas.functional_operators.op import Tensor, scope -from hannah.nas.functional_operators.operators import Add, Conv2d, Quantize, Relu -from hannah.nas.ops import quantization - - -def conv1d(input, weight, stride, dilation=1): - in_channels = input.shape()[1] - conv = Conv2d(stride, dilation)(input, weight) - return conv - - -def quantize(input, dtype, quantization): - # FIXME: Use quantization - return Quantize()(input) - - -def add(input, other): - return Add()(input, other) - - -def avg_pool(input): - return input # FIXME: Implement avg pool - - -def relu(input): - return Relu()(input) - - -def optional(input, alternative): - return input - - -@scope -def ut_op( - weight_bits: int = 8, - bias_bits: int = 8, - activation_bits: int = 8, - accumulator_bits: int = 8, - max_weight_bits: int = 8, - max_kernel_size: int = 2**4, - max_input_length: int = 2**7, - max_input_channel_block: int = 2**4, - max_output_channel_block: int = 2**4, - stride_range: int = 2**3, -): - input_data_type = IntType(signed=True, bits=activation_bits) - input_quantization = quantization(scale=UndefinedFloat(), zero_point=0) - - input = Tensor(name='input', - shape=(UndefinedInt(), UndefinedInt(), UndefinedInt()), - axis=('N', 'C', 'W'), - dtype=input_data_type) - - weight_data_type = IntType(signed=True, bits=weight_bits) - weight_quantization = quantization(scale=UndefinedFloat(), zero_point=0) - - weight = Tensor(name="weight", - axis=('O', 'I', 'kW'), - shape=(UndefinedInt(), UndefinedInt(), IntRange(1, max_kernel_size)), - dtype=weight_data_type) - - res_input = Tensor(name='res_input', - shape=(UndefinedInt(), UndefinedInt(), UndefinedInt()), - axis=('N', 'C', 'W'), - dtype=input_data_type) - - conv_out = conv1d(input, weight, IntRange(1, stride_range)) - - accumulator_data_type = IntType(signed=True, bits=accumulator_bits) - accumulator_quantization = quantization(scale=UndefinedFloat(), zero_point=0) - - requant_conv = quantize(conv_out, accumulator_data_type, accumulator_quantization) - - bias_data_type = IntType(signed=True, bits=bias_bits) - bias = Tensor(name="bias", - shape=(UndefinedInt()), - axis=("C"), - dtype=bias_data_type) - - bias_add = add(requant_conv, bias) - out = optional(bias_add, requant_conv) - - res_add = add(out, res_input) - out = optional(res_add, out) - - pool = avg_pool(out) - out = optional(pool, out) - - activation = relu(pool) - out = optional(activation, out) - - requantization = quantize(out, input_data_type, input_quantization) - - return requantization - - -if __name__ == '__main__': - op = ut_op() - print(op) - - - - diff --git a/hannah/nas/ops.py b/hannah/nas/ops.py deleted file mode 100644 index 08b528a9..00000000 --- a/hannah/nas/ops.py +++ /dev/null @@ -1,185 +0,0 @@ -# -# Copyright (c) 2022 University of Tübingen. -# -# This file is part of hannah. -# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -from typing import Optional, Tuple, Union - -import hannah.nas.dataflow.registry as registry -from hannah.nas.core.expression import Expression -from hannah.nas.dataflow.dataflow_utils import process_int -from hannah.nas.dataflow.tensor import Tensor - -from .dataflow.axis_type import AxisType -from .dataflow.compression_type import CompressionType -from .dataflow.data_type import DataType, FloatType, IntType -from .dataflow.dataflow_graph import dataflow -from .dataflow.op_type import OpType -from .dataflow.optional_op import OptionalOp -from .dataflow.quantization_type import QuantizationType -from .dataflow.tensor_type import TensorType -from .expressions.placeholder import DefaultInt, UndefinedInt -from .hardware_description.memory_type import MemoryType - - -def int_t(signed: bool = True, bits: int = 8): - return IntType(signed=signed, bits=bits) - - -def float_t(signed=True, significand_bits=23, exponent_bits=8): - return FloatType( - signed=signed, significand_bits=significand_bits, exponent_bits=exponent_bits - ) - - -def axis( - name: str, size: Optional[int] = None, compression: Optional[CompressionType] = None -): - return AxisType(name=name, size=size, compression=compression) - - -def memory(size: Optional[int] = None, name: Optional[str] = ""): - return MemoryType(size=size, name=name) - - -def quantization( - axis: Optional[AxisType] = None, - scale: Optional[float] = None, - zero_point: Optional[float] = None, -): - return QuantizationType(axis=axis, scale=scale, zero_point=zero_point) - - -def tensor( - axis: Tuple[AxisType, ...], - dtype: DataType, - quantization: Optional[QuantizationType] = None, - memory: Optional[MemoryType] = None, - name: str = "", -): - tensor_type = TensorType( - axis=axis, dtype=dtype, quantization=quantization, memory=memory, name=name - ) - return Tensor(tensor_type=tensor_type, name=name) - - -def tensor_by_tuples(shape, axis_names, dtype=float_t(), name="tensor"): - assert len(shape) == len(axis_names) - ax = [] - for dim, ax_name in zip(shape, axis_names): - ax.append(axis(ax_name, process_int(dim))) - - return tensor(axis=ax, dtype=dtype, name=name) - - -def batched_image_tensor(shape=(1, 3, 16, 16), dtype=float_t(), name=""): - assert len(shape) == 4 - return tensor_by_tuples( - shape=shape, dtype=dtype, name=name, axis_names=("n", "c", "h", "w") - ) - - -def weight_tensor( - dtype: DataType = float_t(), shape: tuple = (None, None, None, None), name="" -): - processed_shape = [None for i in range(len(shape))] - for i in range(len(shape)): - processed_shape[i] = process_int(shape[i]) - - return tensor( - ( - axis("o", processed_shape[0]), - axis("i", processed_shape[1]), - axis("kh", processed_shape[2]), - axis("kw", processed_shape[3]), - ), - dtype=dtype, - name=name, - ) - - -@dataflow -def conv(input): - kernel_size = UndefinedInt() - stride = DefaultInt(1) - weight = tensor( - ( - axis("o", UndefinedInt()), - axis("i", UndefinedInt()), - axis("kh", kernel_size), - axis("kw", kernel_size), - ), - dtype=IntType(), - ) - return registry.op("conv", input, weight, stride=stride) - - -@dataflow -def avg_pool(input: TensorType): - window_size = UndefinedInt() - stride = UndefinedInt() - return OpType("avg_pool", input, window_size=window_size, stride=stride) - - -@dataflow -def requantize(input: TensorType, dtype: DataType, quantization: QuantizationType): - return OpType("requantize", input, dtype=dtype, quantization=quantization) - - -@dataflow -def add(input: TensorType, other: TensorType): - return OpType("add", input, other) - - 
-@dataflow -def leaky_relu(input: TensorType, negative_slope: float = 0.0001): - return OpType("leaky_relu", input, negative_slope=negative_slope) - - -@dataflow -def relu(input: TensorType): - return OpType("relu", input) - - -@dataflow -def broadcast(input: TensorType, axis: int = 1): - return OpType("broadcast", input, axis=axis) - - -@dataflow -def optional(op: Union[OpType, TensorType], default: Union[OpType, TensorType]): - return OptionalOp(op, default) - - -# @dataflow -# def conv_block(input: TensorType, kernel_size: int = 4): -# out = add( -# conv(out, kernel_size=kernel_size, stride=CategoricalParameter(1, 2)), -# conv(out, kernel_size=DefaultInt(4), name="residual"), -# ) -# out = leaky_relu(out) -# return out - - -# @dataflow -# def network(input: TensorType, blocks: Optional[int] = None): -# out = inp -# with Repeat(blocks): -# with Parametrize( -# {"leaky_relu.negative_slope": FloatScalarParameter(0.000001, 0.1)} -# ): -# out = conv_block(out, kernel_size=4) -# return out diff --git a/hannah/nas/parameters/parameters.py b/hannah/nas/parameters/parameters.py index dbea24c3..47013e47 100644 --- a/hannah/nas/parameters/parameters.py +++ b/hannah/nas/parameters/parameters.py @@ -20,7 +20,8 @@ from abc import abstractmethod from copy import deepcopy -from typing import Optional, Union +from datetime import datetime +from typing import Any, Optional, Sequence, Union import numpy as np @@ -50,24 +51,19 @@ def register(self): self._registered = True @abstractmethod - def sample(self): - ... + def sample(self): ... @abstractmethod - def instantiate(self): - ... + def instantiate(self): ... @abstractmethod - def set_current(self): - ... + def set_current(self): ... @abstractmethod - def check(self, value): - ... + def check(self, value): ... @abstractmethod - def from_float(self, value): - ... + def from_float(self, value): ... # FIXME: evaluate and instantiate? 
def evaluate(self): @@ -101,8 +97,6 @@ def __repr__(self): + ", ".join((f"{k} = {v}" for k, v in self.__dict__.items())) + ")" ) - - class IntScalarParameter(Parameter): @@ -163,6 +157,7 @@ def set_current(self, value): def from_float(self, val): return int(val * (self.max - self.min) + self.min) + class FloatScalarParameter(Parameter): def __init__( self, @@ -177,7 +172,7 @@ def __init__( self.current_value = self.min def sample(self): - self.current_value = self.rng.uniform(self.min, self.max) + self.current_value = float(self.rng.uniform(self.min, self.max)) return self.current_value def instantiate(self): @@ -202,15 +197,16 @@ def set_current(self, value): def from_float(self, val): return val * (self.max - self.min) + self.min + class CategoricalParameter(Parameter): def __init__( self, - choices, + choices: Sequence[Any], name: Optional[str] = "", rng: Optional[Union[np.random.Generator, int]] = None, ) -> None: super().__init__(name, rng) - self.choices = choices + self.choices = tuple(choices) self.sample() def sample(self): @@ -243,7 +239,7 @@ def set_current(self, value): def __iter__(self): yield from iter(self.choices) - + def from_float(self, val): return self.choices[int(val * len(self.choices))] @@ -304,4 +300,3 @@ def set_current(self, value): def from_float(self, val): raise NotImplementedError("SubsetParameter does not support from_float") - diff --git a/hannah/nas/spaces/darts/darts_space.py b/hannah/nas/spaces/darts/darts_space.py deleted file mode 100644 index 5449050c..00000000 --- a/hannah/nas/spaces/darts/darts_space.py +++ /dev/null @@ -1,70 +0,0 @@ -from hannah.nas.dataflow.dataflow_graph import dataflow -from hannah.nas.dataflow.registry import op -from hannah.nas.expressions.placeholder import DefaultInt -from hannah.nas.ops import weight_tensor -# from hannah.nas.dataflow.ops import conv2d, sum, concat - - -@dataflow -def stem(input): - return input - - -@dataflow -def sum_node(*inputs): - return op('Sum', *inputs) - - -@dataflow -def concat_node(*inputs): - return op('Concat', *inputs, axis=1, out_axis_name="c") - - -@dataflow -def input_node(input): - return op('Identity', input) - - -@dataflow -def op_node(input): - input_tensor = input.tensor_type() - output_channel = DefaultInt(32) # FIXME: - kernel_size = DefaultInt(3) # FIXME: - stride = DefaultInt(1) # FIXME: - weight = weight_tensor(shape=(output_channel, input_tensor['c'], kernel_size, kernel_size), name='weight') - c = op("Conv2d", input, weight, stride=stride) - return c - - -@dataflow -def darts_cell(input, input_prev, num_nodes=4, reduction=True): - input_0 = input_node(input) - input_1 = input_node(input_prev) - - nodes = [input_0, input_1] - - for _ in range(num_nodes): - connections = [] - for node in nodes: - op_connection = op_node(node) - connections.append(op_connection) - s = sum_node(*connections) - nodes.append(s) - - # input nodes are not connected to concat/output node - nodes = nodes[2:] - c = concat_node(*nodes) - - return c - - -@dataflow -def darts_space(input, num_cells=4): - output = input_node(input) - previous = input_node(input) - for _ in range(num_cells): - tmp = output - output = darts_cell(output, previous) - previous = tmp - - return output diff --git a/hannah/nas/spaces/mobilenet/mobilenet.py b/hannah/nas/spaces/mobilenet/mobilenet.py index 91093a54..82b92c56 100644 --- a/hannah/nas/spaces/mobilenet/mobilenet.py +++ b/hannah/nas/spaces/mobilenet/mobilenet.py @@ -1,10 +1,41 @@ +# +# Copyright (c) 2024 Hannah contributors. +# +# This file is part of hannah. 
+# See https://github.com/ekut-es/hannah for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from hannah.nas.dataflow.dataflow_graph import dataflow from hannah.nas.dataflow.registry import op from hannah.nas.dataflow.repeat import repeat from hannah.nas.expressions.placeholder import DefaultFloat, DefaultInt from hannah.nas.ops import weight_tensor -from hannah.nas.parameters.parameters import CategoricalParameter, FloatScalarParameter, IntScalarParameter -from hannah.nas.dataflow.ops import conv2d, batch_nom, relu, dropout, pooling, linear, add, identity # noqa +from hannah.nas.parameters.parameters import ( + CategoricalParameter, + FloatScalarParameter, + IntScalarParameter, +) +from hannah.nas.dataflow.ops import ( + conv2d, + batch_nom, + relu, + dropout, + pooling, + linear, + add, + identity, +) # noqa @dataflow @@ -13,14 +44,22 @@ def residual(input): @dataflow -def add(input, other): # noqa: F811 +def add(input, other): # noqa: F811 return op("Add", input, other) @dataflow -def conv_bn_relu(input, out_channel, kernel_size=DefaultInt(1), stride=DefaultInt(1), groups=DefaultInt(1)): +def conv_bn_relu( + input, + out_channel, + kernel_size=DefaultInt(1), + stride=DefaultInt(1), + groups=DefaultInt(1), +): input_tensor = input.tensor_type() - weight = weight_tensor(shape=(out_channel, input_tensor['c'], kernel_size, kernel_size), name='weight') + weight = weight_tensor( + shape=(out_channel, input_tensor["c"], kernel_size, kernel_size), name="weight" + ) conv = op("Conv2d", input, weight, stride=stride, groups=groups) bn = op("BatchNorm2d", conv) act = op("Relu", bn) @@ -28,23 +67,39 @@ def conv_bn_relu(input, out_channel, kernel_size=DefaultInt(1), stride=DefaultIn @dataflow -def depthwise_separable_convolution(input, out_channel, kernel_size, stride, expand_ratio): +def depthwise_separable_convolution( + input, out_channel, kernel_size, stride, expand_ratio +): input_tensor = input.tensor_type() - expand_conv = conv_bn_relu(input, out_channel=input_tensor['c'].size*expand_ratio) - depthwise_conv = conv_bn_relu(expand_conv, out_channel=input_tensor['c'].size*expand_ratio, kernel_size=kernel_size, stride=stride, groups=input_tensor['c'].size) + expand_conv = conv_bn_relu(input, out_channel=input_tensor["c"].size * expand_ratio) + depthwise_conv = conv_bn_relu( + expand_conv, + out_channel=input_tensor["c"].size * expand_ratio, + kernel_size=kernel_size, + stride=stride, + groups=input_tensor["c"].size, + ) pointwise_conv = conv_bn_relu(depthwise_conv, out_channel=out_channel) return pointwise_conv @dataflow def stem(input, out_channel, stride, kernel_size=DefaultInt(3)): - conv = conv_bn_relu(input, out_channel=out_channel, stride=stride, kernel_size=kernel_size) + conv = conv_bn_relu( + input, out_channel=out_channel, stride=stride, kernel_size=kernel_size + ) return conv @dataflow def inverted_block(input, out_channel, expand_ratio, stride): - conv = depthwise_separable_convolution(input, out_channel=out_channel, kernel_size=DefaultInt(3), expand_ratio=expand_ratio, 
stride=stride) + conv = depthwise_separable_convolution( + input, + out_channel=out_channel, + kernel_size=DefaultInt(3), + expand_ratio=expand_ratio, + stride=stride, + ) # TODO: conditional residual connection (only if applicable?) res = residual(input) residual_add = add(conv, res) @@ -65,7 +120,12 @@ def mobilenet(input, num_cells): expand_ratio = FloatScalarParameter(min=0.5, max=6.0) stride = CategoricalParameter(choices=[1, 2]) s = stem(input, out_channel=out_channel.new(), stride=DefaultInt(2)) - graph = repeat(inverted_block, num_repeats=num_cells)(s, out_channel=out_channel.new(), expand_ratio=expand_ratio.new(), stride=stride.new()) + graph = repeat(inverted_block, num_repeats=num_cells)( + s, + out_channel=out_channel.new(), + expand_ratio=expand_ratio.new(), + stride=stride.new(), + ) # graph = inverted_block(s, out_channel=out_channel.new(), expand_ratio=expand_ratio.new(), stride=stride.new()) clf = classifier(graph, classes=DefaultInt(10)) return clf diff --git a/hannah/nas/test/network.py b/hannah/nas/test/network.py deleted file mode 100644 index f37ca5fe..00000000 --- a/hannah/nas/test/network.py +++ /dev/null @@ -1,92 +0,0 @@ -# -# Copyright (c) 2022 University of Tübingen. -# -# This file is part of hannah. -# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -from hannah.nas.dataflow.dataflow_graph import dataflow -from hannah.nas.dataflow.op_type import OpType -from hannah.nas.dataflow.ops import ( # noqa: F401 (Import to load in registry) - add, - conv2d, -) -from hannah.nas.dataflow.registry import op -from hannah.nas.dataflow.tensor_expression import TensorExpression -from hannah.nas.ops import weight_tensor -from hannah.nas.parameters.parameters import CategoricalParameter, IntScalarParameter - - -@dataflow -def conv_relu( - input: TensorExpression, - output_channel=IntScalarParameter(4, 64), - kernel_size=CategoricalParameter([1, 3, 5]), - stride=CategoricalParameter([1, 2]), -): - input_tensor = input.tensor_type() - weight = weight_tensor( - shape=(output_channel, input_tensor["c"], kernel_size, kernel_size), - name="weight", - ) - - c = op("Conv2d", input, weight, stride=stride) - relu = OpType(c, name="Relu") - return relu - - -@dataflow -def block( - input: TensorExpression, - expansion=IntScalarParameter(1, 6), - output_channel=IntScalarParameter(4, 64), - kernel_size=CategoricalParameter([1, 3, 5]), - stride=CategoricalParameter([1, 2]), -): - input_tensor = input.tensor_type() - out = conv_relu( - input, output_channel=output_channel, kernel_size=kernel_size, stride=stride - ) - out = conv_relu( - out, output_channel=output_channel.new(), kernel_size=kernel_size, stride=stride - ) - out = conv_relu( - out, output_channel=output_channel.new(), kernel_size=kernel_size, stride=stride - ) - return out - - -@dataflow -def residual(input: TensorExpression, stride, output_channel): - out = conv_relu( - input, - stride=stride, - output_channel=output_channel.new(), - kernel_size=CategoricalParameter([1, 3, 5]), - ) - return out - - -@dataflow -def add(input: TensorExpression, other: TensorExpression): # noqa - out = op("Add", input, other) - return out - - -@dataflow -def residual_block(input: TensorExpression, stride, output_channel): - main_branch = block(input, stride=stride, output_channel=output_channel) - residual_branch = residual(input, stride=stride, output_channel=output_channel) - add_branches = add(main_branch, residual_branch) - return add_branches diff --git a/hannah/nas/test/test_add.py b/hannah/nas/test/test_add.py deleted file mode 100644 index 61392324..00000000 --- a/hannah/nas/test/test_add.py +++ /dev/null @@ -1,80 +0,0 @@ -# -# Copyright (c) 2022 University of Tübingen. -# -# This file is part of hannah. -# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from hannah.nas.dataflow.dataflow_graph import dataflow -from hannah.nas.dataflow.dataflow_utils import process_int -from hannah.nas.dataflow.ops import add # import to register -from hannah.nas.dataflow.registry import op -from hannah.nas.expressions.placeholder import DefaultInt -from hannah.nas.ops import tensor_by_tuples, weight_tensor -from hannah.nas.parameters.parameters import CategoricalParameter - - -def test_add(): - tensor_a = tensor_by_tuples((1, 4, 16, 16), ("a", "b", "c", "d")) - tensor_b = tensor_by_tuples((1, 4, 16, 16), ("a", "b", "c", "d")) - - add_op = op("Add", tensor_a, tensor_b) - returned_tensor = add_op.tensor_type() - for name, ax in returned_tensor.axis.items(): - print("{}: {}".format(name, ax.size.evaluate())) - print() - - -@dataflow -def parallel_convs(input): - channel = process_int(32) - kernel_size = CategoricalParameter([1, 3, 5]) - stride = CategoricalParameter([1, 2]) - dilation = DefaultInt(1) - - input_tensor = input.tensor_type() - - weight1 = weight_tensor( - shape=(channel, input_tensor["c"], kernel_size, kernel_size), name="weight" - ) - conv1 = op("Conv2d", input, weight1, dilation=dilation, stride=stride) - - weight2 = weight_tensor( - shape=(channel, input_tensor["c"], kernel_size, kernel_size), name="weight" - ) - conv2 = op("Conv2d", input, weight2, dilation=dilation, stride=stride) - - add_op = op("Add", conv1, conv2) - - return add_op - - -def test_parallel_convs(): - input_tensor = tensor_by_tuples((1, 3, 16, 16), ("n", "c", "h", "w"), name="input") - convs = parallel_convs(input_tensor) - - # convs['parallel_convs.0.Conv2d.0'].stride.set_current(2) - - returned_tensor = convs.output.tensor_type() - - for name, ax in returned_tensor.axis.items(): - print("{}: {}".format(name, ax.size.evaluate())) - print() - - -if __name__ == "__main__": - test_add() - test_parallel_convs() - print() diff --git a/hannah/nas/test/test_adjacency.py b/hannah/nas/test/test_adjacency.py deleted file mode 100644 index 0981b7d7..00000000 --- a/hannah/nas/test/test_adjacency.py +++ /dev/null @@ -1,52 +0,0 @@ -# -# Copyright (c) 2022 University of Tübingen. -# -# This file is part of hannah. -# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-#
-import matplotlib.pyplot as plt
-import networkx as nx
-import pytest
-
-from hannah.nas.ops import batched_image_tensor
-from hannah.nas.parameters.parameters import IntScalarParameter
-from hannah.nas.test.network import residual_block
-
-
-@pytest.mark.xfail
-def test_adjacency():
-    input = batched_image_tensor(shape=(1, 3, 32, 32), name="input")
-    graph = residual_block(
-        input,
-        stride=IntScalarParameter(1, 2),
-        output_channel=IntScalarParameter(4, 512, 4),
-    )
-
-    a, indices = graph.adjacency()
-    print(a)
-    g = nx.from_numpy_array(a)
-    nx.topological_sort(g)
-    a_top = nx.to_numpy_array(g)
-    print(a_top)
-    # mapping = {i: n for n, i in indices.items()}
-    # g = nx.relabel_nodes(g, mapping)
-    # nx.draw(g, with_labels=True)
-    # plt.show()
-
-    print()
-
-
-if __name__ == "__main__":
-    test_adjacency()
diff --git a/hannah/nas/test/test_constraint_model.py b/hannah/nas/test/test_constraint_model.py
deleted file mode 100644
index af4660cf..00000000
--- a/hannah/nas/test/test_constraint_model.py
+++ /dev/null
@@ -1,114 +0,0 @@
-#
-# Copyright (c) 2022 University of Tübingen.
-#
-# This file is part of hannah.
-# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import pytest
-from z3 import And
-
-from hannah.nas.constraints.constraint_model import ConstraintModel
-from hannah.nas.dataflow.dataflow_graph import flatten
-from hannah.nas.ops import batched_image_tensor
-from hannah.nas.parameters.parameters import IntScalarParameter
-from hannah.nas.test.network import residual_block
-
-
-@pytest.mark.skip(reason="Tests are written for old version of constraint model. Must be updated.")
-def test_constraint_model():
-    # Create a network and flatten the graph
-    input = batched_image_tensor(name="input")
-    graph = residual_block(
-        input,
-        stride=IntScalarParameter(1, 2),
-        output_channel=IntScalarParameter(4, 512, 4),
-    )
-    graph = flatten(graph)
-
-    # build a constraint model
-    cm = ConstraintModel()
-    cm.build_model(graph)
-
-    # retrieve constraint vars from cm for better clarity
-    out_channel_main = cm.vars[
-        "residual_block.0.block.0.conv_relu.2.Conv2d.0.weight.0.axis.o.size"
-    ]
-    out_channel_residual = cm.vars[
-        "residual_block.0.residual.0.conv_relu.0.Conv2d.0.weight.0.axis.o.size"
-    ]
-
-    # Check assumptions for satisfiability
-    assert cm.solver.check(out_channel_main <= 256).r > 0
-    assert cm.solver.check(out_channel_main >= 1024).r < 0
-    assert cm.solver.check(out_channel_main == 128).r > 0
-    assert cm.solver.check(out_channel_main == 129).r < 0
-
-    assert (
-        cm.solver.check(And(out_channel_main >= 64), (out_channel_residual <= 128)).r
-        > 0
-    )
-    assert (
-        cm.solver.check(And(out_channel_main <= 64), (out_channel_residual >= 128)).r
-        < 0
-    )
-
-    # one can find a possible (if current constraints SAT) configuration with .model()
-    # cm.solver.check()
-    # model = cm.solver.model()
-    # print(model)
-
-    # Iterative solving
-    cm.solver.push()
-    cm.solver.add(out_channel_main <= 256)
-
-    assert cm.solver.check(out_channel_residual == 128).r > 0
-    assert cm.solver.check(out_channel_residual == 512).r < 0
-
-    # restore previous model
-    cm.solver.pop()
-
-    assert cm.solver.check(out_channel_residual == 128).r > 0
-    assert cm.solver.check(out_channel_residual == 512).r > 0
-
-
-@pytest.mark.skip(reason="Tests are written for old version of constraint model. Must be updated.")
-def test_constraint_model_parameters():
-    input = batched_image_tensor(name="input")
-    graph = residual_block(
-        input,
-        stride=IntScalarParameter(1, 2),
-        output_channel=IntScalarParameter(4, 512, 4),
-    )
-    graph = flatten(graph)
-
-    # build a constraint model
-    cm = ConstraintModel()
-    cm.build_model(graph)
-
-    params = graph.parametrization(flatten=True)
-
-    for i in range(1):
-        cm.solver.push()
-        for name, param in params.items():
-            assert name in cm.vars
-            cm.solver.add(cm.vars[name] == param.sample().item())
-
-        print("Sat: ", cm.solver.check())
-        cm.solver.pop()
-
-
-if __name__ == "__main__":
-    test_constraint_model()
-    test_constraint_model_parameters()
diff --git a/hannah/nas/test/test_conv2d.py b/hannah/nas/test/test_conv2d.py
deleted file mode 100644
index d5f7986d..00000000
--- a/hannah/nas/test/test_conv2d.py
+++ /dev/null
@@ -1,108 +0,0 @@
-#
-# Copyright (c) 2022 University of Tübingen.
-#
-# This file is part of hannah.
-# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import pytest
-
-from hannah.nas.dataflow.dataflow_graph import dataflow
-from hannah.nas.dataflow.ops import conv2d  # noqa #Import to load in registry
-from hannah.nas.dataflow.registry import op
-from hannah.nas.expressions.placeholder import DefaultInt
-from hannah.nas.ops import batched_image_tensor, weight_tensor
-from hannah.nas.parameters.parameters import CategoricalParameter, IntScalarParameter
-
-
-@dataflow
-def conv2d(  # noqa
-    input,
-    channel,
-    kernel_size=DefaultInt(1),
-    stride=DefaultInt(1),
-    dilation=DefaultInt(1),
-):
-    weight = weight_tensor(
-        shape=(channel, input["c"], kernel_size, kernel_size), name="weight"
-    )
-    padding = kernel_size // 2
-    return op(
-        "Conv2d", input, weight, dilation=dilation, stride=stride, padding=padding
-    )
-
-
-@dataflow
-def chained_convs(
-    input,
-    channel,
-    kernel_size=DefaultInt(1),
-    stride=DefaultInt(1),
-    dilation=DefaultInt(1),
-):
-    padding = kernel_size // 2
-    weight1 = weight_tensor(
-        shape=(channel, input["c"], kernel_size, kernel_size), name="weight"
-    )
-    conv1 = op(
-        "Conv2d", input, weight1, dilation=dilation, stride=stride, padding=padding
-    )
-
-    weight2 = weight_tensor(
-        shape=(channel, input["c"], kernel_size, kernel_size), name="weight"
-    )
-    conv2 = op(
-        "Conv2d", conv1, weight2, dilation=dilation, stride=stride, padding=padding
-    )
-
-    return conv2
-
-
-@pytest.mark.xfail()
-def test_conv2d():
-    inp = batched_image_tensor(name="input")
-
-    kernel_size = CategoricalParameter([1, 3, 5])
-    stride = CategoricalParameter([1, 2])
-
-    conv = conv2d(
-        inp, channel=IntScalarParameter(4, 64), kernel_size=kernel_size, stride=stride
-    )
-
-    conv["conv2d.0.Conv2d.0"].kernel_size.set_current(3)
-    conv["conv2d.0.Conv2d.0"].stride.set_current(2)
-
-    returned_tensor = conv.output.tensor_type()
-    for name, ax in returned_tensor.tensor_type.axis.items():
-        print("{}: {}".format(name, ax.size.evaluate()))
-    print()
-
-
-def test_chained_conv2d():
-    inp = batched_image_tensor(name="input")
-
-    ks = CategoricalParameter([1, 3, 5])
-    ks.set_current(3)
-    convs = chained_convs(inp, channel=IntScalarParameter(4, 64), kernel_size=ks)
-    returned_tensor = convs.output.tensor_type()
-
-    for name, ax in returned_tensor.axis.items():
-        print("{}: {}".format(name, ax.size.evaluate()))
-    print()
-
-
-if __name__ == "__main__":
-    # test_conv2d()
-    test_chained_conv2d()
-    print()
diff --git a/hannah/nas/test/test_darts_space.py b/hannah/nas/test/test_darts_space.py
deleted file mode 100644
index cae58e49..00000000
--- a/hannah/nas/test/test_darts_space.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from hannah.nas.ops import batched_image_tensor
-from hannah.nas.spaces.darts.darts_space import darts_space
-from hannah.nas.dataflow.ops import conv2d, identity, sum, concat
-
-
-def test_darts_space():
-    input = batched_image_tensor(name='input')
-    # darts = darts_space(input)
-    print()
-
-
-if __name__ == '__main__':
-    test_darts_space()
\ No newline at end of file
diff --git a/hannah/nas/test/test_dataflow.py b/hannah/nas/test/test_dataflow.py
deleted file mode 100644
index 8ea0cc38..00000000
--- a/hannah/nas/test/test_dataflow.py
+++ /dev/null
@@ -1,98 +0,0 @@
-from hannah.nas.dataflow.dataflow_graph import dataflow, DataFlowGraph, flatten
-from hannah.nas.dataflow.op_type import OpType
-from hannah.nas.dataflow.registry import op
-from hannah.nas.dataflow.tensor_expression import TensorExpression
-from hannah.nas.expressions.placeholder import DefaultInt
-from hannah.nas.parameters.parameters import CategoricalParameter, FloatScalarParameter, IntScalarParameter
-from hannah.nas.ops import batched_image_tensor, weight_tensor
-from hannah.nas.dataflow.ops import conv2d, add  # noqa: F401 (Import to load in registry)
-
-
-@dataflow
-def conv_relu(input: TensorExpression,
-              output_channel=IntScalarParameter(4, 64),
-              kernel_size=CategoricalParameter([1, 3, 5]),
-              stride=CategoricalParameter([1, 2])):
-    input_tensor = input.tensor_type()
-    weight = weight_tensor(shape=(output_channel, input_tensor['c'], kernel_size, kernel_size), name='weight')
-
-    c = op("Conv2d", input, weight, stride=stride)
-    relu = OpType(c, name='Relu')
-    return relu
-
-
-@dataflow
-def block(input: TensorExpression,
-          expansion=FloatScalarParameter(1, 6, name='expansion'),
-          output_channel=IntScalarParameter(4, 64, name='out_channels'),
-          kernel_size=CategoricalParameter([1, 3, 5], name='kernel_size'),
-          stride=CategoricalParameter([1, 2], name='stride')):
-
-    out = conv_relu(input, output_channel=output_channel.new()*expansion.new(), kernel_size=kernel_size.new(), stride=DefaultInt(1))
-    out = conv_relu(out, output_channel=output_channel.new(), kernel_size=DefaultInt(1), stride=stride.new())
-    return out
-
-
-@dataflow
-def add(input: TensorExpression, other: TensorExpression):  # noqa
-    out = op('Add', input, other)
-    return out
-
-
-def test_dataflow():
-    input = batched_image_tensor(name='input')
-    out = conv_relu(input)
-    assert isinstance(out, DataFlowGraph)
-
-
-def test_dataflow_linking():
-    input = batched_image_tensor(name='input')
-    out = conv_relu(input)
-    out = conv_relu(out)
-    assert isinstance(out, DataFlowGraph)
-
-
-def test_dataflow_block():
-    input = batched_image_tensor(name='input')
-    out = block(input)
-    out = block(out)
-
-    assert isinstance(out, DataFlowGraph)
-
-
-def test_parallel_blocks():
-    input = batched_image_tensor(name='input')
-    graph_0 = block(input, stride=IntScalarParameter(min=1, max=2))
-    graph_1 = block(input, stride=DefaultInt(2))
-    graph = add(graph_0, graph_1)
-
-    assert isinstance(graph, DataFlowGraph)
-
-
-def test_flatten():
-    input = batched_image_tensor(name='input')
-    graph_0 = block(input, stride=IntScalarParameter(min=1, max=2))
-    graph_1 = block(input, stride=DefaultInt(2))
-    graph = add(graph_0, graph_1)
-    flattened_graph = flatten(graph)
-
-    assert isinstance(flattened_graph, OpType)
-
-
-def test_parameter_extraction():
-    input = batched_image_tensor(name='input')
-    out = block(input, stride=IntScalarParameter(min=1, max=2, name='stride'))
-    out = block(out)
-    # flattened_graph = flatten(out)
-    params = out.parametrization(include_empty=True, flatten=True)
-
-    assert isinstance(out, DataFlowGraph)
-    assert 'block.0.conv_relu.1.Conv2d.0.stride' in params and isinstance(params['block.0.conv_relu.1.Conv2d.0.stride'], IntScalarParameter)
-
-
-if __name__ == '__main__':
-    test_dataflow()
-    test_dataflow_linking()
-    test_dataflow_block()
-    test_parallel_blocks()
-    test_parameter_extraction()
diff --git a/hannah/nas/test/test_description_ultratrail.py b/hannah/nas/test/test_description_ultratrail.py
deleted file mode 100644
index 67bda1a9..00000000
--- a/hannah/nas/test/test_description_ultratrail.py
+++ /dev/null
@@ -1,39 +0,0 @@
-#
-# Copyright (c) 2022 University of Tübingen.
-#
-# This file is part of hannah.
-# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import pytest
-
-from hannah.nas.hardware_description.device import Ultratrail
-from hannah.nas.parameters import IntScalarParameter
-
-
-@pytest.mark.xfail
-def test_ultratrail_description():
-    ultratrail = Ultratrail(
-        weight_bits=IntScalarParameter(min=1, max=8),
-        bias_bits=IntScalarParameter(min=1, max=8),
-        activation_bits=IntScalarParameter(min=1, max=8),
-        accumulator_bits=IntScalarParameter(min=1, max=32),
-        max_weight_bits=IntScalarParameter(min=4, max=8),
-    )
-
-    print(ultratrail)
-
-
-if __name__ == "__main__":
-    test_ultratrail_description()
diff --git a/hannah/nas/test/test_dfg_removal.py b/hannah/nas/test/test_dfg_removal.py
deleted file mode 100644
index f4fe7daf..00000000
--- a/hannah/nas/test/test_dfg_removal.py
+++ /dev/null
@@ -1,140 +0,0 @@
-#
-# Copyright (c) 2022 University of Tübingen.
-#
-# This file is part of hannah.
-# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-from hannah.nas.dataflow.dataflow_graph import dataflow, flatten
-from hannah.nas.dataflow.ops import (  # noqa: F401 (Import to load in registry)
-    add,
-    conv2d,
-)
-from hannah.nas.dataflow.registry import op
-from hannah.nas.expressions.placeholder import DefaultInt
-from hannah.nas.ops import batched_image_tensor, weight_tensor
-from hannah.nas.parameters.parameters import CategoricalParameter, IntScalarParameter
-
-
-@dataflow
-def conv(
-    input,
-    channel,
-    kernel_size=DefaultInt(1),
-    stride=DefaultInt(1),
-    dilation=DefaultInt(1),
-):
-    weight = weight_tensor(
-        shape=(channel, input["c"], kernel_size, kernel_size), name="weight"
-    )
-    padding = kernel_size // 2
-    return op(
-        "Conv2d", input, weight, dilation=dilation, stride=stride, padding=padding
-    )
-
-
-@dataflow
-def convs(
-    input,
-    channel,
-    kernel_size=DefaultInt(1),
-    stride=DefaultInt(1),
-    dilation=DefaultInt(1),
-):
-    padding = kernel_size // 2
-    input_tensor = input.tensor_type()
-
-    weight1 = weight_tensor(
-        shape=(channel, input_tensor["c"], kernel_size, kernel_size), name="weight"
-    )
-    conv1 = op(
-        "Conv2d", input, weight1, dilation=dilation, stride=stride, padding=padding
-    )
-    # conv1_tensor = conv1.tensor_type()
-
-    # weight2 = weight_tensor(shape=(channel, conv1_tensor['c'], kernel_size, kernel_size), name='weight')
-    weight2 = weight_tensor(
-        shape=(channel, input_tensor["c"], kernel_size, kernel_size), name="weight"
-    )
-    conv2 = op(
-        "Conv2d", conv1, weight2, dilation=dilation, stride=stride, padding=padding
-    )
-
-    return conv2
-
-
-@dataflow
-def add(input, other):  # noqa
-    return op("Add", input, other)
-
-
-def traverse_users(node):
-    print(node)
-    for user in node.users:
-        traverse_users(user)
-
-
-def test_dfg_removal():
-    inp = batched_image_tensor(name="input")
-
-    kernel_size = CategoricalParameter([1, 3, 5])
-    stride = CategoricalParameter([1, 2])
-    channel = IntScalarParameter(4, 64)
-    channel.set_current(52)
-    graph = conv(inp, channel=channel, kernel_size=kernel_size, stride=stride)
-
-
-def test_chained_convs_removal():
-    inp = batched_image_tensor(name="input")
-
-    ks = CategoricalParameter([1, 3, 5])
-    ks.set_current(3)
-    graph = convs(inp, channel=IntScalarParameter(4, 64), kernel_size=ks)
-    # flattened_graph = flatten(graph)
-
-    print()
-
-
-def test_chained_dfg_removal():
-    inp = batched_image_tensor(name="input")
-
-    ks = CategoricalParameter([1, 3, 5])
-    ks.set_current(3)
-    graph = convs(inp, channel=IntScalarParameter(4, 64), kernel_size=ks)
-    graph1 = convs(graph, channel=IntScalarParameter(4, 64), kernel_size=ks)
-    flattened_graph = flatten(graph1)
-    # traverse_users(inp)
-    print()
-
-
-def test_parallel_branch_dfg_removal():
-    inp = batched_image_tensor(name="input")
-
-    ks = CategoricalParameter([1, 3, 5])
-    ks.set_current(3)
-    graph_0 = convs(inp, channel=IntScalarParameter(4, 64), kernel_size=ks)
-    graph_1 = convs(inp, channel=IntScalarParameter(4, 64), kernel_size=ks)
-
-    graph_add = add(graph_0, graph_1)
-    flattened_graph = flatten(graph_add)
-    traverse_users(inp)
-    print()
-
-
-if __name__ == "__main__":
-    # test_dfg_removal()
-    # test_chained_convs_removal()
-    # test_chained_dfg_removal()
-    test_parallel_branch_dfg_removal()
-    print()
diff --git a/hannah/nas/test/test_fake_quantize.py b/hannah/nas/test/test_fake_quantize.py
new file mode 100644
index 00000000..036e8de5
--- /dev/null
+++ b/hannah/nas/test/test_fake_quantize.py
@@ -0,0 +1,74 @@
+#
+# Copyright (c) 2024 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+
+import torch
+
+from hannah.nas.functional_operators.data_type import IntType
+from hannah.nas.functional_operators.executor import BasicExecutor
+from hannah.nas.functional_operators.op import Tensor
+from hannah.nas.functional_operators.quant import FixedQuantize
+from hannah.nas.functional_operators.op import search_space
+
+
+def test_fixed_quantize():
+    for bits in range(3, 8):
+        for signed in [True, False]:
+            dtype = IntType(bits=bits, signed=signed)
+
+            scale = -1.0 / 2 ** (bits - 1)
+            zero_point = float(2 ** (bits - 1) - 1)
+
+            quantizer = FixedQuantize(scale=scale, zero_point=zero_point, dtype=dtype)
+
+            x = torch.tensor([0.0, 0.5, 0.75])
+
+            x_quantized = quantizer.forward((x,))
+
+            torch.testing.assert_close(
+                x_quantized, torch.tensor([0.0, 0.5, 0.75]), rtol=1e-3, atol=1e-3
+            )
+
+
+def test_onnx_export():
+    shape = (1, 1, 3, 3)
+    input = Tensor(name="input", shape=shape, axis=("N", "C", "H", "W"), grad=False)
+
+    @search_space
+    def network(input):
+        out = FixedQuantize(
+            scale=1.0, zero_point=0.0, dtype=IntType(bits=8, signed=True)
+        )(input)
+
+        return out
+
+    executor = BasicExecutor(network(input))
+    executor.initialize()
+
+    real_input = torch.rand(shape)
+    output = executor.forward(real_input)
+
+    # Convert the model to ONNX
+    executor.eval()
+
+    torch.onnx.export(executor, real_input, "test.onnx", verbose=True)
+
+    # registry = torch.onnx.OnnxRegistry()
+
+    # onnx_program = torch.onnx.dynamo_export(executor, real_input, export_options = torch.onnx.ExportOptions(op_level_debug=True, diagnostic_options=torch.onnx.DiagnosticOptions(verbosity_level=logging.DEBUG))).save('fixed_quantize.onnx')
diff --git a/hannah/nas/test/test_graph_transformer.py b/hannah/nas/test/test_graph_transformer.py
deleted file mode 100644
index 3fbce12e..00000000
--- a/hannah/nas/test/test_graph_transformer.py
+++ /dev/null
@@ -1,194 +0,0 @@
-#
-# Copyright (c) 2022 University of Tübingen.
-#
-# This file is part of hannah.
-# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-from hannah.nas.dataflow.dataflow_graph import (
-    DataFlowGraph,
-    dataflow,
-    find_first_input,
-    flatten,
-)
-from hannah.nas.dataflow.op_type import OpType
-from hannah.nas.dataflow.registry import op
-from hannah.nas.dataflow.transformations.graph_tranformer import GraphTransformer
-from hannah.nas.expressions.placeholder import DefaultInt
-from hannah.nas.ops import batched_image_tensor, weight_tensor
-from hannah.nas.parameters.parameters import IntScalarParameter
-from hannah.nas.test.network import residual_block
-
-
-def get_dict(graph):
-    queue = [graph]
-    visited = []
-    hierarchy_dict = {}
-
-    while queue:
-        current = queue.pop(-1)
-        visited.append(current)
-        name = current.id.split(".")
-        sub_dict = hierarchy_dict
-
-        for i, key in enumerate(name):
-            try:
-                k = int(key)
-            except Exception:
-                k = key
-            k = key
-
-            if k not in sub_dict:
-                sub_dict[k] = {}
-            if i != len(name) - 1:
-                sub_dict = sub_dict[k]
-            else:
-                sub_dict[k]["texpr"] = current
-
-        for operand in current.operands:
-            queue.append(operand)
-    return hierarchy_dict
-
-
-def get_parent_scope(node):
-    return node.id.split(".")[:-2]
-
-
-def check_dfg_change(node, successor):
-    node_parent_scope = ".".join(get_parent_scope(node))
-    successor_parent_scope = ".".join(get_parent_scope(successor))
-    return node_parent_scope != successor_parent_scope
-
-
-def get_dfg_depth(node):
-    scope = get_parent_scope(node)
-    return int(len(scope) / 2)
-
-
-def create_dfgs(op_graph):
-    queue = [find_first_input(op_graph)]
-    enter = None
-    output = None
-    previous = None
-    operands = []
-
-    while queue:
-        current = queue.pop(-1)
-        parent_scope = get_parent_scope(current)
-
-        if previous:
-            if check_dfg_change(previous, current):
-                if get_dfg_depth(previous) > 0:
-                    dfg = DataFlowGraph(
-                        output=previous, name=get_parent_scope(previous)[-2]
-                    )
-                    dfg.enter = enter
-                enter = current
-                print(f"Enter dfg {get_parent_scope(current)}")
-            if get_dfg_depth(previous) != get_dfg_depth(current):
-                print(
-                    f"changed dfg depth from {get_dfg_depth(previous)} to {get_dfg_depth(current)}"
-                )
-
-        for user in current.users:
-            queue.append(user)
-
-        previous = current
-
-
-def make_dataflow_graph(hierarchy_dict, dfgs, output_dict):
-    for key in hierarchy_dict.keys():
-        for num in hierarchy_dict[key].keys():
-            if "texpr" in hierarchy_dict[key][num]:
-                output = hierarchy_dict[key][num]["texpr"]
-
-                name_list = output.id.split(".")
-                if len(name_list) > 2:
-                    name = output.id.split(".")[-4]
-                else:
-                    dfgs.append(output)
-                    return output
-            else:
-                output = make_dataflow_graph(
-                    hierarchy_dict[key][num], dfgs, output_dict
-                )
-                name = key
-            dfg = DataFlowGraph(output=output, name=f"{name}")
-            output_dict[output] = dfg
-            dfgs.append(dfg)
-    return dfgs[0]
-
-
-def unflatten(graph):
-    hierarchy = {}
-
-    queue = [graph]
-    visited = []
-
-    while queue:
-        current = queue.pop(-1)
-        visited.append(current)
-        name = current.id.split(".")
-
-
-def write_down(graph):
-    print(graph.id)
-    for operand in graph.operands:
-        write_down(operand)
-
-
-@dataflow
-def exchange_block(input):
-    input_tensor = input.tensor_type()
-    weight = weight_tensor(
-        shape=(DefaultInt(1), input_tensor["c"], DefaultInt(1), DefaultInt(1)),
-        name="weight",
-    )
-    c = op("Conv2d", input, weight, stride=DefaultInt(1))
-    relu = OpType(c, name="Relu")
-    return relu
-
-
-def test_graph_transformer():
-    # Create a network and flatten the graph
-    input = batched_image_tensor(name="input")
-    graph = residual_block(
-        input,
-        stride=IntScalarParameter(1, 2),
-        output_channel=IntScalarParameter(4, 512, 4),
-    )
-
-    # flat = flatten(graph)
-    # create_dfgs(flat)
-
-    # d = get_dict(flat)
-    # dfgs = []
-    # output_dict = {}
-    # make_dataflow_graph(d, dfgs, output_dict)
-    # dfg = unflatten(flat)
-
-    transformer = GraphTransformer(graph)
-
-    def transform(source, target):
-        args = [op for op in source.operands]
-        kwargs = {}
-
-        return args, kwargs
-
-    transformer.transform("conv_relu", exchange_block, transform)
-    print()
-
-
-if __name__ == "__main__":
-    test_graph_transformer()
diff --git a/hannah/nas/test/test_mobilenet.py b/hannah/nas/test/test_mobilenet.py
deleted file mode 100644
index b405f6cd..00000000
--- a/hannah/nas/test/test_mobilenet.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from hannah.nas.ops import batched_image_tensor
-from hannah.nas.spaces.mobilenet.mobilenet import mobilenet
-
-
-
-def test_mbn_space():
-    input = batched_image_tensor(name='input')
-    space = mobilenet(input, num_cells=4)
-    print()
-
-
-if __name__ == '__main__':
-    test_mbn_space()
diff --git a/hannah/nas/test/test_onnx_export.py b/hannah/nas/test/test_onnx_export.py
index b7f45be9..fa76cb22 100644
--- a/hannah/nas/test/test_onnx_export.py
+++ b/hannah/nas/test/test_onnx_export.py
@@ -23,7 +23,7 @@ import onnx
 
 from hannah.models.embedded_vision_net.models import embedded_vision_net, search_space
-from hannah.models.ai8x.models import ai8x_search_space
+from hannah.models.ai8x.models_simplified import ai8x_search_space
 from hannah.nas.constraints.random_walk import RandomWalkConstraintSolver
 from hannah.nas.export import to_onnx
 from hannah.nas.functional_operators.op import ChoiceOp, Tensor, scope
diff --git a/hannah/nas/test/test_op_to_torch_conversion.py b/hannah/nas/test/test_op_to_torch_conversion.py
deleted file mode 100644
index c1360455..00000000
--- a/hannah/nas/test/test_op_to_torch_conversion.py
+++ /dev/null
@@ -1,62 +0,0 @@
-#
-# Copyright (c) 2022 University of Tübingen.
-#
-# This file is part of hannah.
-# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import pytest
-
-from hannah.nas.dataflow.dataflow_graph import dataflow
-from hannah.nas.dataflow.ops import conv2d  # Import to load in registry
-from hannah.nas.dataflow.registry import op
-from hannah.nas.expressions.placeholder import DefaultInt
-from hannah.nas.ops import batched_image_tensor, weight_tensor
-from hannah.nas.parameters.parameters import CategoricalParameter, IntScalarParameter
-
-
-@dataflow
-def conv(
-    input,
-    channel,
-    kernel_size=DefaultInt(1),
-    stride=DefaultInt(1),
-    dilation=DefaultInt(1),
-):
-    weight = weight_tensor(
-        shape=(channel, input["c"], kernel_size, kernel_size), name="weight"
-    )
-    padding = kernel_size // 2
-    return op(
-        "Conv2d", input, weight, dilation=dilation, stride=stride, padding=padding
-    )
-
-
-@pytest.mark.xfail
-def test_conv2d():
-    inp = batched_image_tensor(name="input")
-
-    kernel_size = CategoricalParameter([1, 3, 5])
-    stride = CategoricalParameter([1, 2])
-    channel = IntScalarParameter(4, 64)
-    channel.set_current(52)
-    conv_dataflow = conv(inp, channel=channel, kernel_size=kernel_size, stride=stride)
-
-    torch_op = conv_dataflow["conv.0.Conv2d.0"].convert(target="torch")
-    print()
-
-
-if __name__ == "__main__":
-    test_conv2d()
-    print()
diff --git a/hannah/nas/test/test_repeat.py b/hannah/nas/test/test_repeat.py
deleted file mode 100644
index ece8e63a..00000000
--- a/hannah/nas/test/test_repeat.py
+++ /dev/null
@@ -1,56 +0,0 @@
-from hannah.nas.dataflow.dataflow_graph import dataflow, DataFlowGraph, flatten
-from hannah.nas.dataflow.op_type import OpType
-from hannah.nas.dataflow.tensor_type import TensorType
-from hannah.nas.expressions.placeholder import UndefinedInt
-from hannah.nas.parameters.parameters import CategoricalParameter, FloatScalarParameter, IntScalarParameter
-from hannah.nas.ops import axis, tensor, batched_image_tensor, float_t
-from hannah.nas.dataflow.repeat import repeat
-
-
-@dataflow
-def conv_relu(input: TensorType,
-              output_channel=IntScalarParameter(4, 64),
-              kernel_size=CategoricalParameter([1, 3, 5]),
-              stride=CategoricalParameter([1, 2])):
-
-    weight = tensor((axis('o', size=output_channel),
-                     axis('i', size=UndefinedInt()),
-                     axis('kh', size=kernel_size),
-                     axis('kw', size=kernel_size)),
-                    dtype=float_t(),
-                    name='weight')
-
-    op = OpType(input, weight, stride=stride, name='conv2d')
-    relu = OpType(op, name='relu')
-    return relu
-
-
-@dataflow
-def block(input: TensorType,
-          expansion=FloatScalarParameter(1, 6),
-          output_channel=IntScalarParameter(4, 64),
-          kernel_size=CategoricalParameter([1, 3, 5]),
-          stride=CategoricalParameter([1, 2])):
-
-    out = conv_relu(input,
-                    output_channel=output_channel*expansion,
-                    kernel_size=kernel_size,
-                    stride=stride)
-    out = conv_relu(out,
-                    output_channel=output_channel,
-                    kernel_size=1,
-                    stride=1)
-    return out
-
-
-def test_repeat():
-    input = batched_image_tensor(name='input')
-    graph = repeat(block, num_repeats=IntScalarParameter(min=1, max=5))(input)
-    graph = block(graph)
-    print(graph)
-
-    assert isinstance(graph, DataFlowGraph)
-
-
-if __name__ == '__main__':
-    test_repeat()
diff --git a/hannah/nas/test/test_scoping.py b/hannah/nas/test/test_scoping.py
deleted file mode 100644
index d283dff7..00000000
--- a/hannah/nas/test/test_scoping.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from hannah.nas.dataflow.dataflow_graph import dataflow, DataFlowGraph
-from hannah.nas.dataflow.dataflow_utils import traverse_by_users
-from hannah.nas.dataflow.op_type import OpType
-from hannah.nas.dataflow.tensor import Tensor
-from hannah.nas.dataflow.tensor_type import TensorType
-from hannah.nas.expressions.placeholder import UndefinedInt
-from hannah.nas.parameters.parameters import CategoricalParameter, FloatScalarParameter, IntScalarParameter
-from hannah.nas.ops import axis, tensor, batched_image_tensor, float_t
-
-
-@dataflow
-def conv_relu(input: TensorType,
-              output_channel=IntScalarParameter(4, 64),
-              kernel_size=CategoricalParameter([1, 3, 5]),
-              stride=CategoricalParameter([1, 2])):
-
-    weight = tensor((axis('o', size=output_channel),
-                     axis('i', size=UndefinedInt()),
-                     axis('kh', size=kernel_size),
-                     axis('kw', size=kernel_size)),
-                    dtype=float_t(),
-                    name='weight')
-
-    op = OpType(input, weight, stride=stride, name='conv2d')
-    relu = OpType(op, name='relu')
-    return relu
-
-
-@dataflow
-def block(input: Tensor,
-          expansion=FloatScalarParameter(1, 6),
-          output_channel=IntScalarParameter(4, 64),
-          kernel_size=CategoricalParameter([1, 3, 5]),
-          stride=CategoricalParameter([1, 2])):
-
-    out = conv_relu(input,
-                    output_channel=output_channel*expansion,
-                    kernel_size=kernel_size,
-                    stride=stride)
-    out = conv_relu(out,
-                    output_channel=output_channel,
-                    kernel_size=1,
-                    stride=1)
-    return out
-
-
-@dataflow
-def network(input: Tensor):
-    out = block(input)
-    out = block(out)
-    return out
-
-
-def test_scoping():
-    input = batched_image_tensor(name='input')
-    out = block(input)
-    graph = block(out)
-    # graph = network(input)
-    traverse_by_users(input)
-    print()
-    print(graph['block.0.conv_relu.0.conv2d.0.weight.0'])
-
-    assert isinstance(graph, DataFlowGraph)
-
-
-if __name__ == '__main__':
-    test_scoping()
diff --git a/hannah/nas/test/test_target_desc_markdown.py b/hannah/nas/test/test_target_desc_markdown.py
new file mode 100644
index 00000000..6702325d
--- /dev/null
+++ b/hannah/nas/test/test_target_desc_markdown.py
@@ -0,0 +1,33 @@
+#
+# Copyright (c) 2023 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from hannah.nas.hardware_description.backend import MarkdownBackend
+from hannah.nas.hardware_description.testing import get_device
+
+
+def test_simple_device():
+    simple_device = get_device("simple_device")
+
+    backend = MarkdownBackend()
+    created_markdown = backend.generate(simple_device)
+
+    print("Created Markdown:\n\n", created_markdown, sep="")
+
+
+if __name__ == "__main__":
+    test_simple_device()
diff --git a/hannah/nas/test/test_target_desc_search_space.py b/hannah/nas/test/test_target_desc_search_space.py
new file mode 100644
index 00000000..9a6aa494
--- /dev/null
+++ b/hannah/nas/test/test_target_desc_search_space.py
@@ -0,0 +1,66 @@
+#
+# Copyright (c) 2024 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from hannah.nas.functional_operators.operators import Conv2d, Tensor
+from hannah.nas.functional_operators.op import search_space
+from hannah.nas.hardware_description.backend import HannahBackend
+from hannah.nas.hardware_description.testing import get_device
+from hannah.nas.parameters import CategoricalParameter
+
+
+@search_space
+def get_simple_space(kernel_size=3):
+    input = Tensor("input", shape=[1, 3, 32, 32], axis=["N", "C", "H", "W"])
+    weight = Tensor(
+        "weight", shape=[16, 3, kernel_size, kernel_size], axis=["O", "I", "kH", "kW"]
+    )
+    conv = Conv2d(stride=1, padding=1)(input, weight)
+
+    return conv
+
+
+def test_simple_device():
+    simple_device = get_device("simple_device")
+
+    backend = HannahBackend()
+    hannah_target = backend.generate(simple_device)
+
+    simple_space = get_simple_space()
+
+    hannah_target.run(simple_space)
+
+    for match in hannah_target.matches:
+        print(match)
+
+
+def test_simple_device_categorical():
+    simple_device = get_device("simple_device")
+
+    backend = HannahBackend()
+    hannah_target = backend.generate(simple_device)
+
+    simple_space = get_simple_space(CategoricalParameter("kernel_size", (3, 5, 7)))
+
+    hannah_target.run(simple_space)
+
+    for match in hannah_target.matches:
+        print(match)
+
+
+if __name__ == "__main__":
+    test_simple_device()
diff --git a/hannah/nas/test/test_tvm_backend.py b/hannah/nas/test/test_tvm_backend.py
new file mode 100644
index 00000000..b10e28ac
--- /dev/null
+++ b/hannah/nas/test/test_tvm_backend.py
@@ -0,0 +1,75 @@
+#
+# Copyright (c) 2023 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+
+from pytest import importorskip
+
+importorskip("tvm")
+
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.dataflow_pattern import is_op, wildcard
+from tvm.relay.op.contrib.register import get_pattern_table
+from tvm.relay.testing.resnet import get_workload
+
+from hannah.nas.functional_operators.operators import Conv2d, Tensor
+from hannah.nas.hardware_description.backend import TVMBackend
+from hannah.nas.hardware_description.testing import get_device
+from hannah.nas.parameters import CategoricalParameter
+
+"""
+
+Unified device descriptions for neural architecture search (HANNAH) and compilation (TVM, equality saturation).
+
+"""
+
+
+def test_simple_device():
+    simple_device = get_device("simple_device", relu=False)
+
+    backend = TVMBackend()
+    tvm_target = backend.generate(simple_device)
+
+    sys.path.append(".")
+
+    os.makedirs("temp", exist_ok=True)
+    with open("temp/pattern_table.py", "w") as f:
+        f.write(tvm_target)
+
+    import temp.pattern_table as pt
+
+    table = pt.pattern_table()
+    print(table)
+
+    print(get_pattern_table("simple_device"))
+
+    mod, params = get_workload()
+
+    mod = transform.MergeComposite(table)(mod)
+    print(mod)
+
+    # mod = transform.AnnotateTarget(simple_device.name, False)(mod)
+    # mod = transform.MergeCompilerRegions()(mod)
+    # mod = transform.PartitionGraph()(mod)
+
+
+if __name__ == "__main__":
+    test_simple_device()
diff --git a/hannah/nas/test/test_z3.py b/hannah/nas/test/test_z3.py
deleted file mode 100644
index 657f3e13..00000000
--- a/hannah/nas/test/test_z3.py
+++ /dev/null
@@ -1,179 +0,0 @@
-from z3 import *
-from hannah.nas.dataflow.dataflow_graph import DataFlowGraph
-from hannah.nas.dataflow.op_type import OpType
-from hannah.nas.dataflow.tensor import Tensor
-from hannah.nas.expressions.placeholder import DefaultInt
-from hannah.nas.ops import batched_image_tensor, weight_tensor
-from hannah.nas.parameters.parameters import IntScalarParameter
-from hannah.nas.test.network import residual_block
-
-
-class ConstraintModel:
-    def __init__(self) -> None:
-        self.solver = Solver()
-        self.vars = {}
-
-    def model(self):
-        if self.solver.check():
-            return self.solver.model()
-        else:
-            return None
-
-    def check(self):
-        return self.solver.check()
-
-    def traverse(self, graph):
-        queue = [graph]
-        visited = [graph]
-
-        input_dict = {}
-        output_dict = {}
-        enter_dict = {}
-
-        while queue:
-            graph = queue.pop(-1)
-
-            if isinstance(graph, DataFlowGraph):
-                h_in = Int(graph.id + '.h_in')
-                input_dict[graph] = h_in
-
-                for e in graph.enter:
-                    enter_dict[e] = h_in
-                h_out = Int(graph.id + '.h_out')
-
-                output_dict[graph.output] = h_out
-
-                self.solver.add(h_in > 0)
-                self.solver.add(h_out > 0)
-
-                for user in graph.users:
-                    if user in input_dict and not graph == user.output:
-                        self.solver.add(input_dict[user] == h_out)
-
-                if graph in output_dict:
-                    self.solver.add(h_out == output_dict[graph])
-
-                if graph in enter_dict:
-                    self.solver.add(h_in == enter_dict[graph])
-
-                if graph.output not in visited:
-                    queue.append(graph.output)
-                    visited.append(graph.output)
-            elif isinstance(graph, OpType):
-
-                h_in = Int(graph.id + '.h_in')
-                input_dict[graph] = h_in
-                h_out = Int(graph.id + '.h_out')
-
-                self.solver.add(h_in > 0)
-                self.solver.add(h_out > 0)
-
-                if graph.name == "Conv2d":
-                    s = Int(graph.id + '.stride')
-                    self.vars[graph.id + '.stride'] = s
-                    k = Int(graph.id + '.kernel_size')
-                    p = Int(graph.id + '.padding')
-                    d = Int(graph.id + '.dilation')
-
-                    self.solver.add(s > 0)
-                    self.solver.add(s <= 2)
-                    self.solver.add(k > 0)
-                    self.solver.add(k <= 9)
-
-                    self.solver.add(d == 1)
-                    self.solver.add(k % 2 != 0)
-                    self.solver.add(k / 2 == p)
-                    # self.solver.add(p == 2)
-
-                    self.solver.add(h_out == ((h_in + p * 2 - d * (k - 1) - 1) / s) + 1)
-                else:
-                    self.solver.add(h_out == h_in)
-
-                # connect the output of this node with the input
-                # of the following (user) node
-                for user in graph.users:
-                    if user in input_dict:
-                        if hasattr(user, 'output') and graph == user.output:
-                            break
-                        self.solver.add(input_dict[user] == h_out)
-
-                if graph in output_dict:
-                    self.solver.add(h_out == output_dict[graph])
-
-                if graph in enter_dict:
-                    self.solver.add(h_in == enter_dict[graph])
-
-                for o in graph.operands:
-                    if o not in visited:
-                        queue.append(o)
-                        visited.append(o)
-            elif isinstance(graph, Tensor):
-                dim = Int("dim")
-                if 'input' in graph.name:
-                    self.solver.add(dim == int(graph['h'].size.evaluate()))
-
-                for user in graph.users:
-                    if user in input_dict:
-                        self.solver.add(input_dict[user] == dim)
-                # elif 'weight' in graph.name:
-                #     self.solver.add(dim == int(graph['kh'].size.evaluate()))
-
-        print()
-
-
-
-
-def test_z3():
-    p = Int('p')
-    s = Int('s')
-    k = Int('k')
-    d = Int('d')
-    h_0 = Int('h_0')
-    h_1 = Int('h_1')
-
-    solver = Solver()
-    solver.add(h_1 == (((h_0 + p * 2 - d * (k - 1) - 1) / s) + 1))
-    solver.add(h_0 == 32)
-    # solver.add(s >= 1)
-    solver.add(s == 2)
-
-    solver.add(d == 1)
-    solver.add(k > 0)
-    solver.add(k <= 9)
-    solver.add(k % 2 != 0)
-    solver.add(p == k / 2)
-
-    pb = Int('pb')
-    sb = Int('sb')
-    kb = Int('kb')
-    db = Int('db')
-    hb_0 = Int('hb_0')
-    hb_1 = Int('hb_1')
-
-    solver.add(hb_1 == (((hb_0 + pb * 2 - db * (kb - 1) - 1) / sb) + 1))
-    solver.add(hb_0 == 32)
-    # solver.add(sb >= 1)
-    solver.add(sb == 2)
-
-    solver.add(db == 1)
-    solver.add(kb > 0)
-    solver.add(kb <= 9)
-    solver.add(kb % 2 != 0)
-    solver.add(pb == kb / 2)
-
-    solver.add(hb_1 == h_1)
-
-    print()
-
-def test_traversal():
-    input = batched_image_tensor(shape=(1,3, 32, 32), name='input')
-    out = residual_block(input, stride=IntScalarParameter(1, 2), output_channel=DefaultInt(64))
-
-    cm = ConstraintModel()
-    cm.traverse(out)
-    # TODO: INPUT AND OUPUTS OF GRAPHS
-    print()
-
-if __name__ == '__main__':
-    # test_z3()
-    test_traversal()
\ No newline at end of file
diff --git a/hannah/nn/qat.py b/hannah/nn/qat.py
index 6fe88b42..b4ac7454 100644
--- a/hannah/nn/qat.py
+++ b/hannah/nn/qat.py
@@ -1,8 +1,8 @@
 #
-# Copyright (c) 2022 University of Tübingen.
+# Copyright (c) 2024 Hannah contributors.
 #
 # This file is part of hannah.
-# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
+# See https://github.com/ekut-es/hannah for further info.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,14 +31,18 @@ import torch.nn.functional as F
 import torch.nn.intrinsic as nni
 from torch import Tensor
+from torch.ao.quantization.qconfig import QConfig as QConfigAO
 from torch.nn import init
 from torch.nn.modules.utils import _pair, _single
 from torch.nn.parameter import Parameter
 
-from hannah.quantization.qconfig import QConfig
+from hannah.quantization.qconfig import QConfig as QConfigHannah
 
 from . import quantized as q
 
+QConfig = Union[QConfigHannah, QConfigAO]
+
+
 _BN_CLASS_MAP = {1: nn.BatchNorm1d, 2: nn.BatchNorm2d, 3: nn.BatchNorm3d}
@@ -87,7 +91,6 @@ def _real_conv_forward(
 class _ConvBnNd(
     nn.modules.conv._ConvNd, _ConvForwardMixin
 ):  # pytype: disable=module-attr
-
     _version = 2
 
     def __init__(
@@ -196,8 +199,6 @@ def scaled_weight(self) -> Tensor:
         scale_factor = self.scale_factor
         weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
-        bias_shape = [1] * len(self.weight.shape)
-        bias_shape[1] = -1
         scaled_weight = self.weight_fake_quant(
             self.weight * scale_factor.reshape(weight_shape)
         )
@@ -330,7 +331,7 @@ def from_float(cls, mod):
 
         Args: `mod` a float module, either produced by torch.quantization utilities or directly from user
         """
-        assert type(mod) == cls._FLOAT_MODULE, (
+        assert type(mod) is cls._FLOAT_MODULE, (
             "qat."
             + cls.__name__
             + ".from_float only works for "
@@ -475,7 +476,6 @@ def __init__(
         qconfig: Union[QConfig, QConfig] = None,
         out_quant: bool = True,
     ) -> None:
-
         super().__init__(
             in_channels,
             out_channels,
@@ -954,7 +954,7 @@ def from_float(cls, mod):
 
         Args: `mod` a float module, either produced by torch.quantization utilities or directly from user
         """
-        assert type(mod) == cls._FLOAT_MODULE, (
+        assert type(mod) is cls._FLOAT_MODULE, (
             " qat."
             + cls.__name__
             + ".from_float only works for "
@@ -962,7 +962,7 @@ def from_float(cls, mod):
         )
         assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
         assert mod.qconfig, "Input float module must have a valid qconfig"
-        if type(mod) == LinearReLU:
+        if type(mod) is LinearReLU:
             mod = mod[0]
         qconfig = mod.qconfig
@@ -1034,7 +1034,7 @@ def from_float(cls, mod):
 
         Args: `mod` a float module, either produced by torch.quantization utilities or directly from user
         """
-        assert type(mod) == cls._FLOAT_MODULE, (
+        assert type(mod) is cls._FLOAT_MODULE, (
             " qat."
            + cls.__name__
            + ".from_float only works for "
@@ -1042,7 +1042,7 @@ def from_float(cls, mod):
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid qconfig"
-        if type(mod) == LinearReLU:
+        if type(mod) is LinearReLU:
            mod = mod[0]
        qconfig = mod.qconfig
@@ -1088,7 +1088,7 @@ def from_float(cls, mod):
 
         Args: `mod` a float module, either produced by torch.quantization utilities or directly from user
         """
-        assert type(mod) == cls._FLOAT_MODULE, (
+        assert type(mod) is cls._FLOAT_MODULE, (
             " qat."
             + cls.__name__
             + ".from_float only works for "
diff --git a/hannah/quantization/qconfig.py b/hannah/quantization/qconfig.py
index 3d72665b..03329d89 100644
--- a/hannah/quantization/qconfig.py
+++ b/hannah/quantization/qconfig.py
@@ -94,6 +94,7 @@ def __init__(
         self.debug = debug
 
     def quantize(self, x: Union[Tensor, Parameter]) -> Tensor:
+        # print("Pre-Quantized:", x.min(), x.max(), x.mean(), x.std(), x.shape)
         if self.debug:
             print("x", x)
         x = x / self.scale
@@ -102,6 +103,8 @@ def quantize(self, x: Union[Tensor, Parameter]) -> Tensor:
             print("rounded", x)
 
         x = torch.clamp(x, self.min, self.max)
+        # print("Post-Quantized:", x.min(), x.max(), x.mean(), x.std(), x.shape)
+
         return x
 
     def __call__(self, x: Union[Tensor, Parameter]) -> Tensor:
diff --git a/hannah/train.py b/hannah/train.py
index 542f5995..1d1ffe33 100644
--- a/hannah/train.py
+++ b/hannah/train.py
@@ -22,7 +22,7 @@
 import os
 import shutil
 from pathlib import Path
-from typing import Any, Dict, List, Mapping, Type, Union
+from typing import Any, Dict, List, Mapping, Sequence, Type, Union
 
 import pandas as pd
 import tabulate
@@ -198,7 +198,7 @@ def train(
         backend_output.append(profile_backend(config, lit_module))
 
     @rank_zero_only
-    def summarize_stage(stage: str, output: Mapping["str", float]) -> None:
+    def summarize_stage(stage: str, output: Sequence[Mapping["str", float]]) -> None:
         if not output:
             return
         result_frame = pd.DataFrame.from_dict(output)
diff --git a/hannah/utils/__init__.py b/hannah/utils/__init__.py
index 0eae0fc4..62e947c3 100644
--- a/hannah/utils/__init__.py
+++ b/hannah/utils/__init__.py
@@ -16,6 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from . import registry
 from .imports import lazy_import
 from .tuple import pair, quadruple, single, triple
 from .utils import (
@@ -46,4 +47,5 @@
     "pair",
     "triple",
     "quadruple",
+    "registry",
 ]
diff --git a/hannah/utils/logger.py b/hannah/utils/logger.py
index b85048c1..dfbfbe71 100644
--- a/hannah/utils/logger.py
+++ b/hannah/utils/logger.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2023 Hannah contributors.
+# Copyright (c) 2024 Hannah contributors.
 #
 # This file is part of hannah.
 # See https://github.com/ekut-es/hannah for further info.
@@ -30,10 +30,13 @@
 from pytorch_lightning.loggers import Logger
 from torch import Tensor
 
+_PATH = Union[str, pathlib.Path]
+
 import fsspec
 
 log = logging.getLogger(__name__)
 
+
 def _is_dir(fs, path, strict=False):
     return fs.isdir(path) or (not strict and fs.exists(path) and not fs.isfile(path))
 
@@ -62,7 +65,7 @@ class JSONLogger(Logger):
 
     def __init__(
         self,
-        root_dir: pathlib.Path,
+        root_dir: _PATH,
         name: str = "lightning_logs",
         version: Optional[Union[int, str]] = None,
         prefix: str = "",
diff --git a/hannah/utils/registry.py b/hannah/utils/registry.py
new file mode 100644
index 00000000..f9d4a726
--- /dev/null
+++ b/hannah/utils/registry.py
@@ -0,0 +1,58 @@
+#
+# Copyright (c) 2024 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+class Registry:
+    def __init__(self, name=""):
+        self._name = name
+        self.registered_classes = {}
+
+    def register(self, cls):
+        self.registered_classes[cls.__name__] = cls
+
+    def instantiate(self, name, *args, **kwargs):
+        return self.registered_classes[name](*args, **kwargs)
+
+    def __iter__(self):
+        return iter(self.registered_classes.values())
+
+    def __len__(self):
+        return len(self.registered_classes)
+
+    def __getitem__(self, key):
+        return self.registered_classes[key]
+
+    def __contains__(self, key):
+        return key in self.registered_classes
+
+    def __repr__(self):
+        return f"Registry({self._name}, {self.registered_classes})"
+
+    def __str__(self):
+        return f"Registry({self._name}, {self.registered_classes})"
+
+    def keys(self):
+        return self.registered_classes.keys()
+
+    def values(self):
+        return self.registered_classes.values()
+
+    def items(self):
+        return self.registered_classes.items()
+
+    def get(self, key, default=None):
+        return self.registered_classes.get(key, default)
diff --git a/poetry.lock b/poetry.lock
index 4b3e2ad2..5bac0449 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -35,87 +35,87 @@ files = [
 
 [[package]]
 name = "aiohttp"
-version = "3.11.2"
+version = "3.11.4"
 description = "Async http client/server framework (asyncio)"
 optional = false
 python-versions = ">=3.9"
 files = [
-    {file = "aiohttp-3.11.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:783741f534c14957fbe657d62a34b947ec06db23d45a2fd4a8aeb73d9c84d7e6"},
-    {file = "aiohttp-3.11.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:435f7a08d8aa42371a94e7c141205a9cb092ba551084b5e0c57492e6673601a3"},
-    {file = "aiohttp-3.11.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c681f34e2814bc6e1eef49752b338061b94a42c92734d0be9513447d3f83718c"},
-    {file = "aiohttp-3.11.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73a664478ae1ea011b5a710fb100b115ca8b2146864fa0ce4143ff944df714b8"},
-    {file = "aiohttp-3.11.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1d06c8fd8b453c3e553c956bd3b8395100401060430572174bb7876dd95ad49"},
-    {file = "aiohttp-3.11.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b1f4844909321ef2c1cee50ddeccbd6018cd8c8d1ddddda3f553e94a5859497"},
-    {file = "aiohttp-3.11.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cdc6f8dce09281ae534eaf08a54f0d38612398375f28dad733a8885f3bf9b978"},
-    {file = "aiohttp-3.11.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d2d942421cf3a1d1eceae8fa192f1fbfb74eb9d3e207d35ad2696bd2ce2c987c"},
-    {file = "aiohttp-3.11.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:08ebe7a1d6c1e5ca766d68407280d69658f5f98821c2ba6c41c63cabfed159af"},
-    {file = "aiohttp-3.11.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:2793d3297f3e49015140e6d3ea26142c967e07998e2fb00b6ee8d041138fbc4e"},
-    {file = "aiohttp-3.11.2-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:4a23475d8d5c56e447b7752a1e2ac267c1f723f765e406c81feddcd16cdc97bc"},
-    {file = "aiohttp-3.11.2-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:556564d89e2f4a6e8fe000894c03e4e84cf0b6cfa5674e425db122633ee244d1"},
-    {file = "aiohttp-3.11.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:57993f406ce3f114b2a6756d7809be3ffd0cc40f33e8f8b9a4aa1b027fd4e3eb"},
-    {file = "aiohttp-3.11.2-cp310-cp310-win32.whl", hash = "sha256:177b000efaf8d2f7012c649e8aee5b0bf488677b1162be5e7511aa4f9d567607"},
-    {file = "aiohttp-3.11.2-cp310-cp310-win_amd64.whl", hash = "sha256:ff5d22eece44528023254b595c670dfcf9733ac6af74c4b6cb4f6a784dc3870c"},
-    {file = "aiohttp-3.11.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:50e0aee4adc9abcd2109c618a8d1b2c93b85ac277b24a003ab147d91e068b06d"},
-    {file = "aiohttp-3.11.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9aa4e68f1e4f303971ec42976fb170204fb5092de199034b57199a1747e78a2d"},
-    {file = "aiohttp-3.11.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d84930b4145991214602372edd7305fc76b700220db79ac0dd57d3afd0f0a1ca"},
-    {file = "aiohttp-3.11.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b4ec8afd362356b8798c8caa806e91deb3f0602d8ffae8e91d2d3ced2a90c35e"},
-    {file = "aiohttp-3.11.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fb0544a0e8294a5a5e20d3cacdaaa9a911d7c0a9150f5264aef36e7d8fdfa07e"},
-    {file = "aiohttp-3.11.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7b0a1618060e3f5aa73d3526ca2108a16a1b6bf86612cd0bb2ddcbef9879d06"},
-    {file = "aiohttp-3.11.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d878a0186023ac391861958035174d0486f3259cabf8fd94e591985468da3ea"},
-    {file = "aiohttp-3.11.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e33a7eddcd07545ccf5c3ab230f60314a17dc33e285475e8405e26e21f02660"},
-    {file = "aiohttp-3.11.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4d7fad8c456d180a6d2f44c41cfab4b80e2e81451815825097db48b8293f59d5"},
-    {file = "aiohttp-3.11.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8d954ba0eae7f33884d27dc00629ca4389d249eb8d26ca07c30911257cae8c96"},
-    {file = "aiohttp-3.11.2-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:afa55e863224e664a782effa62245df73fdfc55aee539bed6efacf35f6d4e4b7"},
-    {file = "aiohttp-3.11.2-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:10a5f91c319d9d4afba812f72984816b5fcd20742232ff7ecc1610ffbf3fc64d"},
-    {file = "aiohttp-3.11.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6e8e19a80ba194db5c06915a9df23c0c06e0e9ca9a4db9386a6056cca555a027"},
-    {file = "aiohttp-3.11.2-cp311-cp311-win32.whl", hash = "sha256:9c8d1db4f65bbc9d75b7b271d68fb996f1c8c81a525263862477d93611856c2d"},
-    {file = "aiohttp-3.11.2-cp311-cp311-win_amd64.whl", hash = "sha256:2adb967454e10e69478ba4a8d8afbba48a7c7a8619216b7c807f8481cc66ddfb"},
-    {file = "aiohttp-3.11.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:f833a80d9de9307d736b6af58c235b17ef7f90ebea7b9c49cd274dec7a66a2f1"},
-    {file = "aiohttp-3.11.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:382f853516664d2ebfc75dc01da4a10fdef5edcb335fe7b45cf471ce758ecb18"},
-    {file = "aiohttp-3.11.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d3a2bcf6c81639a165da93469e1e0aff67c956721f3fa9c0560f07dd1e505116"},
-    {file = "aiohttp-3.11.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de3b4d5fb5d69749104b880a157f38baeea7765c93d9cd3837cedd5b84729e10"},
-    {file = "aiohttp-3.11.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0a90a0dc4b054b5af299a900bf950fe8f9e3e54322bc405005f30aa5cacc5c98"},
-    {file = "aiohttp-3.11.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:32334f35824811dd20a12cc90825d000e6b50faaeaa71408d42269151a66140d"},
-    {file = "aiohttp-3.11.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cba0b8d25aa2d450762f3dd6df85498f5e7c3ad0ddeb516ef2b03510f0eea32"},
-    {file = "aiohttp-3.11.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bbb2dbc2701ab7e9307ca3a8fa4999c5b28246968e0a0202a5afabf48a42e22"},
-    {file = "aiohttp-3.11.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:97fba98fc5d9ccd3d33909e898d00f2494d6a9eec7cbda3d030632e2c8bb4d00"},
-    {file = "aiohttp-3.11.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0ebdf5087e2ce903d8220cc45dcece90c2199ae4395fd83ca616fcc81010db2c"},
-    {file = "aiohttp-3.11.2-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:122768e3ae9ce74f981b46edefea9c6e5a40aea38aba3ac50168e6370459bf20"},
-    {file = "aiohttp-3.11.2-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:5587da333b7d280a312715b843d43e734652aa382cba824a84a67c81f75b338b"},
-    {file = "aiohttp-3.11.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:85de9904bc360fd29a98885d2bfcbd4e02ab33c53353cb70607f2bea2cb92468"},
-    {file = "aiohttp-3.11.2-cp312-cp312-win32.whl", hash = "sha256:b470de64d17156c37e91effc109d3b032b39867000e2c126732fe01d034441f9"},
-    {file = "aiohttp-3.11.2-cp312-cp312-win_amd64.whl", hash = "sha256:3f617a48b70f4843d54f52440ea1e58da6bdab07b391a3a6aed8d3b311a4cc04"},
-    {file = "aiohttp-3.11.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:5d90b5a3b0f32a5fecf5dd83d828713986c019585f5cddf40d288ff77f366615"},
-    {file = "aiohttp-3.11.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d23854e5867650d40cba54d49956aad8081452aa80b2cf0d8c310633f4f48510"},
-    {file = "aiohttp-3.11.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:486273d3b5af75a80c31c311988931bdd2a4b96a74d5c7f422bad948f99988ef"},
-    {file = "aiohttp-3.11.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9075313f8e41b481e4cb10af405054564b0247dc335db5398ed05f8ec38787e2"},
-    {file = "aiohttp-3.11.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:44b69c69c194ffacbc50165911cf023a4b1b06422d1e1199d3aea82eac17004e"},
-    {file = "aiohttp-3.11.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b339d91ac9060bd6ecdc595a82dc151045e5d74f566e0864ef3f2ba0887fec42"},
-    {file = "aiohttp-3.11.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64e8f5178958a9954043bc8cd10a5ae97352c3f2fc99aa01f2aebb0026010910"},
-    {file = "aiohttp-3.11.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3129151378f858cdc4a0a4df355c9a0d060ab49e2eea7e62e9f085bac100551b"},
-    {file = "aiohttp-3.11.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:14eb6c628432720e41b4fab1ada879d56cfe7034159849e083eb536b4c2afa99"},
-    {file = "aiohttp-3.11.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:e57a10aacedcf24666f4c90d03e599f71d172d1c5e00dcf48205c445806745b0"},
-    {file = "aiohttp-3.11.2-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:66e58a2e8c7609a3545c4b38fb8b01a6b8346c4862e529534f7674c5265a97b8"},
-    {file = "aiohttp-3.11.2-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:9b6d15adc9768ff167614ca853f7eeb6ee5f1d55d5660e3af85ce6744fed2b82"},
-    {file = "aiohttp-3.11.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2914061f5ca573f990ec14191e6998752fa8fe50d518e3405410353c3f44aa5d"},
-    {file = "aiohttp-3.11.2-cp313-cp313-win32.whl", hash = "sha256:1c2496182e577042e0e07a328d91c949da9e77a2047c7291071e734cd7a6e780"},
-    {file = "aiohttp-3.11.2-cp313-cp313-win_amd64.whl", hash = "sha256:cccb2937bece1310c5c0163d0406aba170a2e5fb1f0444d7b0e7fdc9bd6bb713"},
-    {file = "aiohttp-3.11.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:994cb893936dd2e1803655ae8667a45066bfd53360b148e22b4e3325cc5ea7a3"},
-    {file = "aiohttp-3.11.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3666c750b73ce463a413692e3a57c60f7089e2d9116a2aa5a0f0eaf2ae325148"},
-    {file = "aiohttp-3.11.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6ad9a7d2a3a0f235184426425f80bd3b26c66b24fd5fddecde66be30c01ebe6e"},
-    {file = "aiohttp-3.11.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c979fc92aba66730b66099cd5becb42d869a26c0011119bc1c2478408a8bf7a"},
-    {file = "aiohttp-3.11.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:766d0ebf8703d28f854f945982aa09224d5a27a29594c70d921c43c3930fe7ac"},
-    {file = "aiohttp-3.11.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:79efd1ee3827b2f16797e14b1e45021206c3271249b4d0025014466d416d7413"},
-    {file = "aiohttp-3.11.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d6e069b882c1fdcbe5577dc4be372eda705180197140577a4cddb648c29d22e"},
-    {file = "aiohttp-3.11.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5e9a766c346b2ed7e88937919d84ed64b4ef489dad1d8939f806ee52901dc142"},
-    {file = "aiohttp-3.11.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2b02a68b9445c70d7f5c8b578c5f5e5866b1d67ca23eb9e8bc8658ae9e3e2c74"},
-    {file = "aiohttp-3.11.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:374baefcb1b6275f350da605951f5f02487a9bc84a574a7d5b696439fabd49a3"},
-    {file = "aiohttp-3.11.2-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:d2f991c18132f3e505c108147925372ffe4549173b7c258cf227df1c5977a635"},
-    {file = "aiohttp-3.11.2-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:34f37c59b12bc3afc52bab6fcd9cd3be82ff01c4598a84cbea934ccb3a9c54a0"},
-    {file = "aiohttp-3.11.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:33af11eca7bb0f5c6ffaf5e7d9d2336c2448f9c6279b93abdd6f3c35f9ee321f"},
-    {file = "aiohttp-3.11.2-cp39-cp39-win32.whl", hash = "sha256:83a70e22e0f6222effe7f29fdeba6c6023f9595e59a0479edacfbd7de4b77bb7"},
-    {file = "aiohttp-3.11.2-cp39-cp39-win_amd64.whl", hash = "sha256:c28c1677ea33ccb8b14330560094cc44d3ff4fad617a544fd18beb90403fe0f1"},
-    {file = "aiohttp-3.11.2.tar.gz", hash = "sha256:68d1f46f9387db3785508f5225d3acbc5825ca13d9c29f2b5cce203d5863eb79"},
+    {file = "aiohttp-3.11.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a60f8206818e3582c999c999c799ab068e14f1870ade47d1fe8536dbfd88010b"},
+    {file = "aiohttp-3.11.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e5786e5926f888ce3a996d38d9c9b8f9306f399edb1f1ca3ce7760dab9b1043c"},
+    {file = "aiohttp-3.11.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:262e45dbd7f1829bcb024259f65b2cf69d1ef5b37626af6955a1c487613aeb3a"},
+    {file = "aiohttp-3.11.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:696adff3594bd449e0fe287441062bdc6f5300928426275b39ed27884ba083a7"},
+    {file = "aiohttp-3.11.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6dd1411ecfc070af4df129e81fe42c799d95d81c29c22d2c3e4341d974c38f1a"},
+    {file = "aiohttp-3.11.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:06defa9017ab50d215446ebbee294e07eb2fcee72d9a909a08192cfacbd43a08"},
+    {file = "aiohttp-3.11.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4bc936d10b8fa3f2aa66e59e034085208b588442263400ddb042703d0db99421"},
+    {file = "aiohttp-3.11.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:769457243dc4bc902d376cd14c5c7ec234a4faadb4f283dc2738f004cce9a9e1"},
+    {file = "aiohttp-3.11.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:7a360c18b2cb391fec9585ba1efc55150e2fbc6100308113117dfea521e810d8"},
+    {file = "aiohttp-3.11.4-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3e9fd9c11299d6d230eb2669fd1ed0238d33970e36b495b0432ace7f157fc931"},
+    {file = "aiohttp-3.11.4-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:0ccbe8ece8a7796ef41b86a3240034c5918d9b324c2ae48fa0be33565e297c64"},
+    {file = "aiohttp-3.11.4-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:9a8b6b3c788a8a6f88f5ce23d729cfde7a2ccebbeb09db0822ef266de0445a27"},
+    {file = "aiohttp-3.11.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cbe3e356523d0b336543996f92a0e65f760be82447db21c95c60392c8075ff5c"},
+    {file = "aiohttp-3.11.4-cp310-cp310-win32.whl", hash = "sha256:a54424050d1eb36edfef913b1bc8552d52a37864c0ea7df3e1e764663e11053a"},
+    {file = "aiohttp-3.11.4-cp310-cp310-win_amd64.whl", hash = "sha256:a51f983d91edae7777b5a2af8e5d83224ba01284502c6874a17647ad6cbf0211"},
+    {file = "aiohttp-3.11.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:89261fee04715083ef3b5a0d222b094021793c1728b8ff21da361c79f6384095"},
+    {file = "aiohttp-3.11.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4ef6eb1367046fb22085f10c5c84ea2efd0d836ad8088306d652ab1d743faf9e"},
+    {file = "aiohttp-3.11.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d68bb99bc6a4b0a3eceb95a246f5a0262e600e094b5178c2b1ab0f4bcbae6729"},
+    {file = "aiohttp-3.11.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a550b4ff70d06c15057d75ddad89a3e7c496e0609d28c567c20b61cd1265c0a6"},
+    {file = "aiohttp-3.11.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9b41e0fb3b415beccd6d0c6e5f3ee34b7952cd76120a1db3e45507b83dc5ef81"},
+    {file = "aiohttp-3.11.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8feffa8396724116be5bc05bf4fcba0c738cbe908c82a95f71371e32b28cd2ca"},
+    {file = "aiohttp-3.11.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1dd5b7947e23a08c70d4c1924809b91211f14136ffd13d303dc487913cfebfeb"},
+    {file = "aiohttp-3.11.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ab5c6a521b156edef13a57a6d524903c547573ff8101e3d1bbe9ee1b97267973"},
+    {file = "aiohttp-3.11.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:010bc9606f798eda8ef071759c7b163893071502bcaedc7d5dc49f9d8f12e553"},
+    {file = "aiohttp-3.11.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e7d182164aebad4e2faf2742ee7486d4af73d933461adbd8f183ac9b1837323c"},
+    {file = "aiohttp-3.11.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:88e681c0d17bb285d2ccbb73ae77ef86339b632ee7967044c2284411120b9730"},
+    {file = "aiohttp-3.11.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:0d2cea21ec05b832e9f6a2390b23d32ce5575f6cf4812bd171d4493f59c101fe"},
+    {file = "aiohttp-3.11.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:635397b5b4de2397f8136f8fd15c8ebee560e36473195c7aa992ffb8e46acdd3"},
+    {file = 
"aiohttp-3.11.4-cp311-cp311-win32.whl", hash = "sha256:cb2d5a24586b508f658ddd710f7d4b7e4f5656cb5d569aeb1f432c1c3704347a"}, + {file = "aiohttp-3.11.4-cp311-cp311-win_amd64.whl", hash = "sha256:ee081375d10fa2f3f7b0d050c8b9c1ae23190e1d9be256035bf8a41059c4df3a"}, + {file = "aiohttp-3.11.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:5cd60673be31449c63f59886f3581478bbdfaddd87e7394a4d73ad134d9be9b9"}, + {file = "aiohttp-3.11.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4ff6105856ae688b29d5daaede1256f5e02e9d5cb3059f8f5ef55d975c2e6992"}, + {file = "aiohttp-3.11.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b169507c98b924fd68b82ae366c285daf6d22456835294c329c3226d61e1f69d"}, + {file = "aiohttp-3.11.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ec84106c8b7ff347be06bf579c298a23b6d1d2225c57273a8cd502f257125d4"}, + {file = "aiohttp-3.11.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:03d53b0888f984f4f0c216a37577ee7e7b1ed1dac89cdd2fde61bf2ccb32009b"}, + {file = "aiohttp-3.11.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:822dedad93947fcb1096cc41ee8fd32e9f652777561a37c740e5335699f01cea"}, + {file = "aiohttp-3.11.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aef239c307f3a3f830933d612c0aef4ad4b3aa9ce5233a0954262a00f5c379f1"}, + {file = "aiohttp-3.11.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:49eb5a0338f141ef32299d48f1415486f47953d37b0c7fa6d778b73b66f3a7e2"}, + {file = "aiohttp-3.11.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7be4efe60e9bddf78ee165a296e80170147282081e1366f0580cf4cc0fb1182f"}, + {file = "aiohttp-3.11.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:66e83a9a1131f0060aaedcc57f1a7e489898b6c3607eededccc7a9f80b95bdb4"}, + {file = "aiohttp-3.11.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:a7986fb988314fd2225c1ecab45fd457e1f2c097dcc3c0aacd2a7aec7486beb6"}, + {file = "aiohttp-3.11.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:a34c30e1461da3a69c5bdcfce44418b6f969e1e68ebf367edfa5eaab380abf7a"}, + {file = "aiohttp-3.11.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:cb4c676ab99ca2dd231928d481e19cd540155dff36e70e613179c4927bd520b8"}, + {file = "aiohttp-3.11.4-cp312-cp312-win32.whl", hash = "sha256:d40d9a740053cb7fef72442fa7bd699060ff4c710971ebdb8dd7c8b36417570f"}, + {file = "aiohttp-3.11.4-cp312-cp312-win_amd64.whl", hash = "sha256:365df6cf2ad144479ba0e0b58abdc5276923676d34da4c1c45613a80d2aac130"}, + {file = "aiohttp-3.11.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:f307632f3eaa676f8c2f5df11e4c00ad47dfa79b06cb2fa39156a4e9c6821bdb"}, + {file = "aiohttp-3.11.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:cc2d64b1747efa183ced57b6bce53c9ea8e16e53419e389051b2a214ad0ed051"}, + {file = "aiohttp-3.11.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f37ece590451ecffc815f2eb41f07191d1a31a0404361d1ae2ed532e05c86da4"}, + {file = "aiohttp-3.11.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b10b316413c80a4dcc5228c092a8d019e4b75d4efbca8988cb5b67ae9fa56881"}, + {file = "aiohttp-3.11.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:beaed1b2d03033dd301a7b67430f03c8255d6856a269c20995a0292de596519e"}, + {file = "aiohttp-3.11.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:518578d6821c942362daa14a56f26b739abeede6e408b0b83e27dfcde17730f7"}, + 
{file = "aiohttp-3.11.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1e09bc44a1abbd96f55d15330d6cab80459cb8b06a0b656efd712ce47a3710d"}, + {file = "aiohttp-3.11.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ae8480148d696dae49126e97568333fc01493069ad46a94b82f69c7a33197ea"}, + {file = "aiohttp-3.11.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:b71aab89800fa2eaeb28923ee05e7e56c28dab4ebdba524db06e963431bf6192"}, + {file = "aiohttp-3.11.4-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:821c9c640d3dc456c6a7b005e38bc5af05326b6a08ce91a068719934d108a1bb"}, + {file = "aiohttp-3.11.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d60255f3ed71aa14a2e75383543ca31bd362fdc7f0d2eafc060d85a9051598df"}, + {file = "aiohttp-3.11.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:9788781f57fb732426ae74b9955b899e677ce42b848e60a11be29358fb20c976"}, + {file = "aiohttp-3.11.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:94acecf2eee13a45f627ed25a28f5a7f2db66b90ff94cd7a1e9cc1ad32cddd43"}, + {file = "aiohttp-3.11.4-cp313-cp313-win32.whl", hash = "sha256:d0fd6510c6d67d08ec80d9ba10cd340a8cfb0dd33436c858ed38d4564abb27c7"}, + {file = "aiohttp-3.11.4-cp313-cp313-win_amd64.whl", hash = "sha256:474f7266a61d1c3218ef4ec0325747884b2d5a13fab5bff5dd3b55d9c849406a"}, + {file = "aiohttp-3.11.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cfe8646a24856624c1eb7649da99333f0d7e75d9cf7c155ea870957d24b7c63c"}, + {file = "aiohttp-3.11.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e69d9869df50dd591228c62fbb3923d6124517d6bfc47a804492813888b497be"}, + {file = "aiohttp-3.11.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:eb4f1fe110332651c00d2df160978cf1be70896ed9e612ff7c7e67955091b2c4"}, + {file = "aiohttp-3.11.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d97668595bf03299148ea968fed2195cc76ad063aeec8161731aa6a5dbc2f675"}, + {file = "aiohttp-3.11.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4c0b3378dc294ad6ec6c038ed57164165e0b83ef5f61eee72f6eefccd7df34b8"}, + {file = "aiohttp-3.11.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0898a77298dc24eef368511d98e551e0b2db293fa9b40c982f4d5ab4d8d2a3a"}, + {file = "aiohttp-3.11.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ecdf43ddde709c336a655c8b3858c56af8f7402de2572001a5a99f7bebf2f78"}, + {file = "aiohttp-3.11.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:12bf9c139dfa004b65d2d71906abc593dcafe78a508f33d56c1ca9d87b18337f"}, + {file = "aiohttp-3.11.4-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2d978a95e4b58ef1fd937fbe347ab397c79ba24e17912595b54faafb88b9b937"}, + {file = "aiohttp-3.11.4-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:1e32517c01905e0f4e665c3f3a495868ad996a32c243fcd917587d740253d589"}, + {file = "aiohttp-3.11.4-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:4275160583df18158e0d6789797ad314a14ae611b98933fbe7d7a1c3dcc6bad4"}, + {file = "aiohttp-3.11.4-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:1ff7afc3c461bd9217e2b8a90ddbe5edd94687d5a331c4ae6166dca5876d1a4b"}, + {file = "aiohttp-3.11.4-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:83bd5aa621b732a0ca1aa3490abd2b010247c4677371a804431935aeedf26e74"}, + {file = "aiohttp-3.11.4-cp39-cp39-win32.whl", hash = "sha256:542a4610571b0affc6e13dda9357235f5f1f2ad9859acc69b188eb53901292d6"}, + {file 
= "aiohttp-3.11.4-cp39-cp39-win_amd64.whl", hash = "sha256:a468b1b9d5499cbfd0411f5d28adbe651c90508540fdaefb4b7a2171a837a88d"}, + {file = "aiohttp-3.11.4.tar.gz", hash = "sha256:9d95cce8bb010597b3f2217155befe4708e0538d3548aa08d640ebf54e3f57cb"}, ] [package.dependencies] @@ -453,17 +453,6 @@ d = ["aiohttp (>=3.10)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] -[[package]] -name = "cerberus" -version = "1.3.5" -description = "Lightweight, extensible schema and data validation tool for Pythondictionaries." -optional = true -python-versions = "*" -files = [ - {file = "Cerberus-1.3.5-py3-none-any.whl", hash = "sha256:7649a5815024d18eb7c6aa5e7a95355c649a53aacfc9b050e9d0bf6bfa2af372"}, - {file = "Cerberus-1.3.5.tar.gz", hash = "sha256:81011e10266ef71b6ec6d50e60171258a5b134d69f8fb387d16e4936d0d47642"}, -] - [[package]] name = "certifi" version = "2024.8.30" @@ -2619,42 +2608,52 @@ test = ["pytest (>=7.4)", "pytest-cov (>=4.1)"] [[package]] name = "libcst" -version = "1.5.0" +version = "1.5.1" description = "A concrete syntax tree with AST-like properties for Python 3.0 through 3.13 programs." optional = false python-versions = ">=3.9" files = [ - {file = "libcst-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:23d0e07fd3ed11480f8993a1e99d58a45f914a711b14f858b8db08ae861a8a34"}, - {file = "libcst-1.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d92c5ae2e2dc9356ad7e3d05077d9b7e5065423e45788fd86729c88729e45c6e"}, - {file = "libcst-1.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96adc45e96476350df6b8a5ddbb1e1d6a83a7eb3f13087e52eb7cd2f9b65bcc7"}, - {file = "libcst-1.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d5978fd60c66794bb60d037b2e6427ea52d032636e84afce32b0f04e1cf500a"}, - {file = "libcst-1.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6502aeb11412afc759036160c686be1107eb5a4466db56b207c786b9b4da7c4"}, - {file = "libcst-1.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cccfc0a78e110c0d0a9d2c6fdeb29feb5274c9157508a8baef7edf352420f6d"}, - {file = "libcst-1.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:585b3aa705b3767d717d2100935d8ef557275ecdd3fac81c3e28db0959efb0ea"}, - {file = "libcst-1.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8935dd3393e30c2f97344866a4cb14efe560200e232166a8db1de7865c2ef8b2"}, - {file = "libcst-1.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc80ea16c7d44e38f193e4d4ef7ff1e0ba72d8e60e8b61ac6f4c87f070a118bd"}, - {file = "libcst-1.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02be4aab728261bb76d16e77c9a457884cebb60d09c8edee844de43b0e08aff7"}, - {file = "libcst-1.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a8fcd78be4d9ce3c36d0c5d0bdd384e0c7d5f72970a9e4ebd56070141972b4ad"}, - {file = "libcst-1.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:52b6aadfe54e3ae52c3b815eaaa17ba4da9ff010d5e8adf6a70697872886dd10"}, - {file = "libcst-1.5.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:83bc5fbe34d33597af1d5ea113dcb9b5dd5afe5a5f4316bac4293464d5e3971a"}, - {file = "libcst-1.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5f10124bf99a0b075eae136ef0ce06204e5f6b8da4596a9c4853a0663e80ddf3"}, - {file = "libcst-1.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:48e581af6127c5af4c9f483e5986d94f0c6b2366967ee134f0a8eba0aa4c8c12"}, - {file = "libcst-1.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7dba93cca0a5c6d771ed444c44d21ce8ea9b277af7036cea3743677aba9fbbb8"}, - {file = "libcst-1.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:80b5c4d87721a7bab265c202575809b810815ab81d5e2e7a5d4417a087975840"}, - {file = "libcst-1.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:b48bf71d52c1e891a0948465a94d9817b5fc1ec1a09603566af90585f3b11948"}, - {file = "libcst-1.5.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:88520b6dea59eaea0cae80f77c0a632604a82c5b2d23dedb4b5b34035cbf1615"}, - {file = "libcst-1.5.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:208ea92d80b2eeed8cbc879d5f39f241582a5d56b916b1b65ed2be2f878a2425"}, - {file = "libcst-1.5.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4592872aaf5b7fa5c2727a7d73c0985261f1b3fe7eff51f4fd5b8174f30b4e2"}, - {file = "libcst-1.5.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2788b2b5838b78fe15df8e9fa6b6903195ea49b2d2ba43e8f423f6c90e4b69f"}, - {file = "libcst-1.5.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b5b5bcd3a9ba92840f27ad34eaa038acbee195ec337da39536c0a2efbbf28efd"}, - {file = "libcst-1.5.0-cp313-cp313-win_amd64.whl", hash = "sha256:4d6acb0bdee1e55b44c6215c59755ec4693ac01e74bb1fde04c37358b378835d"}, - {file = "libcst-1.5.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6453b5a8755a6eee3ad67ee246f13a8eac9827d2cfc8e4a269e8bf0393db74bc"}, - {file = "libcst-1.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:40748361f4ea66ab6cdd82f8501c82c29808317ac7a3bd132074efd5fd9bfae2"}, - {file = "libcst-1.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f71aed85932c2ea92058fd9bbd99a6478bd69eada041c3726b4f4c9af1f564e"}, - {file = "libcst-1.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b60b09abcc2848ab52d479c3a9b71b606d91a941e3779616efd083bb87dbe8ad"}, - {file = "libcst-1.5.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6fb324ed20f3a725d152df5dba8d80f7e126d9c93cced581bf118a5fc18c1065"}, - {file = "libcst-1.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:99e7c52150a135d66716b03e00c7b1859a44336dc2a2bf8f9acc164494308531"}, - {file = "libcst-1.5.0.tar.gz", hash = "sha256:8478abf21ae3861a073e898d80b822bd56e578886331b33129ba77fec05b8c24"}, + {file = "libcst-1.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ab83633e61ee91df575a3838b1e73c371f19d4916bf1816554933235553d41ea"}, + {file = "libcst-1.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b58a49895d95ec1fd34fad041a142d98edf9b51fcaf632337c13befeb4d51c7c"}, + {file = "libcst-1.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d9ec764aa781ef35ab96b693569ac3dced16df9feb40ee6c274d13e86a1472e"}, + {file = "libcst-1.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99bbffd8596d192bc0e844a4cf3c4fc696979d4e20ab1c0774a01768a59b47ed"}, + {file = "libcst-1.5.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ec6ee607cfe4cc4cc93e56e0188fdb9e50399d61a1262d58229752946f288f5e"}, + {file = "libcst-1.5.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = 
"sha256:72132756f985a19ef64d702a821099d4afc3544974662772b44cbc55b7279727"}, + {file = "libcst-1.5.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:40b75bf2d70fc0bc26b1fa73e61bdc46fef59f5c71aedf16128e7c33db8d5e40"}, + {file = "libcst-1.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:56c944acaa781b8e586df3019374f5cf117054d7fc98f85be1ba84fe810005dc"}, + {file = "libcst-1.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:db7711a762b0327b581be5a963908fecd74412bdda34db34553faa521563c22d"}, + {file = "libcst-1.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:aa524bd012aaae1f485fd44490ef5abf708b14d2addc0f06b28de3e4585c4b9e"}, + {file = "libcst-1.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3ffb8135c09e41e8cf710b152c33e9b7f1d0d0b9f242bae0c502eb082fdb1fb"}, + {file = "libcst-1.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76a8ac7a84f9b6f678a668bff85b360e0a93fa8d7f25a74a206a28110734bb2a"}, + {file = "libcst-1.5.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:89c808bdb5fa9ca02df41dd234cbb0e9de0d2e0c029c7063d5435a9f6781cc10"}, + {file = "libcst-1.5.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:40fbbaa8b839bfbfa5b300623ca2b6b0768b58bbc31b341afbc99110c9bee232"}, + {file = "libcst-1.5.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c7021e3904d8d088c369afc3fe17c279883e583415ef07edacadba76cfbecd27"}, + {file = "libcst-1.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:f053a5deb6a214972dbe9fa26ecd8255edb903de084a3d7715bf9e9da8821c50"}, + {file = "libcst-1.5.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:666813950b8637af0c0e96b1ca46f5d5f183d2fe50bbac2186f5b283a99f3529"}, + {file = "libcst-1.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b7b58b36022ae77a5a00002854043ae95c03e92f6062ad08473eff326f32efa0"}, + {file = "libcst-1.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eeb13d7c598fe9a798a1d22eae56ab3d3d599b38b83436039bd6ae229fc854d7"}, + {file = "libcst-1.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5987daff8389b0df60b5c20499ff4fb73fc03cb3ae1f6a746eefd204ed08df85"}, + {file = "libcst-1.5.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:00f3d2f32ee081bad3394546b0b9ac5e31686d3b5cfe4892d716d2ba65f9ec08"}, + {file = "libcst-1.5.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1ff21005c33b634957a98db438e882522febf1cacc62fa716f29e163a3f5871a"}, + {file = "libcst-1.5.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:15697ea9f1edbb9a263364d966c72abda07195d1c1a6838eb79af057f1040770"}, + {file = "libcst-1.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:cedd4c8336e01c51913113fbf5566b8f61a86d90f3d5cc5b1cb5049575622c5f"}, + {file = "libcst-1.5.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:06a9b4c9b76da4a7399e6f1f3a325196fb5febd3ea59fac1f68e2116f3517cd8"}, + {file = "libcst-1.5.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:940ec4c8db4c2d620a7268d6c83e64ff646e4afd74ae5183d0f0ef3b80e05be0"}, + {file = "libcst-1.5.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fbccb016b1ac6d892344300dcccc8a16887b71bb7f875ba56c0ed6c1a7ade8be"}, + {file = "libcst-1.5.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c615af2117320e9a218083c83ec61227d3547e38a0de80329376971765f27a9e"}, + {file = 
"libcst-1.5.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02b38fa4d9f13e79fe69e9b5407b9e173557bcfb5960f7866cf4145af9c7ae09"}, + {file = "libcst-1.5.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:3334afe9e7270e175de01198f816b0dc78dda94d9d72152b61851c323e4e741e"}, + {file = "libcst-1.5.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:26c804fa8091747128579013df0b5f8e6b0c7904d9c4ee83841f136f53e18684"}, + {file = "libcst-1.5.1-cp313-cp313-win_amd64.whl", hash = "sha256:b5a0d3c632aa2b21c5fa145e4e8dbf86f45c9b37a64c0b7221a5a45caf58915a"}, + {file = "libcst-1.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1cc7393aaac733e963f0ee00466d059db74a38e15fc7e6a46dddd128c5be8d08"}, + {file = "libcst-1.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:bbaf5755be50fa9b35a3d553d1e62293fbb2ee5ce2c16c7e7ffeb2746af1ab88"}, + {file = "libcst-1.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e397f5b6c0fc271acea44579f154b0f3ab36011050f6db75ab00cef47441946"}, + {file = "libcst-1.5.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1947790a4fd7d96bcc200a6ecaa528045fcb26a34a24030d5859c7983662289e"}, + {file = "libcst-1.5.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:697eabe9f5ffc40f76d6d02e693274e0a382826d0cf8183bd44e7407dfb0ab90"}, + {file = "libcst-1.5.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:dc06b7c60d086ef1832aebfd31b64c3c8a645adf0c5638d6243e5838f6a9356e"}, + {file = "libcst-1.5.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:19e39cfef4316599ca20d1c821490aeb783b52e8a8543a824972a525322a85d0"}, + {file = "libcst-1.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:01e01c04f0641188160d3b99c6526436e93a3fbf9783dba970f9885a77ec9b38"}, + {file = "libcst-1.5.1.tar.gz", hash = "sha256:71cb294db84df9e410208009c732628e920111683c2f2b2e0c5b71b98464f365"}, ] [package.dependencies] @@ -5601,127 +5600,127 @@ type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (>=1.12 [[package]] name = "simsimd" -version = "6.0.5" +version = "6.0.6" description = "Portable mixed-precision BLAS-like vector math library for x86 and ARM" optional = true python-versions = "*" files = [ - {file = "simsimd-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:97e8ee6e6b41a172de32978bf8beadd22e9d6e0d1f80c883984bff5bc5fa2ab0"}, - {file = "simsimd-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c88030432bacb4394a78ba43c57843ed0c2461de2f711bb07022729c7ab3e7e7"}, - {file = "simsimd-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f1edd05ddf2561393233b05d4b58706177436130e79596d3037b0fa802344fe6"}, - {file = "simsimd-6.0.5-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b4b9388e606b5dd286a2205499ca90ba70e01669fc32da91fbc3895a613e327f"}, - {file = "simsimd-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8d48936c1c4b260179d7c6803c55e2c9e3dcec1ac06c5c9bd45ee8afd42157d"}, - {file = "simsimd-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:53495c7114d2d1f4683c8bbb2b2a5322b902b6f155ca7e1f4294ad8bb548f994"}, - {file = "simsimd-6.0.5-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:67db32a5188824198e6260d5af806ea92f97e77e0db173498154c3a13e235934"}, - {file = "simsimd-6.0.5-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:70991048164c17c0e222147ea9046de4a8162aa106e7dcbbf049933576643734"}, - {file = 
"simsimd-6.0.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ecec3b1aa3e2d77aee092d6438b1ba04807cb1e4066852d227db13f93fb95afe"}, - {file = "simsimd-6.0.5-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:11f4a0c1b3a4689011096d906783f697aee6a72d0665e2261b38b204534c2cf9"}, - {file = "simsimd-6.0.5-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:6cbfa0b2694db8273ab88c90e224d4340be4da5fc8661778ad52afe592ccd7a3"}, - {file = "simsimd-6.0.5-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:0f69767bd1342ff1bb067a6366f223f63e00488699cb649ca06590196ada6f7f"}, - {file = "simsimd-6.0.5-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:06fb743c770285ca4019c80590b8cd33b9cd3ba662f99bfacf00509b81584a14"}, - {file = "simsimd-6.0.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d9c08bcb71209bbca2336d91a2c44b0dd36d5915e4eef226ed7502897bd26efd"}, - {file = "simsimd-6.0.5-cp310-cp310-win32.whl", hash = "sha256:b12a55b62229d350675c14010fe8093c3f1b63e17148648bb13538f08a636370"}, - {file = "simsimd-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:ca65754f4377277f78b1f03b00e195fbbb07eab5fbd05148e474ada869b85c8d"}, - {file = "simsimd-6.0.5-cp310-cp310-win_arm64.whl", hash = "sha256:7fafda6a025f4ed81f4668d0fd695bbf0de51168a97ac310e1b7b8bcba6b2c29"}, - {file = "simsimd-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:33f9785d094fc9478c0c1f33fbb33d4f906fc85a73ab96397a792844817494b8"}, - {file = "simsimd-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:eebec66ffe4085d14fdb48ac4cf47fa3a06ab1b3d4f25a316b7fb6f9e4468f9c"}, - {file = "simsimd-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1434c215c70a0e7effa0fe79cae11b706794eb7e47c3f095b5e5cb1bd9ac2517"}, - {file = "simsimd-6.0.5-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b473183e36b02e2889cf367ef23bdcbb19085703818fca656b90dc7a4cd74929"}, - {file = "simsimd-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:154bcd62d2878c688188215f058e6374895bc3d3078a92c006df2f1590164e8a"}, - {file = "simsimd-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b3d41e402b84b5e0713b9a83262286ac3b99bda4ede70fa1d0c10693f76e483"}, - {file = "simsimd-6.0.5-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:652e2f993310d1c7a052178ab5f56874a10285eb273c269b19ac661ae6f19f48"}, - {file = "simsimd-6.0.5-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:3619924158a3002cdce1c4a4a8ae44c828a1bb87b032c57419a320744e4c1e62"}, - {file = "simsimd-6.0.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:131aedc11f33470f122565c807bf53933a1821febcb2296a04344a33d292ea69"}, - {file = "simsimd-6.0.5-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:54f62c73ca5d720a4654ea5d6c9d5011ef53e324b924d03e9f698a03070c11ce"}, - {file = "simsimd-6.0.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:a41247d91f83e298cdcb3288b9ffa0ef5759d9617703de6d37806c6f3c7c524e"}, - {file = "simsimd-6.0.5-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e5a300e464de23fac083822c44c6485ab843f9fcdf3caace10f2018c17bc24de"}, - {file = "simsimd-6.0.5-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:0998db0d57c05a6e7cceebe720c15ca3a4f870a3b1c93644801fadd0addc9069"}, - {file = "simsimd-6.0.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:5f37dc913d8eea667ec182fcfdad2beef10cde36ace0affb4eda8a1f03c828b4"}, - {file = "simsimd-6.0.5-cp311-cp311-win32.whl", hash = "sha256:9fca350fc4667037a9d50fc382b10bac788be7df62aedec3473697f25b10a911"}, - {file = 
"simsimd-6.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:91589ed0279986a4dd3964cc3b42fff1591ce7f96e745293870f44a45b6cfaca"}, - {file = "simsimd-6.0.5-cp311-cp311-win_arm64.whl", hash = "sha256:324267f9e3a1c5e2b40e387424b6d033d2b20bbd19b82e0930f53ecb518dde89"}, - {file = "simsimd-6.0.5-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:b1681e0db03b6e34540efaafbb31241cf17c75a33ea0b58a295e59f23f94c0b6"}, - {file = "simsimd-6.0.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4ca6b49a205de20fa983df1160903ca854ba50489d2bf5348f2c86de91a8823a"}, - {file = "simsimd-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8cf1a1f1f47524d22af4a1bdb88fd96d0b2181cdba03414bfdc6c5890161a56e"}, - {file = "simsimd-6.0.5-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6e75febe1f8be7cdaecc5dfe6cf4eb11547fc0f11ce6daf1ff2c606a584a987"}, - {file = "simsimd-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a915d87afb256036551cdcde45fbc58e780ddace99d5913b0f33b0ddb1147f93"}, - {file = "simsimd-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4bba18e3ef420e11641aa60f3ebae6ba0785bb4796cb5919cee73699687d6471"}, - {file = "simsimd-6.0.5-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:4412d4b3f5b56dd76f855ba592a374afc6f8c0245deca4e76689746df2057527"}, - {file = "simsimd-6.0.5-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:86f24448520bf1a60292b289b170575ccd33e201a8c623a008cdc4d6fc5cb293"}, - {file = "simsimd-6.0.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:37be47390d0e172ea33dc9dab280149a393a34af0928fe708dc866705bb49df3"}, - {file = "simsimd-6.0.5-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:925c47451faf59e84f5321d0ddefabea8b72d7489bb67ffb45605afa5855a7ce"}, - {file = "simsimd-6.0.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:7554f487f731063bf4b1b79a123ee8fd6b2a073f3c74f6099529f3f36750f4f8"}, - {file = "simsimd-6.0.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:adc081cabd70edcf8bd6ab555411d040446a4b3e427ea22331ea5620e7c13f4a"}, - {file = "simsimd-6.0.5-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:f8e597e137c9ecfa943abe7ec16b622151a7cd074afb158126908c49482c510a"}, - {file = "simsimd-6.0.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c5cb6bdc8bc446481bcee5db3fa77aaae34d6834d70378bad6fe097635136ee0"}, - {file = "simsimd-6.0.5-cp312-cp312-win32.whl", hash = "sha256:93d006e9523665c50999b9373d1d3f43eb5417b045a2a10f0e71fd9e9798b83c"}, - {file = "simsimd-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:33408710a43410897ddee52b4b6d101361d60bc935593cec964e4f50251050c7"}, - {file = "simsimd-6.0.5-cp312-cp312-win_arm64.whl", hash = "sha256:1a6579aa88177bc3f29b39636fd296c1eddbc0e6113dae60fd0d37b9d8bfe3a6"}, - {file = "simsimd-6.0.5-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:87212670927256032ad4534ed113f57b360bf6f2f80b80aaa286147135e8ebbb"}, - {file = "simsimd-6.0.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:9d41eb39e827b24c2729e368a48cc69c6fc83c78a9c64d4cbe36bfb6056db7f8"}, - {file = "simsimd-6.0.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:73fc7b123cd098ab0aa9f5a7192b6eabec38fde6199c03a24065a801ef7fcc82"}, - {file = "simsimd-6.0.5-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a638b059fb611e7c1ac0ac55c1a9768df935850fc3f24deb6823fd0e2d0bebf"}, - {file = "simsimd-6.0.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:d68f1c6778362ed2faaff101cb51b469c1afb414c093e75227d871b7b616daf9"}, - {file = "simsimd-6.0.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7bb9b91024a855320b355b0e7343aa34ed47cbe91e8242c77f740f7bce529768"}, - {file = "simsimd-6.0.5-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:ab495e1a03164ad1141b8f7fd29b1a67c88bd9889aaf617c4555d7eea83acd7b"}, - {file = "simsimd-6.0.5-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:44873124c0bd8b67f362680cd527e68fe080f7ee88c8e15e7100029b8a07f742"}, - {file = "simsimd-6.0.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:900bf8c6f55386f172402bd0075145a91f09c4eac6707be10411e85e7f96517e"}, - {file = "simsimd-6.0.5-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:40d66a4a10fa3865d0d9bb5befecaee842a1f36ae8430661b00c1535b9760caa"}, - {file = "simsimd-6.0.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:c38080d3d4e2c4f85d79e635a9a3e0a98230da1bb06e658f8e111c8589f7e401"}, - {file = "simsimd-6.0.5-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:a72d72114741911adbf191dba4f81e87d63f51e6ff73abc1367921c469f00ca1"}, - {file = "simsimd-6.0.5-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:f2aa032ff48d7993b41df0d0886717bd07a4196e11c228a2062988781933ca56"}, - {file = "simsimd-6.0.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:0ae543d1f38b092d73d63372885757d8816b5108f875117bca5dd8ee57394d7b"}, - {file = "simsimd-6.0.5-cp313-cp313-win32.whl", hash = "sha256:15fdc7ec7609f863df054f93b54305be5fe549fa1ba7aecdddba5db4892b54e7"}, - {file = "simsimd-6.0.5-cp313-cp313-win_amd64.whl", hash = "sha256:d068c58742d5504bfeaa67bdcb4f7b8c44f3688f0e575bfd7a0ebba088000194"}, - {file = "simsimd-6.0.5-cp313-cp313-win_arm64.whl", hash = "sha256:e131b3f51fa858bfb462a0ad2d497874a63be63bcfa4b1b9519134f905c87218"}, - {file = "simsimd-6.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:795c8192ae33dfbb9df5f7e7ca05c41a286a397238434cd7cb18f56baa721d4b"}, - {file = "simsimd-6.0.5-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fdeef507c2e50b6113d3feec9d6f60dbfcd37ea6b5823a067a34837a866d4a6b"}, - {file = "simsimd-6.0.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a93dfbe322ba7f30facd257e5d868e0a19cd53daa7e3cf76c71241445f46902"}, - {file = "simsimd-6.0.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:824c26e106b260baaef2bc09dd921dada1544e0cded0a1b93c0af1168ede1ef4"}, - {file = "simsimd-6.0.5-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:61e20f80b3237da7e65fdf24b815f728daa405aaf95933eb2ae0c2e6eedd3468"}, - {file = "simsimd-6.0.5-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:ffc4d4a970315ba1d3dcbdd2a63cd41a88b5de5bfdbedcbcd31e89a4052532c9"}, - {file = "simsimd-6.0.5-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:edc18e11002e4a483858766046831d1ce845568d6b95b0ebc08c2e4c48a2cc17"}, - {file = "simsimd-6.0.5-cp37-cp37m-musllinux_1_2_armv7l.whl", hash = "sha256:8fcd9e4af49c204ebab3274f3973a2faea1f561e596f9ea90363535eff79cf40"}, - {file = "simsimd-6.0.5-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:4930e24ac6aaafc769c77156e737802dba291ac796fb8bf1202ca213b74e974c"}, - {file = "simsimd-6.0.5-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:a8aeaa207fc7056f917ef6077e58d5c6213b954a27e825d30c56a4f7639d66ba"}, - {file = "simsimd-6.0.5-cp37-cp37m-musllinux_1_2_s390x.whl", hash = "sha256:a09f3ff00b19b789f28a297f2103a9c9bd33714029850d9f040ba0b4daf51c93"}, - {file = 
"simsimd-6.0.5-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:608427d1f3f4013641395c7cab304e482853ba8a7debb22a6829616a0f95fb24"}, - {file = "simsimd-6.0.5-cp37-cp37m-win32.whl", hash = "sha256:c2178eee56335249daf13be171105e01270bd78535fc9e763f5154e4b38245ef"}, - {file = "simsimd-6.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:7961a4b5d003a3743c6976989fdf9a2b144dbed103fc32cda220b317185a6aab"}, - {file = "simsimd-6.0.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:949a8a681e61475414f473e00402feb31700219b2e60862cfc62940b95d89f1b"}, - {file = "simsimd-6.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c20f0d033d06897411ed5303b7272c2294f8fcb37306030a0080f088240feb5d"}, - {file = "simsimd-6.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8c66661be8d0fc5a50f6ed79ec25b57b4cd67cc3967acc266107105bcb749561"}, - {file = "simsimd-6.0.5-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bafe77b98da15985091ef023f8bb145ba5f7a359c0df2994526b5185c0befa58"}, - {file = "simsimd-6.0.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea73a0a87920a709156256163ca29d14ccd6904bdbe5a4b0dcb9098c1da8f9ff"}, - {file = "simsimd-6.0.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d1ab50f2c42ff3cc97bce3b49dbb4631b6edfd4ed78830005f04cc8007531b6"}, - {file = "simsimd-6.0.5-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:ac24e92c67bd7f6cf5944e661fa899ff826a15f9f5dc5206e027fb022ab7cc2d"}, - {file = "simsimd-6.0.5-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:08f4e5eb024275ab167087242c144123c01ea9248cb041b725107685cec5a5e2"}, - {file = "simsimd-6.0.5-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:f4f22e3880daf1baf05cfb9193bd7b72150ff8c87394daf5f09a8385ee20bdb3"}, - {file = "simsimd-6.0.5-cp38-cp38-musllinux_1_2_armv7l.whl", hash = "sha256:b8cafa194bc864c78045999d9a952b08000f6e1da87868b9ae269ec945132af2"}, - {file = "simsimd-6.0.5-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:0bf3a4bd944de3cfc33916c038d639c34e12756b62d00620cf8d35d1d6c9948e"}, - {file = "simsimd-6.0.5-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:cb70e43618925ab07c3cab17bfc6cb2ec09b3884898eda3d6129053da35a3884"}, - {file = "simsimd-6.0.5-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:e880d568b5ea42eb2977202b3b77bdfc5932423b925f740e725c8301646d8421"}, - {file = "simsimd-6.0.5-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:160d342c5c5302e7a5897e596344574a6a70ee63c380261c8a8f342d7ffb16cf"}, - {file = "simsimd-6.0.5-cp38-cp38-win32.whl", hash = "sha256:08326b2b99627ac7b0304b67b84e6ee3d51d0e2013b793b09f0f0130a1e991df"}, - {file = "simsimd-6.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:72d74a5f8d2e2058ead9ee4e27ef1b1213b1f9e2b9d4b7bfadd58ec57432a888"}, - {file = "simsimd-6.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b6655febe3194ca88460a078cccc514156160a29c1734995fdd01081552f3b15"}, - {file = "simsimd-6.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ec6802f3160312622af1564e434e710abb15b49cf630fbbb25f8576780283022"}, - {file = "simsimd-6.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:75ba79f67ccf33bb76f9511c5e4ee3ea8dca0b84cfbf013b1581b6ca1cada32e"}, - {file = "simsimd-6.0.5-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6521cc8967da5bcdbcd6744d3c70404fbafd4818b6d62d7a55547719f02f5e39"}, - {file = "simsimd-6.0.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3ce6b1f74f90edacb30805f5c4434aeec4e8bc3987b060485c54e1fea767b7b7"}, - {file = 
"simsimd-6.0.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5332881267b846da4f2647b306c8a44c3b908ee8c40e0af552d726d9cc51f409"}, - {file = "simsimd-6.0.5-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:96ca2cd1cdf421069824f0f98eccfa268a68aff65a7b49ccc1cc5d5534786063"}, - {file = "simsimd-6.0.5-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:87f307b0c27ea1f35581829d74354bd4ab40faa5fd3bef56842daf7f7d692793"}, - {file = "simsimd-6.0.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2244186de645c051dfc09cf8d4a6e373fec95aa04757abd1ee852637e35b10b5"}, - {file = "simsimd-6.0.5-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:e422be0f5babe9d6f82bfb3bb88d88bed6fb3113e7fdf37b82cef670a1704d32"}, - {file = "simsimd-6.0.5-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4fbbc09f188c5b988d2249d54511d055c9ef8a17fb7c6e5f04c4db5b5864d396"}, - {file = "simsimd-6.0.5-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ed0855a2a2ea842aa77ff1be05b4c6c825ebd308687df83a0a7d0ecd9cf93d0b"}, - {file = "simsimd-6.0.5-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:4a9d16fd6e5f175e4d9aff41f0139440875a17c2f267672649bb7e73c3de1f2c"}, - {file = "simsimd-6.0.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:efae84ce075a80dbc2a6ad6660c11ce7b4551a044cba7b715032b01b84b96eac"}, - {file = "simsimd-6.0.5-cp39-cp39-win32.whl", hash = "sha256:517d8ebe410934ad8d204cb0d1442868a8af6b0855f5e5eae3d673cd90f42b9f"}, - {file = "simsimd-6.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:7a215934e47731b886bc7d9fe76cbb905280ec7f879948683bd4d5db0ab11690"}, - {file = "simsimd-6.0.5-cp39-cp39-win_arm64.whl", hash = "sha256:772b311598621d98cbf681edab6a06662f1284c27fb90ccadbfedd0d5cff1e8a"}, - {file = "simsimd-6.0.5.tar.gz", hash = "sha256:3951df7ef31af6257c22b91fbb16b4cd683a13bf48426de2c4d490fb5e8c5088"}, + {file = "simsimd-6.0.6-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b5189d956bf35b6e73afdfc294674110e8deeb9c92d35f56f0c02f707fe3f5d3"}, + {file = "simsimd-6.0.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:936864168136a8b974745ab28ec1e9efb918399a68b01650634e1cfc2d505734"}, + {file = "simsimd-6.0.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5ae3835f6e8c97cc2276e49993e771d265711c0c561ce6fe91a57ecdc0892591"}, + {file = "simsimd-6.0.6-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5bbc24a3de2e6970f5c619dd911d20eced9c6d8beee8a9d0618ca68750a3a0f0"}, + {file = "simsimd-6.0.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6c157008d23528fe70cb3cb28c58562a124cd367a1976bda548d299eed5cbc6a"}, + {file = "simsimd-6.0.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7ebbe49a5ae275b0e345df4143a7faccc840891ae1e472f8991cd710e64ff066"}, + {file = "simsimd-6.0.6-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:567ecf3628ed1030307c11b5be2827d16579bddb35c5f59d619d83628bc39fff"}, + {file = "simsimd-6.0.6-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:c6bcb009b2777dc633886848569c927982bf726e2c71ae6e26672c59e1f3741f"}, + {file = "simsimd-6.0.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:aa8da3cb8660c8e172228d3c5ca1e0d2c658a8221b617fa66c273f37bd8544c2"}, + {file = "simsimd-6.0.6-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:6cd9e526c545be385bac93e5265acb7cb1df79eb643ff807d912ee0481964c91"}, + {file = "simsimd-6.0.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:6c01f34eb825542a3b8443bcb6e2f204f96a6404a2df5e0c00ce10f4b97520c7"}, + {file = 
"simsimd-6.0.6-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:f7903c6a85ad51d244162652d2f594df9f2e6f47143a61ce05d048cb61f8b711"}, + {file = "simsimd-6.0.6-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:94072026c8d9f162fa266597eb9725d8b77c95ae7c6cd2a3e5469cda76ecbdc6"}, + {file = "simsimd-6.0.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:aba322dd07ead8ae5258950112ac7bd5e226c4daa79864795dc738f9f95a013c"}, + {file = "simsimd-6.0.6-cp310-cp310-win32.whl", hash = "sha256:87c1af803db6b4bda2fc8facfb66ce13389d90a0c834d385dc120f8741ce74ca"}, + {file = "simsimd-6.0.6-cp310-cp310-win_amd64.whl", hash = "sha256:440fb9e0238e1de12626fe6670db6cace6c93d2256c9d2be04282b91d14ae220"}, + {file = "simsimd-6.0.6-cp310-cp310-win_arm64.whl", hash = "sha256:ad65c93e50309a4ef07cf45e97f4876de2f7123b818c107dde9cb42515d9e60f"}, + {file = "simsimd-6.0.6-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5e85a8f850cc026aac14978871aee70278201de3f0038613019a7946628bec53"}, + {file = "simsimd-6.0.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ede6accf7ebbe21947eb087de9f38adcb6eaf4ea869db611f7455eb14fe09721"}, + {file = "simsimd-6.0.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a0d32b21d29534cd908b4f23d8778f4a163b74bdb0e6c3a44bac26ef6c3eabdf"}, + {file = "simsimd-6.0.6-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2fb8a12cb07e2d2e5d3215d0e6102e3cafde6336823328252359c737ac06b557"}, + {file = "simsimd-6.0.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8e92ce5bfb8cac62a5f35b74e7d66d26c844480c2cd15cde81b18ef58645ca5"}, + {file = "simsimd-6.0.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:37f3f88b4c5ba4792a8a4cfe9a78986899d32e9d628c0fb31f68c89a13f1df6f"}, + {file = "simsimd-6.0.6-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:66f7550dd23f4eb38523c22e794c2486df7f0090b1dffc287ec66dba3d552c61"}, + {file = "simsimd-6.0.6-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:47c7017cf1e06ac74492bcf639fcafe1fad1c6f48212010e4f1c4cd2348c529b"}, + {file = "simsimd-6.0.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6764cb1d511eabb9cc260ae19da78742c9a1ea829138eee44a75a58a315ac6ce"}, + {file = "simsimd-6.0.6-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:bb77af32bd68af0c4e674335e4e1f5ea2a2a11d8aeb12b98059504d6af540752"}, + {file = "simsimd-6.0.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:c6d7f25748f2aa7f44478702a1da81246dbff508e93a493b388fd5d4f6eefd87"}, + {file = "simsimd-6.0.6-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:3368f6762f0a7c364b7775832afc35792b37810bd5ef66e5112a9c27d4f4d5b1"}, + {file = "simsimd-6.0.6-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:429edac00081df912d57ef495c6608b832d09f6cec21dfcc820bae88c7d56ff8"}, + {file = "simsimd-6.0.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:98aa21b76dcedf4f93e58500675f0da45135787b6e003eb128233adb8a8f3a24"}, + {file = "simsimd-6.0.6-cp311-cp311-win32.whl", hash = "sha256:0a82e20d8317d082237305de762485a789a2db90a510460f3ced5d1edd3daf66"}, + {file = "simsimd-6.0.6-cp311-cp311-win_amd64.whl", hash = "sha256:f260d6edf4b620d8bb5508f34a2726f277ca31dcf5bb6dba59bf89c951495ad8"}, + {file = "simsimd-6.0.6-cp311-cp311-win_arm64.whl", hash = "sha256:2dfdc5de67d51826afede135c82d85c0578b42cbc4aad5b23460d40bf3791046"}, + {file = "simsimd-6.0.6-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:6af52f3841cddef2fec984afe7eafb62604cb1a54111e5dd70138963aae4b376"}, + {file = 
"simsimd-6.0.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5ba288103a93412182bdd86ca42af09c46b0364e48fc33b3b6d974758218e592"}, + {file = "simsimd-6.0.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:427626f12103fac06ccff4ae8e5ff3084e10aed2ad7c072b98c51dfcc705fbbb"}, + {file = "simsimd-6.0.6-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:628134f9c9d402af992f794143b98c38623ec6f54a7f68be6677e2096979a5f2"}, + {file = "simsimd-6.0.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b42134b5ef13e43f990cf85d134d2147942b9b56dd00fbe1df55fea503fae927"}, + {file = "simsimd-6.0.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64590dc9da6b81fa9d7b73b15eec54db69054508b1bff258604dc76138c0689e"}, + {file = "simsimd-6.0.6-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:806f1f7d0f2838966de451493cfab24c962b0058177a7304aa25d4df49cefbeb"}, + {file = "simsimd-6.0.6-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:0c0ba7cba0d5f5473a3edca0884fec5b861c2516d6e92bb9720e0e9121f28c9f"}, + {file = "simsimd-6.0.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:61cfbfc77098e186433dbf5e4b93594d2b5ccbefb12c9cb6ca508c9a1616ab03"}, + {file = "simsimd-6.0.6-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:cea5390dae55dbdc83d9517b95d64dbbb2457fbffe96ea8173765198fba13b76"}, + {file = "simsimd-6.0.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0445d6914a88d3162a7694d9800c553527d8fe1d2d0b053689fbaca40b2a6b32"}, + {file = "simsimd-6.0.6-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:36945209b3508d47baec95ba77e7ce61332c2065cd88abbda37c033595d14ac5"}, + {file = "simsimd-6.0.6-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:af5207d5713516b1591c037829a9ef484f94f74eb7194289708c8e48ec997e4c"}, + {file = "simsimd-6.0.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b3cc3bc0120b0bde86fd675f7031f2080cda912a7be8f708a36db0ba381e1281"}, + {file = "simsimd-6.0.6-cp312-cp312-win32.whl", hash = "sha256:53a64c7793b99e384c17bfb79a688548c422443ae2a3a2d37bd3df3f4c8fab90"}, + {file = "simsimd-6.0.6-cp312-cp312-win_amd64.whl", hash = "sha256:c97d865ea20606262647c73c543e61745b625b6f7ad13f707571f9dbde246672"}, + {file = "simsimd-6.0.6-cp312-cp312-win_arm64.whl", hash = "sha256:fddd33293c2d233d2014a095d50a8507c22dc256aa9263932078d1ef9a6d1498"}, + {file = "simsimd-6.0.6-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:67e6434c9d04eb259691296794486cbbe0b47bae9300a8519ef320c00c035161"}, + {file = "simsimd-6.0.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ff498a2734ae42d13211651b31743df0995964dd9fa1c511254b43bc9b121712"}, + {file = "simsimd-6.0.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:205800799df116297d8936ec2b594a758d93536126236db5f227c4a0797782e5"}, + {file = "simsimd-6.0.6-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f85542bd370561ede66815b59df10e2fbba1997686548a33364cf299fb1f443d"}, + {file = "simsimd-6.0.6-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:756d76ecfa9d853363855377f1fe0b06d55e2cfb48f57f6e1363f61454266834"}, + {file = "simsimd-6.0.6-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:438a82c7005cd229693663703486d53205db5796168e2434511cdedd37eb4c07"}, + {file = "simsimd-6.0.6-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:88bff7d397fe6cb4c62b6aecaf32f6c7d2ffbc09522531e6f28317311136e8d8"}, + {file = "simsimd-6.0.6-cp313-cp313-manylinux_2_28_x86_64.whl", hash = 
"sha256:c5b373d4a27a1392fefa1253f636954947aad9718dbdca2ae233f6e7b3347b2f"}, + {file = "simsimd-6.0.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:871071e26e74da1f5a15ea0b033a3ef22901bad24fc3bdafa3d7c2922a511468"}, + {file = "simsimd-6.0.6-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:8a92e6648524e2739134a753949b2ab0c54c91d60a21f6a72d2f38d1a06314fb"}, + {file = "simsimd-6.0.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:e97fa0f2dd3d8a1c3f7d2a347d4737dad70be9f3708020642950a8e507d67518"}, + {file = "simsimd-6.0.6-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:8a1d5438fec10b2dfe90c0c050f16c03b323d848d0d0369545c7fcd7caace73b"}, + {file = "simsimd-6.0.6-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:591115d284c817d46d23eff041b2f414691ecb50e3c79e548427edcdffb461c4"}, + {file = "simsimd-6.0.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7900e5c49a99cb1f92ebd5b501aa345b70488e1cdbe0807269452944215b30d8"}, + {file = "simsimd-6.0.6-cp313-cp313-win32.whl", hash = "sha256:45c40f7268e4d1e9506db49ddcd8e4f9be0c7857ba0dc668e85946cdc8d20712"}, + {file = "simsimd-6.0.6-cp313-cp313-win_amd64.whl", hash = "sha256:68483ff5213fab2baf41d758854a867abff4635c4cbd7e3cab1573b1b7296e6b"}, + {file = "simsimd-6.0.6-cp313-cp313-win_arm64.whl", hash = "sha256:14a7f14450263f02f1509536302824896bc88ab90a2623a5525293eb69e1a2bf"}, + {file = "simsimd-6.0.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:fc58ec09dbfde9cc74fa0e7abfa5926674743509a42fa482d487d65ded2f4ac4"}, + {file = "simsimd-6.0.6-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dee88effd675eb6fa00a0c7956f67da41f3c14ee0944d186c26b1d2f129aea85"}, + {file = "simsimd-6.0.6-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:278c334d6b3200b28f63be99db7e3d0b7d8108630e14c26a519ad8d5704c91b9"}, + {file = "simsimd-6.0.6-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:475af54f170860729d922040b522e81379ab8295141f783fa5168c7099e52953"}, + {file = "simsimd-6.0.6-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:8725ce30750a7d4ee6b552a0a5db3d9541398be6f51ab038047cb1bbf0cc1a28"}, + {file = "simsimd-6.0.6-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:8d265bc13e46f55fc51c4de092c3d60aa2d06c72796d5457e6bf6663444b264c"}, + {file = "simsimd-6.0.6-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:f2deca6de9bca6f0ef647a6ce34910a454885bd5655ba838b28b863f27642de5"}, + {file = "simsimd-6.0.6-cp37-cp37m-musllinux_1_2_armv7l.whl", hash = "sha256:01819a0f410babae8a6d8aaebae1ecdbfcc33ef15dfaed147f48bf2328ad04ee"}, + {file = "simsimd-6.0.6-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:8eb0e77c58432bb01337d2cea26e7a3aa57ec38866fd293418be3f5936bf3efe"}, + {file = "simsimd-6.0.6-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:ce1c641e465d3d99ab08bef7ce9dc341183f6729f58717163bd2121d3b3fed15"}, + {file = "simsimd-6.0.6-cp37-cp37m-musllinux_1_2_s390x.whl", hash = "sha256:97f88a8b6ab516269d4dc1e99e33a2cd4359ee1c7c1a48e6278e44f8a8b0a154"}, + {file = "simsimd-6.0.6-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:c91f29feba0bc246d4eb60b8a441778ebc69c152a1c9e2069f9d29e7f7e25f8f"}, + {file = "simsimd-6.0.6-cp37-cp37m-win32.whl", hash = "sha256:d4792394b8f56d88198c1ffb64a6047cdbc5e94088d4b31a62cf8947e8e67b2f"}, + {file = "simsimd-6.0.6-cp37-cp37m-win_amd64.whl", hash = "sha256:7974e9e00945ccc51aaaf26dfb186c3fb5deb059b912452ad427cf49ec79f596"}, + {file = "simsimd-6.0.6-cp38-cp38-macosx_10_9_universal2.whl", hash = 
"sha256:12b6abe3c290e43246368dec02326d89814c7d79c0d55a9861d0989afad7d40f"}, + {file = "simsimd-6.0.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0e18ab52b5bd788b54cd97c8a9172dc69cf961982f4af11d8177cb7264f7da15"}, + {file = "simsimd-6.0.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8037fbd97738c43607625b9570187040bdf3c5d4281d3978cbe9ab29eec2afef"}, + {file = "simsimd-6.0.6-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5909b7f0305bd82112fac279cddfa17041bd2582156d44b1de18219699438f11"}, + {file = "simsimd-6.0.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a6a9e59555ea3be214d9b0c32a1c3f4e18b496aee1767e87115c8ab764710d96"}, + {file = "simsimd-6.0.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8de287c235ce7950412d08e1c175eae9f4587bca07e6da54c795cedcceb2125a"}, + {file = "simsimd-6.0.6-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:2a92202d6dfbf6b64d5eb8b909dd3fced21636f95dbd4b98c4ec65e22b2b3920"}, + {file = "simsimd-6.0.6-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:419e70d111af89e4e342c0f9a0406a25fde430ab33fbc835e4bda542578ea9b0"}, + {file = "simsimd-6.0.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:1f0b9f48b2e8a9ed44ed49998c76119d35012a6c15891339cb9701c7793d84f6"}, + {file = "simsimd-6.0.6-cp38-cp38-musllinux_1_2_armv7l.whl", hash = "sha256:45b216d2b046370afd0a98b1c8cac26f8d92b3847ad99ce7a9d116d54abbac62"}, + {file = "simsimd-6.0.6-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:8d6fbee80d018cfa675e648e67ad017e9035cf337beed27e1d960dc56ac4d25b"}, + {file = "simsimd-6.0.6-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:bb12265aae33ad3d5a35e602c684d17d9b44a53d0b53a553df4d0dab4f446d12"}, + {file = "simsimd-6.0.6-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:addc7ef20edcf63904c6fa6af4aeeeb88dd022b1e51f662bcf2c4df438563c7d"}, + {file = "simsimd-6.0.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:9a2d65753395380c1196e364896667edd95cb18d426c58b0fbfa5b0a6928c62c"}, + {file = "simsimd-6.0.6-cp38-cp38-win32.whl", hash = "sha256:587421f4f8869ee63c8685cf72d172ac6096778ecafc2225c17d058e112e0e09"}, + {file = "simsimd-6.0.6-cp38-cp38-win_amd64.whl", hash = "sha256:84d662fa9e98a3b16a7d475627ee0115b65f594afd6012299df22977b3d8731d"}, + {file = "simsimd-6.0.6-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d5074ca110097c217892151f09e5db5eb9fd32bdf0681cccfb148658fbaf4284"}, + {file = "simsimd-6.0.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eee2fe5b4ffb9a5adddbb06823f9cbb8fa9694f9abbcda509c011e0703d50d4c"}, + {file = "simsimd-6.0.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f9bfef8da409ceb5f579f74f1494dd25cc199cb6c0856d8e3eeb4bbd11cef1c9"}, + {file = "simsimd-6.0.6-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76302e54ef372826cf34ff409a44c513d96a7175207958635287f81c2c0fa5ae"}, + {file = "simsimd-6.0.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:684290cccca75c9227fe43994318425a5919823906c2601c6f64beaae9ba91bc"}, + {file = "simsimd-6.0.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:96ee96bb7dc9c0f7ff5f3036565f3ffe2d5fe5cb7891c612a2c0f46092c22deb"}, + {file = "simsimd-6.0.6-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:01796a4997e3fc5ca054633b78c4ad555a09ad449119c6d98d7fca2788cd14c2"}, + {file = "simsimd-6.0.6-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ef71393a583eebf60566bf127e8e903b7975a22bf7e2a22988364160a5a4c804"}, + {file = 
"simsimd-6.0.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:1b21e667f6cf481d0d0f5e3854eae7182cf41f184a563eef191c69b7e91c52cd"}, + {file = "simsimd-6.0.6-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:bc0c8ce7a9bc28386788fc6cb0c447f132cefb01ca7f83c4a9193b817b11a30a"}, + {file = "simsimd-6.0.6-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:3eb01310531ce366ae03e3553c93bcc84dd2cab14b18bf4dca15d6468cd22987"}, + {file = "simsimd-6.0.6-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:c0e63fdad92fbf210dee1697e271457ac16563cd9f8ac4e85d4fb703356795ed"}, + {file = "simsimd-6.0.6-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9435822f5a3d518a6c7548796b39162ab7ae0f91b8daf147c1a32b572261ef8c"}, + {file = "simsimd-6.0.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:a7125ecea94e1207583397fd982eddcdfe6a3103abc55f7b228d9c9eb5c5e03c"}, + {file = "simsimd-6.0.6-cp39-cp39-win32.whl", hash = "sha256:9b99e9afb33c888904f576fc002932abe1bc02d426532c5a1748d2a6bfde14e9"}, + {file = "simsimd-6.0.6-cp39-cp39-win_amd64.whl", hash = "sha256:fb374841905e7dc1483d0b39f3f8b81fbb676b8d7405bb2e781c718683d97ab6"}, + {file = "simsimd-6.0.6-cp39-cp39-win_arm64.whl", hash = "sha256:ab05dde439afc95db209072dffcc609d2c8043c044170886336327253da050c7"}, + {file = "simsimd-6.0.6.tar.gz", hash = "sha256:f2c2b886bd7806be626c58affbf48152ac315b91c7eaebb77e712f1f7d4bb0c5"}, ] [[package]] @@ -7078,30 +7077,28 @@ files = [ [[package]] name = "zigzag-dse" -version = "3.7.4" +version = "2.5.4" description = "ZigZag - Deep Learning Hardware Design Space Exploration" optional = true -python-versions = ">=3.11" +python-versions = ">=3.9" files = [ - {file = "zigzag_dse-3.7.4-py3-none-any.whl", hash = "sha256:73f1ce52659f58745eef23d39de3e44ed1f86d791ae72785cf30897a10e82b53"}, - {file = "zigzag_dse-3.7.4.tar.gz", hash = "sha256:efd09c3f2759cea8aa614d9a88f42823097b7e8823690d879f2e85c587a7ccf1"}, + {file = "zigzag-dse-2.5.4.tar.gz", hash = "sha256:0244663b8704e3874728b1e93a4b591cf2c43219c1fecbd11e0e9a0a9e1062ab"}, + {file = "zigzag_dse-2.5.4-py3-none-any.whl", hash = "sha256:8ef3f6f3c4948ec000455e0360fe4078ae2801fb73a50714dfd9f90b0085a773"}, ] [package.dependencies] -cerberus = "*" matplotlib = "*" multiprocessing-on-dill = "*" networkx = "*" numpy = "*" onnx = "*" pyyaml = "*" -seaborn = "*" sympy = "*" +tomli = {version = "*", markers = "python_version < \"3.11\""} tqdm = "*" -typeguard = "*" [package.extras] -dev = ["build", "bumpver", "pip-tools", "pre-commit", "pytest", "twine"] +dev = ["build", "bumpver", "pip-tools", "twine"] [[package]] name = "zipp" @@ -7131,4 +7128,4 @@ vision = ["albumentations", "gdown", "imagecorruptions", "kornia", "pycocotools" [metadata] lock-version = "2.0" python-versions = ">=3.10 <3.12" -content-hash = "038672aa7bc64273ee5d399e932a95214c3e55a6cbdc37a870c13b84076a4f9c" +content-hash = "f2b2d4fc896f5063b7707d63c75d8a8a37b21da21b88548c39c25d1c3cf2118f" diff --git a/pyproject.toml b/pyproject.toml index 9b04ec7f..a630abf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ scikit-learn = "~1.2.2" hydra-colorlog = "^1.2.0" z3-solver = "^4.11.2.0" xgboost = "^1.5.2" +pytorch-lightning = "^2.4.0" # optional dependencies for backends onnxruntime = {version = "^1.4.0", optional = true} @@ -58,17 +59,17 @@ pycocotools = {version = "^2.0.6", optional = true} gdown = {version = "^4.5.3", optional = true} albumentations = {version = "^1.4.3", optional = true} kornia = {version = "^0.7.1", optional = true} -lightning = "^2.1.2" dgl = "1.1.3" -pytorch-lightning = "^2.1.3" +nn-meter = 
{version = "^2.0", optional = true} +zigzag-dse = {version = "^2.5.3", optional = true} +rich = "^13.7.0" onnx = "^1.16.0" spox = "^0.12.0" optree = "^0.11.0" netdeployonnx = {path = "external/netdeployonnx", develop = true} tensorboard = "^2.17.1" pymoo = {version = "^0.6.1.3", optional = true} -zigzag-dse = {version = "^3.6.1", optional = true, python = ">=3.11"} -nn-meter = {version = "^2.0", optional = true} +lightning = "^2.4.0" [tool.poetry.dev-dependencies] pytest = ">=7.2.0" @@ -102,6 +103,8 @@ hannah-eval = 'hannah.tools.eval:main' hannah-objectdetection-eval = 'hannah.tools.objectdetection_eval:main' hannah-nas-eval = "hannah.nas.eval.__main__:main" hannah-exec = "hannah.tools.exec:main" +hannah-hw = "hannah.nas.hardware_description.__main__:main" + [tool.poetry.group.dev.dependencies] pyre-check = "^0.9.17" diff --git a/tests/test_registry.py b/tests/test_registry.py new file mode 100644 index 00000000..ddbfd549 --- /dev/null +++ b/tests/test_registry.py @@ -0,0 +1,28 @@ +import pytest +from hannah.utils.registry import Registry + +class TestClass: + def __init__(self, x): + self.x = x + +def test_register(): + registry = Registry() + registry.register(TestClass) + assert 'TestClass' in registry.registered_classes + +def test_instantiate(): + registry = Registry() + registry.register(TestClass) + instance = registry.instantiate('TestClass', 5) + assert isinstance(instance, TestClass) + assert instance.x == 5 + +def test_iter(): + registry = Registry() + registry.register(TestClass) + assert list(registry) == [TestClass] + +def test_len(): + registry = Registry() + registry.register(TestClass) + assert len(registry) == 1 \ No newline at end of file