diff --git a/replit_river/codegen/client.py b/replit_river/codegen/client.py index 10ac905..f533f03 100644 --- a/replit_river/codegen/client.py +++ b/replit_river/codegen/client.py @@ -66,7 +66,7 @@ class RiverService(BaseModel): class RiverSchema(BaseModel): services: Dict[str, RiverService] - handshakeSchema: Optional[RiverConcreteType] + handshakeSchema: Optional[RiverConcreteType] = Field(default=None) RiverSchemaFile = RootModel[RiverSchema] @@ -304,11 +304,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: typeddict_encoder.append("None") return ("None", ()) if type.type == "Date": - # typeddict_encoder.append("TODO") + typeddict_encoder.append("TODO: dstewart") return ("datetime.datetime", ()) if type.type == "array" and type.items: type_name, type_chunks = encode_type(type.items, prefix, base_model) - # typeddict_encoder.append("TODO") + typeddict_encoder.append("TODO: dstewart") return (f"List[{type_name}]", type_chunks) if ( type.type == "object" @@ -355,6 +355,14 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: typeddict_encoder.append( f"if x['{safe_name}'] else None" ) + elif prop.type == "array": + assert type_name.startswith( + "List[" + ) # in case we change to list[...] + _inner_type_name = type_name[len("List[") : -len("]")] + typeddict_encoder.append( + f"[encode_{_inner_type_name}(y) for y in x['{name}']]" + ) else: typeddict_encoder.append(f"x['{safe_name}']") @@ -372,9 +380,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: if name not in type.required: value = "" if base_model != "TypedDict": - value = ( - f" = Field({field_value}, alias='{name}', default=None)" - ) + args = f"alias='{name}', default=None" + if field_value != "...": + value = f" = Field({field_value}, {args})" + else: + value = f" = Field({args})" current_chunks.append(f" kind: Optional[{type_name}]{value}") else: value = "" @@ -411,6 +421,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: def generate_river_client_module( client_name: str, schema_root: RiverSchema, + typed_dict_inputs: bool, ) -> Sequence[str]: chunks: List[str] = [ dedent( @@ -448,6 +459,7 @@ def generate_river_client_module( else: handshake_type = "Literal[None]" + input_base_class = "TypedDict" if typed_dict_inputs else "BaseModel" for schema_name, schema in schema_root.services.items(): current_chunks: List[str] = [ dedent( @@ -464,13 +476,13 @@ def __init__(self, client: river.Client[{handshake_type}]): init_type, input_chunks = encode_type( procedure.init, f"{schema_name.title()}{name.title()}Init", - base_model="TypedDict", + base_model=input_base_class, ) chunks.extend(input_chunks) input_type, input_chunks = encode_type( procedure.input, f"{schema_name.title()}{name.title()}Input", - base_model="TypedDict", + base_model=input_base_class, ) chunks.extend(input_chunks) output_type, output_chunks = encode_type( @@ -692,13 +704,18 @@ def __init__(self, client: river.Client[{handshake_type}]): def schema_to_river_client_codegen( - schema_path: str, target_path: str, client_name: str + schema_path: str, + target_path: str, + client_name: str, + typed_dict_inputs: bool, ) -> None: """Generates the lines of a River module.""" with open(schema_path) as f: schemas = RiverSchemaFile(json.load(f)) with open(target_path, "w") as f: - s = "\n".join(generate_river_client_module(client_name, schemas.root)) + s = "\n".join( + generate_river_client_module(client_name, schemas.root, typed_dict_inputs) + ) try: f.write( black.format_str(s, mode=black.FileMode(string_normalization=False)) diff --git a/replit_river/codegen/run.py b/replit_river/codegen/run.py index 2fbbc64..f7c2b8e 100644 --- a/replit_river/codegen/run.py +++ b/replit_river/codegen/run.py @@ -27,6 +27,12 @@ def main() -> None: ) client.add_argument("--output", help="output file", required=True) client.add_argument("--client-name", help="name of the class", required=True) + client.add_argument( + "--typed-dict-inputs", + help="Enable typed dicts", + action="store_true", + default=False, + ) client.add_argument("schema", help="schema file") args = parser.parse_args() @@ -41,6 +47,8 @@ def main() -> None: elif args.command == "client": schema_path = os.path.abspath(args.schema) target_path = os.path.abspath(args.output) - schema_to_river_client_codegen(schema_path, target_path, args.client_name) + schema_to_river_client_codegen( + schema_path, target_path, args.client_name, args.typed_dict_inputs + ) else: raise NotImplementedError(f"Unknown command {args.command}")