Skip to content

Commit

Permalink
fix jobs never reloading (#934)
Browse files: browse the repository at this point in the history
  • Loading branch information
theomonnom authored Oct 16, 2024
1 parent e6b2470 commit 879c3e2
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 19 deletions.
5 changes: 5 additions & 0 deletions .changeset/tidy-eggs-smell.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-agents": patch
---

fix jobs never reloading
4 changes: 1 addition & 3 deletions livekit-agents/livekit/agents/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,7 @@ async def _worker_run(worker: Worker) -> None:
if args.watch:
from .watcher import WatchClient

assert args.mp_cch is not None

watch_client = WatchClient(worker, args.mp_cch, loop=loop)
watch_client = WatchClient(worker, args, loop=loop)
watch_client.start()

try:
Expand Down
5 changes: 5 additions & 0 deletions livekit-agents/livekit/agents/cli/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class ActiveJobsRequest:
class ActiveJobsResponse:
MSG_ID: ClassVar[int] = 2
jobs: list[RunningJobInfo] = field(default_factory=list)
reload_count: int = 0

def write(self, b: io.BytesIO) -> None:
channel.write_int(b, len(self.jobs))
Expand All @@ -52,6 +53,8 @@ def write(self, b: io.BytesIO) -> None:
channel.write_string(b, running_job.url)
channel.write_string(b, running_job.token)

channel.write_int(b, self.reload_count)

def read(self, b: io.BytesIO) -> None:
for _ in range(channel.read_int(b)):
job = agent.Job()
Expand All @@ -69,6 +72,8 @@ def read(self, b: io.BytesIO) -> None:
)
)

self.reload_count = channel.read_int(b)


@dataclass
class ReloadJobsRequest:
Expand Down
43 changes: 27 additions & 16 deletions livekit-agents/livekit/agents/cli/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
self._loop = loop

self._recv_jobs_fut = asyncio.Future[None]()
self._reloading_jobs = False
self._worker_reloading = False

async def run(self) -> None:
watch_paths = _find_watchable_paths(self._main_file)
Expand All @@ -99,25 +99,29 @@ async def run(self) -> None:
await self._pch.aclose()

async def _on_reload(self, _: Set[watchfiles.main.FileChange]) -> None:
self._cli_args.reload_count += 1

if self._reloading_jobs:
if self._worker_reloading:
return

await channel.asend_message(self._pch, proto.ActiveJobsRequest())
self._working_reloading = True
self._worker_reloading = True

self._recv_jobs_fut = asyncio.Future()
with contextlib.suppress(asyncio.TimeoutError):
# wait max 1.5s to get the active jobs
await asyncio.wait_for(self._recv_jobs_fut, timeout=1.5)
try:
await channel.asend_message(self._pch, proto.ActiveJobsRequest())
self._recv_jobs_fut = asyncio.Future()
with contextlib.suppress(asyncio.TimeoutError):
# wait max 1.5s to get the active jobs
await asyncio.wait_for(self._recv_jobs_fut, timeout=1.5)
finally:
self._cli_args.reload_count += 1

@utils.log_exceptions(logger=logger)
async def _read_ipc_task(self) -> None:
active_jobs = []
while True:
msg = await channel.arecv_message(self._pch, proto.IPC_MESSAGES)
if isinstance(msg, proto.ActiveJobsResponse) and self._working_reloading:
if isinstance(msg, proto.ActiveJobsResponse):
if msg.reload_count != self._cli_args.reload_count:
continue

active_jobs = msg.jobs
with contextlib.suppress(asyncio.InvalidStateError):
self._recv_jobs_fut.set_result(None)
Expand All @@ -126,29 +130,33 @@ async def _read_ipc_task(self) -> None:
self._pch, proto.ReloadJobsResponse(jobs=active_jobs)
)
if isinstance(msg, proto.Reloaded):
self._working_reloading = False
self._worker_reloading = False


class WatchClient:
def __init__(
self,
worker: Worker,
mp_cch: socket.socket,
cli_args: proto.CliArgs,
loop: asyncio.AbstractEventLoop | None = None,
) -> None:
self._loop = loop or asyncio.get_event_loop()
self._worker = worker
self._mp_cch = mp_cch
self._cli_args = cli_args

def start(self) -> None:
self._main_task = self._loop.create_task(self._run())

@utils.log_exceptions(logger=logger)
async def _run(self) -> None:
assert self._cli_args.mp_cch
try:
self._cch = await utils.aio.duplex_unix._AsyncDuplex.open(self._mp_cch)
self._cch = await utils.aio.duplex_unix._AsyncDuplex.open(
self._cli_args.mp_cch
)

await channel.asend_message(self._cch, proto.ReloadJobsRequest())

while True:
try:
msg = await channel.arecv_message(self._cch, proto.IPC_MESSAGES)
Expand All @@ -158,7 +166,10 @@ async def _run(self) -> None:
if isinstance(msg, proto.ActiveJobsRequest):
jobs = self._worker.active_jobs
await channel.asend_message(
self._cch, proto.ActiveJobsResponse(jobs=jobs)
self._cch,
proto.ActiveJobsResponse(
jobs=jobs, reload_count=self._cli_args.reload_count
),
)
elif isinstance(msg, proto.ReloadJobsResponse):
# TODO(theomonnom): wait for the worker to be fully initialized/connected
Expand Down

0 comments on commit 879c3e2

Please sign in to comment.