diff --git a/mmengine/optim/optimizer/default_constructor.py b/mmengine/optim/optimizer/default_constructor.py
index ddbef1ff5a..72f30135d1 100644
--- a/mmengine/optim/optimizer/default_constructor.py
+++ b/mmengine/optim/optimizer/default_constructor.py
@@ -98,17 +98,20 @@ class DefaultOptimWrapperConstructor:
         >>> optim_wrapper = optim_wrapper_builder(model)

     Example 2:
-        >>> # assume model have attribute model.backbone and model.cls_head
+        >>> # assume model has attributes model.backbone, model.backbone.stem
+        >>> # and model.cls_head
         >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict(
         >>>     type='SGD', lr=0.01, weight_decay=0.95))
         >>> paramwise_cfg = dict(custom_keys={
-        >>>     'backbone': dict(lr_mult=0.1, decay_mult=0.9)})
+        >>>     'backbone': dict(lr_mult=0.1, decay_mult=0.9),
+        >>>     'backbone.stem': dict(requires_grad=False)})
         >>> optim_wrapper_builder = DefaultOptimWrapperConstructor(
         >>>     optim_wrapper_cfg, paramwise_cfg)
         >>> optim_wrapper = optim_wrapper_builder(model)
         >>> # Then the `lr` and `weight_decay` for model.backbone is
         >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
-        >>> # model.cls_head is (0.01, 0.95).
+        >>> # model.cls_head is (0.01, 0.95). Gradients are disabled
+        >>> # for model.backbone.stem (requires_grad=False).
     """

     def __init__(self,
@@ -217,6 +220,9 @@ def add_params(self,
                     if self.base_wd is not None:
                         decay_mult = custom_keys[key].get('decay_mult', 1.)
                         param_group['weight_decay'] = self.base_wd * decay_mult
+                    requires_grad = custom_keys[key].get('requires_grad', True)
+                    if not requires_grad:
+                        param_group['params'][0].requires_grad_(requires_grad)
                     # add custom settings to param_group
                     for k, v in custom_keys[key].items():
                         param_group[k] = v
diff --git a/tests/test_optim/test_optimizer/test_optimizer.py b/tests/test_optim/test_optimizer/test_optimizer.py
index ddbda7e58d..24436d8ae5 100644
--- a/tests/test_optim/test_optimizer/test_optimizer.py
+++ b/tests/test_optim/test_optimizer/test_optimizer.py
@@ -636,7 +636,7 @@ def test_default_optimizer_constructor_custom_key(self):
             'momentum': self.momentum,
             'weight_decay': self.base_wd,
         })
-        # group 3, matches of 'sub'
+        # group 3, matches of 'sub.conv1'
         groups.append(['sub.conv1.weight', 'sub.conv1.bias'])
         group_settings.append({
             'lr': self.base_lr * 0.1,
@@ -674,7 +674,10 @@ def test_default_optimizer_constructor_custom_key(self):
             type='OptimWrapper',
             optimizer=dict(
                 type='SGD', lr=self.base_lr, momentum=self.momentum))
-        paramwise_cfg = dict(custom_keys={'param1': dict(lr_mult=10)})
+        paramwise_cfg = dict(custom_keys={
+            'param1': dict(lr_mult=10),
+            'sub.gn': dict(requires_grad=False)
+        })

         optim_constructor = DefaultOptimWrapperConstructor(
             optim_wrapper_cfg, paramwise_cfg)
@@ -697,11 +700,18 @@ def test_default_optimizer_constructor_custom_key(self):
             'momentum': self.momentum,
             'weight_decay': 0,
         })
-        # group 2, default group
+        # group 2, matches of 'sub.gn'
+        groups.append(['sub.gn.weight', 'sub.gn.bias'])
+        group_settings.append({
+            'lr': self.base_lr,
+            'momentum': self.momentum,
+            'weight_decay': 0,
+            'requires_grad': False,
+        })
+        # group 3, default group
         groups.append([
-            'sub.conv1.weight', 'sub.conv1.bias', 'sub.gn.weight',
-            'sub.gn.bias', 'conv1.weight', 'conv2.weight', 'conv2.bias',
-            'bn.weight', 'bn.bias'
+            'sub.conv1.weight', 'sub.conv1.bias', 'conv1.weight',
+            'conv2.weight', 'conv2.bias', 'bn.weight', 'bn.bias'
         ])
         group_settings.append({
             'lr': self.base_lr,
@@ -716,8 +726,13 @@ def test_default_optimizer_constructor_custom_key(self):
             for group, settings in zip(groups, group_settings):
                 if name in group:
                     for setting in settings:
-                        assert param_groups[i][setting] == settings[
-                            setting], f'{name} {setting}'
+                        if setting == 'requires_grad':
+                            assert param_groups[i][setting] == settings[
+                                setting] == param_groups[i]['params'][
+                                    0].requires_grad, f'{name} {setting}'
+                        else:
+                            assert param_groups[i][setting] == settings[
+                                setting], f'{name} {setting}'

     @unittest.skipIf(
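
For reference, a minimal usage sketch of the new `requires_grad` custom key, not part of the patch itself. `ToyModel` below is a made-up stand-in for the docstring's model with `backbone`, `backbone.stem` and `cls_head` attributes; `DefaultOptimWrapperConstructor` is assumed importable from `mmengine.optim` as usual.

import torch.nn as nn
from mmengine.optim import DefaultOptimWrapperConstructor

class ToyModel(nn.Module):
    """Hypothetical model mirroring the docstring's Example 2."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Module()
        self.backbone.stem = nn.Conv2d(3, 8, 3)   # to be frozen
        self.backbone.layer = nn.Conv2d(8, 8, 3)
        self.cls_head = nn.Linear(8, 10)

optim_wrapper_cfg = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, weight_decay=0.95))
paramwise_cfg = dict(custom_keys={
    'backbone': dict(lr_mult=0.1, decay_mult=0.9),
    'backbone.stem': dict(requires_grad=False),
})
optim_wrapper = DefaultOptimWrapperConstructor(
    optim_wrapper_cfg, paramwise_cfg)(ToyModel())
# With this patch, every parameter under backbone.stem now reports
# requires_grad == False (i.e. the stem is frozen), while the rest of
# the backbone keeps lr 0.01 * 0.1 and weight_decay 0.95 * 0.9.

Since `custom_keys` are matched longest-key-first, `backbone.stem` takes priority over `backbone` for the stem parameters, so they receive only the `requires_grad=False` setting.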