From e75363ed0fcc14d90b1e6b4eae9e0d97b54e6b9a Mon Sep 17 00:00:00 2001 From: Christoph Gerum Date: Wed, 15 Nov 2023 18:32:56 +0000 Subject: [PATCH] Fix dependency on computer vision datasets --- hannah/conf/module/anomaly_detection.yaml | 2 +- hannah/conf/module/image_classifier.yaml | 2 +- hannah/datasets/fake1d.py | 43 +++++++++-------------- hannah/datasets/vision/base.py | 11 ++---- hannah/datasets/vision/cifar.py | 2 -- hannah/modules/__init__.py | 3 -- 6 files changed, 21 insertions(+), 42 deletions(-) diff --git a/hannah/conf/module/anomaly_detection.yaml b/hannah/conf/module/anomaly_detection.yaml index 7304498a..3ad4608a 100644 --- a/hannah/conf/module/anomaly_detection.yaml +++ b/hannah/conf/module/anomaly_detection.yaml @@ -1,4 +1,4 @@ -_target_: hannah.modules.AnomalyDetectionModule +_target_: hannah.modules.vision.AnomalyDetectionModule num_workers: 0 batch_size: 128 shuffle_all_dataloaders: False diff --git a/hannah/conf/module/image_classifier.yaml b/hannah/conf/module/image_classifier.yaml index 40fb3236..15278d12 100644 --- a/hannah/conf/module/image_classifier.yaml +++ b/hannah/conf/module/image_classifier.yaml @@ -16,7 +16,7 @@ ## See the License for the specific language governing permissions and ## limitations under the License. ## -_target_: hannah.modules.ImageClassifierModule +_target_: hannah.modules.vision.ImageClassifierModule num_workers: 0 batch_size: 128 shuffle_all_dataloaders: False diff --git a/hannah/datasets/fake1d.py b/hannah/datasets/fake1d.py index 570a8997..d3e0801b 100644 --- a/hannah/datasets/fake1d.py +++ b/hannah/datasets/fake1d.py @@ -26,44 +26,35 @@ from ..utils.utils import extract_from_download_cache, list_all_files from .base import AbstractDataset, DatasetType -from .vision.base import TorchvisionDatasetBase -class Fake1dDataset(TorchvisionDatasetBase): +class Fake1dDataset(AbstractDataset): + def __init__(self, config, size): + self.config = config + self.size = size + + self.data = torch.randn((size, config["channels"], config["resolution"])).split( + 1, 0 + ) + self.target = torch.randn( + (size, config.size), dtype=torch.int32, min=0, max=config["num_classes"] + ) + @classmethod def prepare(cls, config): pass @classmethod def splits(cls, config): - resolution = config.resolution - channels = config.channels - - test_data = torchvision.datasets.FakeData( - size=128, - image_size=(channels, resolution), - num_classes=config.num_classes, - ) - val_data = torchvision.datasets.FakeData( - size=128, - image_size=(channels, resolution), - num_classes=config.num_classes, - ) - train_data = torchvision.datasets.FakeData( - size=512, - image_size=(channels, resolution), - num_classes=config.num_classes, - ) - - return cls(config, train_data), cls(config, val_data), cls(config, test_data) + return cls(config, size=128), cls(config, size=32), cls(config, size=32) def __getitem__(self, index): - data, target = self.dataset[index] - data = np.array(data).astype(np.float32) / 255 - data = self.transform(image=data)["image"] - data = torch.squeeze(data) + data, target = self.data[index], self.target[index] return data, target @property def class_names(self): return [f"class{n}" for n in range(self.config.num_classes)] + + def __len__(self): + return len(self.targets) diff --git a/hannah/datasets/vision/base.py b/hannah/datasets/vision/base.py index 9219623f..04f609d3 100644 --- a/hannah/datasets/vision/base.py +++ b/hannah/datasets/vision/base.py @@ -18,20 +18,13 @@ # import logging import re -import tarfile -from collections import Counter, namedtuple -from typing import Dict, List, Optional +from collections import Counter +from typing import List import albumentations as A import cv2 import numpy as np -import pandas as pd -import requests -import torch -import torchvision from albumentations.pytorch import ToTensorV2 -from omegaconf import DictConfig -from sklearn.model_selection import train_test_split from ..base import AbstractDataset diff --git a/hannah/datasets/vision/cifar.py b/hannah/datasets/vision/cifar.py index 04f7f398..7c1803ad 100644 --- a/hannah/datasets/vision/cifar.py +++ b/hannah/datasets/vision/cifar.py @@ -19,10 +19,8 @@ import logging import os -import albumentations as A import torch.utils.data as data import torchvision -from albumentations.pytorch.transforms import ToTensorV2 from torchvision import datasets from .base import TorchvisionDatasetBase diff --git a/hannah/modules/__init__.py b/hannah/modules/__init__.py index 7f76e1c7..77117e8d 100644 --- a/hannah/modules/__init__.py +++ b/hannah/modules/__init__.py @@ -27,14 +27,11 @@ StreamClassifierModule, ) from .object_detection import ObjectDetectionModule -from .vision import AnomalyDetectionModule, ImageClassifierModule __all__ = [ - "AnomalyDetectionModule", "CrossValidationStreamClassifierModule", "SpeechClassifierModule", "StreamClassifierModule", - "ImageClassifierModule", "AnomalyDetectionModule", "ObjectDetectionModule", "CartesianClassifierModule",