From 8bcbda3dc94ce22f469531679af1b5d490104cdb Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Thu, 26 Oct 2023 17:44:46 +0800 Subject: [PATCH] [Feature] Support using lazy object in list, tuple and set --- mmengine/config/__init__.py | 8 +- mmengine/config/config.py | 201 ++++++++++++------ .../config/lazy_module_config/toy_model.py | 7 + tests/test_config/test_config.py | 43 +++- 4 files changed, 186 insertions(+), 73 deletions(-) diff --git a/mmengine/config/__init__.py b/mmengine/config/__init__.py index 9a1bc47db4..c904372206 100644 --- a/mmengine/config/__init__.py +++ b/mmengine/config/__init__.py @@ -1,4 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .config import Config, ConfigDict, DictAction, read_base +from .config import (Config, ConfigDict, ConfigList, ConfigSet, ConfigTuple, + DictAction, read_base) -__all__ = ['Config', 'ConfigDict', 'DictAction', 'read_base'] +__all__ = [ + 'Config', 'ConfigDict', 'DictAction', 'read_base', 'ConfigList', + 'ConfigSet', 'ConfigTuple' +] diff --git a/mmengine/config/config.py b/mmengine/config/config.py index 316ac65d4d..58e597f50f 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -45,19 +45,87 @@ import re # type: ignore -def _lazy2string(cfg_dict, dict_type=None): - if isinstance(cfg_dict, dict): - dict_type = dict_type or type(cfg_dict) - return dict_type({k: _lazy2string(v) for k, v in dict.items(cfg_dict)}) - elif isinstance(cfg_dict, (tuple, list)): - return type(cfg_dict)(_lazy2string(v) for v in cfg_dict) - elif isinstance(cfg_dict, (LazyAttr, LazyObject)): - return f'{cfg_dict.module}.{str(cfg_dict)}' - else: - return cfg_dict +class LazyContainerMeta(type): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.lazy = False + + +class LazyContainerMixin(metaclass=LazyContainerMeta): + + def to_builtin(self, keep_lazy=False): + + def _to_builtin(cfg): + if isinstance(cfg, dict): + return dict({k: _to_builtin(v) for k, v in dict.items(cfg)}) + elif isinstance(cfg, tuple): + return tuple(_to_builtin(v) for v in tuple.__iter__(cfg)) + elif isinstance(cfg, list): + return list(_to_builtin(v) for v in list.__iter__(cfg)) + elif isinstance(cfg, set): + return {_to_builtin(v) for v in set.__iter__(cfg)} + elif isinstance(cfg, (LazyAttr, LazyObject)): + if not keep_lazy: + return f'{cfg.module}.{str(cfg)}' + else: + return cfg + else: + return cfg + + return _to_builtin(self) + + def build_lazy(self, value: Any) -> Any: + """If class attribute ``lazy`` is False, the LazyObject will be built + and returned. + Args: + value (Any): The value to be built. -class ConfigDict(Dict): + Returns: + Any: The built value. + """ + if (isinstance(value, (LazyAttr, LazyObject)) + and not self.__class__.lazy): + value = value.build() + return value + + def __deepcopy__(self, memo): + return self.__class__( + copy.deepcopy(item, memo) for item in super().__iter__()) + + def __copy__(self): + return self.__class__(item for item in super().__iter__()) + + def __iter__(self): + # Override `__iter__` to overwrite to support star unpacking + # `*cfg_list` + yield from map(self.build_lazy, super().__iter__()) + + def __getitem__(self, idx): + try: + value = self.build_lazy(super().__getitem__(idx)) + except Exception as e: + raise e + else: + return value + + def __eq__(self, other): + return all(a == b for a, b in zip(self, other)) + + def __reduce_ex__(self, proto): + # Override __reduce_ex__ to avoid dump the built lazy object. + if digit_version(platform.python_version()) < digit_version('3.8'): + return (self.__class__, (tuple(i for i in super().__iter__()), ), + None, None, None) + else: + return (self.__class__, (tuple(i for i in super().__iter__()), ), + None, None, None, None) + + copy = __copy__ + + +class ConfigDict(LazyContainerMixin, Dict): """A dictionary for config which has the same interface as python's built- in dictionary and can be used as a normal dictionary. @@ -72,7 +140,6 @@ class ConfigDict(Dict): object during configuration parsing, and it should be set to False outside the Config to ensure that users do not experience the ``LazyObject``. """ - lazy = False def __init__(__self, *args, **kwargs): object.__setattr__(__self, '__parent', kwargs.pop('__parent', None)) @@ -118,8 +185,14 @@ def __getattr__(self, name): @classmethod def _hook(cls, item): # avoid to convert user defined dict to ConfigDict. - if type(item) in (dict, OrderedDict): + if isinstance(item, ConfigDict): + return item + elif type(item) in (dict, OrderedDict): return cls(item) + elif isinstance(item, LazyContainerMixin): + return type(item)( + cls._hook(elem) + for elem in super(LazyContainerMixin, item).__iter__()) elif isinstance(item, (list, tuple)): return type(item)(cls._hook(elem) for elem in item) return item @@ -150,11 +223,6 @@ def __copy__(self): copy = __copy__ - def __iter__(self): - # Implement `__iter__` to overwrite the unpacking operator `**cfg_dict` - # to get the built lazy object - return iter(self.keys()) - def get(self, key: str, default: Optional[Any] = None) -> Any: """Get the value of the key. If class attribute ``lazy`` is True, the LazyObject will be built and returned. @@ -201,20 +269,6 @@ def update(self, *args, **kwargs) -> None: else: self[k].update(v) - def build_lazy(self, value: Any) -> Any: - """If class attribute ``lazy`` is False, the LazyObject will be built - and returned. - - Args: - value (Any): The value to be built. - - Returns: - Any: The built value. - """ - if isinstance(value, (LazyAttr, LazyObject)) and not self.lazy: - value = value.build() - return value - def values(self): """Yield the values of the dictionary. @@ -288,28 +342,28 @@ def __eq__(self, other): return False def _to_lazy_dict(self): - """Convert the ConfigDict to a normal dictionary recursively, and keep - the ``LazyObject`` or ``LazyAttr`` object not built.""" + # NOTE: Keep this function for backward compatibility. - def _to_dict(data): - if isinstance(data, ConfigDict): - return { - key: _to_dict(value) - for key, value in Dict.items(data) - } - elif isinstance(data, dict): - return {key: _to_dict(value) for key, value in data.items()} - elif isinstance(data, (list, tuple)): - return type(data)(_to_dict(item) for item in data) - else: - return data - - return _to_dict(self) + return self.to_builtin(keep_lazy=True) def to_dict(self): - """Convert the ConfigDict to a normal dictionary recursively, and - convert the ``LazyObject`` or ``LazyAttr`` to string.""" - return _lazy2string(self, dict_type=dict) + # NOTE: Keep this function for backward compatibility. + return self.to_builtin() + + +class ConfigList(LazyContainerMixin, list): # type: ignore + ... + + +class ConfigTuple(LazyContainerMixin, tuple): # type: ignore + ... + + +class ConfigSet(LazyContainerMixin, set): # type: ignore + ... + + def pop(self, idx): + return self.build_lazy(super().pop(idx)) def add_args(parser: ArgumentParser, @@ -479,20 +533,8 @@ def fromfile(filename: Union[str, Path], env_variables=env_variables, ) else: - # Enable lazy import when parsing the config. - # Using try-except to make sure ``ConfigDict.lazy`` will be reset - # to False. See more details about lazy in the docstring of - # ConfigDict - ConfigDict.lazy = True - try: + with Config._lazy_context(): cfg_dict, imported_names = Config._parse_lazy_import(filename) - except Exception as e: - raise e - finally: - # disable lazy import to get the real type. See more details - # about lazy in the docstring of ConfigDict - ConfigDict.lazy = False - cfg = Config( cfg_dict, filename=filename, @@ -500,6 +542,21 @@ def fromfile(filename: Union[str, Path], object.__setattr__(cfg, '_imported_names', imported_names) return cfg + @staticmethod + @contextmanager + def _lazy_context(): + ConfigDict.lazy = True + ConfigSet.lazy = True + ConfigList.lazy = True + ConfigTuple.lazy = True + + yield + + ConfigDict.lazy = False + ConfigSet.lazy = False + ConfigList.lazy = False + ConfigTuple.lazy = False + @staticmethod def fromstring(cfg_str: str, file_format: str) -> 'Config': """Build a Config instance from config text. @@ -1110,12 +1167,12 @@ def _parse_lazy_import(filename: str) -> Tuple[ConfigDict, set]: continue ret[key] = value # convert dict to ConfigDict - cfg_dict = Config._dict_to_config_dict_lazy(ret) + cfg_dict = Config._to_lazy_container(ret) return cfg_dict, imported_names @staticmethod - def _dict_to_config_dict_lazy(cfg: dict): + def _to_lazy_container(cfg: dict): """Recursively converts ``dict`` to :obj:`ConfigDict`. The only difference between ``_dict_to_config_dict_lazy`` and ``_dict_to_config_dict_lazy`` is that the former one does not consider @@ -1131,11 +1188,15 @@ def _dict_to_config_dict_lazy(cfg: dict): if isinstance(cfg, dict): cfg_dict = ConfigDict() for key, value in cfg.items(): - cfg_dict[key] = Config._dict_to_config_dict_lazy(value) + cfg_dict[key] = Config._to_lazy_container(value) return cfg_dict - if isinstance(cfg, (tuple, list)): - return type(cfg)( - Config._dict_to_config_dict_lazy(_cfg) for _cfg in cfg) + if isinstance(cfg, list): + return ConfigList(Config._to_lazy_container(_cfg) for _cfg in cfg) + if isinstance(cfg, tuple): + return ConfigTuple(Config._to_lazy_container(_cfg) for _cfg in cfg) + if isinstance(cfg, set): + return ConfigSet(Config._to_lazy_container(_cfg) for _cfg in cfg) + return cfg @staticmethod diff --git a/tests/data/config/lazy_module_config/toy_model.py b/tests/data/config/lazy_module_config/toy_model.py index a9d2a3f64a..3f2bbab383 100644 --- a/tests/data/config/lazy_module_config/toy_model.py +++ b/tests/data/config/lazy_module_config/toy_model.py @@ -1,4 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +from torch.distributed.fsdp.wrap import (size_based_auto_wrap_policy, + transformer_auto_wrap_policy) + +from mmengine._strategy import ColossalAIStrategy from mmengine.config import read_base from mmengine.dataset import DefaultSampler from mmengine.hooks import EMAHook @@ -46,4 +50,7 @@ priority=49) ] +# illegal model wrapper config, just for unit test. +strategy = dict(type=ColossalAIStrategy, model_wrapper=dict( + auto_wrap_policy=(size_based_auto_wrap_policy, transformer_auto_wrap_policy))) runner_type = FlexibleRunner diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index 905485c16a..3f74abb3e7 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -15,7 +15,7 @@ import pytest import mmengine -from mmengine import Config, ConfigDict, DictAction +from mmengine import Config, ConfigDict, ConfigList, DictAction from mmengine.config.lazy import LazyObject from mmengine.fileio import dump, load from mmengine.registry import MODELS, DefaultScope, Registry @@ -1209,3 +1209,44 @@ def _recursive_check_lazy(self, cfg, expr): [self._recursive_check_lazy(value, expr) for value in cfg] else: self.assertTrue(expr(cfg)) + + +class TestConfigList(TestCase): + + def test_getitem(self): + cfg_list = ConfigList([ + 1, 2, + ConfigDict(type=LazyObject('mmengine')), + LazyObject('mmengine') + ]) + self.assertIs(cfg_list[2]['type'], mmengine) + self.assertIs(cfg_list[3], mmengine) + + def test_star(self): + + def check_star(a, b, c): + self.assertIs(c, mmengine) + + cfg_list = ConfigList([1, 2, LazyObject('mmengine')]) + check_star(*cfg_list) + + def check_for_loop(self): + cfg_list = ConfigList([LazyObject('mmengine')]) + for i in cfg_list: + self.assertIs(i, mmengine) + + def test_copy(self): + cfg_list = ConfigList([ + 1, 2, + ConfigDict(type=LazyObject('mmengine')), + LazyObject('mmengine') + ]) + cfg_copy = cfg_list.copy() + self.assertIsInstance(cfg_copy, ConfigList) + self.assertEqual(cfg_list, cfg_copy) + self.assertIs(cfg_list[2], cfg_copy[2]) + + cfg_copy = copy.deepcopy(cfg_list) + self.assertIsInstance(cfg_copy, ConfigList) + self.assertEqual(cfg_list, cfg_copy) + self.assertIsNot(cfg_list[2], cfg_copy[2])