Skip to content

Commit

Permalink
Add new test for concurrent futures for TorchX Role
Browse files Browse the repository at this point in the history
Differential Revision: D63046717

Pull Request resolved: #957
  • Loading branch information
andywag authored Sep 23, 2024
1 parent 710f654 commit 94ac896
Showing 1 changed file with 29 additions and 1 deletion.
30 changes: 29 additions & 1 deletion torchx/specs/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
# pyre-strict

import asyncio
import concurrent
import os
import time
import unittest
from dataclasses import asdict
from typing import Dict, List, Mapping, Union
from typing import Dict, List, Mapping, Tuple, Union
from unittest.mock import MagicMock

import torchx.specs.named_resources_aws as named_resources_aws
Expand Down Expand Up @@ -299,6 +300,33 @@ async def update(value: str, time_seconds: int) -> str:
self.assertEqual("base", default.image)
self.assertEqual("nentry", default.entrypoint)

def test_concurrent_override_role(self) -> None:

def delay(value: Tuple[str, str], time_seconds: int) -> Tuple[str, str]:
time.sleep(time_seconds)
return value

with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
launcher_fbpkg_future: concurrent.futures.Future = executor.submit(
delay, ("value1", "value2"), 2
)

def get_image() -> str:
concurrent.futures.wait([launcher_fbpkg_future], 3)
return launcher_fbpkg_future.result()[0]

def get_entrypoint() -> str:
concurrent.futures.wait([launcher_fbpkg_future], 3)
return launcher_fbpkg_future.result()[1]

default = Role(
"foobar",
"torch",
overrides={"image": get_image, "entrypoint": get_entrypoint},
)
self.assertEqual("value1", default.image)
self.assertEqual("value2", default.entrypoint)


class AppHandleTest(unittest.TestCase):
def test_parse_malformed_app_handles(self) -> None:
Expand Down

0 comments on commit 94ac896

Please sign in to comment.