Skip to content

Commit

Permalink
checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
caniko committed Nov 10, 2024
1 parent f1b4e5d commit 7f32c03
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 70 deletions.
3 changes: 1 addition & 2 deletions schemantic/schema/arg_type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ def __or__(self, other: Self) -> Self:
msg = f"type hints do not match: {self.type_hint} != {other.type_hint}"
raise AttributeError(msg)

self.merge_owner_to_default_with_other(other)
return self
return self.merge_owner_to_default_with_other(other)

@model_serializer(mode="wrap")
def serialize_model(self, handler, info):
Expand Down
53 changes: 32 additions & 21 deletions schemantic/schema/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import copy
from functools import reduce
from typing import Annotated, Any, Iterable, Mapping, TypeAlias

Expand All @@ -9,13 +10,13 @@
from schemantic.utils.mapping import dict_sorted_by_dict_key, extract_sort_keys

LoadedSchema: TypeAlias = Annotated[dict[str, Any], Doc("Schema that has been loaded from a schemantic map or file")]
LoadedOrPathSchema: TypeAlias = Annotated[LoadedSchema | FilePath, Doc("Loaded schema or path to a schema")]
ParsedSchema: TypeAlias = Annotated[
dict[str, dict[str, Any]], Doc("Schema that has been parsed from a schemantic map or file")
]
SchemaDefinition: TypeAlias = Annotated[dict[str, Any], Doc("Definition of the outer a schema")]
SchemaPathDefinition: TypeAlias = Annotated[SchemaDefinition | FilePath, Doc("Path to schema or schema dict")]
NameToInstance: TypeAlias = Annotated[dict[str, Any], Doc("The mapping name of the instance to the instance itself")]

SchemaDefinition: TypeAlias = Annotated[dict[str, Any], Doc("Definition of the outer a schema")]

class ClassNameMixin(BaseModel):
class_name: str = ...
Expand Down Expand Up @@ -43,21 +44,19 @@ def all_fields(self) -> ArgNameToTypeInfo:
return {**self.required, **self.optional, **self.mixed}

def intersection(self, other: "SignatureModel", dump_arg_name_only: bool = False) -> "SignatureModel":
required: set[str] = set(self.required)
optional: set[str] = set(self.optional)
mixed: set[str] = set(self.mixed)
required: ArgNameToTypeInfo = copy(self.required)
optional: ArgNameToTypeInfo = copy(self.optional)
mixed: ArgNameToTypeInfo = copy(self.mixed)

assert required.isdisjoint(optional) and required.isdisjoint(mixed) and optional.isdisjoint(mixed)

mixed.intersection_update(other.mixed)
_intersection_update_type_info(mixed, other.mixed)
_merge_new_fields_to_info_trackers(other.required, required, optional, mixed)
_merge_new_fields_to_info_trackers(other.optional, optional, required, mixed)

return SignatureModel(
dump_arg_name_only=dump_arg_name_only,
required={field: self.required[field] for field in required},
optional={field: self.optional[field] for field in optional},
mixed={field: self.mixed[field] for field in mixed},
required=required,
optional=optional,
mixed=mixed,
)


Expand Down Expand Up @@ -117,18 +116,30 @@ class CultureSchema(BaseModel):


def _merge_new_fields_to_info_trackers(
new_fields: Iterable[str], target_set: set[str], mutually_exclusive_set: set[str], mutually_inclusive_set: set[str]
new_fields: ArgNameToTypeInfo, target_mapping: ArgNameToTypeInfo, mutually_exclusive_mapping: ArgNameToTypeInfo, mutually_inclusive_mapping: ArgNameToTypeInfo
) -> None:
for new_field in new_fields:
if new_field in mutually_inclusive_set:
continue

if new_field in mutually_exclusive_set:
mutually_exclusive_set.remove(new_field)
target_set.discard(new_field)
mutually_inclusive_set.add(new_field)
for field_name, field_info in new_fields.items():
if field_name in mutually_inclusive_mapping:
mutually_inclusive_mapping[field_name] |= field_info

elif existing := mutually_exclusive_mapping.pop(field_name, None):
assert field_name not in mutually_inclusive_mapping, "Field was just explored in mutually exclusive mapping, it should not be in the mutually inclusive mapping, yet"
merged = existing | field_info
if existing_ii := target_mapping.pop(field_name, None):
merged |= existing_ii
mutually_inclusive_mapping[field_name] = merged

elif field_name in target_mapping:
target_mapping[field_name] |= field_info

else:
target_set.add(new_field)
target_mapping[field_name] = field_info


