diff --git a/schemantic/schema/arg_type_info.py b/schemantic/schema/arg_type_info.py index 5f053cd..a16e5dc 100644 --- a/schemantic/schema/arg_type_info.py +++ b/schemantic/schema/arg_type_info.py @@ -34,8 +34,7 @@ def __or__(self, other: Self) -> Self: msg = f"type hints do not match: {self.type_hint} != {other.type_hint}" raise AttributeError(msg) - self.merge_owner_to_default_with_other(other) - return self + return self.merge_owner_to_default_with_other(other) @model_serializer(mode="wrap") def serialize_model(self, handler, info): diff --git a/schemantic/schema/model.py b/schemantic/schema/model.py index 5cd015e..d651631 100644 --- a/schemantic/schema/model.py +++ b/schemantic/schema/model.py @@ -1,3 +1,4 @@ +from copy import copy from functools import reduce from typing import Annotated, Any, Iterable, Mapping, TypeAlias @@ -9,13 +10,13 @@ from schemantic.utils.mapping import dict_sorted_by_dict_key, extract_sort_keys LoadedSchema: TypeAlias = Annotated[dict[str, Any], Doc("Schema that has been loaded from a schemantic map or file")] +LoadedOrPathSchema: TypeAlias = Annotated[LoadedSchema | FilePath, Doc("Loaded schema or path to a schema")] ParsedSchema: TypeAlias = Annotated[ dict[str, dict[str, Any]], Doc("Schema that has been parsed from a schemantic map or file") ] -SchemaDefinition: TypeAlias = Annotated[dict[str, Any], Doc("Definition of the outer a schema")] -SchemaPathDefinition: TypeAlias = Annotated[SchemaDefinition | FilePath, Doc("Path to schema or schema dict")] NameToInstance: TypeAlias = Annotated[dict[str, Any], Doc("The mapping name of the instance to the instance itself")] +SchemaDefinition: TypeAlias = Annotated[dict[str, Any], Doc("Definition of the outer a schema")] class ClassNameMixin(BaseModel): class_name: str = ... @@ -43,21 +44,19 @@ def all_fields(self) -> ArgNameToTypeInfo: return {**self.required, **self.optional, **self.mixed} def intersection(self, other: "SignatureModel", dump_arg_name_only: bool = False) -> "SignatureModel": - required: set[str] = set(self.required) - optional: set[str] = set(self.optional) - mixed: set[str] = set(self.mixed) + required: ArgNameToTypeInfo = copy(self.required) + optional: ArgNameToTypeInfo = copy(self.optional) + mixed: ArgNameToTypeInfo = copy(self.mixed) - assert required.isdisjoint(optional) and required.isdisjoint(mixed) and optional.isdisjoint(mixed) - - mixed.intersection_update(other.mixed) + _intersection_update_type_info(mixed, other.mixed) _merge_new_fields_to_info_trackers(other.required, required, optional, mixed) _merge_new_fields_to_info_trackers(other.optional, optional, required, mixed) return SignatureModel( dump_arg_name_only=dump_arg_name_only, - required={field: self.required[field] for field in required}, - optional={field: self.optional[field] for field in optional}, - mixed={field: self.mixed[field] for field in mixed}, + required=required, + optional=optional, + mixed=mixed, ) @@ -117,18 +116,30 @@ class CultureSchema(BaseModel): def _merge_new_fields_to_info_trackers( - new_fields: Iterable[str], target_set: set[str], mutually_exclusive_set: set[str], mutually_inclusive_set: set[str] + new_fields: ArgNameToTypeInfo, target_mapping: ArgNameToTypeInfo, mutually_exclusive_mapping: ArgNameToTypeInfo, mutually_inclusive_mapping: ArgNameToTypeInfo ) -> None: - for new_field in new_fields: - if new_field in mutually_inclusive_set: - continue - - if new_field in mutually_exclusive_set: - mutually_exclusive_set.remove(new_field) - target_set.discard(new_field) - mutually_inclusive_set.add(new_field) + for field_name, field_info in new_fields.items(): + if field_name in mutually_inclusive_mapping: + mutually_inclusive_mapping[field_name] |= field_info + + elif existing := mutually_exclusive_mapping.pop(field_name, None): + assert field_name not in mutually_inclusive_mapping, "Field was just explored in mutually exclusive mapping, it should not be in the mutually inclusive mapping, yet" + merged = existing | field_info + if existing_ii := target_mapping.pop(field_name, None): + merged |= existing_ii + mutually_inclusive_mapping[field_name] = merged + + elif field_name in target_mapping: + target_mapping[field_name] |= field_info + else: - target_set.add(new_field) + target_mapping[field_name] = field_info + + +def _intersection_update_type_info(target: ArgNameToTypeInfo, source: ArgNameToTypeInfo) -> None: + for s_field_name, s_field_info in source.items(): + if s_field_name in target: + target[s_field_name] |= s_field_info def _sort_signature_dump(dumped_schema: dict, dump_arg_name_only: bool = False) -> None: diff --git a/schemantic/schemer/abstract.py b/schemantic/schemer/abstract.py index c2ede0e..5d5de6f 100644 --- a/schemantic/schemer/abstract.py +++ b/schemantic/schemer/abstract.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, FilePath, validate_call from schemantic.schema.arg_type_info import InitArgTypeInfo -from schemantic.schema.model import LoadedSchema, ParsedSchema, SchemaDefinition, SchemaPathDefinition, NameToInstance +from schemantic.schema.model import LoadedSchema, ParsedSchema, SchemaDefinition, LoadedOrPathSchema, NameToInstance Schema = TypeVar("Schema", bound=BaseModel) T = TypeVar("T") @@ -14,17 +14,17 @@ class AbstractSchemer(BaseModel, Generic[Schema], ABC, arbitrary_types_allowed=True): @abstractmethod - def _logical_post_dump_sort(self, dumped_schema: dict) -> dict: + def logical_post_dump_sort(self, dumped_schema: dict) -> None: ... @abstractmethod - def load_into_mapping_name_to_instance(self, schema_finding_info: SchemaPathDefinition) -> dict[str, NameToInstance]: ... + def load_into_mapping_name_to_instance(self, schema_finding_info: LoadedOrPathSchema) -> dict[str, NameToInstance]: ... @validate_call def dump(self, dump_path: Path, **schema_kwargs) -> None: schema: Schema = self.schema(**schema_kwargs) dumped_schema = schema.model_dump(exclude_defaults=True) - dumped_schema = self._logical_post_dump_sort(dumped_schema) + self.logical_post_dump_sort(dumped_schema) match dump_path.suffix: case ".toml": @@ -57,7 +57,7 @@ def load(schema_path: FilePath) -> dict: msg = f"{schema_path.suffix} is unsupported" raise NotImplementedError(msg) - def ensure_config_is_loaded(self, schema_finding_info: SchemaPathDefinition) -> LoadedSchema: + def ensure_config_is_loaded(self, schema_finding_info: LoadedOrPathSchema) -> LoadedSchema: if isinstance(schema_finding_info, Path): return self.load(schema_finding_info) if isinstance(schema_finding_info, dict): @@ -92,17 +92,29 @@ def mapping_name(self) -> str: ... def arg_to_info(self) -> dict[str, InitArgTypeInfo]: ... @validate_call - def load_into_mapping_name_to_instance(self, schema_finding_info: SchemaPathDefinition) -> dict[str, NameToInstance]: + def load_into_mapping_name_to_instance(self, schema_finding_info: LoadedOrPathSchema) -> dict[str, NameToInstance]: + """ + Load schema and parse it into a dictionary of instances. + """ loaded_schema = self.ensure_config_is_loaded(schema_finding_info) return self._parse_into_instance_by_mapping_name(loaded_schema, {}) @validate_call def load_definitions( self, - schema_finding_info: SchemaPathDefinition, + schema_finding_info: LoadedOrPathSchema, *, _inferior_config_kwargs: Optional[SchemaDefinition] = None, ) -> ParsedSchema: + """ + Load schema and parse it into a dictionary of configuration. + + Example + ------- + >>> schemer = MySchemer() + >>> schemer.load_definitions("path/to/schema.toml") + ... {"config1": {"key1": "value1"}, "config2": {"key2": "value2"}} + """ defined_schema = self.ensure_config_is_loaded(schema_finding_info) return self._schema_parser(defined_schema, _inferior_config_kwargs=_inferior_config_kwargs) diff --git a/schemantic/schemer/many.py b/schemantic/schemer/many.py index 32953b6..fd7fca9 100644 --- a/schemantic/schemer/many.py +++ b/schemantic/schemer/many.py @@ -21,7 +21,7 @@ LoadedSchema, ParsedSchema, SchemaDefinition, - SchemaPathDefinition, + LoadedOrPathSchema, merge_signature_models_by_flattening, merge_signature_models_by_intersection, NameToInstance, ) @@ -32,7 +32,7 @@ CULTURE_KEY, DEFINED_MAPPING_KEY, GROUP_MEMBER_KEY, - HOMOLOGUE_INSTANCE_KEY, INIT_SIGNATURE_MAPPING_KEY, CLASS_NAME_KEY, + HOMOLOGUE_INSTANCE_KEY, INIT_SIGNATURE_MAPPING_KEY, CLASS_NAME_KEY, ARGUMENT_TO_TYPING_KEY, ) from schemantic.utils.mapping import dict_sorted_by_dict_key, update_assert_disjoint @@ -131,8 +131,8 @@ def _parse_into_instance_by_mapping_name( def from_originating_type(cls, origin: type, **kwargs) -> Self: return cls(single_schemer=SingleSchemer.from_origin(origin=origin), **kwargs) - def _logical_post_dump_sort(self, dumped_schema: dict) -> dict: - return dumped_schema + def logical_post_dump_sort(self, dumped_schema: dict) -> None: + pass class GroupSchemer(HomologousGroupSchemer[GroupSchema]): @@ -254,14 +254,13 @@ def _schema_parser( ) return result - def _logical_post_dump_sort(self, dumped_schema: dict) -> dict: - _ensure_defined_mapping_in_common_key(dumped_schema) + def logical_post_dump_sort(self, dumped_schema: dict) -> None: + _ensure_defined_key_in_mapping(dumped_schema, COMMON_MAPPING_KEY) for key, member_schema in dumped_schema[GROUP_MEMBER_KEY].items(): dumped_schema[GROUP_MEMBER_KEY][key] = { DEFINED_MAPPING_KEY: member_schema.pop(DEFINED_MAPPING_KEY, {}), **member_schema, } - return dumped_schema class CultureSchemer(AbstractSchemer): @@ -276,28 +275,28 @@ def schema( culture = {} signature_models = [] - for model_schema in self.source_schemers: - if isinstance(model_schema, SingleSchemer): - schema = model_schema.schema() + for schemer in self.source_schemers: + if isinstance(schemer, SingleSchemer): + schema = schemer.schema() schema.dump_arg_name_only = True - signature_models.append(model_schema.signature_model) + signature_models.append(schemer.signature_model) - elif isinstance(model_schema, HomologueSchemer): - schema = model_schema.schema(name_getter_kwargs=homologue_name_getter_kwargs) + elif isinstance(schemer, HomologueSchemer): + schema = schemer.schema(name_getter_kwargs=homologue_name_getter_kwargs) schema.init_signature.dump_arg_name_only = True - signature_models.append(model_schema.single_schemer.signature_model) + signature_models.append(schemer.single_schemer.signature_model) - elif isinstance(model_schema, GroupSchemer): - schema = model_schema.schema() + elif isinstance(schemer, GroupSchemer): + schema = schemer.schema() for member_schema in schema.members.values(): member_schema.dump_arg_name_only = True - signature_models.append(model_schema.common_signature_model) + signature_models.append(schemer.common_signature_model) else: - msg = f"{model_schema.__class__} is not supported" + msg = f"{schemer.__class__} is not supported" raise NotImplementedError(msg) - culture[model_schema.mapping_name] = schema + culture[schemer.mapping_name] = schema signature_common_schema = merge_signature_models_by_intersection(*signature_models) signature_common_schema.dump_arg_name_only = True @@ -308,7 +307,7 @@ def schema( argument_to_typing=merge_signature_models_by_flattening(*signature_models), ) - def parse_schema(self, schema_finding_info: SchemaPathDefinition) -> dict: + def load_definitions(self, schema_finding_info: LoadedOrPathSchema) -> dict: loaded_schema = self.ensure_config_is_loaded(schema_finding_info) common = loaded_schema.get(COMMON_MAPPING_KEY, {}).get(DEFINED_MAPPING_KEY, {}) result: ParsedSchema = {} @@ -317,7 +316,7 @@ def parse_schema(self, schema_finding_info: SchemaPathDefinition) -> dict: result[schemer_name] = schemer._schema_parser(internal, _inferior_config_kwargs=common) return result - def load_into_mapping_name_to_instance(self, schema_finding_info: SchemaPathDefinition) -> NameToInstance: + def load_into_mapping_name_to_instance(self, schema_finding_info: LoadedOrPathSchema) -> NameToInstance: loaded_schema = self.ensure_config_is_loaded(schema_finding_info) common = loaded_schema.get(COMMON_MAPPING_KEY, {}).get(DEFINED_MAPPING_KEY, {}) result: ParsedSchema = {} @@ -330,17 +329,28 @@ def load_into_mapping_name_to_instance(self, schema_finding_info: SchemaPathDefi def _mapping_name_to_schemer(self) -> dict[str, SingleSchemer | HomologueSchemer | GroupSchemer]: return {schema.mapping_name: schema for schema in self.source_schemers} - def _logical_post_dump_sort(self, dumped_schema: dict) -> dict: + def logical_post_dump_sort(self, dumped_schema: dict) -> None: if COMMON_MAPPING_KEY in dumped_schema: - _ensure_defined_mapping_in_common_key(dumped_schema) - return dumped_schema + _ensure_defined_key_in_mapping(dumped_schema, COMMON_MAPPING_KEY) + + for schemer in self.source_schemers: + _ensure_defined_key_in_mapping(dumped_schema, CULTURE_KEY, schemer.mapping_name) + if isinstance(schemer, GroupSchemer): + dumped_schema[CULTURE_KEY][schemer.mapping_name].pop(ARGUMENT_TO_TYPING_KEY) Schemer: TypeAlias = SingleSchemer | HomologueSchemer | GroupSchemer | CultureSchemer -def _ensure_defined_mapping_in_common_key(dumped_schema: dict) -> None: - dumped_schema[COMMON_MAPPING_KEY] = { - DEFINED_MAPPING_KEY: dumped_schema[COMMON_MAPPING_KEY].pop(DEFINED_MAPPING_KEY, {}), - **dumped_schema[COMMON_MAPPING_KEY], +def _ensure_defined_key_in_mapping(dumped_schema: dict, *target_path: str) -> None: + target = dumped_schema + *remaining_path, last_key = target_path + + for path in remaining_path: + target = target[path] + + new_target = { + DEFINED_MAPPING_KEY: target[last_key].pop(DEFINED_MAPPING_KEY, {}), + **target[last_key], } + target[last_key] = new_target diff --git a/schemantic/schemer/single.py b/schemantic/schemer/single.py index 1ea769b..e3181ef 100644 --- a/schemantic/schemer/single.py +++ b/schemantic/schemer/single.py @@ -12,7 +12,7 @@ LoadedSchema, ParsedSchema, SchemaDefinition, - SchemaPathDefinition, + LoadedOrPathSchema, SignatureModel, SingleSchema, ) @@ -141,8 +141,8 @@ def _parse_into_instance_by_mapping_name(self, loaded_schema: LoadedSchema, infe def signature_model(self) -> SignatureModel: return SignatureModel(required=self.required, optional=self.optional) - def _logical_post_dump_sort(self, dumped_schema: dict) -> dict: - return dumped_schema + def logical_post_dump_sort(self, dumped_schema: dict) -> None: + pass def _schema_parser( self, diff --git a/schemantic/utils/constant.py b/schemantic/utils/constant.py index d032b6e..c603f32 100644 --- a/schemantic/utils/constant.py +++ b/schemantic/utils/constant.py @@ -8,6 +8,7 @@ CLASS_NAME_KEY: Final[str] = "class_name" INIT_SIGNATURE_MAPPING_KEY: Final[str] = "init_signature" +ARGUMENT_TO_TYPING_KEY: Final[str] = "argument_to_typing" HOMOLOGUE_INSTANCE_KEY: Final[str] = "instances" GROUP_MEMBER_KEY: Final[str] = "members" diff --git a/tests/test_case/abstract.py b/tests/test_case/abstract.py index 0ddabed..a3954aa 100644 --- a/tests/test_case/abstract.py +++ b/tests/test_case/abstract.py @@ -33,12 +33,12 @@ def expected_mapping_name_to_instance(self) -> dict: ... def test_schema_generates(self): self.assertTrue(self.main_schemer.schema()) - def test_parse_schema_to_instance_call(self): + def test_load_into_mapping_name_to_instance_call(self): self.assertTrue(self.expected_mapping_name_to_instance) - def test_parse_schema(self): + def test_load_definitions(self): self.assertEqual( - self.main_schemer.parse_schema(self.dump_expected_schema()), + self.main_schemer.load_definitions(self.dump_expected_schema()), self.expected_mapping_name_to_config, ) diff --git a/tests/test_case/main.py b/tests/test_case/main.py index baa3902..4956072 100644 --- a/tests/test_case/main.py +++ b/tests/test_case/main.py @@ -25,8 +25,8 @@ def _infer_single_schemer(origin: type, pre_definition: dict | None) -> SingleSc class SetSchemerMixin: - def test_parse_schema_to_instance_control(self): - parse_result = self.main_schemer._parse_into_instance_by_mapping_name(self.dump_expected_schema()) + def test_load_into_mapping_name_to_instance_control(self): + parse_result = self.main_schemer._parse_into_instance_by_mapping_name(self.dump_expected_schema(), {}) for mapping_name, instance in parse_result.items(): self.assertEqual(instance, self.expected_mapping_name_to_instance[mapping_name]) diff --git a/tests/test_schema_dataclass.py b/tests/test_schema_dataclass.py index 88a5afe..73cb5d0 100644 --- a/tests/test_schema_dataclass.py +++ b/tests/test_schema_dataclass.py @@ -38,7 +38,7 @@ class WithDataclassMixin: test_class_a: ClassVar[Type] = TestDataclass test_class_b: ClassVar[Type] = OtherTestDataclass - def test_parse_schema_to_instance(self): + def test_load_into_mapping_name_to_instance(self): self.assertTrue(self.schemer._parse_into_instance_by_mapping_name(self.schema_with_instance_configuration)) diff --git a/tests/test_schema_model.py b/tests/test_schema_model.py index 0efb72b..a5f1d14 100644 --- a/tests/test_schema_model.py +++ b/tests/test_schema_model.py @@ -76,7 +76,7 @@ class TestHomologueWithModel(unittest.TestCase, AbstractTestHomologue): }, } - def test_parse_schema_to_instance_control(self): + def test_load_into_mapping_name_to_instance_control(self): origin = self.main_schemer.single_schemer.origin self.assertEqual( self.main_schemer._parse_into_instance_by_mapping_name(self.dump_expected_schema), @@ -122,7 +122,7 @@ class TestGroupWithModel(unittest.TestCase, AbstractTestGroup, WithModelMixin): }, } - def test_parse_schema_to_instance_control(self): + def test_load_into_mapping_name_to_instance_control(self): self.assertEqual( self.main_schemer._parse_into_instance_by_mapping_name(self.dump_expected_schema), { @@ -182,7 +182,7 @@ class TestCultureWithModel(unittest.TestCase, AbstractTestCulture, WithModelMixi }, } - def test_parse_schema_to_instance(self): + def test_load_into_mapping_name_to_instance(self): self.assertEqual( self.main_schemer._parse_into_instance_by_mapping_name(self.dumped_schema_with_instance_configuration), {