From e0a975374082416528bcf5858713b3188dad66d0 Mon Sep 17 00:00:00 2001
From: canqunxiang
Date: Fri, 25 Nov 2022 12:56:26 +0800
Subject: [PATCH 1/5] add requires_grad feat

---
 mmengine/optim/optimizer/default_constructor.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mmengine/optim/optimizer/default_constructor.py b/mmengine/optim/optimizer/default_constructor.py
index ddbef1ff5a..63490d7dd5 100644
--- a/mmengine/optim/optimizer/default_constructor.py
+++ b/mmengine/optim/optimizer/default_constructor.py
@@ -217,6 +217,9 @@ def add_params(self,
                     if self.base_wd is not None:
                         decay_mult = custom_keys[key].get('decay_mult', 1.)
                         param_group['weight_decay'] = self.base_wd * decay_mult
+                    requires_grad = custom_keys[key].get('requires_grad', True)
+                    if not requires_grad:
+                        param_group['params'][0].requires_grad = requires_grad
                     # add custom settings to param_group
                     for k, v in custom_keys[key].items():
                         param_group[k] = v

From 022ef78d83330f6e46161d7c3046af9b5dfeb83d Mon Sep 17 00:00:00 2001
From: cq
Date: Sat, 26 Nov 2022 11:22:04 +0800
Subject: [PATCH 2/5] replace inplace op

---
 mmengine/optim/optimizer/default_constructor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmengine/optim/optimizer/default_constructor.py b/mmengine/optim/optimizer/default_constructor.py
index 63490d7dd5..1d961a117d 100644
--- a/mmengine/optim/optimizer/default_constructor.py
+++ b/mmengine/optim/optimizer/default_constructor.py
@@ -219,7 +219,7 @@ def add_params(self,
                         param_group['weight_decay'] = self.base_wd * decay_mult
                     requires_grad = custom_keys[key].get('requires_grad', True)
                     if not requires_grad:
-                        param_group['params'][0].requires_grad = requires_grad
+                        param_group['params'][0].requires_grad_(requires_grad)
                     # add custom settings to param_group
                     for k, v in custom_keys[key].items():
                         param_group[k] = v
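Taken together, patches 1 and 2 make add_params honor a requires_grad flag in
custom_keys: when a parameter name matches a key carrying requires_grad=False,
the single tensor stored in that parameter's group is frozen in place. A
minimal sketch of that behaviour follows; the toy conv layer and the
hand-built param_group are illustrative only, not code from the patch:

    import torch.nn as nn

    conv = nn.Conv2d(3, 8, 3)

    # add_params builds one group per parameter, so index 0 is the tensor.
    param_group = {'params': [conv.weight]}
    requires_grad = False  # i.e. custom_keys[key].get('requires_grad', True)
    if not requires_grad:
        # Tensor.requires_grad_() flips the flag in place; patch 1 assigned
        # the attribute directly and patch 2 switched to the method form.
        param_group['params'][0].requires_grad_(requires_grad)

    assert conv.weight.requires_grad is False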
From effc23fbcaaa1c3251909b6e4ce27733ae0a8d88 Mon Sep 17 00:00:00 2001
From: cq
Date: Sat, 26 Nov 2022 23:53:56 +0800
Subject: [PATCH 3/5] add requires_grad in custom_keys unit test

---
 .../test_optimizer/test_optimizer.py | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/tests/test_optim/test_optimizer/test_optimizer.py b/tests/test_optim/test_optimizer/test_optimizer.py
index ddbda7e58d..6ba01a0910 100644
--- a/tests/test_optim/test_optimizer/test_optimizer.py
+++ b/tests/test_optim/test_optimizer/test_optimizer.py
@@ -636,7 +636,7 @@ def test_default_optimizer_constructor_custom_key(self):
             'momentum': self.momentum,
             'weight_decay': self.base_wd,
         })
-        # group 3, matches of 'sub'
+        # group 3, matches of 'sub.conv1'
         groups.append(['sub.conv1.weight', 'sub.conv1.bias'])
         group_settings.append({
             'lr': self.base_lr * 0.1,
@@ -674,7 +674,10 @@ def test_default_optimizer_constructor_custom_key(self):
             type='OptimWrapper',
             optimizer=dict(
                 type='SGD', lr=self.base_lr, momentum=self.momentum))
-        paramwise_cfg = dict(custom_keys={'param1': dict(lr_mult=10)})
+        paramwise_cfg = dict(custom_keys={
+            'param1': dict(lr_mult=10),
+            'sub.gn': dict(requires_grad=False)
+        })
 
         optim_constructor = DefaultOptimWrapperConstructor(
             optim_wrapper_cfg, paramwise_cfg)
@@ -697,11 +700,18 @@ def test_default_optimizer_constructor_custom_key(self):
             'momentum': self.momentum,
             'weight_decay': 0,
         })
-        # group 2, default group
+        # group 2, matches of 'sub.gn'
+        groups.append(['sub.gn.weight', 'sub.gn.bias'])
+        group_settings.append({
+            'lr': self.base_lr,
+            'momentum': self.momentum,
+            'weight_decay': 0,
+            'requires_grad': False,
+        })
+        # group 3, default group
         groups.append([
-            'sub.conv1.weight', 'sub.conv1.bias', 'sub.gn.weight',
-            'sub.gn.bias', 'conv1.weight', 'conv2.weight', 'conv2.bias',
-            'bn.weight', 'bn.bias'
+            'sub.conv1.weight', 'sub.conv1.bias', 'conv1.weight',
+            'conv2.weight', 'conv2.bias', 'bn.weight', 'bn.bias'
         ])
         group_settings.append({
             'lr': self.base_lr,

From 66be2945208fca2e4780349a056d416e8095f727 Mon Sep 17 00:00:00 2001
From: cq
Date: Sun, 27 Nov 2022 00:05:31 +0800
Subject: [PATCH 4/5] fix test logic

---
 tests/test_optim/test_optimizer/test_optimizer.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/tests/test_optim/test_optimizer/test_optimizer.py b/tests/test_optim/test_optimizer/test_optimizer.py
index 6ba01a0910..24436d8ae5 100644
--- a/tests/test_optim/test_optimizer/test_optimizer.py
+++ b/tests/test_optim/test_optimizer/test_optimizer.py
@@ -726,8 +726,13 @@ def test_default_optimizer_constructor_custom_key(self):
             for group, settings in zip(groups, group_settings):
                 if name in group:
                     for setting in settings:
-                        assert param_groups[i][setting] == settings[
-                            setting], f'{name} {setting}'
+                        if setting == 'requires_grad':
+                            assert param_groups[i][setting] == settings[
+                                setting] == param_groups[i]['params'][
+                                    0].requires_grad, f'{name} {setting}'
+                        else:
+                            assert param_groups[i][setting] == settings[
+                                setting], f'{name} {setting}'
 
 
 @unittest.skipIf(

From ab2f6f40dfc85186fd6447b8c460647a0150e069 Mon Sep 17 00:00:00 2001
From: cq
Date: Sun, 27 Nov 2022 20:38:13 +0800
Subject: [PATCH 5/5] supplement the usage of requires_grad

---
 mmengine/optim/optimizer/default_constructor.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/mmengine/optim/optimizer/default_constructor.py b/mmengine/optim/optimizer/default_constructor.py
index 1d961a117d..72f30135d1 100644
--- a/mmengine/optim/optimizer/default_constructor.py
+++ b/mmengine/optim/optimizer/default_constructor.py
@@ -98,17 +98,20 @@ class DefaultOptimWrapperConstructor:
         >>> optim_wrapper = optim_wrapper_builder(model)
 
     Example 2:
-        >>> # assume model have attribute model.backbone and model.cls_head
+        >>> # assume the model has attributes model.backbone,
+        >>> # model.backbone.stem and model.cls_head
         >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict(
         >>>     type='SGD', lr=0.01, weight_decay=0.95))
         >>> paramwise_cfg = dict(custom_keys={
-        >>>     'backbone': dict(lr_mult=0.1, decay_mult=0.9)})
+        >>>     'backbone': dict(lr_mult=0.1, decay_mult=0.9),
+        >>>     'backbone.stem': dict(requires_grad=False)})
         >>> optim_wrapper_builder = DefaultOptimWrapperConstructor(
         >>>     optim_wrapper_cfg, paramwise_cfg)
         >>> optim_wrapper = optim_wrapper_builder(model)
         >>> # Then the `lr` and `weight_decay` for model.backbone is
         >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
-        >>> # model.cls_head is (0.01, 0.95).
+        >>> # model.cls_head is (0.01, 0.95). Gradients are disabled
+        >>> # for model.backbone.stem, so its parameters stay frozen.
     """
 
     def __init__(self,
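With all five patches applied, a submodule can be frozen straight from the
optimizer config. A usage sketch, assuming the series is applied to mmengine;
the toy classes below merely mirror the backbone/backbone.stem/cls_head
layout from the docstring example and are not part of the patch:

    import torch.nn as nn
    from mmengine.optim import DefaultOptimWrapperConstructor

    class Backbone(nn.Module):
        def __init__(self):
            super().__init__()
            self.stem = nn.Conv2d(3, 16, 3)
            self.layer = nn.Conv2d(16, 16, 3)

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = Backbone()
            self.cls_head = nn.Linear(16, 10)

    model = ToyModel()
    optim_wrapper_cfg = dict(
        type='OptimWrapper',
        optimizer=dict(type='SGD', lr=0.01, weight_decay=0.95))
    paramwise_cfg = dict(custom_keys={
        'backbone': dict(lr_mult=0.1, decay_mult=0.9),
        'backbone.stem': dict(requires_grad=False)})
    optim_wrapper = DefaultOptimWrapperConstructor(
        optim_wrapper_cfg, paramwise_cfg)(model)

    # 'backbone.stem' is the longest key matching the stem parameters, so
    # they are frozen, while the rest of the backbone still trains with
    # lr 0.01 * 0.1 and weight_decay 0.95 * 0.9.
    assert model.backbone.stem.weight.requires_grad is False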