Skip to content

Commit

Permalink
[Feature] Support using lazy object in list, tuple and set
Browse files Browse the repository at this point in the history
  • Loading branch information
HAOCHENYE committed Oct 26, 2023
1 parent 6c5eebb commit 8bcbda3
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 73 deletions.
8 changes: 6 additions & 2 deletions mmengine/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .config import Config, ConfigDict, DictAction, read_base
from .config import (Config, ConfigDict, ConfigList, ConfigSet, ConfigTuple,
DictAction, read_base)

__all__ = ['Config', 'ConfigDict', 'DictAction', 'read_base']
__all__ = [
'Config', 'ConfigDict', 'DictAction', 'read_base', 'ConfigList',
'ConfigSet', 'ConfigTuple'
]
201 changes: 131 additions & 70 deletions mmengine/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,87 @@
import re # type: ignore


def _lazy2string(cfg_dict, dict_type=None):
if isinstance(cfg_dict, dict):
dict_type = dict_type or type(cfg_dict)
return dict_type({k: _lazy2string(v) for k, v in dict.items(cfg_dict)})
elif isinstance(cfg_dict, (tuple, list)):
return type(cfg_dict)(_lazy2string(v) for v in cfg_dict)
elif isinstance(cfg_dict, (LazyAttr, LazyObject)):
return f'{cfg_dict.module}.{str(cfg_dict)}'
else:
return cfg_dict
class LazyContainerMeta(type):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lazy = False


class LazyContainerMixin(metaclass=LazyContainerMeta):

def to_builtin(self, keep_lazy=False):

def _to_builtin(cfg):
if isinstance(cfg, dict):
return dict({k: _to_builtin(v) for k, v in dict.items(cfg)})
elif isinstance(cfg, tuple):
return tuple(_to_builtin(v) for v in tuple.__iter__(cfg))
elif isinstance(cfg, list):
return list(_to_builtin(v) for v in list.__iter__(cfg))
elif isinstance(cfg, set):
return {_to_builtin(v) for v in set.__iter__(cfg)}
elif isinstance(cfg, (LazyAttr, LazyObject)):
if not keep_lazy:
return f'{cfg.module}.{str(cfg)}'
else:
return cfg
else:
return cfg

return _to_builtin(self)

def build_lazy(self, value: Any) -> Any:
"""If class attribute ``lazy`` is False, the LazyObject will be built
and returned.
Args:
value (Any): The value to be built.
class ConfigDict(Dict):
Returns:
Any: The built value.
"""
if (isinstance(value, (LazyAttr, LazyObject))
and not self.__class__.lazy):
value = value.build()
return value

def __deepcopy__(self, memo):
return self.__class__(
copy.deepcopy(item, memo) for item in super().__iter__())

def __copy__(self):
return self.__class__(item for item in super().__iter__())

def __iter__(self):
# Override `__iter__` to overwrite to support star unpacking
# `*cfg_list`
yield from map(self.build_lazy, super().__iter__())

def __getitem__(self, idx):
try:
value = self.build_lazy(super().__getitem__(idx))
except Exception as e:
raise e
else:
return value

def __eq__(self, other):
return all(a == b for a, b in zip(self, other))

def __reduce_ex__(self, proto):
# Override __reduce_ex__ to avoid dump the built lazy object.
if digit_version(platform.python_version()) < digit_version('3.8'):
return (self.__class__, (tuple(i for i in super().__iter__()), ),
None, None, None)
else:
return (self.__class__, (tuple(i for i in super().__iter__()), ),
None, None, None, None)

copy = __copy__


class ConfigDict(LazyContainerMixin, Dict):
"""A dictionary for config which has the same interface as python's built-
in dictionary and can be used as a normal dictionary.
Expand All @@ -72,7 +140,6 @@ class ConfigDict(Dict):
object during configuration parsing, and it should be set to False outside
the Config to ensure that users do not experience the ``LazyObject``.
"""
lazy = False

