From 9477ac43ae0c1e9c79eb68141d4a72bc1008c875 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Thu, 26 Oct 2023 17:44:46 +0800 Subject: [PATCH] cherry-pick config container --- mmengine/config/__init__.py | 8 +- mmengine/config/config.py | 160 ++++++++++++------ mmengine/config/new_config.py | 40 +++-- .../config/lazy_module_config/toy_model.py | 7 + tests/test_config/test_config.py | 49 +++++- 5 files changed, 195 insertions(+), 69 deletions(-) diff --git a/mmengine/config/__init__.py b/mmengine/config/__init__.py index 71292d274c..fc0e5cea2e 100644 --- a/mmengine/config/__init__.py +++ b/mmengine/config/__init__.py @@ -1,5 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .config import Config, ConfigDict, DictAction +from .config import (Config, ConfigDict, ConfigList, ConfigSet, ConfigTuple, + DictAction) from .new_config import read_base -__all__ = ['Config', 'ConfigDict', 'DictAction', 'read_base'] +__all__ = [ + 'Config', 'ConfigDict', 'DictAction', 'read_base', 'ConfigList', + 'ConfigSet', 'ConfigTuple' +] diff --git a/mmengine/config/config.py b/mmengine/config/config.py index 116c09ab85..7196fcbe84 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -28,19 +28,86 @@ DELETE_KEY = '_delete_' -def _lazy2string(cfg_dict, dict_type=None): - if isinstance(cfg_dict, dict): - dict_type = dict_type or type(cfg_dict) - return dict_type({k: _lazy2string(v) for k, v in dict.items(cfg_dict)}) - elif isinstance(cfg_dict, (tuple, list)): - return type(cfg_dict)(_lazy2string(v) for v in cfg_dict) - elif isinstance(cfg_dict, LazyObject): - return cfg_dict.dump_str - else: - return cfg_dict - - -class ConfigDict(Dict): +class LazyContainerMeta(type): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.lazy = False + + +class LazyContainerMixin(metaclass=LazyContainerMeta): + + def to_builtin(self, keep_lazy=False): + + def _to_builtin(cfg): + if isinstance(cfg, dict): + return dict({k: _to_builtin(v) for k, v in dict.items(cfg)}) + elif isinstance(cfg, tuple): + return tuple(_to_builtin(v) for v in tuple.__iter__(cfg)) + elif isinstance(cfg, list): + return list(_to_builtin(v) for v in list.__iter__(cfg)) + elif isinstance(cfg, set): + return {_to_builtin(v) for v in set.__iter__(cfg)} + elif isinstance(cfg, LazyObject): + if not keep_lazy: + return cfg.dump_str + else: + return cfg + else: + return cfg + + return _to_builtin(self) + + def build_lazy(self, value: Any) -> Any: + """If class attribute ``lazy`` is False, the LazyObject will be built + and returned. + + Args: + value (Any): The value to be built. + + Returns: + Any: The built value. + """ + if (isinstance(value, LazyObject) and not self.__class__.lazy): + value = value.build() + return value + + def __deepcopy__(self, memo): + return self.__class__( + copy.deepcopy(item, memo) for item in super().__iter__()) + + def __copy__(self): + return self.__class__(item for item in super().__iter__()) + + def __iter__(self): + # Override `__iter__` to overwrite to support star unpacking + # `*cfg_list` + yield from map(self.build_lazy, super().__iter__()) + + def __getitem__(self, idx): + try: + value = self.build_lazy(super().__getitem__(idx)) + except Exception as e: + raise e + else: + return value + + def __eq__(self, other): + return all(a == b for a, b in zip(self, other)) + + def __reduce_ex__(self, proto): + # Override __reduce_ex__ to avoid dump the built lazy object. + if digit_version(platform.python_version()) < digit_version('3.8'): + return (self.__class__, (tuple(i for i in super().__iter__()), ), + None, None, None) + else: + return (self.__class__, (tuple(i for i in super().__iter__()), ), + None, None, None, None) + + copy = __copy__ + + +class ConfigDict(LazyContainerMixin, Dict): """A dictionary for config which has the same interface as python's built- in dictionary and can be used as a normal dictionary. @@ -55,7 +122,6 @@ class ConfigDict(Dict): object during configuration parsing, and it should be set to False outside the Config to ensure that users do not experience the ``LazyObject``. """ - lazy = False def __init__(__self, *args, **kwargs): object.__setattr__(__self, '__parent', kwargs.pop('__parent', None)) @@ -101,8 +167,14 @@ def __getattr__(self, name): @classmethod def _hook(cls, item): # avoid to convert user defined dict to ConfigDict. - if type(item) in (dict, OrderedDict): + if isinstance(item, ConfigDict): + return item + elif type(item) in (dict, OrderedDict): return cls(item) + elif isinstance(item, LazyContainerMixin): + return type(item)( + cls._hook(elem) + for elem in super(LazyContainerMixin, item).__iter__()) elif isinstance(item, (list, tuple)): return type(item)(cls._hook(elem) for elem in item) return item @@ -133,11 +205,6 @@ def __copy__(self): copy = __copy__ - def __iter__(self): - # Implement `__iter__` to overwrite the unpacking operator `**cfg_dict` - # to get the built lazy object - return iter(self.keys()) - def get(self, key: str, default: Optional[Any] = None) -> Any: """Get the value of the key. If class attribute ``lazy`` is True, the LazyObject will be built and returned. @@ -184,20 +251,6 @@ def update(self, *args, **kwargs) -> None: else: self[k].update(v) - def build_lazy(self, value: Any) -> Any: - """If class attribute ``lazy`` is False, the LazyObject will be built - and returned. - - Args: - value (Any): The value to be built. - - Returns: - Any: The built value. - """ - if isinstance(value, LazyObject) and not self.lazy: - value = value.build() - return value - def values(self): """Yield the values of the dictionary. @@ -271,28 +324,29 @@ def __eq__(self, other): return False def _to_lazy_dict(self): - """Convert the ConfigDict to a normal dictionary recursively, and keep - the ``LazyObject`` or ``LazyAttr`` object not built.""" - - def _to_dict(data): - if isinstance(data, ConfigDict): - return { - key: _to_dict(value) - for key, value in Dict.items(data) - } - elif isinstance(data, dict): - return {key: _to_dict(value) for key, value in data.items()} - elif isinstance(data, (list, tuple)): - return type(data)(_to_dict(item) for item in data) - else: - return data + # NOTE: Keep this function for backward compatibility. - return _to_dict(self) + return self.to_builtin(keep_lazy=True) def to_dict(self): - """Convert the ConfigDict to a normal dictionary recursively, and - convert the ``LazyObject`` or ``LazyAttr`` to string.""" - return _lazy2string(self, dict_type=dict) + # NOTE: Keep this function for backward compatibility. + return self.to_builtin() + + +class ConfigList(LazyContainerMixin, list): # type: ignore + + def pop(self, idx): + return self.build_lazy(super().pop(idx)) + + +class ConfigTuple(LazyContainerMixin, tuple): # type: ignore + ... + + +class ConfigSet(LazyContainerMixin, set): # type: ignore + + def pop(self, idx): + return self.build_lazy(super().pop(idx)) def add_args(parser: ArgumentParser, diff --git a/mmengine/config/new_config.py b/mmengine/config/new_config.py index c8d6f93d4a..de16175f8d 100644 --- a/mmengine/config/new_config.py +++ b/mmengine/config/new_config.py @@ -5,6 +5,7 @@ import os import platform import sys +from contextlib import contextmanager from importlib.machinery import PathFinder from pathlib import Path from types import BuiltinFunctionType, FunctionType, ModuleType @@ -12,7 +13,7 @@ from yapf.yapflib.yapf_api import FormatCode -from .config import Config, ConfigDict +from .config import Config, ConfigDict, ConfigList, ConfigSet, ConfigTuple from .lazy import LazyImportContext, LazyObject, recover_lazy_field RESERVED_KEYS = ['filename', 'text', 'pretty_text'] @@ -177,8 +178,7 @@ def fromfile( # type: ignore # Using try-except to make sure ``ConfigDict.lazy`` will be reset # to False. See more details about lazy in the docstring of # ConfigDict - ConfigDict.lazy = True - try: + with ConfigV2._lazy_context(): module = ConfigV2._get_config_module(filename) module_dict = { k: getattr(module, k) @@ -193,7 +193,6 @@ def fromfile( # type: ignore cfg_dict, filename=filename, format_python_code=format_python_code) - finally: ConfigDict.lazy = False global _CFG_UID _CFG_UID = 0 @@ -203,6 +202,21 @@ def fromfile( # type: ignore return cfg + @staticmethod + @contextmanager + def _lazy_context(): + ConfigDict.lazy = True + ConfigSet.lazy = True + ConfigList.lazy = True + ConfigTuple.lazy = True + + yield + + ConfigDict.lazy = False + ConfigSet.lazy = False + ConfigList.lazy = False + ConfigTuple.lazy = False + @staticmethod def _get_config_module(filename: Union[str, Path]): file = Path(filename).absolute() @@ -222,7 +236,7 @@ def _get_config_module(filename: Union[str, Path]): return module @staticmethod - def _dict_to_config_dict_lazy(cfg: dict): + def _to_lazy_container(cfg: dict): """Recursively converts ``dict`` to :obj:`ConfigDict`. The only difference between ``_dict_to_config_dict_lazy`` and ``_dict_to_config_dict_lazy`` is that the former one does not consider @@ -238,11 +252,17 @@ def _dict_to_config_dict_lazy(cfg: dict): if isinstance(cfg, dict): cfg_dict = ConfigDict() for key, value in cfg.items(): - cfg_dict[key] = ConfigV2._dict_to_config_dict_lazy(value) + cfg_dict[key] = ConfigV2._to_lazy_container(value) return cfg_dict - if isinstance(cfg, (tuple, list)): - return type(cfg)( - ConfigV2._dict_to_config_dict_lazy(_cfg) for _cfg in cfg) + if isinstance(cfg, list): + return ConfigList( + ConfigV2._to_lazy_container(_cfg) for _cfg in cfg) + if isinstance(cfg, tuple): + return ConfigTuple( + ConfigV2._to_lazy_container(_cfg) for _cfg in cfg) + if isinstance(cfg, set): + return ConfigSet(ConfigV2._to_lazy_container(_cfg) for _cfg in cfg) + return cfg @property @@ -433,7 +453,7 @@ def new_import(name, globals=None, locals=None, fromlist=(), level=0): mod = ConfigV2._get_config_module(cur_file) for k in dir(mod): - mod.__dict__[k] = ConfigV2._dict_to_config_dict_lazy( + mod.__dict__[k] = ConfigV2._to_lazy_container( getattr(mod, k)) else: mod = old_import( diff --git a/tests/data/config/lazy_module_config/toy_model.py b/tests/data/config/lazy_module_config/toy_model.py index a9d2a3f64a..3f2bbab383 100644 --- a/tests/data/config/lazy_module_config/toy_model.py +++ b/tests/data/config/lazy_module_config/toy_model.py @@ -1,4 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +from torch.distributed.fsdp.wrap import (size_based_auto_wrap_policy, + transformer_auto_wrap_policy) + +from mmengine._strategy import ColossalAIStrategy from mmengine.config import read_base from mmengine.dataset import DefaultSampler from mmengine.hooks import EMAHook @@ -46,4 +50,7 @@ priority=49) ] +# illegal model wrapper config, just for unit test. +strategy = dict(type=ColossalAIStrategy, model_wrapper=dict( + auto_wrap_policy=(size_based_auto_wrap_policy, transformer_auto_wrap_policy))) runner_type = FlexibleRunner diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index 2fbe3d1eff..1a3fa85636 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -15,7 +15,7 @@ import pytest import mmengine -from mmengine import Config, ConfigDict, DictAction +from mmengine import Config, ConfigDict, ConfigList, DictAction from mmengine.config.lazy import LazyObject from mmengine.config.old_config import ConfigV1 from mmengine.fileio import dump, load @@ -1002,10 +1002,10 @@ def test_lazy_import(self, tmp_path): cfg = Config.fromfile(lazy_import_cfg_path) cfg_dict = cfg.to_dict() assert (cfg_dict['train_dataloader']['dataset']['type'] == - '') + '') assert (cfg_dict['custom_hooks'][0]['type'] - in ('', - '')) + in ('', + '')) # Dumped config dumped_cfg_path = tmp_path / 'test_dump_lazy.py' cfg.dump(dumped_cfg_path) @@ -1220,3 +1220,44 @@ def _recursive_check_lazy(self, cfg, expr): [self._recursive_check_lazy(value, expr) for value in cfg] else: self.assertTrue(expr(cfg)) + + +class TestConfigList(TestCase): + + def test_getitem(self): + cfg_list = ConfigList([ + 1, 2, + ConfigDict(type=LazyObject('mmengine')), + LazyObject('mmengine') + ]) + self.assertIs(cfg_list[2]['type'], mmengine) + self.assertIs(cfg_list[3], mmengine) + + def test_star(self): + + def check_star(a, b, c): + self.assertIs(c, mmengine) + + cfg_list = ConfigList([1, 2, LazyObject('mmengine')]) + check_star(*cfg_list) + + def check_for_loop(self): + cfg_list = ConfigList([LazyObject('mmengine')]) + for i in cfg_list: + self.assertIs(i, mmengine) + + def test_copy(self): + cfg_list = ConfigList([ + 1, 2, + ConfigDict(type=LazyObject('mmengine')), + LazyObject('mmengine') + ]) + cfg_copy = cfg_list.copy() + self.assertIsInstance(cfg_copy, ConfigList) + self.assertEqual(cfg_list, cfg_copy) + self.assertIs(cfg_list[2], cfg_copy[2]) + + cfg_copy = copy.deepcopy(cfg_list) + self.assertIsInstance(cfg_copy, ConfigList) + self.assertEqual(cfg_list, cfg_copy) + self.assertIsNot(cfg_list[2], cfg_copy[2])