Commit

Update kakao resnet
cgerum committed Nov 30, 2023
1 parent 4458e31 commit f0d7dec
Showing 4 changed files with 36 additions and 13 deletions.
3 changes: 2 additions & 1 deletion experiments/rhode_island/config.yaml
@@ -22,6 +22,7 @@ defaults:
   - override features: identity # Feature extractor configuration name (use identity for vision datasets)
   - override model: timm_mobilenetv3_small_075 # Neural network name (for now timm_resnet50 or timm_efficientnet_lite1)
   - override scheduler: 1cycle # learning rate scheduler config name
+  - override augmentation: ri_augment
   - override optimizer: sgd # Optimizer config name
   - override normalizer: null # Feature normalizer (used for quantized neural networks)
   - override module: image_classifier # Lightning module config for the training loop (image classifier for image classification tasks)
@@ -31,7 +32,7 @@ dataset:
   data_folder: ${oc.env:HANNAH_DATA_FOLDER,${hydra:runtime.cwd}/../../datasets/}

 module:
-  batch_size: 128
+  batch_size: 512

 trainer:
   max_epochs: 15
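Note: the new batch size and the ri_augment override live directly in this experiment config, so the values can be sanity-checked without launching a training run. A minimal sketch (file path taken from the diff; running it from the repository root is an assumption):

```python
from omegaconf import OmegaConf

# Hypothetical quick check of the updated experiment config.
cfg = OmegaConf.load("experiments/rhode_island/config.yaml")
print(cfg.module.batch_size)   # 512 after this commit (previously 128)
print(cfg.trainer.max_epochs)  # 15
```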
7 changes: 0 additions & 7 deletions experiments/rhode_island/dvc.yaml

This file was deleted.

16 changes: 13 additions & 3 deletions hannah/callbacks/backends.py
@@ -42,7 +42,9 @@
 except ModuleNotFoundError:
     onnxrt_backend = None

-from ..models.factory.qat import QAT_MODULE_MAPPINGS
+from typing import Mapping
+
+from ..nn.qat import QAT_MODULE_MAPPINGS

 logger = logging.getLogger(__name__)

@@ -172,6 +174,8 @@ def on_test_epoch_start(self, trainer, pl_module):
         Returns:
         """
         logger.info("Exporting module")
+
+        pl_module = self.quantize(pl_module)
         self.prepare(pl_module)
         self.export()
@@ -222,8 +226,14 @@ def on_test_batch_end(
         """
         if batch_idx < self.test_batches:
-            result = self.run_batch(inputs=batch[0])
-            target = pl_module(batch[0].to(pl_module.device))
+            # decode batches from target device
+            if isinstance(batch, Mapping) or isinstance(batch, dict):
+                inputs = batch["data"]
+            else:
+                inputs = batch[0]
+
+            result = self.run_batch(inputs=inputs)
+            target = pl_module(inputs.to(pl_module.device))
             target = target[: result.shape[0]]

             mse = torch.nn.functional.mse_loss(
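Note: the on_test_batch_end change above lets the backend callback consume both tuple-style batches and dict-style batches that carry the tensor under a "data" key. A standalone sketch of that decoding logic, using a hypothetical extract_inputs helper (not part of backends.py):

```python
from typing import Any, Mapping

import torch


def extract_inputs(batch: Any) -> torch.Tensor:
    """Return the input tensor from either a dict-style or tuple-style batch.

    Mirrors the isinstance(batch, Mapping) check added in on_test_batch_end.
    """
    if isinstance(batch, Mapping):
        return batch["data"]
    return batch[0]


# Both batch layouts yield the same tensor for the backend run.
tuple_batch = (torch.randn(4, 3, 32, 32), torch.zeros(4, dtype=torch.long))
dict_batch = {"data": torch.randn(4, 3, 32, 32), "labels": torch.zeros(4, dtype=torch.long)}

assert extract_inputs(tuple_batch).shape == extract_inputs(dict_batch).shape
```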
23 changes: 21 additions & 2 deletions hannah/models/kakao_resnet.py
@@ -79,9 +79,28 @@ def resnet8(*args, **kwargs):


 def resnet8_025(*args, **kwargs):
-    num_class = 10
+    num_class = 4
+    model = nn.Sequential(
+        conv_bn(3, 16, kernel_size=8, stride=8, padding=0),
+        conv_bn(16, 32, kernel_size=5, stride=2, padding=2),
+        Residual(nn.Sequential(conv_bn(32, 32), conv_bn(32, 32))),
+        conv_bn(32, 64, kernel_size=3, stride=1, padding=1),
+        nn.MaxPool2d(2),
+        Residual(nn.Sequential(conv_bn(64, 64), conv_bn(64, 64))),
+        conv_bn(64, 32, kernel_size=3, stride=1, padding=0),
+        nn.AdaptiveMaxPool2d((1, 1)),
+        Flatten(),
+        nn.Linear(32, num_class, bias=False),
+        Mul(0.2),
+    )
+
+    return model
+
+
+def resnet8_012(*args, **kwargs):
+    num_class = 4
     model = nn.Sequential(
-        conv_bn(3, 16, kernel_size=4, stride=2, padding=0),
+        conv_bn(3, 16, kernel_size=16, stride=16, padding=0),
         conv_bn(16, 32, kernel_size=5, stride=2, padding=2),
         Residual(nn.Sequential(conv_bn(32, 32), conv_bn(32, 32))),
         conv_bn(32, 64, kernel_size=3, stride=1, padding=1),
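Note: the new resnet8_012 variant reuses helper blocks already defined in kakao_resnet.py (conv_bn, Residual, Flatten, Mul). The sketch below uses illustrative stand-ins for those helpers, not the actual definitions from the file, to smoke-test the shapes implied by the new 16x16/stride-16 stem:

```python
import torch
import torch.nn as nn


def conv_bn(c_in, c_out, kernel_size=3, stride=1, padding=1):
    # Illustrative stand-in: Conv2d + BatchNorm + ReLU, matching the call sites above.
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, kernel_size, stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(c_out),
        nn.ReLU(inplace=True),
    )


class Residual(nn.Module):
    # Adds the block input back onto the block output (identity skip connection).
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        return x + self.module(x)


class Mul(nn.Module):
    # Scales the logits by a constant factor (0.2 in the models above).
    def __init__(self, weight):
        super().__init__()
        self.weight = weight

    def forward(self, x):
        return x * self.weight


Flatten = nn.Flatten  # assumption: behaves like torch's built-in flatten layer

# The aggressive 16x16/stride-16 stem of resnet8_012 shrinks a 256x256 image
# to 16x16 feature maps before the residual stages.
stem = conv_bn(3, 16, kernel_size=16, stride=16, padding=0)
print(stem(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 16, 16, 16])
```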
