From c3901be34b882685c3804bd668f167ac3441f80a Mon Sep 17 00:00:00 2001 From: Julia Werner Date: Wed, 15 Nov 2023 11:26:38 +0000 Subject: [PATCH] Remove freezing of layer for retraining since we achieve better results by... --- hannah/conf/dataset/chbmit.yaml | 2 +- hannah/conf/dataset/chbmitrt.yaml | 2 +- hannah/train.py | 19 +------------------ 3 files changed, 3 insertions(+), 20 deletions(-) diff --git a/hannah/conf/dataset/chbmit.yaml b/hannah/conf/dataset/chbmit.yaml index fce6d25b..1849310f 100644 --- a/hannah/conf/dataset/chbmit.yaml +++ b/hannah/conf/dataset/chbmit.yaml @@ -21,4 +21,4 @@ dataset: chbmit split: common_channels_simple data_folder: ${hydra:runtime.cwd}/datasets/ samplingrate: 256 -dataset_name: samp256halfsec +dataset_name: 16c_retrain_id diff --git a/hannah/conf/dataset/chbmitrt.yaml b/hannah/conf/dataset/chbmitrt.yaml index 88a5f949..a745239c 100644 --- a/hannah/conf/dataset/chbmitrt.yaml +++ b/hannah/conf/dataset/chbmitrt.yaml @@ -21,4 +21,4 @@ dataset: chbmit split: common_channels_simple data_folder: ${hydra:runtime.cwd}/datasets/ samplingrate: 256 -dataset_name: samp256halfsec \ No newline at end of file +dataset_name: 16c_retrain_id \ No newline at end of file diff --git a/hannah/train.py b/hannah/train.py index 62810a05..b601acaf 100644 --- a/hannah/train.py +++ b/hannah/train.py @@ -145,24 +145,7 @@ def train( lit_module.setup("train") input_ckpt = pl_load(config.input_file) lit_module.load_state_dict(input_ckpt["state_dict"], strict=False) - - if config.dataset.get("retrain_patient", None): - if config.get("input_file", None): - for param in lit_module.model.parameters(): - param.requires_grad = False - - for name, module in lit_module.model.named_modules(): - if isinstance(module, nn.Linear): - print("Unfreezing weights and bias of", name) - - module.reset_parameters() - module.weight.requires_grad = True - if module.bias is not None: - module.bias.requires_grad = True - - else: - raise AttributeError("Patient-specific retraining requires a pretrained model. Please specify the model weights using the \"input_file\" parameter.") - + if config["auto_lr"]: # run lr finder (counts as one epoch) lr_finder = lit_trainer.lr_find(lit_module)