diff --git a/torchx/schedulers/ray_scheduler.py b/torchx/schedulers/ray_scheduler.py index a425fc107..8d986bf00 100644 --- a/torchx/schedulers/ray_scheduler.py +++ b/torchx/schedulers/ray_scheduler.py @@ -8,12 +8,13 @@ import json import logging import os +import re import tempfile import time from dataclasses import dataclass, field from datetime import datetime from shutil import copy2, rmtree -from typing import Any, cast, Dict, Iterable, List, Mapping, Optional, Set, Type # noqa +from typing import Any, cast, Dict, Iterable, List, Optional, Tuple # noqa from torchx.schedulers.api import ( AppDryRunInfo, @@ -322,13 +323,25 @@ def wait_until_finish(self, app_id: str, timeout: int = 30) -> None: break time.sleep(1) - def _cancel_existing(self, app_id: str) -> None: # pragma: no cover + def _parse_app_id(self, app_id: str) -> Tuple[str, str]: + # find index of '-' in the first :\d+- + m = re.search(r":\d+-", app_id) + if m: + sep = m.span()[1] + addr = app_id[: sep - 1] + app_id = app_id[sep:] + return addr, app_id + addr, _, app_id = app_id.partition("-") + return addr, app_id + + def _cancel_existing(self, app_id: str) -> None: # pragma: no cover + addr, app_id = self._parse_app_id(app_id) client = JobSubmissionClient(f"http://{addr}") client.stop_job(app_id) def _get_job_status(self, app_id: str) -> JobStatus: - addr, _, app_id = app_id.partition("-") + addr, app_id = self._parse_app_id(app_id) client = JobSubmissionClient(f"http://{addr}") status = client.get_job_status(app_id) if isinstance(status, str): @@ -375,7 +388,7 @@ def log_iter( streams: Optional[Stream] = None, ) -> Iterable[str]: # TODO: support tailing, streams etc.. - addr, _, app_id = app_id.partition("-") + addr, app_id = self._parse_app_id(app_id) client: JobSubmissionClient = JobSubmissionClient(f"http://{addr}") logs: str = client.get_job_logs(app_id) iterator = split_lines(logs) diff --git a/torchx/schedulers/test/ray_scheduler_test.py b/torchx/schedulers/test/ray_scheduler_test.py index cd05d62da..f4f6fb04f 100644 --- a/torchx/schedulers/test/ray_scheduler_test.py +++ b/torchx/schedulers/test/ray_scheduler_test.py @@ -298,6 +298,23 @@ def test_requirements(self) -> None: job = req.request self.assertEqual(job.requirements, reqs) + def test_parse_app_id(self) -> None: + test_addr_appid = [ + ( + "0.0.0.0:1234-app_id", + "0.0.0.0:1234", + "app_id", + ), # (full address, address:port, app_id) + ("addr-of-cluster:1234-app-id", "addr-of-cluster:1234", "app-id"), + ("www.test.com:1234-app:id", "www.test.com:1234", "app:id"), + ("foo", "foo", ""), + ("foo-bar-bar", "foo", "bar-bar"), + ] + for test_example, addr, app_id in test_addr_appid: + parsed_addr, parsed_appid = self._scheduler._parse_app_id(test_example) + self.assertEqual(parsed_addr, addr) + self.assertEqual(parsed_appid, app_id) + class RayClusterSetup: _instance = None # pyre-ignore[4]