diff --git a/torchx/specs/test/api_test.py b/torchx/specs/test/api_test.py index 5b3503bd0..389102547 100644 --- a/torchx/specs/test/api_test.py +++ b/torchx/specs/test/api_test.py @@ -8,11 +8,12 @@ # pyre-strict import asyncio +import concurrent import os import time import unittest from dataclasses import asdict -from typing import Dict, List, Mapping, Union +from typing import Dict, List, Mapping, Tuple, Union from unittest.mock import MagicMock import torchx.specs.named_resources_aws as named_resources_aws @@ -299,6 +300,33 @@ async def update(value: str, time_seconds: int) -> str: self.assertEqual("base", default.image) self.assertEqual("nentry", default.entrypoint) + def test_concurrent_override_role(self) -> None: + + def delay(value: Tuple[str, str], time_seconds: int) -> Tuple[str, str]: + time.sleep(time_seconds) + return value + + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + launcher_fbpkg_future: concurrent.futures.Future = executor.submit( + delay, ("value1", "value2"), 2 + ) + + def get_image() -> str: + concurrent.futures.wait([launcher_fbpkg_future], 3) + return launcher_fbpkg_future.result()[0] + + def get_entrypoint() -> str: + concurrent.futures.wait([launcher_fbpkg_future], 3) + return launcher_fbpkg_future.result()[1] + + default = Role( + "foobar", + "torch", + overrides={"image": get_image, "entrypoint": get_entrypoint}, + ) + self.assertEqual("value1", default.image) + self.assertEqual("value2", default.entrypoint) + class AppHandleTest(unittest.TestCase): def test_parse_malformed_app_handles(self) -> None: