Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create session id and use it in table 'Pytorch Elastic Tsm Log' #953

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion torchx/runner/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from typing import Dict, Optional, Type

from torchx.runner.events.handlers import get_logging_handler
from torchx.util.session import get_session_id_or_create_new

from .api import SourceType, TorchxEvent # noqa F401

Expand Down Expand Up @@ -136,7 +137,7 @@ def _generate_torchx_event(
workspace: Optional[str] = None,
) -> TorchxEvent:
return TorchxEvent(
session=app_id or "",
session=get_session_id_or_create_new(),
scheduler=scheduler,
api=api,
app_id=app_id,
Expand Down
2 changes: 1 addition & 1 deletion torchx/runner/events/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class TorchxEvent:
The class represents the event produced by ``torchx.runner`` api calls.
Arguments:
session: Session id that was used to execute request.
session: Session id of the current run
scheduler: Scheduler that is used to execute request
api: Api name
app_id: Unique id that is set by the underlying scheduler
Expand Down
25 changes: 18 additions & 7 deletions torchx/runner/events/test/lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
TorchxEvent,
)

SESSION_ID = "123"


class TorchxEventLibTest(unittest.TestCase):
def assert_event(
Expand All @@ -44,14 +46,14 @@ def test_get_or_create_logger(self, logging_handler_mock: MagicMock) -> None:
def test_event_created(self) -> None:
test_metadata = {"test_key": "test_value"}
event = TorchxEvent(
session="test_session",
session=SESSION_ID,
scheduler="test_scheduler",
api="test_api",
app_image="test_app_image",
app_metadata=test_metadata,
workspace="test_workspace",
)
self.assertEqual("test_session", event.session)
self.assertEqual(SESSION_ID, event.session)
self.assertEqual("test_scheduler", event.scheduler)
self.assertEqual("test_api", event.api)
self.assertEqual("test_app_image", event.app_image)
Expand All @@ -76,6 +78,7 @@ def test_event_deser(self) -> None:


@patch("torchx.runner.events.record")
@patch("torchx.runner.events.get_session_id_or_create_new")
class LogEventTest(unittest.TestCase):
def assert_torchx_event(self, expected: TorchxEvent, actual: TorchxEvent) -> None:
self.assertEqual(expected.session, actual.session)
Expand All @@ -86,7 +89,10 @@ def assert_torchx_event(self, expected: TorchxEvent, actual: TorchxEvent) -> Non
self.assertEqual(expected.workspace, actual.workspace)
self.assertEqual(expected.app_metadata, actual.app_metadata)

def test_create_context(self, _) -> None:
def test_create_context(
self, get_session_id_or_create_new_mock: MagicMock, record_mock: MagicMock
) -> None:
get_session_id_or_create_new_mock.return_value = SESSION_ID
test_dict = {"test_key": "test_value"}
cfg = json.dumps(test_dict)
context = log_event(
Expand All @@ -99,7 +105,7 @@ def test_create_context(self, _) -> None:
workspace="test_workspace",
)
expected_torchx_event = TorchxEvent(
"test_app_id",
SESSION_ID,
"local",
"test_call",
"test_app_id",
Expand All @@ -111,7 +117,10 @@ def test_create_context(self, _) -> None:

self.assert_torchx_event(expected_torchx_event, context._torchx_event)

def test_record_event(self, record_mock: MagicMock) -> None:
def test_record_event(
self, get_session_id_or_create_new_mock: MagicMock, record_mock: MagicMock
) -> None:
get_session_id_or_create_new_mock.return_value = SESSION_ID
test_dict = {"test_key": "test_value"}
cfg = json.dumps(test_dict)
with log_event(
Expand All @@ -126,7 +135,7 @@ def test_record_event(self, record_mock: MagicMock) -> None:
pass

expected_torchx_event = TorchxEvent(
"test_app_id",
SESSION_ID,
"local",
"test_call",
"test_app_id",
Expand All @@ -139,7 +148,9 @@ def test_record_event(self, record_mock: MagicMock) -> None:
)
self.assert_torchx_event(expected_torchx_event, ctx._torchx_event)

def test_record_event_with_exception(self, record_mock: MagicMock) -> None:
def test_record_event_with_exception(
self, get_session_id_or_create_new_mock: MagicMock, record_mock: MagicMock
) -> None:
cfg = json.dumps({"test_key": "test_value"})
with self.assertRaises(RuntimeError):
with log_event("test_call", "local", "test_app_id", cfg) as ctx:
Expand Down
48 changes: 45 additions & 3 deletions torchx/runner/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get_full_path(name: str) -> str:
return os.path.join(os.path.dirname(__file__), "resource", name)


@patch("torchx.runner.api.log_event")
@patch("torchx.runner.events.record")
class RunnerTest(TestWithTmpDir):
def setUp(self) -> None:
super().setUp()
Expand Down Expand Up @@ -104,7 +104,35 @@ def test_validate_invalid_replicas(self, _) -> None:
with self.assertRaises(ValueError):
runner.run(app, scheduler="local_dir")

def test_run(self, _) -> None:
def test_session_id(self, record_mock: MagicMock) -> None:
test_file = self.tmpdir / "test_file"

with self.get_runner() as runner:
self.assertEqual(1, len(runner.scheduler_backends()))
role = Role(
name="touch",
image=str(self.tmpdir),
resource=resource.SMALL,
entrypoint="touch.sh",
args=[str(test_file)],
)
app = AppDef("name", roles=[role])

app_handle_1 = runner.run(app, scheduler="local_dir", cfg=self.cfg)
none_throws(runner.wait(app_handle_1, wait_interval=0.1))

app_handle_2 = runner.run(app, scheduler="local_dir", cfg=self.cfg)
none_throws(runner.wait(app_handle_2, wait_interval=0.1))

from torchx.util.session import CURRENT_SESSION_ID

self.assertIsNotNone(CURRENT_SESSION_ID)
record_mock.assert_called()
for i in range(record_mock.call_count):
event = record_mock.call_args_list[i].args[0]
self.assertEqual(event.session, CURRENT_SESSION_ID)

def test_run(self, record_mock: MagicMock) -> None:
test_file = self.tmpdir / "test_file"

with self.get_runner() as runner:
Expand All @@ -121,8 +149,15 @@ def test_run(self, _) -> None:
app_handle = runner.run(app, scheduler="local_dir", cfg=self.cfg)
app_status = none_throws(runner.wait(app_handle, wait_interval=0.1))
self.assertEqual(AppState.SUCCEEDED, app_status.state)
from torchx.util.session import CURRENT_SESSION_ID

def test_dryrun(self, _) -> None:
self.assertIsNotNone(CURRENT_SESSION_ID)
record_mock.assert_called()
for i in range(record_mock.call_count):
event = record_mock.call_args_list[i].args[0]
self.assertEqual(event.session, CURRENT_SESSION_ID)

def test_dryrun(self, record_mock: MagicMock) -> None:
scheduler_mock = MagicMock()
scheduler_mock.run_opts.return_value.resolve.return_value = {
**self.cfg,
Expand All @@ -145,6 +180,13 @@ def test_dryrun(self, _) -> None:
app, {**self.cfg, "foo": "bar"}
)
scheduler_mock._validate.assert_called_once()
from torchx.util.session import CURRENT_SESSION_ID

self.assertIsNotNone(CURRENT_SESSION_ID)
record_mock.assert_called()
for i in range(record_mock.call_count):
event = record_mock.call_args_list[i].args[0]
self.assertEqual(event.session, CURRENT_SESSION_ID)

def test_dryrun_env_variables(self, _) -> None:
scheduler_mock = MagicMock()
Expand Down
26 changes: 26 additions & 0 deletions torchx/util/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import uuid
from typing import Optional

CURRENT_SESSION_ID: Optional[str] = None


def get_session_id_or_create_new() -> str:
"""
Returns the current session ID, or creates a new one if none exists.
The session ID remains the same as long as it is in the same process.
"""
global CURRENT_SESSION_ID
if CURRENT_SESSION_ID:
return CURRENT_SESSION_ID
session_id = str(uuid.uuid4())
CURRENT_SESSION_ID = session_id
return session_id
Loading