diff --git a/replit_river/codegen/client.py b/replit_river/codegen/client.py index 0447e56..18d31f5 100644 --- a/replit_river/codegen/client.py +++ b/replit_river/codegen/client.py @@ -466,6 +466,7 @@ def generate_river_client_module( chunks: List[str] = [ dedent( """\ + # ruff: noqa # Code generated by river.codegen. DO NOT EDIT. from collections.abc import AsyncIterable, AsyncIterator import datetime diff --git a/replit_river/codegen/run.py b/replit_river/codegen/run.py index f7c2b8e..66f04d3 100644 --- a/replit_river/codegen/run.py +++ b/replit_river/codegen/run.py @@ -13,6 +13,9 @@ def main() -> None: server = subparsers.add_parser( "server", help="Codegen a River server from gRPC protos" ) + server.add_argument( + "--module", dest="module_name", help="output module", default="." + ) server.add_argument("--output", help="output directory", required=True) server.add_argument("proto", help="proto file") @@ -39,7 +42,7 @@ def main() -> None: if args.command == "server": proto_path = os.path.abspath(args.proto) target_directory = os.path.abspath(args.output) - proto_to_river_server_codegen(proto_path, target_directory) + proto_to_river_server_codegen(args.module_name, proto_path, target_directory) elif args.command == "server-schema": proto_path = os.path.abspath(args.proto) target_directory = os.path.abspath(args.output) diff --git a/replit_river/codegen/server.py b/replit_river/codegen/server.py index c14b3f4..6976b81 100644 --- a/replit_river/codegen/server.py +++ b/replit_river/codegen/server.py @@ -1,6 +1,7 @@ import collections import os.path import tempfile +from textwrap import dedent from typing import DefaultDict, List, Sequence import black @@ -220,20 +221,27 @@ def message_encoder( def generate_river_module( module_name: str, + pb_module_name: str, fds: descriptor_pb2.FileDescriptorSet, ) -> Sequence[str]: """Generates the lines of a River module.""" chunks: List[str] = [ - "# Code generated by river.codegen. DO NOT EDIT.", - "import datetime", - "from typing import Any, Dict, Mapping, Tuple", - "", - "from google.protobuf import timestamp_pb2", - "from google.protobuf.wrappers_pb2 import BoolValue", + dedent( + f"""\ + # Code generated by river.codegen. DO NOT EDIT. + import datetime + from typing import Any, Dict, Mapping, Tuple + + from google.protobuf import timestamp_pb2 + from google.protobuf.wrappers_pb2 import BoolValue + + import replit_river as river + + from {module_name} import {pb_module_name}_pb2, {pb_module_name}_pb2_grpc + """ + ), "", - "import replit_river as river", "", - f"from . import {module_name}_pb2, {module_name}_pb2_grpc\n\n", ] for pd in fds.file: @@ -242,15 +250,15 @@ def _remove_namespace(name: str) -> str: # Generate the message encoders/decoders. for message in pd.message_type: - chunks.extend(message_encoder(module_name, message)) - chunks.extend(message_decoder(module_name, message)) + chunks.extend(message_encoder(pb_module_name, message)) + chunks.extend(message_decoder(pb_module_name, message)) # Generate the service stubs. for service in pd.service: chunks.extend( [ f"""def add_{service.name}Servicer_to_server( - servicer: {module_name}_pb2_grpc.{service.name}Servicer, + servicer: {pb_module_name}_pb2_grpc.{service.name}Servicer, server: river.Server, ) -> None:""", ( @@ -301,7 +309,11 @@ def _remove_namespace(name: str) -> str: return chunks -def proto_to_river_server_codegen(proto_path: str, target_directory: str) -> None: +def proto_to_river_server_codegen( + module_name: str, + proto_path: str, + target_directory: str, +) -> None: fds = descriptor_pb2.FileDescriptorSet() with tempfile.TemporaryDirectory() as tempdir: descriptor_path = os.path.join(tempdir, "descriptor.pb") @@ -317,12 +329,12 @@ def proto_to_river_server_codegen(proto_path: str, target_directory: str) -> Non ) with open(descriptor_path, "rb") as f: fds.ParseFromString(f.read()) - module_name = os.path.splitext(os.path.basename(proto_path))[0] + pb_module_name = os.path.splitext(os.path.basename(proto_path))[0] contents = black.format_str( - "\n".join(generate_river_module(module_name, fds)), + "\n".join(generate_river_module(module_name, pb_module_name, fds)), mode=black.FileMode(string_normalization=False), ) os.makedirs(target_directory, exist_ok=True) - output_path = f"{target_directory}/{module_name}_river.py" + output_path = f"{target_directory}/{pb_module_name}_river.py" with open(output_path, "w") as f: f.write(contents)