Merge branch 'fix/cluster_svd' into 'main'
Fix/cluster svd

See merge request es/ai/hannah/hannah!363
cgerum committed Nov 28, 2023
2 parents 2d07031 + f6ec6f6 commit 4797b2b
Showing 7 changed files with 37 additions and 23 deletions.
15 changes: 8 additions & 7 deletions hannah/callbacks/clustering.py
@@ -1,8 +1,8 @@
 #
-# Copyright (c) 2022 University of Tübingen.
+# Copyright (c) 2023 Hannah contributors.
 #
 # This file is part of hannah.
-# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
+# See https://github.com/ekut-es/hannah for further info.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@
 import torch.nn as nn
 from pytorch_lightning.callbacks import Callback
 from scipy.sparse import csr_matrix
-from sklearn.cluster import KMeans, MiniBatchKMeans
+from sklearn.cluster import MiniBatchKMeans
 
 from ..models.factory.qat import Conv1d, ConvBn1d, ConvBnReLU1d, ConvReLU1d

@@ -55,7 +55,7 @@ class kMeans(Callback):
 def __init__(self, cluster):
 self.cluster = cluster
 
-def on_fit_end(self, trainer, pl_module):
+def on_test_epoch_start(self, trainer, pl_module):
 """
 Args:
@@ -92,7 +92,6 @@ def replace_modules(module):
 """
 for name, child in module.named_children():
 replace_modules(child)
-
 if isinstance(child, ConvBn1d):
 tmp = Conv1d(
 child.in_channels,
@@ -160,9 +159,11 @@ def replace_values_by_centers(x):
 replace_values_by_centers
 ) # _ symbolizes inplace function, tensor moved to cpu, since apply_() only works that way
 module.to(device=device) # move from cpu to gpu
+# PATH = os.getcwd() + '/checkpoints/last.ckpt'
+# torch.save(pl_module.state_dict(), PATH)
 logger.critical("Clustering error: %f", float(inertia))
 
-def on_epoch_end(self, trainer, pl_module):
+def on_train_epoch_end(self, trainer, pl_module):
 """
 Args:
@@ -205,4 +206,4 @@ def replace_values_by_centers(x):
 )
 module.weight.data = clustered_data
 module.to(device=device)
-logger.info("Clustering error: %f", float(inertia)) # summed over all layers
+logger.info("Clustering error: %f", float(inertia)) # summed over all layers
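
A note on the hook changes in this file: PyTorch Lightning's generic on_epoch_end callback hook was removed in Lightning 2.x in favour of stage-specific hooks such as on_train_epoch_end, and clustering now also runs at on_test_epoch_start instead of on_fit_end, so the test pass evaluates clustered weights regardless of which checkpoint was loaded beforehand. The sketch below illustrates the overall idea (cluster each Conv1d weight tensor with MiniBatchKMeans and replace the values by their centers); the class name, n_clusters argument, and layer selection are illustrative, not hannah's exact implementation.

    import torch
    from pytorch_lightning.callbacks import Callback
    from sklearn.cluster import MiniBatchKMeans


    class WeightClustering(Callback):
        """Illustrative sketch: replace Conv1d weights by their k-means cluster centers."""

        def __init__(self, n_clusters: int = 16):
            self.n_clusters = n_clusters

        def _cluster_weights(self, pl_module) -> float:
            total_inertia = 0.0
            for module in pl_module.modules():
                if isinstance(module, torch.nn.Conv1d):
                    weight = module.weight.data
                    flat = weight.detach().cpu().reshape(-1, 1).numpy()
                    kmeans = MiniBatchKMeans(n_clusters=self.n_clusters, n_init=3)
                    labels = kmeans.fit_predict(flat)
                    clustered = kmeans.cluster_centers_[labels].reshape(tuple(weight.shape))
                    module.weight.data = torch.as_tensor(
                        clustered, dtype=weight.dtype, device=weight.device
                    )
                    total_inertia += float(kmeans.inertia_)
            return total_inertia

        def on_train_epoch_end(self, trainer, pl_module):
            # Re-cluster after every training epoch so training tracks the clustered weights.
            inertia = self._cluster_weights(pl_module)
            print(f"Clustering error: {inertia}")

        def on_test_epoch_start(self, trainer, pl_module):
            # Cluster once more right before testing, mirroring the hook move in this commit.
            self._cluster_weights(pl_module)
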
7 changes: 3 additions & 4 deletions hannah/callbacks/svd_compress.py
@@ -1,8 +1,8 @@
 #
-# Copyright (c) 2022 University of Tübingen.
+# Copyright (c) 2023 Hannah contributors.
 #
 # This file is part of hannah.
-# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
+# See https://github.com/ekut-es/hannah for further info.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ def __init__(self, rank_compression, compress_after):
 self.compress_after = compress_after
 super().__init__()
 
-def on_epoch_start(self, trainer, pl_module):
+def on_train_epoch_start(self, trainer, pl_module):
 """
 Args:
@@ -45,7 +45,6 @@ def on_epoch_start(self, trainer, pl_module):
 if trainer.current_epoch == self.compress_after / 2:
 with torch.no_grad():
 for name, module in pl_module.named_modules():
-
 # First case: conv-net-trax model with Sequential Layers
 if name == "model.linear.0.0" and not isinstance(
 pl_module.model.linear[0][0], nn.Sequential
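
As with the clustering callback, the hook is renamed to the stage-specific on_train_epoch_start that Lightning 2.x expects. For readers unfamiliar with the compression itself, the snippet below sketches the core SVD step only (a low-rank approximation of a 2-D weight matrix); the helper name and the rank value are illustrative, while the callback's actual layer matching and compress_after scheduling are as shown in the diff above.

    import torch


    def low_rank_approximation(weight: torch.Tensor, rank: int) -> torch.Tensor:
        """Return the best rank-`rank` approximation of a 2-D weight matrix (via SVD)."""
        u, s, vh = torch.linalg.svd(weight, full_matrices=False)
        return u[:, :rank] @ torch.diag(s[:rank]) @ vh[:rank, :]


    # Example: compress a linear layer's weight matrix to rank 4 in place.
    linear = torch.nn.Linear(64, 32)
    with torch.no_grad():
        linear.weight.copy_(low_rank_approximation(linear.weight, rank=4))
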
20 changes: 20 additions & 0 deletions hannah/conf/compression/clustering.yaml
@@ -0,0 +1,20 @@
+##
+## Copyright (c) 2022 University of Tübingen.
+##
+## This file is part of hannah.
+## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
+##
+## Licensed under the Apache License, Version 2.0 (the "License");
+## you may not use this file except in compliance with the License.
+## You may obtain a copy of the License at
+##
+## http://www.apache.org/licenses/LICENSE-2.0
+##
+## Unless required by applicable law or agreed to in writing, software
+## distributed under the License is distributed on an "AS IS" BASIS,
+## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+## See the License for the specific language governing permissions and
+## limitations under the License.
+##
+defaults:
+- clustering: kmeans
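
This new group file lets a clustering config be pulled in through Hydra's defaults list (compression → clustering → kmeans). For anyone new to the `_target_` convention used by the kmeans.yaml below: hydra.utils.instantiate imports the dotted class path and calls it with the remaining keys. The snippet uses a local stand-in class so it is self-contained; hannah's real kMeans callback takes a single `cluster` argument (see its __init__ in the diff above), so this only demonstrates the general mechanism, not hannah's exact wiring.

    from hydra.utils import instantiate
    from omegaconf import OmegaConf


    class DummyClustering:
        """Stand-in for a clustering callback, only to show how _target_ resolves."""

        def __init__(self, method: str, amount: int):
            self.method = method
            self.amount = amount


    # Mirrors the shape of conf/compression/clustering/kmeans.yaml, but points at the
    # local stand-in so the example runs without hannah installed.
    cfg = OmegaConf.create(
        {"_target_": "__main__.DummyClustering", "method": "kmeans", "amount": 16}
    )

    callback = instantiate(cfg)
    print(type(callback).__name__, callback.method, callback.amount)
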
3 changes: 2 additions & 1 deletion hannah/conf/compression/clustering/kmeans.yaml
@@ -17,4 +17,5 @@
 ## limitations under the License.
 ##
 _target_: hannah.callbacks.clustering.kMeans
-amount: 15
+method: kmeans
+amount: 16
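
If `amount` is the number of cluster centers handed to MiniBatchKMeans (the import change above suggests this, though the exact mapping lives in the callback), raising it from 15 to 16 means each weight tensor is quantised to at most 16 distinct values, indexable with 4 bits. A quick, self-contained illustration:

    import numpy as np
    from sklearn.cluster import MiniBatchKMeans

    rng = np.random.default_rng(0)
    weights = rng.normal(size=(64, 3, 9)).astype(np.float32)  # stand-in for a conv weight

    kmeans = MiniBatchKMeans(n_clusters=16, n_init=3, random_state=0)
    labels = kmeans.fit_predict(weights.reshape(-1, 1))
    clustered = kmeans.cluster_centers_[labels].reshape(weights.shape)

    print("distinct values after clustering:", np.unique(clustered).size)  # at most 16
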
1 change: 0 additions & 1 deletion hannah/conf/compression/full.yaml
@@ -20,4 +20,3 @@ defaults:
 - pruning: l1_unstructured
 - decomposition: svd
 - clustering: kmeans
-- quantization: default
12 changes: 3 additions & 9 deletions hannah/train.py
@@ -177,15 +177,9 @@ def train(
 )
 lit_trainer.fit(lit_module, ckpt_path=ckpt_path)
 
-if config.get("compression", None) and (
-config.get("compression").get("clustering", None)
-or config.get("compression").get("decomposition", None)
-):
-ckpt_path = None
-else:
-if lit_trainer.checkpoint_callback.kth_best_model_path:
-ckpt_path = "best"
-ckpt_path = None
+if lit_trainer.checkpoint_callback.kth_best_model_path:
+ckpt_path = "best"
+ckpt_path = None
 
 if not lit_trainer.fast_dev_run:
 reset_seed()
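
The train.py change drops the compression-aware branching around the evaluation checkpoint. For context, Lightning's Trainer.validate()/Trainer.test() accept ckpt_path="best" to reload the best checkpoint tracked by ModelCheckpoint, or None to evaluate the weights currently in memory; with clustering now applied at on_test_epoch_start, that choice no longer needs to special-case compression. A generic helper sketching the decision (names are illustrative, and this is not a verbatim extract of train.py):

    from typing import Optional

    from pytorch_lightning import LightningModule, Trainer


    def evaluation_ckpt_path(trainer: Trainer) -> Optional[str]:
        """Pick the weights validate()/test() should start from.

        "best" reloads the best checkpoint tracked by ModelCheckpoint; None keeps
        whatever is currently in memory (e.g. weights a compression callback just wrote).
        """
        cb = trainer.checkpoint_callback
        if cb is not None and cb.kth_best_model_path:
            return "best"
        return None


    def evaluate(trainer: Trainer, module: LightningModule) -> None:
        if not trainer.fast_dev_run:
            ckpt_path = evaluation_ckpt_path(trainer)
            trainer.validate(module, ckpt_path=ckpt_path, verbose=False)
            trainer.test(module, ckpt_path=ckpt_path, verbose=False)
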
2 changes: 1 addition & 1 deletion hannah/utils/utils.py
@@ -231,7 +231,7 @@ def common_callbacks(config: DictConfig) -> list:
 )
 device_stats = DeviceStatsMonitor(cpu_stats=config.get("device_stats", False))
 callbacks.append(device_stats)
-use_fx_mac_summary = config.get('fx_mac_summary', False)
+use_fx_mac_summary = config.get("fx_mac_summary", False)
 if use_fx_mac_summary:
 mac_summary_callback = FxMACSummaryCallback()
 else:
