Skip to content

Commit

Permalink
cherry-pick config container
Browse files Browse the repository at this point in the history
  • Loading branch information
HAOCHENYE committed Oct 26, 2023
1 parent 5c7a83b commit 9477ac4
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 69 deletions.
8 changes: 6 additions & 2 deletions mmengine/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .config import Config, ConfigDict, DictAction
from .config import (Config, ConfigDict, ConfigList, ConfigSet, ConfigTuple,
DictAction)
from .new_config import read_base

__all__ = ['Config', 'ConfigDict', 'DictAction', 'read_base']
__all__ = [
'Config', 'ConfigDict', 'DictAction', 'read_base', 'ConfigList',
'ConfigSet', 'ConfigTuple'
]
160 changes: 107 additions & 53 deletions mmengine/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,86 @@
DELETE_KEY = '_delete_'


def _lazy2string(cfg_dict, dict_type=None):
if isinstance(cfg_dict, dict):
dict_type = dict_type or type(cfg_dict)
return dict_type({k: _lazy2string(v) for k, v in dict.items(cfg_dict)})
elif isinstance(cfg_dict, (tuple, list)):
return type(cfg_dict)(_lazy2string(v) for v in cfg_dict)
elif isinstance(cfg_dict, LazyObject):
return cfg_dict.dump_str
else:
return cfg_dict


class ConfigDict(Dict):
class LazyContainerMeta(type):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lazy = False


class LazyContainerMixin(metaclass=LazyContainerMeta):

def to_builtin(self, keep_lazy=False):

def _to_builtin(cfg):
if isinstance(cfg, dict):
return dict({k: _to_builtin(v) for k, v in dict.items(cfg)})
elif isinstance(cfg, tuple):
return tuple(_to_builtin(v) for v in tuple.__iter__(cfg))
elif isinstance(cfg, list):
return list(_to_builtin(v) for v in list.__iter__(cfg))
elif isinstance(cfg, set):
return {_to_builtin(v) for v in set.__iter__(cfg)}
elif isinstance(cfg, LazyObject):
if not keep_lazy:
return cfg.dump_str
else:
return cfg
else:
return cfg

return _to_builtin(self)

def build_lazy(self, value: Any) -> Any:
"""If class attribute ``lazy`` is False, the LazyObject will be built
and returned.
Args:
value (Any): The value to be built.
Returns:
Any: The built value.
"""
if (isinstance(value, LazyObject) and not self.__class__.lazy):
value = value.build()
return value

def __deepcopy__(self, memo):
return self.__class__(
copy.deepcopy(item, memo) for item in super().__iter__())

def __copy__(self):
return self.__class__(item for item in super().__iter__())

def __iter__(self):
# Override `__iter__` to overwrite to support star unpacking
# `*cfg_list`
yield from map(self.build_lazy, super().__iter__())

def __getitem__(self, idx):
try:
value = self.build_lazy(super().__getitem__(idx))
except Exception as e:
raise e
else:
return value

def __eq__(self, other):
return all(a == b for a, b in zip(self, other))

def __reduce_ex__(self, proto):
# Override __reduce_ex__ to avoid dump the built lazy object.
if digit_version(platform.python_version()) < digit_version('3.8'):
return (self.__class__, (tuple(i for i in super().__iter__()), ),
None, None, None)
else:
return (self.__class__, (tuple(i for i in super().__iter__()), ),
None, None, None, None)

copy = __copy__


class ConfigDict(LazyContainerMixin, Dict):
"""A dictionary for config which has the same interface as python's built-
in dictionary and can be used as a normal dictionary.
Expand All @@ -55,7 +122,6 @@ class ConfigDict(Dict):
object during configuration parsing, and it should be set to False outside
the Config to ensure that users do not experience the ``LazyObject``.
"""
lazy = False

def __init__(__self, *args, **kwargs):
object.__setattr__(__self, '__parent', kwargs.pop('__parent', None))
Expand Down Expand Up @@ -101,8 +167,14 @@ def __getattr__(self, name):
@classmethod
def _hook(cls, item):
# avoid to convert user defined dict to ConfigDict.
if type(item) in (dict, OrderedDict):
if isinstance(item, ConfigDict):
return item
elif type(item) in (dict, OrderedDict):
return cls(item)
elif isinstance(item, LazyContainerMixin):
return type(item)(
cls._hook(elem)
for elem in super(LazyContainerMixin, item).__iter__())
elif isinstance(item, (list, tuple)):
return type(item)(cls._hook(elem) for elem in item)
return item
Expand Down Expand Up @@ -133,11 +205,6 @@ def __copy__(self):

copy = __copy__

def __iter__(self):
# Implement `__iter__` to overwrite the unpacking operator `**cfg_dict`
# to get the built lazy object
return iter(self.keys())

def get(self, key: str, default: Optional[Any] = None) -> Any:
"""Get the value of the key. If class attribute ``lazy`` is True, the
LazyObject will be built and returned.
Expand Down Expand Up @@ -184,20 +251,6 @@ def update(self, *args, **kwargs) -> None:
else:
self[k].update(v)

def build_lazy(self, value: Any) -> Any:
"""If class attribute ``lazy`` is False, the LazyObject will be built
and returned.
Args:
value (Any): The value to be built.
Returns:
Any: The built value.
"""
if isinstance(value, LazyObject) and not self.lazy:
value = value.build()
return value

def values(self):
"""Yield the values of the dictionary.
Expand Down Expand Up @@ -271,28 +324,29 @@ def __eq__(self, other):
return False

def _to_lazy_dict(self):
"""Convert the ConfigDict to a normal dictionary recursively, and keep
the ``LazyObject`` or ``LazyAttr`` object not built."""

def _to_dict(data):
if isinstance(data, ConfigDict):
return {
key: _to_dict(value)
for key, value in Dict.items(data)
}
elif isinstance(data, dict):
return {key: _to_dict(value) for key, value in data.items()}
elif isinstance(data, (list, tuple)):
return type(data)(_to_dict(item) for item in data)
else:
return data
# NOTE: Keep this function for backward compatibility.

return _to_dict(self)
return self.to_builtin(keep_lazy=True)

def to_dict(self):
"""Convert the ConfigDict to a normal dictionary recursively, and
convert the ``LazyObject`` or ``LazyAttr`` to string."""
return _lazy2string(self, dict_type=dict)
# NOTE: Keep this function for backward compatibility.
return self.to_builtin()


class ConfigList(LazyContainerMixin, list): # type: ignore

def pop(self, idx):
return self.build_lazy(super().pop(idx))


class ConfigTuple(LazyContainerMixin, tuple): # type: ignore
...


class ConfigSet(LazyContainerMixin, set): # type: ignore

def pop(self, idx):
return self.build_lazy(super().pop(idx))


