[Enhance] Speed up optimizer. #909

Open · wants to merge 6 commits into base: main
56 changes: 52 additions & 4 deletions mmengine/optim/optimizer/default_constructor.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import logging
from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
@@ -9,8 +10,8 @@
from mmengine.logging import print_log
from mmengine.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
OPTIMIZERS)
from mmengine.utils import is_list_of
from mmengine.utils.dl_utils import mmcv_full_available
from mmengine.utils import digit_version, is_list_of
from mmengine.utils.dl_utils import TORCH_VERSION, mmcv_full_available
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
from .optimizer_wrapper import OptimWrapper

@@ -52,6 +53,10 @@
of a model.
- ``bypass_duplicate`` (bool): If true, the duplicate parameters
would not be added into optimizer. Defaults to False.
- ``reduce_param_groups`` (bool): If true, the constructor will merge
parameter groups that share the same learning rate, momentum and other
hyper-parameters, which can speed up the optimizer. Defaults to True.
New in version 0.7.2.
Reviewer comment (Collaborator): Maybe a newer version number? 0.10.4?

Note:

@@ -293,6 +298,9 @@
optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
optimizer_cfg = self.optimizer_cfg.copy()
optimizer_type = self.optimizer_cfg['type']
if isinstance(optimizer_type, str):
optimizer_type = OPTIMIZERS.get(optimizer_type)
# if no paramwise option is specified, just use the global setting
if not self.paramwise_cfg:
optimizer_cfg['params'] = model.parameters()
@@ -301,8 +309,48 @@
# set param-wise lr and weight decay recursively
params: List = []
self.add_params(params, model)
optimizer_cfg['params'] = params
# grouping parameters with the same hyper-parameters
if self.paramwise_cfg.get('reduce_param_groups', True):
optimizer_cfg['params'] = self.reduce_param_groups(params)

Codecov warning: added line #L314 in mmengine/optim/optimizer/default_constructor.py was not covered by tests.
else:
optimizer_cfg['params'] = params
# enable foreach for pytorch 1.12.0+ to speed up training
if (digit_version(TORCH_VERSION) >= digit_version('1.12.0') and
'foreach' in inspect.getfullargspec(optimizer_type).args):
optimizer_cfg.setdefault('foreach', True)

Codecov warning: added line #L320 in mmengine/optim/optimizer/default_constructor.py was not covered by tests.
else:
optimizer_cfg.pop('foreach', None)

optimizer = OPTIMIZERS.build(optimizer_cfg)
optim_wrapper = OPTIM_WRAPPERS.build(
optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
return optim_wrapper

@staticmethod
def reduce_param_groups(
params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Reorganize the parameter groups and merge duplicated groups. The
number of parameter groups needs to be as small as possible in order to
efficiently use the PyTorch multi-tensor optimizer. Therefore instead
of using a parameter_group per single parameter, we reorganize the
parameter groups and merge duplicated groups. This approach speeds up
multi-tensor optimizer significantly.

References: https://github.com/facebookresearch/detectron2/blob/main/detectron2/solver/build.py

Args:
params (List[Dict[str, Any]]): The parameter groups.

Returns:
List[Dict[str, Any]]: The reorganized parameter groups.
""" # noqa: E501
groups: dict = dict()

for item in params:
hyper_params_id = tuple(
(x, y) for x, y in item.items() if x != 'params')
if hyper_params_id in groups:
groups[hyper_params_id]['params'].extend(item['params'])
else:
groups[hyper_params_id] = item
return list(groups.values())
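
For context, a minimal usage sketch of the new behaviour follows. It is not part of this diff; it assumes DefaultOptimWrapperConstructor is importable from mmengine.optim (as in the tests below) and that SGD is registered in OPTIMIZERS, and the model and hyper-parameter values are arbitrary. With reduce_param_groups left at its default of True, parameters that end up with identical hyper-parameters are merged into a single param_group, and on PyTorch 1.12.0+ the constructor additionally tries to pass foreach=True to optimizers that accept it.

# Usage sketch (import path and registry names are assumptions based on
# the tests below).
import torch.nn as nn

from mmengine.optim import DefaultOptimWrapperConstructor

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

optim_wrapper_cfg = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=1e-4))
# norm_decay_mult gives the BatchNorm parameters weight_decay=0;
# reduce_param_groups (the default) then merges groups whose
# hyper-parameters are identical.
paramwise_cfg = dict(norm_decay_mult=0., reduce_param_groups=True)

optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg,
                                                   paramwise_cfg)
optim_wrapper = optim_constructor(model)
# Only a few merged groups remain instead of one group per parameter.
print(len(optim_wrapper.optimizer.param_groups))

Setting reduce_param_groups=False in paramwise_cfg restores the previous one-group-per-parameter behaviour, which is what most of the updated tests below do to keep their existing assertions valid.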
66 changes: 56 additions & 10 deletions tests/test_optim/test_optimizer/test_optimizer.py
@@ -144,7 +144,8 @@ def _check_sgd_optimizer(self,
dwconv_decay_mult=1,
dcn_offset_lr_mult=1,
flat_decay_mult=1,
bypass_duplicate=False):
bypass_duplicate=False,
reduce_param_groups=False):
param_groups = optimizer.param_groups
assert isinstance(optimizer, torch.optim.SGD)
assert optimizer.defaults['lr'] == self.base_lr
@@ -289,7 +290,8 @@ def test_build_default_optimizer_constructor(self):
norm_decay_mult=0,
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1,
flat_decay_mult=0.3)
flat_decay_mult=0.3,
reduce_param_groups=False)
optim_constructor_cfg = dict(
type='DefaultOptimWrapperConstructor',
optim_wrapper_cfg=optim_wrapper,
@@ -422,7 +424,8 @@ def test_default_optimizer_constructor_with_model_wrapper(self):
norm_decay_mult=0,
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1,
flat_decay_mult=0.3)
flat_decay_mult=0.3,
reduce_param_groups=False)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)
optim_wrapper = optim_constructor(model)
@@ -462,7 +465,8 @@ def test_default_optimizer_constructor_with_model_wrapper(self):
norm_decay_mult=0,
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1,
flat_decay_mult=0.3)
flat_decay_mult=0.3,
reduce_param_groups=False)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)
optim_wrapper = optim_constructor(model)
@@ -518,7 +522,8 @@ def test_default_optimizer_constructor_with_paramwise_cfg(self):
norm_decay_mult=0,
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1,
flat_decay_mult=0.3)
flat_decay_mult=0.3,
reduce_param_groups=False)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)
optim_wrapper = optim_constructor(self.model)
@@ -539,7 +544,8 @@ def test_default_optimizer_constructor_no_grad(self):
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1)
dcn_offset_lr_mult=0.1,
reduce_param_groups=False)

for param in self.model.parameters():
param.requires_grad = False
@@ -573,7 +579,8 @@ def test_default_optimizer_constructor_bypass_duplicate(self):
bias_lr_mult=2,
bias_decay_mult=0.5,
norm_decay_mult=0,
dwconv_decay_mult=0.1)
dwconv_decay_mult=0.1,
reduce_param_groups=False)

with self.assertRaisesRegex(
ValueError,
@@ -589,7 +596,8 @@ def test_default_optimizer_constructor_bypass_duplicate(self):
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1,
flat_decay_mult=0.3,
bypass_duplicate=True)
bypass_duplicate=True,
reduce_param_groups=False)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)

@@ -635,7 +643,8 @@ def test_default_optimizer_constructor_custom_key(self):
'sub.gn': dict(lr_mult=0.01),
'non_exist_key': dict(lr_mult=0.0)
},
norm_decay_mult=0.5)
norm_decay_mult=0.5,
reduce_param_groups=False)

with self.assertRaises(TypeError):
# custom_keys should be a dict
@@ -722,7 +731,9 @@ def test_default_optimizer_constructor_custom_key(self):
type='OptimWrapper',
optimizer=dict(
type='SGD', lr=self.base_lr, momentum=self.momentum))
paramwise_cfg = dict(custom_keys={'param1': dict(lr_mult=10)})
paramwise_cfg = dict(
custom_keys={'param1': dict(lr_mult=10)},
reduce_param_groups=False)

optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)
@@ -767,6 +778,41 @@ def test_default_optimizer_constructor_custom_key(self):
assert param_groups[i][setting] == settings[
setting], f'{name} {setting}'

def test_reduce_param_groups(self):
# ref: https://github.com/facebookresearch/detectron2/blob/main/tests/test_solver.py # noqa: E501
params = [
dict(params=['p1'], lr=1.0, weight_decay=4.0),
dict(params=['p2', 'p6'], lr=2.0, weight_decay=3.0, momentum=2.0),
dict(params=['p3'], lr=2.0, weight_decay=3.0, momentum=2.0),
dict(params=['p4'], lr=1.0, weight_decay=3.0),
dict(params=['p5'], lr=2.0, momentum=2.0),
]
gt_groups = [
{
'lr': 1.0,
'weight_decay': 4.0,
'params': ['p1'],
},
{
'lr': 2.0,
'weight_decay': 3.0,
'momentum': 2.0,
'params': ['p2', 'p6', 'p3'],
},
{
'lr': 1.0,
'weight_decay': 3.0,
'params': ['p4'],
},
{
'lr': 2.0,
'momentum': 2.0,
'params': ['p5'],
},
]
out = DefaultOptimWrapperConstructor.reduce_param_groups(params)
self.assertEqual(out, gt_groups)


@unittest.skipIf(
(digit_version(TORCH_VERSION) < digit_version('1.8.0'))
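
As a side note on the foreach change above, the snippet below is a standalone illustration (not from this PR) of how an optimizer class can be introspected for foreach support with inspect.getfullargspec. In some PyTorch releases foreach is a keyword-only argument, so checking kwonlyargs in addition to args is the more robust variant.

# Standalone illustration; torch.optim.SGD is used only as an example class.
import inspect

import torch

spec = inspect.getfullargspec(torch.optim.SGD)
# foreach may be declared keyword-only, so look at both argument lists.
supports_foreach = 'foreach' in spec.args + spec.kwonlyargs
print(supports_foreach)  # True on PyTorch 1.12.0 and later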