diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py
index 6543dacdd2..8557f4d34c 100644
--- a/mmengine/optim/optimizer/builder.py
+++ b/mmengine/optim/optimizer/builder.py
@@ -132,6 +132,10 @@ def register_sophia_optimizers() -> List[str]:
 def register_bitsandbytes_optimizers() -> List[str]:
     """Register optimizers in ``bitsandbytes`` to the ``OPTIMIZERS`` registry.
 
+    In the ``bitsandbytes`` library, optimizers that have the same name as the
+    default optimizers in PyTorch are prefixed with ``bnb_``. For example,
+    ``bnb_Adagrad``.
+
     Returns:
         List[str]: A list of registered optimizers' name.
     """
@@ -141,16 +145,14 @@ def register_bitsandbytes_optimizers() -> List[str]:
     except ImportError:
         pass
     else:
-        for module_name in [
-                'AdamW8bit', 'Adam8bit', 'Adagrad8bit', 'PagedAdam8bit',
-                'PagedAdamW8bit', 'LAMB8bit', 'LARS8bit', 'RMSprop8bit',
-                'Lion8bit', 'PagedLion8bit', 'SGD8bit'
-        ]:
-            _optim = getattr(bnb.optim, module_name)
-            if inspect.isclass(_optim) and issubclass(_optim,
-                                                      torch.optim.Optimizer):
-                OPTIMIZERS.register_module(module=_optim)
-                dadaptation_optimizers.append(module_name)
+        optim_classes = inspect.getmembers(
+            bnb.optim, lambda _optim: (inspect.isclass(_optim) and issubclass(
+                _optim, torch.optim.Optimizer)))
+        for name, optim_cls in optim_classes:
+            if name in OPTIMIZERS:
+                name = f'bnb_{name}'
+            OPTIMIZERS.register_module(module=optim_cls, name=name)
+            dadaptation_optimizers.append(name)
     return dadaptation_optimizers
 
 
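A minimal usage sketch of the new behavior, for reviewers (not part of the patch; it assumes ``bitsandbytes`` is installed and that importing ``mmengine.optim`` still triggers the module-level registration calls in builder.py):

    import torch.nn as nn

    import mmengine.optim  # noqa: F401  (registers the optimizers on import)
    from mmengine.registry import OPTIMIZERS

    model = nn.Linear(4, 2)  # hypothetical toy model

    # A name unique to bitsandbytes is registered unchanged.
    adamw_8bit = OPTIMIZERS.build(
        dict(type='AdamW8bit', params=model.parameters(), lr=1e-3))

    # A name that collides with a default PyTorch optimizer (here Adagrad,
    # already registered from torch.optim) is looked up with the bnb_ prefix.
    adagrad_bnb = OPTIMIZERS.build(
        dict(type='bnb_Adagrad', params=model.parameters(), lr=1e-2))

Compared with the old hard-coded list, ``inspect.getmembers`` also picks up optimizers added in future ``bitsandbytes`` releases, and the ``bnb_`` prefix keeps those registrations from clashing with the PyTorch optimizers already present in ``OPTIMIZERS``.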