Skip to content

Commit

Permalink
[feat] Expose containing module for servergen (#105)
Browse files Browse the repository at this point in the history
Why
===

We have some `sed` commands that get run against generated code.
Instead, surface config options so we don't need to do that.

We also don't need to lint generated sources.

What changed
============

Added `replit_river.codegen server --module ...`
`# ruff: noqa` at the top of generated files

Test plan
=========

Manual testing resulted in the same file after replacing `sed` with
`--module ...`
  • Loading branch information
blast-hardcheese authored Nov 6, 2024
1 parent 6405ae8 commit e00d75f
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
1 change: 1 addition & 0 deletions replit_river/codegen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ def generate_river_client_module(
chunks: List[str] = [
dedent(
"""\
# ruff: noqa
# Code generated by river.codegen. DO NOT EDIT.
from collections.abc import AsyncIterable, AsyncIterator
import datetime
Expand Down
5 changes: 4 additions & 1 deletion replit_river/codegen/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ def main() -> None:
server = subparsers.add_parser(
"server", help="Codegen a River server from gRPC protos"
)
server.add_argument(
"--module", dest="module_name", help="output module", default="."
)
server.add_argument("--output", help="output directory", required=True)
server.add_argument("proto", help="proto file")

Expand All @@ -39,7 +42,7 @@ def main() -> None:
if args.command == "server":
proto_path = os.path.abspath(args.proto)
target_directory = os.path.abspath(args.output)
proto_to_river_server_codegen(proto_path, target_directory)
proto_to_river_server_codegen(args.module_name, proto_path, target_directory)
elif args.command == "server-schema":
proto_path = os.path.abspath(args.proto)
target_directory = os.path.abspath(args.output)
Expand Down
42 changes: 27 additions & 15 deletions replit_river/codegen/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import os.path
import tempfile
from textwrap import dedent
from typing import DefaultDict, List, Sequence

import black
Expand Down Expand Up @@ -220,20 +221,27 @@ def message_encoder(

def generate_river_module(
module_name: str,
pb_module_name: str,
fds: descriptor_pb2.FileDescriptorSet,
) -> Sequence[str]:
"""Generates the lines of a River module."""
chunks: List[str] = [
"# Code generated by river.codegen. DO NOT EDIT.",
"import datetime",
"from typing import Any, Dict, Mapping, Tuple",
"",
"from google.protobuf import timestamp_pb2",
"from google.protobuf.wrappers_pb2 import BoolValue",
dedent(
f"""\
# Code generated by river.codegen. DO NOT EDIT.
import datetime
from typing import Any, Dict, Mapping, Tuple
from google.protobuf import timestamp_pb2
from google.protobuf.wrappers_pb2 import BoolValue
import replit_river as river
from {module_name} import {pb_module_name}_pb2, {pb_module_name}_pb2_grpc
"""
),
"",
"import replit_river as river",
"",
f"from . import {module_name}_pb2, {module_name}_pb2_grpc\n\n",
]
for pd in fds.file:

Expand All @@ -242,15 +250,15 @@ def _remove_namespace(name: str) -> str:

# Generate the message encoders/decoders.
for message in pd.message_type:
chunks.extend(message_encoder(module_name, message))
chunks.extend(message_decoder(module_name, message))
chunks.extend(message_encoder(pb_module_name, message))
chunks.extend(message_decoder(pb_module_name, message))

# Generate the service stubs.
for service in pd.service:
chunks.extend(
[
f"""def add_{service.name}Servicer_to_server(
servicer: {module_name}_pb2_grpc.{service.name}Servicer,
servicer: {pb_module_name}_pb2_grpc.{service.name}Servicer,
server: river.Server,
) -> None:""",
(
Expand Down Expand Up @@ -301,7 +309,11 @@ def _remove_namespace(name: str) -> str:
return chunks


def proto_to_river_server_codegen(proto_path: str, target_directory: str) -> None:
def proto_to_river_server_codegen(
module_name: str,
proto_path: str,
target_directory: str,
) -> None:
fds = descriptor_pb2.FileDescriptorSet()
with tempfile.TemporaryDirectory() as tempdir:
descriptor_path = os.path.join(tempdir, "descriptor.pb")
Expand All @@ -317,12 +329,12 @@ def proto_to_river_server_codegen(proto_path: str, target_directory: str) -> Non
)
with open(descriptor_path, "rb") as f:
fds.ParseFromString(f.read())
module_name = os.path.splitext(os.path.basename(proto_path))[0]
pb_module_name = os.path.splitext(os.path.basename(proto_path))[0]
contents = black.format_str(
"\n".join(generate_river_module(module_name, fds)),
"\n".join(generate_river_module(module_name, pb_module_name, fds)),
mode=black.FileMode(string_normalization=False),
)
os.makedirs(target_directory, exist_ok=True)
output_path = f"{target_directory}/{module_name}_river.py"
output_path = f"{target_directory}/{pb_module_name}_river.py"
with open(output_path, "w") as f:
f.write(contents)

0 comments on commit e00d75f

Please sign in to comment.