diff --git a/replit_river/codegen/client.py b/replit_river/codegen/client.py index d4117ee..22edead 100644 --- a/replit_river/codegen/client.py +++ b/replit_river/codegen/client.py @@ -1,9 +1,13 @@ import json +import re from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union import black from pydantic import BaseModel, Field, RootModel +_NON_ALNUM_RE = re.compile(r"[^a-zA-Z0-9_]+") +_LITERAL_RE = re.compile(r"^Literal\[(.+)\]$") + class RiverConcreteType(BaseModel): type: Optional[str] = Field(default=None) @@ -49,6 +53,61 @@ def encode_type( if isinstance(type, RiverNotType): return ("None", ()) if isinstance(type, RiverUnionType): + # First check if it's a discriminated union. Typebox currently doesn't have + # a way of expressing the intention of having a discriminated union. So we + # do a bit of detection if that is structurally true by checking that all the + # types in the anyOf are objects, have properties, and have one property common + # to all the alternatives that has a literal value. + one_of_candidate_types: List[RiverConcreteType] = [ + t + for t in type.anyOf + if isinstance(t, RiverConcreteType) + and t.type == "object" + and t.properties + and (not t.patternProperties or "^(.*)$" not in t.patternProperties) + ] + if len(type.anyOf) > 0 and len(type.anyOf) == len(one_of_candidate_types): + # We have established that it is a union-of-objects. Now let's see if + # there is a discriminator field common among all options. + literal_fields = set[str]() + for i, oneof_t in enumerate(one_of_candidate_types): + lf = set[str]( + name + for name, prop in oneof_t.properties.items() + if isinstance(prop, RiverConcreteType) + and prop.type in ("string", "number", "boolean") + and prop.const is not None + ) + if i == 0: + literal_fields = lf + else: + literal_fields.intersection_update(lf) + if not literal_fields: + # There are no more candidates. + break + if len(literal_fields) == 1: + # Hooray! we found a discriminated union. + discriminator_name = literal_fields.pop() + one_of: List[str] = [] + + for oneof_t in one_of_candidate_types: + discriminator_value = [ + _NON_ALNUM_RE.sub("", str(prop.const)) + for name, prop in oneof_t.properties.items() + if isinstance(prop, RiverConcreteType) + and name == discriminator_name + and prop.const is not None + ].pop() + type_name, type_chunks = encode_type( + oneof_t, f"{prefix}OneOf_{discriminator_value}", base_model + ) + chunks.extend(type_chunks) + one_of.append(type_name) + if discriminator_name == "$kind": + discriminator_name = "kind" + chunks.append(f"{prefix} = Union[" + ", ".join(one_of) + "]") + chunks.append("") + return (prefix, chunks) any_of: List[str] = [] for i, t in enumerate(type.anyOf): type_name, type_chunks = encode_type(t, f"{prefix}AnyOf_{i}", base_model) @@ -104,14 +163,25 @@ def encode_type( ) chunks.extend(type_chunks) if name == "$kind": + # If the field is a literal, the Python type-checker will complain + # about the constructor not being able to specify a value for it: + # You can't put `$kind="ok"` in the ctor because `$` is not a valid + # character in an identifier, and putting `**{"$kind":"ok"}` makes + # it not recognize the `"ok"` as being `Literal["ok"]`, so we're + # stuck with an impossible-to-construct object. + field_value = "..." + groups = _LITERAL_RE.match(type_name) + if groups: + field_value = groups.group(1) if name not in type.required: current_chunks.append( f" kind: Optional[{type_name}] = " - f"Field(..., alias='{name}', default=None)" + f"Field({field_value}, alias='{name}', default=None)" ) else: current_chunks.append( - f" kind: {type_name} = Field(..., alias='{name}')" + f" kind: {type_name} = " + f"Field({field_value}, alias='{name}')" ) else: if name not in type.required: