Merge branch 'Remove_freezing_retrain' into 'main'
Remove freezing of layers for retraining since we achieve better results by...

See merge request es/ai/hannah/hannah!352
cgerum committed Nov 15, 2023
2 parents 7ef26fd + c3901be commit 8317fb2
Showing 3 changed files with 3 additions and 20 deletions.
2 changes: 1 addition & 1 deletion hannah/conf/dataset/chbmit.yaml

@@ -21,4 +21,4 @@ dataset: chbmit
 split: common_channels_simple
 data_folder: ${hydra:runtime.cwd}/datasets/
 samplingrate: 256
-dataset_name: samp256halfsec
+dataset_name: 16c_retrain_id
2 changes: 1 addition & 1 deletion hannah/conf/dataset/chbmitrt.yaml

@@ -21,4 +21,4 @@ dataset: chbmit
 split: common_channels_simple
 data_folder: ${hydra:runtime.cwd}/datasets/
 samplingrate: 256
-dataset_name: samp256halfsec
+dataset_name: 16c_retrain_id
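
Both dataset configs are switched to the same renamed preprocessed dataset. As a minimal sketch of how a Hydra config group like this is consumed, assuming a top-level config that pulls in one of the dataset files above; the entry-point module, config_path, and config_name here are illustrative assumptions, not hannah's documented CLI:

# Illustrative only: composes a Hydra config tree containing the dataset
# files changed above. config_path/config_name are assumptions about the
# layout, not hannah's actual entry point.
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="hannah/conf", config_name="config", version_base=None)
def main(config: DictConfig) -> None:
    # ${hydra:runtime.cwd} resolves at runtime to the original working
    # directory, so data_folder points at <cwd>/datasets/.
    print(config.dataset.dataset_name)  # "16c_retrain_id" after this commit
    print(config.dataset.data_folder)


if __name__ == "__main__":
    main()

Selecting the real-time variant is then a standard Hydra group override on the command line, e.g. dataset=chbmitrt.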
19 changes: 1 addition & 18 deletions hannah/train.py

@@ -145,24 +145,7 @@ def train(
         lit_module.setup("train")
         input_ckpt = pl_load(config.input_file)
         lit_module.load_state_dict(input_ckpt["state_dict"], strict=False)
-
-    if config.dataset.get("retrain_patient", None):
-        if config.get("input_file", None):
-            for param in lit_module.model.parameters():
-                param.requires_grad = False
-
-            for name, module in lit_module.model.named_modules():
-                if isinstance(module, nn.Linear):
-                    print("Unfreezing weights and bias of", name)
-
-                    module.reset_parameters()
-                    module.weight.requires_grad = True
-                    if module.bias is not None:
-                        module.bias.requires_grad = True
-
-        else:
-            raise AttributeError("Patient-specific retraining requires a pretrained model. Please specify the model weights using the \"input_file\" parameter.")
-
+
     if config["auto_lr"]:
         # run lr finder (counts as one epoch)
         lr_finder = lit_trainer.lr_find(lit_module)
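
For reference, the deleted hunk implemented a classic head-only fine-tuning scheme: freeze the whole pretrained network, then reinitialize and unfreeze only the final Linear classifier. A self-contained sketch of that removed behaviour, using a stand-in model rather than hannah's actual architecture (the 16 input channels echo the "16c" dataset name but are purely illustrative):

# Stand-in model; hannah's real networks differ. This reproduces only the
# freezing/unfreezing pattern that this commit deletes.
import torch.nn as nn

model = nn.Sequential(
    nn.Conv1d(16, 32, kernel_size=3),  # 16 channels: illustrative assumption
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(32 * 126, 2),  # assumes 128-sample windows; also illustrative
)

# Freeze every parameter of the pretrained backbone ...
for param in model.parameters():
    param.requires_grad = False

# ... then reset and unfreeze only the Linear layers, so patient-specific
# retraining previously touched nothing but the classifier head.
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        print("Unfreezing weights and bias of", name)
        module.reset_parameters()
        module.weight.requires_grad = True
        if module.bias is not None:
            module.bias.requires_grad = True

After this commit the whole network stays trainable: loading the pretrained checkpoint via load_state_dict(..., strict=False) is the only retraining-specific step left, and the optimizer updates all parameters, which the commit message reports gave better results.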
