Skip to content

Commit

Permalink
Create session id and use it in table 'Pytorch Elastic Tsm Log' (pyto…
Browse files Browse the repository at this point in the history
…rch#953)

Summary:
Pull Request resolved: pytorch#953

Please read the doc to understand why we create the session id: https://docs.google.com/document/d/1WJBrqSHrNIc9J1W_1PMIQPu11y2fV_hU36aeiBTgN90/edit

Differential Revision: D62087199
  • Loading branch information
yikaiMeta authored and facebook-github-bot committed Sep 9, 2024
1 parent 66733b7 commit 3ec2ef9
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 12 deletions.
3 changes: 2 additions & 1 deletion torchx/runner/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from typing import Dict, Optional, Type

from torchx.runner.events.handlers import get_logging_handler
from torchx.util.session import get_session_id_or_create_new

from .api import SourceType, TorchxEvent # noqa F401

Expand Down Expand Up @@ -136,7 +137,7 @@ def _generate_torchx_event(
workspace: Optional[str] = None,
) -> TorchxEvent:
return TorchxEvent(
session=app_id or "",
session=get_session_id_or_create_new(),
scheduler=scheduler,
api=api,
app_id=app_id,
Expand Down
2 changes: 1 addition & 1 deletion torchx/runner/events/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class TorchxEvent:
The class represents the event produced by ``torchx.runner`` api calls.
Arguments:
session: Session id that was used to execute request.
session: Session id of the current run
scheduler: Scheduler that is used to execute request
api: Api name
app_id: Unique id that is set by the underlying scheduler
Expand Down
25 changes: 18 additions & 7 deletions torchx/runner/events/test/lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
TorchxEvent,
)

SESSION_ID = "123"


class TorchxEventLibTest(unittest.TestCase):
def assert_event(
Expand All @@ -44,14 +46,14 @@ def test_get_or_create_logger(self, logging_handler_mock: MagicMock) -> None:
def test_event_created(self) -> None:
test_metadata = {"test_key": "test_value"}
event = TorchxEvent(
session="test_session",
session=SESSION_ID,
scheduler="test_scheduler",
api="test_api",
app_image="test_app_image",
app_metadata=test_metadata,
workspace="test_workspace",
)
self.assertEqual("test_session", event.session)
self.assertEqual(SESSION_ID, event.session)
self.assertEqual("test_scheduler", event.scheduler)
self.assertEqual("test_api", event.api)
self.assertEqual("test_app_image", event.app_image)
Expand All @@ -76,6 +78,7 @@ def test_event_deser(self) -> None:


@patch("torchx.runner.events.record")
@patch("torchx.runner.events.get_session_id_or_create_new")
class LogEventTest(unittest.TestCase):
def assert_torchx_event(self, expected: TorchxEvent, actual: TorchxEvent) -> None:
self.assertEqual(expected.session, actual.session)
Expand All @@ -86,7 +89,10 @@ def assert_torchx_event(self, expected: TorchxEvent, actual: TorchxEvent) -> Non
self.assertEqual(expected.workspace, actual.workspace)
self.assertEqual(expected.app_metadata, actual.app_metadata)

def test_create_context(self, _) -> None:
def test_create_context(
self, get_session_id_or_create_new_mock: MagicMock, record_mock: MagicMock
) -> None:
get_session_id_or_create_new_mock.return_value = SESSION_ID
test_dict = {"test_key": "test_value"}
cfg = json.dumps(test_dict)
context = log_event(
Expand All @@ -99,7 +105,7 @@ def test_create_context(self, _) -> None:
workspace="test_workspace",
)
expected_torchx_event = TorchxEvent(
"test_app_id",
SESSION_ID,
"local",
"test_call",
"test_app_id",
Expand All @@ -111,7 +117,10 @@ def test_create_context(self, _) -> None:

self.assert_torchx_event(expected_torchx_event, context._torchx_event)

def test_record_event(self, record_mock: MagicMock) -> None:
def test_record_event(
self, get_session_id_or_create_new_mock: MagicMock, record_mock: MagicMock
) -> None:
get_session_id_or_create_new_mock.return_value = SESSION_ID
test_dict = {"test_key": "test_value"}
cfg = json.dumps(test_dict)
with log_event(
Expand All @@ -126,7 +135,7 @@ def test_record_event(self, record_mock: MagicMock) -> None:
pass

expected_torchx_event = TorchxEvent(
"test_app_id",
SESSION_ID,
"local",
"test_call",
"test_app_id",
Expand All @@ -139,7 +148,9 @@ def test_record_event(self, record_mock: MagicMock) -> None:
)
self.assert_torchx_event(expected_torchx_event, ctx._torchx_event)

def test_record_event_with_exception(self, record_mock: MagicMock) -> None:
def test_record_event_with_exception(
self, get_session_id_or_create_new_mock: MagicMock, record_mock: MagicMock
) -> None:
cfg = json.dumps({"test_key": "test_value"})
with self.assertRaises(RuntimeError):
with log_event("test_call", "local", "test_app_id", cfg) as ctx:
Expand Down
42 changes: 39 additions & 3 deletions torchx/runner/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from torchx.specs.finder import ComponentNotFoundException
from torchx.test.fixtures import TestWithTmpDir
from torchx.tracker.api import ENV_TORCHX_JOB_ID, ENV_TORCHX_PARENT_RUN_ID
from torchx.util.session import get_session_id
from torchx.util.types import none_throws
from torchx.workspace import WorkspaceMixin

Expand All @@ -51,7 +52,7 @@ def get_full_path(name: str) -> str:
return os.path.join(os.path.dirname(__file__), "resource", name)


@patch("torchx.runner.api.log_event")
@patch("torchx.runner.events.record")
class RunnerTest(TestWithTmpDir):
def setUp(self) -> None:
super().setUp()
Expand Down Expand Up @@ -104,7 +105,38 @@ def test_validate_invalid_replicas(self, _) -> None:
with self.assertRaises(ValueError):
runner.run(app, scheduler="local_dir")

def test_run(self, _) -> None:
@patch("torchx.util.session.uuid")
def test_session_id(self, uuid_mock: MagicMock, record_mock: MagicMock) -> None:
uuid_mock.uuid4.return_value = "test_session_id"
test_file = self.tmpdir / "test_file"

with self.get_runner() as runner:
self.assertEqual(1, len(runner.scheduler_backends()))
role = Role(
name="touch",
image=str(self.tmpdir),
resource=resource.SMALL,
entrypoint="touch.sh",
args=[str(test_file)],
)
app = AppDef("name", roles=[role])

app_handle_1 = runner.run(app, scheduler="local_dir", cfg=self.cfg)
none_throws(runner.wait(app_handle_1, wait_interval=0.1))

app_handle_2 = runner.run(app, scheduler="local_dir", cfg=self.cfg)
none_throws(runner.wait(app_handle_2, wait_interval=0.1))

self.assertEqual(get_session_id(), "test_session_id")
uuid_mock.uuid4.assert_called_once()
record_mock.assert_called()
for i in range(record_mock.call_count):
event = record_mock.call_args_list[i].args[0]
self.assertEqual(event.session, "test_session_id")

@patch("torchx.util.session.uuid")
def test_run(self, uuid_mock: MagicMock, _) -> None:
uuid_mock.uuid4.return_value = "test_session_id"
test_file = self.tmpdir / "test_file"

with self.get_runner() as runner:
Expand All @@ -121,8 +153,11 @@ def test_run(self, _) -> None:
app_handle = runner.run(app, scheduler="local_dir", cfg=self.cfg)
app_status = none_throws(runner.wait(app_handle, wait_interval=0.1))
self.assertEqual(AppState.SUCCEEDED, app_status.state)
self.assertEqual(get_session_id(), "test_session_id")

def test_dryrun(self, _) -> None:
@patch("torchx.util.session.uuid")
def test_dryrun(self, uuid_mock: MagicMock, _) -> None:
uuid_mock.uuid4.return_value = "test_session_id"
scheduler_mock = MagicMock()
scheduler_mock.run_opts.return_value.resolve.return_value = {
**self.cfg,
Expand All @@ -145,6 +180,7 @@ def test_dryrun(self, _) -> None:
app, {**self.cfg, "foo": "bar"}
)
scheduler_mock._validate.assert_called_once()
self.assertEqual(get_session_id(), "test_session_id")

def test_dryrun_env_variables(self, _) -> None:
scheduler_mock = MagicMock()
Expand Down
26 changes: 26 additions & 0 deletions torchx/util/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import uuid
from typing import Optional

_CURRENT_SESSION_ID: Optional[str] = None


def get_session_id_or_create_new() -> str:
global _CURRENT_SESSION_ID
if _CURRENT_SESSION_ID:
return _CURRENT_SESSION_ID
session_id = str(uuid.uuid4())
_CURRENT_SESSION_ID = session_id
return session_id


def get_session_id() -> Optional[str]:
return _CURRENT_SESSION_ID

0 comments on commit 3ec2ef9

Please sign in to comment.