def add_args(parser: ArgumentParser,
Expand Down
40 changes: 30 additions & 10 deletions mmengine/config/new_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import os
import platform
import sys
from contextlib import contextmanager
from importlib.machinery import PathFinder
from pathlib import Path
from types import BuiltinFunctionType, FunctionType, ModuleType
from typing import Optional, Tuple, Union

from yapf.yapflib.yapf_api import FormatCode

from .config import Config, ConfigDict
from .config import Config, ConfigDict, ConfigList, ConfigSet, ConfigTuple
from .lazy import LazyImportContext, LazyObject, recover_lazy_field

RESERVED_KEYS = ['filename', 'text', 'pretty_text']
Expand Down Expand Up @@ -177,8 +178,7 @@ def fromfile( # type: ignore
# Using try-except to make sure ``ConfigDict.lazy`` will be reset
# to False. See more details about lazy in the docstring of
# ConfigDict
ConfigDict.lazy = True
try:
with ConfigV2._lazy_context():
module = ConfigV2._get_config_module(filename)
module_dict = {
k: getattr(module, k)
Expand All @@ -193,7 +193,6 @@ def fromfile( # type: ignore
cfg_dict,
filename=filename,
format_python_code=format_python_code)
finally:
ConfigDict.lazy = False
global _CFG_UID
_CFG_UID = 0
Expand All @@ -203,6 +202,21 @@ def fromfile( # type: ignore

return cfg

@staticmethod
@contextmanager
def _lazy_context():
ConfigDict.lazy = True
ConfigSet.lazy = True
ConfigList.lazy = True
ConfigTuple.lazy = True

yield

ConfigDict.lazy = False
ConfigSet.lazy = False
ConfigList.lazy = False
ConfigTuple.lazy = False

@staticmethod
def _get_config_module(filename: Union[str, Path]):
file = Path(filename).absolute()
Expand All @@ -222,7 +236,7 @@ def _get_config_module(filename: Union[str, Path]):
return module

@staticmethod
def _dict_to_config_dict_lazy(cfg: dict):
def _to_lazy_container(cfg: dict):
"""Recursively converts ``dict`` to :obj:`ConfigDict`. The only
difference between ``_dict_to_config_dict_lazy`` and
``_dict_to_config_dict_lazy`` is that the former one does not consider
Expand All @@ -238,11 +252,17 @@ def _dict_to_config_dict_lazy(cfg: dict):
if isinstance(cfg, dict):
cfg_dict = ConfigDict()
for key, value in cfg.items():
cfg_dict[key] = ConfigV2._dict_to_config_dict_lazy(value)
cfg_dict[key] = ConfigV2._to_lazy_container(value)
return cfg_dict
if isinstance(cfg, (tuple, list)):
return type(cfg)(
ConfigV2._dict_to_config_dict_lazy(_cfg) for _cfg in cfg)
if isinstance(cfg, list):
return ConfigList(
ConfigV2._to_lazy_container(_cfg) for _cfg in cfg)
if isinstance(cfg, tuple):
return ConfigTuple(
ConfigV2._to_lazy_container(_cfg) for _cfg in cfg)
if isinstance(cfg, set):
return ConfigSet(ConfigV2._to_lazy_container(_cfg) for _cfg in cfg)

return cfg

@property
Expand Down Expand Up @@ -433,7 +453,7 @@ def new_import(name, globals=None, locals=None, fromlist=(), level=0):
mod = ConfigV2._get_config_module(cur_file)

for k in dir(mod):
mod.__dict__[k] = ConfigV2._dict_to_config_dict_lazy(
mod.__dict__[k] = ConfigV2._to_lazy_container(
getattr(mod, k))
else:
mod = old_import(
Expand Down
7 changes: 7 additions & 0 deletions tests/data/config/lazy_module_config/toy_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch.distributed.fsdp.wrap import (size_based_auto_wrap_policy,
transformer_auto_wrap_policy)

from mmengine._strategy import ColossalAIStrategy
from mmengine.config import read_base
from mmengine.dataset import DefaultSampler
from mmengine.hooks import EMAHook
Expand Down Expand Up @@ -46,4 +50,7 @@
priority=49)
]

# illegal model wrapper config, just for unit test.
strategy = dict(type=ColossalAIStrategy, model_wrapper=dict(
auto_wrap_policy=(size_based_auto_wrap_policy, transformer_auto_wrap_policy)))
runner_type = FlexibleRunner
49 changes: 45 additions & 4 deletions tests/test_config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest

import mmengine
from mmengine import Config, ConfigDict, DictAction
from mmengine import Config, ConfigDict, ConfigList, DictAction
from mmengine.config.lazy import LazyObject
from mmengine.config.old_config import ConfigV1
from mmengine.fileio import dump, load
Expand Down Expand Up @@ -1002,10 +1002,10 @@ def test_lazy_import(self, tmp_path):
cfg = Config.fromfile(lazy_import_cfg_path)
cfg_dict = cfg.to_dict()
assert (cfg_dict['train_dataloader']['dataset']['type'] ==
'<mmengine.testing.runner_test_case.ToyDataset>')
'<imp mmengine.testing.runner_test_case.ToyDataset>')
assert (cfg_dict['custom_hooks'][0]['type']
in ('<mmengine.hooks.EMAHook>',
'<mmengine.hooks.ema_hook.EMAHook>'))
in ('<imp mmengine.hooks.EMAHook>',
'<imp mmengine.hooks.ema_hook.EMAHook>'))
# Dumped config
dumped_cfg_path = tmp_path / 'test_dump_lazy.py'
cfg.dump(dumped_cfg_path)
Expand Down Expand Up @@ -1220,3 +1220,44 @@ def _recursive_check_lazy(self, cfg, expr):
[self._recursive_check_lazy(value, expr) for value in cfg]
else:
self.assertTrue(expr(cfg))


class TestConfigList(TestCase):

def test_getitem(self):
cfg_list = ConfigList([
1, 2,
ConfigDict(type=LazyObject('mmengine')),
LazyObject('mmengine')
])
self.assertIs(cfg_list[2]['type'], mmengine)
self.assertIs(cfg_list[3], mmengine)

def test_star(self):

def check_star(a, b, c):
self.assertIs(c, mmengine)

cfg_list = ConfigList([1, 2, LazyObject('mmengine')])
check_star(*cfg_list)

def check_for_loop(self):
cfg_list = ConfigList([LazyObject('mmengine')])
for i in cfg_list:
self.assertIs(i, mmengine)

def test_copy(self):
cfg_list = ConfigList([
1, 2,
ConfigDict(type=LazyObject('mmengine')),
LazyObject('mmengine')
])
cfg_copy = cfg_list.copy()
self.assertIsInstance(cfg_copy, ConfigList)
self.assertEqual(cfg_list, cfg_copy)
self.assertIs(cfg_list[2], cfg_copy[2])

cfg_copy = copy.deepcopy(cfg_list)
self.assertIsInstance(cfg_copy, ConfigList)
self.assertEqual(cfg_list, cfg_copy)
self.assertIsNot(cfg_list[2], cfg_copy[2])

0 comments on commit 9477ac4

Please sign in to comment.