def __init__(__self, *args, **kwargs):
object.__setattr__(__self, '__parent', kwargs.pop('__parent', None))
Expand Down Expand Up @@ -118,8 +185,14 @@ def __getattr__(self, name):
@classmethod
def _hook(cls, item):
# avoid to convert user defined dict to ConfigDict.
if type(item) in (dict, OrderedDict):
if isinstance(item, ConfigDict):
return item
elif type(item) in (dict, OrderedDict):
return cls(item)
elif isinstance(item, LazyContainerMixin):
return type(item)(
cls._hook(elem)
for elem in super(LazyContainerMixin, item).__iter__())
elif isinstance(item, (list, tuple)):
return type(item)(cls._hook(elem) for elem in item)
return item
Expand Down Expand Up @@ -150,11 +223,6 @@ def __copy__(self):

copy = __copy__

def __iter__(self):
# Implement `__iter__` to overwrite the unpacking operator `**cfg_dict`
# to get the built lazy object
return iter(self.keys())

def get(self, key: str, default: Optional[Any] = None) -> Any:
"""Get the value of the key. If class attribute ``lazy`` is True, the
LazyObject will be built and returned.
Expand Down Expand Up @@ -201,20 +269,6 @@ def update(self, *args, **kwargs) -> None:
else:
self[k].update(v)

def build_lazy(self, value: Any) -> Any:
"""If class attribute ``lazy`` is False, the LazyObject will be built
and returned.
Args:
value (Any): The value to be built.
Returns:
Any: The built value.
"""
if isinstance(value, (LazyAttr, LazyObject)) and not self.lazy:
value = value.build()
return value

def values(self):
"""Yield the values of the dictionary.
Expand Down Expand Up @@ -288,28 +342,28 @@ def __eq__(self, other):
return False

def _to_lazy_dict(self):
"""Convert the ConfigDict to a normal dictionary recursively, and keep
the ``LazyObject`` or ``LazyAttr`` object not built."""
# NOTE: Keep this function for backward compatibility.

def _to_dict(data):
if isinstance(data, ConfigDict):
return {
key: _to_dict(value)
for key, value in Dict.items(data)
}
elif isinstance(data, dict):
return {key: _to_dict(value) for key, value in data.items()}
elif isinstance(data, (list, tuple)):
return type(data)(_to_dict(item) for item in data)
else:
return data

return _to_dict(self)
return self.to_builtin(keep_lazy=True)

def to_dict(self):
"""Convert the ConfigDict to a normal dictionary recursively, and
convert the ``LazyObject`` or ``LazyAttr`` to string."""
return _lazy2string(self, dict_type=dict)
# NOTE: Keep this function for backward compatibility.
return self.to_builtin()


class ConfigList(LazyContainerMixin, list): # type: ignore
...


class ConfigTuple(LazyContainerMixin, tuple): # type: ignore
...


class ConfigSet(LazyContainerMixin, set): # type: ignore
...

def pop(self, idx):
return self.build_lazy(super().pop(idx))


