From 44f8f8ac7b14da9aef086b0ecce07b9eb544eae5 Mon Sep 17 00:00:00 2001 From: Yikai Gao Date: Sun, 8 Sep 2024 15:59:47 -0700 Subject: [PATCH] Create session id and use it in table 'Pytorch Elastic Tsm Log' (#953) Summary: Pull Request resolved: https://github.com/pytorch/torchx/pull/953 Please read the doc to understand why we create the session id: https://docs.google.com/document/d/1WJBrqSHrNIc9J1W_1PMIQPu11y2fV_hU36aeiBTgN90/edit Differential Revision: D62087199 --- torchx/runner/events/__init__.py | 3 +- torchx/runner/events/api.py | 2 +- torchx/runner/events/test/lib_test.py | 25 +++++++++++----- torchx/runner/test/api_test.py | 42 +++++++++++++++++++++++++-- torchx/util/session.py | 26 +++++++++++++++++ 5 files changed, 86 insertions(+), 12 deletions(-) create mode 100644 torchx/util/session.py diff --git a/torchx/runner/events/__init__.py b/torchx/runner/events/__init__.py index 360cb3e7c..c8eb89d96 100644 --- a/torchx/runner/events/__init__.py +++ b/torchx/runner/events/__init__.py @@ -27,6 +27,7 @@ from typing import Dict, Optional, Type from torchx.runner.events.handlers import get_logging_handler +from torchx.util.session import get_session_id_or_create_new from .api import SourceType, TorchxEvent # noqa F401 @@ -136,7 +137,7 @@ def _generate_torchx_event( workspace: Optional[str] = None, ) -> TorchxEvent: return TorchxEvent( - session=app_id or "", + session=get_session_id_or_create_new(), scheduler=scheduler, api=api, app_id=app_id, diff --git a/torchx/runner/events/api.py b/torchx/runner/events/api.py index ce5bc8998..355c03f6c 100644 --- a/torchx/runner/events/api.py +++ b/torchx/runner/events/api.py @@ -25,7 +25,7 @@ class TorchxEvent: The class represents the event produced by ``torchx.runner`` api calls. Arguments: - session: Session id that was used to execute request. + session: Session id of the current run scheduler: Scheduler that is used to execute request api: Api name app_id: Unique id that is set by the underlying scheduler diff --git a/torchx/runner/events/test/lib_test.py b/torchx/runner/events/test/lib_test.py index 92bb3c828..bbeed590e 100644 --- a/torchx/runner/events/test/lib_test.py +++ b/torchx/runner/events/test/lib_test.py @@ -19,6 +19,8 @@ TorchxEvent, ) +SESSION_ID = "123" + class TorchxEventLibTest(unittest.TestCase): def assert_event( @@ -44,14 +46,14 @@ def test_get_or_create_logger(self, logging_handler_mock: MagicMock) -> None: def test_event_created(self) -> None: test_metadata = {"test_key": "test_value"} event = TorchxEvent( - session="test_session", + session=SESSION_ID, scheduler="test_scheduler", api="test_api", app_image="test_app_image", app_metadata=test_metadata, workspace="test_workspace", ) - self.assertEqual("test_session", event.session) + self.assertEqual(SESSION_ID, event.session) self.assertEqual("test_scheduler", event.scheduler) self.assertEqual("test_api", event.api) self.assertEqual("test_app_image", event.app_image) @@ -76,6 +78,7 @@ def test_event_deser(self) -> None: @patch("torchx.runner.events.record") +@patch("torchx.runner.events.get_session_id_or_create_new") class LogEventTest(unittest.TestCase): def assert_torchx_event(self, expected: TorchxEvent, actual: TorchxEvent) -> None: self.assertEqual(expected.session, actual.session) @@ -86,7 +89,10 @@ def assert_torchx_event(self, expected: TorchxEvent, actual: TorchxEvent) -> Non self.assertEqual(expected.workspace, actual.workspace) self.assertEqual(expected.app_metadata, actual.app_metadata) - def test_create_context(self, _) -> None: + def test_create_context( + self, get_session_id_or_create_new_mock: MagicMock, record_mock: MagicMock + ) -> None: + get_session_id_or_create_new_mock.return_value = SESSION_ID test_dict = {"test_key": "test_value"} cfg = json.dumps(test_dict) context = log_event( @@ -99,7 +105,7 @@ def test_create_context(self, _) -> None: workspace="test_workspace", ) expected_torchx_event = TorchxEvent( - "test_app_id", + SESSION_ID, "local", "test_call", "test_app_id", @@ -111,7 +117,10 @@ def test_create_context(self, _) -> None: self.assert_torchx_event(expected_torchx_event, context._torchx_event) - def test_record_event(self, record_mock: MagicMock) -> None: + def test_record_event( + self, get_session_id_or_create_new_mock: MagicMock, record_mock: MagicMock + ) -> None: + get_session_id_or_create_new_mock.return_value = SESSION_ID test_dict = {"test_key": "test_value"} cfg = json.dumps(test_dict) with log_event( @@ -126,7 +135,7 @@ def test_record_event(self, record_mock: MagicMock) -> None: pass expected_torchx_event = TorchxEvent( - "test_app_id", + SESSION_ID, "local", "test_call", "test_app_id", @@ -139,7 +148,9 @@ def test_record_event(self, record_mock: MagicMock) -> None: ) self.assert_torchx_event(expected_torchx_event, ctx._torchx_event) - def test_record_event_with_exception(self, record_mock: MagicMock) -> None: + def test_record_event_with_exception( + self, get_session_id_or_create_new_mock: MagicMock, record_mock: MagicMock + ) -> None: cfg = json.dumps({"test_key": "test_value"}) with self.assertRaises(RuntimeError): with log_event("test_call", "local", "test_app_id", cfg) as ctx: diff --git a/torchx/runner/test/api_test.py b/torchx/runner/test/api_test.py index 155555afa..a53bd314b 100644 --- a/torchx/runner/test/api_test.py +++ b/torchx/runner/test/api_test.py @@ -32,6 +32,7 @@ from torchx.specs.finder import ComponentNotFoundException from torchx.test.fixtures import TestWithTmpDir from torchx.tracker.api import ENV_TORCHX_JOB_ID, ENV_TORCHX_PARENT_RUN_ID +from torchx.util.session import get_session_id from torchx.util.types import none_throws from torchx.workspace import WorkspaceMixin @@ -51,7 +52,7 @@ def get_full_path(name: str) -> str: return os.path.join(os.path.dirname(__file__), "resource", name) -@patch("torchx.runner.api.log_event") +@patch("torchx.runner.events.record") class RunnerTest(TestWithTmpDir): def setUp(self) -> None: super().setUp() @@ -104,7 +105,38 @@ def test_validate_invalid_replicas(self, _) -> None: with self.assertRaises(ValueError): runner.run(app, scheduler="local_dir") - def test_run(self, _) -> None: + @patch("torchx.util.session.uuid") + def test_session_id(self, uuid_mock: MagicMock, record_mock: MagicMock) -> None: + uuid_mock.uuid4.return_value = "test_session_id" + test_file = self.tmpdir / "test_file" + + with self.get_runner() as runner: + self.assertEqual(1, len(runner.scheduler_backends())) + role = Role( + name="touch", + image=str(self.tmpdir), + resource=resource.SMALL, + entrypoint="touch.sh", + args=[str(test_file)], + ) + app = AppDef("name", roles=[role]) + + app_handle_1 = runner.run(app, scheduler="local_dir", cfg=self.cfg) + none_throws(runner.wait(app_handle_1, wait_interval=0.1)) + + app_handle_2 = runner.run(app, scheduler="local_dir", cfg=self.cfg) + none_throws(runner.wait(app_handle_2, wait_interval=0.1)) + + self.assertEqual(get_session_id(), "test_session_id") + uuid_mock.uuid4.assert_called_once() + record_mock.assert_called() + for i in range(record_mock.call_count): + event = record_mock.call_args_list[i].args[0] + self.assertEqual(event.session, "test_session_id") + + @patch("torchx.util.session.uuid") + def test_run(self, uuid_mock: MagicMock, _) -> None: + uuid_mock.uuid4.return_value = "test_session_id" test_file = self.tmpdir / "test_file" with self.get_runner() as runner: @@ -121,8 +153,11 @@ def test_run(self, _) -> None: app_handle = runner.run(app, scheduler="local_dir", cfg=self.cfg) app_status = none_throws(runner.wait(app_handle, wait_interval=0.1)) self.assertEqual(AppState.SUCCEEDED, app_status.state) + self.assertEqual(get_session_id(), "test_session_id") - def test_dryrun(self, _) -> None: + @patch("torchx.util.session.uuid") + def test_dryrun(self, uuid_mock: MagicMock, _) -> None: + uuid_mock.uuid4.return_value = "test_session_id" scheduler_mock = MagicMock() scheduler_mock.run_opts.return_value.resolve.return_value = { **self.cfg, @@ -145,6 +180,7 @@ def test_dryrun(self, _) -> None: app, {**self.cfg, "foo": "bar"} ) scheduler_mock._validate.assert_called_once() + self.assertEqual(get_session_id(), "test_session_id") def test_dryrun_env_variables(self, _) -> None: scheduler_mock = MagicMock() diff --git a/torchx/util/session.py b/torchx/util/session.py new file mode 100644 index 000000000..94f8101d0 --- /dev/null +++ b/torchx/util/session.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import uuid +from typing import Optional + +_CURRENT_SESSION_ID: Optional[str] = None + + +def get_session_id_or_create_new() -> str: + global _CURRENT_SESSION_ID + if _CURRENT_SESSION_ID: + return _CURRENT_SESSION_ID + session_id = str(uuid.uuid4()) + _CURRENT_SESSION_ID = session_id + return session_id + + +def get_session_id() -> Optional[str]: + return _CURRENT_SESSION_ID