diff --git a/.changeset/tidy-eggs-smell.md b/.changeset/tidy-eggs-smell.md new file mode 100644 index 000000000..1f5ce6839 --- /dev/null +++ b/.changeset/tidy-eggs-smell.md @@ -0,0 +1,5 @@ +--- +"livekit-agents": patch +--- + +fix jobs never reloading diff --git a/livekit-agents/livekit/agents/cli/cli.py b/livekit-agents/livekit/agents/cli/cli.py index 9e2475a07..578ce5ced 100644 --- a/livekit-agents/livekit/agents/cli/cli.py +++ b/livekit-agents/livekit/agents/cli/cli.py @@ -271,9 +271,7 @@ async def _worker_run(worker: Worker) -> None: if args.watch: from .watcher import WatchClient - assert args.mp_cch is not None - - watch_client = WatchClient(worker, args.mp_cch, loop=loop) + watch_client = WatchClient(worker, args, loop=loop) watch_client.start() try: diff --git a/livekit-agents/livekit/agents/cli/proto.py b/livekit-agents/livekit/agents/cli/proto.py index 0f33fe519..f7753c579 100644 --- a/livekit-agents/livekit/agents/cli/proto.py +++ b/livekit-agents/livekit/agents/cli/proto.py @@ -40,6 +40,7 @@ class ActiveJobsRequest: class ActiveJobsResponse: MSG_ID: ClassVar[int] = 2 jobs: list[RunningJobInfo] = field(default_factory=list) + reload_count: int = 0 def write(self, b: io.BytesIO) -> None: channel.write_int(b, len(self.jobs)) @@ -52,6 +53,8 @@ def write(self, b: io.BytesIO) -> None: channel.write_string(b, running_job.url) channel.write_string(b, running_job.token) + channel.write_int(b, self.reload_count) + def read(self, b: io.BytesIO) -> None: for _ in range(channel.read_int(b)): job = agent.Job() @@ -69,6 +72,8 @@ def read(self, b: io.BytesIO) -> None: ) ) + self.reload_count = channel.read_int(b) + @dataclass class ReloadJobsRequest: diff --git a/livekit-agents/livekit/agents/cli/watcher.py b/livekit-agents/livekit/agents/cli/watcher.py index 803c9557f..5f4a60751 100644 --- a/livekit-agents/livekit/agents/cli/watcher.py +++ b/livekit-agents/livekit/agents/cli/watcher.py @@ -76,7 +76,7 @@ def __init__( self._loop = loop self._recv_jobs_fut = asyncio.Future[None]() - self._reloading_jobs = False + self._worker_reloading = False async def run(self) -> None: watch_paths = _find_watchable_paths(self._main_file) @@ -99,25 +99,29 @@ async def run(self) -> None: await self._pch.aclose() async def _on_reload(self, _: Set[watchfiles.main.FileChange]) -> None: - self._cli_args.reload_count += 1 - - if self._reloading_jobs: + if self._worker_reloading: return - await channel.asend_message(self._pch, proto.ActiveJobsRequest()) - self._working_reloading = True + self._worker_reloading = True - self._recv_jobs_fut = asyncio.Future() - with contextlib.suppress(asyncio.TimeoutError): - # wait max 1.5s to get the active jobs - await asyncio.wait_for(self._recv_jobs_fut, timeout=1.5) + try: + await channel.asend_message(self._pch, proto.ActiveJobsRequest()) + self._recv_jobs_fut = asyncio.Future() + with contextlib.suppress(asyncio.TimeoutError): + # wait max 1.5s to get the active jobs + await asyncio.wait_for(self._recv_jobs_fut, timeout=1.5) + finally: + self._cli_args.reload_count += 1 @utils.log_exceptions(logger=logger) async def _read_ipc_task(self) -> None: active_jobs = [] while True: msg = await channel.arecv_message(self._pch, proto.IPC_MESSAGES) - if isinstance(msg, proto.ActiveJobsResponse) and self._working_reloading: + if isinstance(msg, proto.ActiveJobsResponse): + if msg.reload_count != self._cli_args.reload_count: + continue + active_jobs = msg.jobs with contextlib.suppress(asyncio.InvalidStateError): self._recv_jobs_fut.set_result(None) @@ -126,29 +130,33 @@ async def _read_ipc_task(self) -> None: self._pch, proto.ReloadJobsResponse(jobs=active_jobs) ) if isinstance(msg, proto.Reloaded): - self._working_reloading = False + self._worker_reloading = False class WatchClient: def __init__( self, worker: Worker, - mp_cch: socket.socket, + cli_args: proto.CliArgs, loop: asyncio.AbstractEventLoop | None = None, ) -> None: self._loop = loop or asyncio.get_event_loop() self._worker = worker - self._mp_cch = mp_cch + self._cli_args = cli_args def start(self) -> None: self._main_task = self._loop.create_task(self._run()) @utils.log_exceptions(logger=logger) async def _run(self) -> None: + assert self._cli_args.mp_cch try: - self._cch = await utils.aio.duplex_unix._AsyncDuplex.open(self._mp_cch) + self._cch = await utils.aio.duplex_unix._AsyncDuplex.open( + self._cli_args.mp_cch + ) await channel.asend_message(self._cch, proto.ReloadJobsRequest()) + while True: try: msg = await channel.arecv_message(self._cch, proto.IPC_MESSAGES) @@ -158,7 +166,10 @@ async def _run(self) -> None: if isinstance(msg, proto.ActiveJobsRequest): jobs = self._worker.active_jobs await channel.asend_message( - self._cch, proto.ActiveJobsResponse(jobs=jobs) + self._cch, + proto.ActiveJobsResponse( + jobs=jobs, reload_count=self._cli_args.reload_count + ), ) elif isinstance(msg, proto.ReloadJobsResponse): # TODO(theomonnom): wait for the worker to be fully initialized/connected