def add_args(parser: ArgumentParser,
Expand Down Expand Up @@ -479,27 +533,30 @@ def fromfile(filename: Union[str, Path],
env_variables=env_variables,
)
else:
# Enable lazy import when parsing the config.
# Using try-except to make sure ``ConfigDict.lazy`` will be reset
# to False. See more details about lazy in the docstring of
# ConfigDict
ConfigDict.lazy = True
try:
with Config._lazy_context():
cfg_dict, imported_names = Config._parse_lazy_import(filename)
except Exception as e:
raise e
finally:
# disable lazy import to get the real type. See more details
# about lazy in the docstring of ConfigDict
ConfigDict.lazy = False

cfg = Config(
cfg_dict,
filename=filename,
format_python_code=format_python_code)
object.__setattr__(cfg, '_imported_names', imported_names)
return cfg

@staticmethod
@contextmanager
def _lazy_context():
ConfigDict.lazy = True
ConfigSet.lazy = True
ConfigList.lazy = True
ConfigTuple.lazy = True

yield

ConfigDict.lazy = False
ConfigSet.lazy = False
ConfigList.lazy = False
ConfigTuple.lazy = False

@staticmethod
def fromstring(cfg_str: str, file_format: str) -> 'Config':
"""Build a Config instance from config text.
Expand Down Expand Up @@ -1110,12 +1167,12 @@ def _parse_lazy_import(filename: str) -> Tuple[ConfigDict, set]:
continue
ret[key] = value
# convert dict to ConfigDict
cfg_dict = Config._dict_to_config_dict_lazy(ret)
cfg_dict = Config._to_lazy_container(ret)

return cfg_dict, imported_names

@staticmethod
def _dict_to_config_dict_lazy(cfg: dict):
def _to_lazy_container(cfg: dict):
"""Recursively converts ``dict`` to :obj:`ConfigDict`. The only
difference between ``_dict_to_config_dict_lazy`` and
``_dict_to_config_dict_lazy`` is that the former one does not consider
Expand All @@ -1131,11 +1188,15 @@ def _dict_to_config_dict_lazy(cfg: dict):
if isinstance(cfg, dict):
cfg_dict = ConfigDict()
for key, value in cfg.items():
cfg_dict[key] = Config._dict_to_config_dict_lazy(value)
cfg_dict[key] = Config._to_lazy_container(value)
return cfg_dict
if isinstance(cfg, (tuple, list)):
return type(cfg)(
Config._dict_to_config_dict_lazy(_cfg) for _cfg in cfg)
if isinstance(cfg, list):
return ConfigList(Config._to_lazy_container(_cfg) for _cfg in cfg)
if isinstance(cfg, tuple):
return ConfigTuple(Config._to_lazy_container(_cfg) for _cfg in cfg)
if isinstance(cfg, set):
return ConfigSet(Config._to_lazy_container(_cfg) for _cfg in cfg)

return cfg

@staticmethod
Expand Down
7 changes: 7 additions & 0 deletions tests/data/config/lazy_module_config/toy_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch.distributed.fsdp.wrap import (size_based_auto_wrap_policy,
transformer_auto_wrap_policy)

from mmengine._strategy import ColossalAIStrategy
from mmengine.config import read_base
from mmengine.dataset import DefaultSampler
from mmengine.hooks import EMAHook
Expand Down Expand Up @@ -46,4 +50,7 @@
priority=49)
]

# illegal model wrapper config, just for unit test.
strategy = dict(type=ColossalAIStrategy, model_wrapper=dict(
auto_wrap_policy=(size_based_auto_wrap_policy, transformer_auto_wrap_policy)))
runner_type = FlexibleRunner
43 changes: 42 additions & 1 deletion tests/test_config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest

import mmengine
from mmengine import Config, ConfigDict, DictAction
from mmengine import Config, ConfigDict, ConfigList, DictAction
from mmengine.config.lazy import LazyObject
from mmengine.fileio import dump, load
from mmengine.registry import MODELS, DefaultScope, Registry
Expand Down Expand Up @@ -1209,3 +1209,44 @@ def _recursive_check_lazy(self, cfg, expr):
[self._recursive_check_lazy(value, expr) for value in cfg]
else:
self.assertTrue(expr(cfg))


class TestConfigList(TestCase):

def test_getitem(self):
cfg_list = ConfigList([
1, 2,
ConfigDict(type=LazyObject('mmengine')),
LazyObject('mmengine')
])
self.assertIs(cfg_list[2]['type'], mmengine)
self.assertIs(cfg_list[3], mmengine)

def test_star(self):

def check_star(a, b, c):
self.assertIs(c, mmengine)

cfg_list = ConfigList([1, 2, LazyObject('mmengine')])
check_star(*cfg_list)

def check_for_loop(self):
cfg_list = ConfigList([LazyObject('mmengine')])
for i in cfg_list:
self.assertIs(i, mmengine)

def test_copy(self):
cfg_list = ConfigList([
1, 2,
ConfigDict(type=LazyObject('mmengine')),
LazyObject('mmengine')
])
cfg_copy = cfg_list.copy()
self.assertIsInstance(cfg_copy, ConfigList)
self.assertEqual(cfg_list, cfg_copy)
self.assertIs(cfg_list[2], cfg_copy[2])

cfg_copy = copy.deepcopy(cfg_list)
self.assertIsInstance(cfg_copy, ConfigList)
self.assertEqual(cfg_list, cfg_copy)
self.assertIsNot(cfg_list[2], cfg_copy[2])

0 comments on commit 8bcbda3

Please sign in to comment.