Skip to content

Commit

Permalink
bug/conditionally enable typeddicts (#78)
Browse files Browse the repository at this point in the history
Why
===

Unfortunately switching over to `TypedDicts` turned out to be overly
optimistic without any sort of test suite to keep things in parity. For
now, I'll make the codegen require explicit opt-in to not break clients.

What changed
============

- Conditionally opt-in to TypedDict inputs
- Bugfix: Encode arrays using the inner encoder
- Optional `Field(...`'s are invalid in pydantic, special case that.

Test plan
=========

Manual testing for now
  • Loading branch information
blast-hardcheese authored Sep 2, 2024
1 parent 4d26277 commit c5299df
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 11 deletions.
37 changes: 27 additions & 10 deletions replit_river/codegen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class RiverService(BaseModel):

class RiverSchema(BaseModel):
services: Dict[str, RiverService]
handshakeSchema: Optional[RiverConcreteType]
handshakeSchema: Optional[RiverConcreteType] = Field(default=None)


RiverSchemaFile = RootModel[RiverSchema]
Expand Down Expand Up @@ -304,11 +304,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
typeddict_encoder.append("None")
return ("None", ())
if type.type == "Date":
# typeddict_encoder.append("TODO")
typeddict_encoder.append("TODO: dstewart")
return ("datetime.datetime", ())
if type.type == "array" and type.items:
type_name, type_chunks = encode_type(type.items, prefix, base_model)
# typeddict_encoder.append("TODO")
typeddict_encoder.append("TODO: dstewart")
return (f"List[{type_name}]", type_chunks)
if (
type.type == "object"
Expand Down Expand Up @@ -355,6 +355,14 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
typeddict_encoder.append(
f"if x['{safe_name}'] else None"
)
elif prop.type == "array":
assert type_name.startswith(
"List["
) # in case we change to list[...]
_inner_type_name = type_name[len("List[") : -len("]")]
typeddict_encoder.append(
f"[encode_{_inner_type_name}(y) for y in x['{name}']]"
)
else:
typeddict_encoder.append(f"x['{safe_name}']")

Expand All @@ -372,9 +380,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
if name not in type.required:
value = ""
if base_model != "TypedDict":
value = (
f" = Field({field_value}, alias='{name}', default=None)"
)
args = f"alias='{name}', default=None"
if field_value != "...":
value = f" = Field({field_value}, {args})"
else:
value = f" = Field({args})"
current_chunks.append(f" kind: Optional[{type_name}]{value}")
else:
value = ""
Expand Down Expand Up @@ -411,6 +421,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
def generate_river_client_module(
client_name: str,
schema_root: RiverSchema,
typed_dict_inputs: bool,
) -> Sequence[str]:
chunks: List[str] = [
dedent(
Expand Down Expand Up @@ -448,6 +459,7 @@ def generate_river_client_module(
else:
handshake_type = "Literal[None]"

input_base_class = "TypedDict" if typed_dict_inputs else "BaseModel"
for schema_name, schema in schema_root.services.items():
current_chunks: List[str] = [
dedent(
Expand All @@ -464,13 +476,13 @@ def __init__(self, client: river.Client[{handshake_type}]):
init_type, input_chunks = encode_type(
procedure.init,
f"{schema_name.title()}{name.title()}Init",
base_model="TypedDict",
base_model=input_base_class,
)
chunks.extend(input_chunks)
input_type, input_chunks = encode_type(
procedure.input,
f"{schema_name.title()}{name.title()}Input",
base_model="TypedDict",
base_model=input_base_class,
)
chunks.extend(input_chunks)
output_type, output_chunks = encode_type(
Expand Down Expand Up @@ -692,13 +704,18 @@ def __init__(self, client: river.Client[{handshake_type}]):


def schema_to_river_client_codegen(
schema_path: str, target_path: str, client_name: str
schema_path: str,
target_path: str,
client_name: str,
typed_dict_inputs: bool,
) -> None:
"""Generates the lines of a River module."""
with open(schema_path) as f:
schemas = RiverSchemaFile(json.load(f))
with open(target_path, "w") as f:
s = "\n".join(generate_river_client_module(client_name, schemas.root))
s = "\n".join(
generate_river_client_module(client_name, schemas.root, typed_dict_inputs)
)
try:
f.write(
black.format_str(s, mode=black.FileMode(string_normalization=False))
Expand Down
10 changes: 9 additions & 1 deletion replit_river/codegen/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ def main() -> None:
)
client.add_argument("--output", help="output file", required=True)
client.add_argument("--client-name", help="name of the class", required=True)
client.add_argument(
"--typed-dict-inputs",
help="Enable typed dicts",
action="store_true",
default=False,
)
client.add_argument("schema", help="schema file")
args = parser.parse_args()

Expand All @@ -41,6 +47,8 @@ def main() -> None:
elif args.command == "client":
schema_path = os.path.abspath(args.schema)
target_path = os.path.abspath(args.output)
schema_to_river_client_codegen(schema_path, target_path, args.client_name)
schema_to_river_client_codegen(
schema_path, target_path, args.client_name, args.typed_dict_inputs
)
else:
raise NotImplementedError(f"Unknown command {args.command}")

0 comments on commit c5299df

Please sign in to comment.