def _intersection_update_type_info(target: ArgNameToTypeInfo, source: ArgNameToTypeInfo) -> None:
for s_field_name, s_field_info in source.items():
if s_field_name in target:
target[s_field_name] |= s_field_info


def _sort_signature_dump(dumped_schema: dict, dump_arg_name_only: bool = False) -> None:
Expand Down
26 changes: 19 additions & 7 deletions schemantic/schemer/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,25 @@
from pydantic import BaseModel, Field, FilePath, validate_call

from schemantic.schema.arg_type_info import InitArgTypeInfo
from schemantic.schema.model import LoadedSchema, ParsedSchema, SchemaDefinition, SchemaPathDefinition, NameToInstance
from schemantic.schema.model import LoadedSchema, ParsedSchema, SchemaDefinition, LoadedOrPathSchema, NameToInstance

Schema = TypeVar("Schema", bound=BaseModel)
T = TypeVar("T")


class AbstractSchemer(BaseModel, Generic[Schema], ABC, arbitrary_types_allowed=True):
@abstractmethod
def _logical_post_dump_sort(self, dumped_schema: dict) -> dict:
def logical_post_dump_sort(self, dumped_schema: dict) -> None:
...

@abstractmethod
def load_into_mapping_name_to_instance(self, schema_finding_info: SchemaPathDefinition) -> dict[str, NameToInstance]: ...
def load_into_mapping_name_to_instance(self, schema_finding_info: LoadedOrPathSchema) -> dict[str, NameToInstance]: ...

@validate_call
def dump(self, dump_path: Path, **schema_kwargs) -> None:
schema: Schema = self.schema(**schema_kwargs)
dumped_schema = schema.model_dump(exclude_defaults=True)
dumped_schema = self._logical_post_dump_sort(dumped_schema)
self.logical_post_dump_sort(dumped_schema)

match dump_path.suffix:
case ".toml":
Expand Down Expand Up @@ -57,7 +57,7 @@ def load(schema_path: FilePath) -> dict:
msg = f"{schema_path.suffix} is unsupported"
raise NotImplementedError(msg)

def ensure_config_is_loaded(self, schema_finding_info: SchemaPathDefinition) -> LoadedSchema:
def ensure_config_is_loaded(self, schema_finding_info: LoadedOrPathSchema) -> LoadedSchema:
if isinstance(schema_finding_info, Path):
return self.load(schema_finding_info)
if isinstance(schema_finding_info, dict):
Expand Down Expand Up @@ -92,17 +92,29 @@ def mapping_name(self) -> str: ...
def arg_to_info(self) -> dict[str, InitArgTypeInfo]: ...

@validate_call
def load_into_mapping_name_to_instance(self, schema_finding_info: SchemaPathDefinition) -> dict[str, NameToInstance]:
def load_into_mapping_name_to_instance(self, schema_finding_info: LoadedOrPathSchema) -> dict[str, NameToInstance]:
"""
Load schema and parse it into a dictionary of instances.
"""
loaded_schema = self.ensure_config_is_loaded(schema_finding_info)
return self._parse_into_instance_by_mapping_name(loaded_schema, {})

@validate_call
def load_definitions(
self,
schema_finding_info: SchemaPathDefinition,
schema_finding_info: LoadedOrPathSchema,
*,
_inferior_config_kwargs: Optional[SchemaDefinition] = None,
) -> ParsedSchema:
"""
Load schema and parse it into a dictionary of configuration.
Example
-------
>>> schemer = MySchemer()
>>> schemer.load_definitions("path/to/schema.toml")
... {"config1": {"key1": "value1"}, "config2": {"key2": "value2"}}
"""
defined_schema = self.ensure_config_is_loaded(schema_finding_info)
return self._schema_parser(defined_schema, _inferior_config_kwargs=_inferior_config_kwargs)

Expand Down
66 changes: 38 additions & 28 deletions schemantic/schemer/many.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
LoadedSchema,
ParsedSchema,
SchemaDefinition,
SchemaPathDefinition,
LoadedOrPathSchema,
merge_signature_models_by_flattening,
merge_signature_models_by_intersection, NameToInstance,
)
Expand All @@ -32,7 +32,7 @@
CULTURE_KEY,
DEFINED_MAPPING_KEY,
GROUP_MEMBER_KEY,
HOMOLOGUE_INSTANCE_KEY, INIT_SIGNATURE_MAPPING_KEY, CLASS_NAME_KEY,
HOMOLOGUE_INSTANCE_KEY, INIT_SIGNATURE_MAPPING_KEY, CLASS_NAME_KEY, ARGUMENT_TO_TYPING_KEY,
)
from schemantic.utils.mapping import dict_sorted_by_dict_key, update_assert_disjoint

Expand Down Expand Up @@ -131,8 +131,8 @@ def _parse_into_instance_by_mapping_name(
def from_originating_type(cls, origin: type, **kwargs) -> Self:
return cls(single_schemer=SingleSchemer.from_origin(origin=origin), **kwargs)

def _logical_post_dump_sort(self, dumped_schema: dict) -> dict:
return dumped_schema
def logical_post_dump_sort(self, dumped_schema: dict) -> None:
pass


class GroupSchemer(HomologousGroupSchemer[GroupSchema]):
Expand Down Expand Up @@ -254,14 +254,13 @@ def _schema_parser(
)
return result

def _logical_post_dump_sort(self, dumped_schema: dict) -> dict:
_ensure_defined_mapping_in_common_key(dumped_schema)
def logical_post_dump_sort(self, dumped_schema: dict) -> None:
_ensure_defined_key_in_mapping(dumped_schema, COMMON_MAPPING_KEY)
for key, member_schema in dumped_schema[GROUP_MEMBER_KEY].items():
dumped_schema[GROUP_MEMBER_KEY][key] = {
DEFINED_MAPPING_KEY: member_schema.pop(DEFINED_MAPPING_KEY, {}),
**member_schema,
}
return dumped_schema


class CultureSchemer(AbstractSchemer):
Expand All @@ -276,28 +275,28 @@ def schema(
culture = {}
signature_models = []

for model_schema in self.source_schemers:
if isinstance(model_schema, SingleSchemer):
schema = model_schema.schema()
for schemer in self.source_schemers:
if isinstance(schemer, SingleSchemer):
schema = schemer.schema()
schema.dump_arg_name_only = True
signature_models.append(model_schema.signature_model)
signature_models.append(schemer.signature_model)

elif isinstance(model_schema, HomologueSchemer):
schema = model_schema.schema(name_getter_kwargs=homologue_name_getter_kwargs)
elif isinstance(schemer, HomologueSchemer):
schema = schemer.schema(name_getter_kwargs=homologue_name_getter_kwargs)
schema.init_signature.dump_arg_name_only = True
signature_models.append(model_schema.single_schemer.signature_model)
signature_models.append(schemer.single_schemer.signature_model)

elif isinstance(model_schema, GroupSchemer):
schema = model_schema.schema()
elif isinstance(schemer, GroupSchemer):
schema = schemer.schema()
for member_schema in schema.members.values():
member_schema.dump_arg_name_only = True
signature_models.append(model_schema.common_signature_model)
signature_models.append(schemer.common_signature_model)

else:
msg = f"{model_schema.__class__} is not supported"
msg = f"{schemer.__class__} is not supported"
raise NotImplementedError(msg)

culture[model_schema.mapping_name] = schema
culture[schemer.mapping_name] = schema

signature_common_schema = merge_signature_models_by_intersection(*signature_models)
signature_common_schema.dump_arg_name_only = True
Expand All @@ -308,7 +307,7 @@ def schema(
argument_to_typing=merge_signature_models_by_flattening(*signature_models),
)

def parse_schema(self, schema_finding_info: SchemaPathDefinition) -> dict:
def load_definitions(self, schema_finding_info: LoadedOrPathSchema) -> dict:
loaded_schema = self.ensure_config_is_loaded(schema_finding_info)
common = loaded_schema.get(COMMON_MAPPING_KEY, {}).get(DEFINED_MAPPING_KEY, {})
result: ParsedSchema = {}
Expand All @@ -317,7 +316,7 @@ def parse_schema(self, schema_finding_info: SchemaPathDefinition) -> dict:
result[schemer_name] = schemer._schema_parser(internal, _inferior_config_kwargs=common)
return result

def load_into_mapping_name_to_instance(self, schema_finding_info: SchemaPathDefinition) -> NameToInstance:
def load_into_mapping_name_to_instance(self, schema_finding_info: LoadedOrPathSchema) -> NameToInstance:
loaded_schema = self.ensure_config_is_loaded(schema_finding_info)
common = loaded_schema.get(COMMON_MAPPING_KEY, {}).get(DEFINED_MAPPING_KEY, {})
result: ParsedSchema = {}
Expand All @@ -330,17 +329,28 @@ def load_into_mapping_name_to_instance(self, schema_finding_info: SchemaPathDefi
def _mapping_name_to_schemer(self) -> dict[str, SingleSchemer | HomologueSchemer | GroupSchemer]:
return {schema.mapping_name: schema for schema in self.source_schemers}

def _logical_post_dump_sort(self, dumped_schema: dict) -> dict:
def logical_post_dump_sort(self, dumped_schema: dict) -> None:
if COMMON_MAPPING_KEY in dumped_schema:
_ensure_defined_mapping_in_common_key(dumped_schema)
return dumped_schema
_ensure_defined_key_in_mapping(dumped_schema, COMMON_MAPPING_KEY)

for schemer in self.source_schemers:
_ensure_defined_key_in_mapping(dumped_schema, CULTURE_KEY, schemer.mapping_name)
if isinstance(schemer, GroupSchemer):
dumped_schema[CULTURE_KEY][schemer.mapping_name].pop(ARGUMENT_TO_TYPING_KEY)


Schemer: TypeAlias = SingleSchemer | HomologueSchemer | GroupSchemer | CultureSchemer


def _ensure_defined_mapping_in_common_key(dumped_schema: dict) -> None:
dumped_schema[COMMON_MAPPING_KEY] = {
DEFINED_MAPPING_KEY: dumped_schema[COMMON_MAPPING_KEY].pop(DEFINED_MAPPING_KEY, {}),
**dumped_schema[COMMON_MAPPING_KEY],
def _ensure_defined_key_in_mapping(dumped_schema: dict, *target_path: str) -> None:
target = dumped_schema
*remaining_path, last_key = target_path

for path in remaining_path:
target = target[path]

new_target = {
DEFINED_MAPPING_KEY: target[last_key].pop(DEFINED_MAPPING_KEY, {}),
**target[last_key],
}
target[last_key] = new_target
6 changes: 3 additions & 3 deletions schemantic/schemer/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
LoadedSchema,
ParsedSchema,
SchemaDefinition,
SchemaPathDefinition,
LoadedOrPathSchema,
SignatureModel,
SingleSchema,
)
Expand Down Expand Up @@ -141,8 +141,8 @@ def _parse_into_instance_by_mapping_name(self, loaded_schema: LoadedSchema, infe
def signature_model(self) -> SignatureModel:
return SignatureModel(required=self.required, optional=self.optional)

def _logical_post_dump_sort(self, dumped_schema: dict) -> dict:
return dumped_schema
def logical_post_dump_sort(self, dumped_schema: dict) -> None:
pass

def _schema_parser(
self,
Expand Down
1 change: 1 addition & 0 deletions schemantic/utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

CLASS_NAME_KEY: Final[str] = "class_name"
INIT_SIGNATURE_MAPPING_KEY: Final[str] = "init_signature"
ARGUMENT_TO_TYPING_KEY: Final[str] = "argument_to_typing"

HOMOLOGUE_INSTANCE_KEY: Final[str] = "instances"
GROUP_MEMBER_KEY: Final[str] = "members"
Expand Down
6 changes: 3 additions & 3 deletions tests/test_case/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ def expected_mapping_name_to_instance(self) -> dict: ...
def test_schema_generates(self):
self.assertTrue(self.main_schemer.schema())

def test_parse_schema_to_instance_call(self):
def test_load_into_mapping_name_to_instance_call(self):
self.assertTrue(self.expected_mapping_name_to_instance)

def test_parse_schema(self):
def test_load_definitions(self):
self.assertEqual(
self.main_schemer.parse_schema(self.dump_expected_schema()),
self.main_schemer.load_definitions(self.dump_expected_schema()),
self.expected_mapping_name_to_config,
)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_case/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def _infer_single_schemer(origin: type, pre_definition: dict | None) -> SingleSc


class SetSchemerMixin:
def test_parse_schema_to_instance_control(self):
parse_result = self.main_schemer._parse_into_instance_by_mapping_name(self.dump_expected_schema())
def test_load_into_mapping_name_to_instance_control(self):
parse_result = self.main_schemer._parse_into_instance_by_mapping_name(self.dump_expected_schema(), {})
for mapping_name, instance in parse_result.items():
self.assertEqual(instance, self.expected_mapping_name_to_instance[mapping_name])

Expand Down
2 changes: 1 addition & 1 deletion tests/test_schema_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class WithDataclassMixin:
test_class_a: ClassVar[Type] = TestDataclass
test_class_b: ClassVar[Type] = OtherTestDataclass

def test_parse_schema_to_instance(self):
def test_load_into_mapping_name_to_instance(self):
self.assertTrue(self.schemer._parse_into_instance_by_mapping_name(self.schema_with_instance_configuration))


Expand Down
Loading

0 comments on commit 7f32c03

Please sign in to comment.