diff --git a/torchx/specs/api.py b/torchx/specs/api.py index 18b6fd588..489e3498a 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -7,7 +7,9 @@ # pyre-strict +import asyncio import copy +import inspect import json import re import typing @@ -370,6 +372,24 @@ class Role: mounts: List[Union[BindMount, VolumeMount, DeviceMount]] = field( default_factory=list ) + overrides: Dict[str, Any] = field(default_factory=dict) + + # pyre-ignore + def __getattribute__(self, attrname: str) -> Any: + if attrname == "overrides": + return super().__getattribute__(attrname) + try: + ov = super().__getattribute__("overrides") + except AttributeError: + ov = {} + if attrname in ov: + if inspect.isawaitable(ov[attrname]): + result = asyncio.get_event_loop().run_until_complete(ov[attrname]) + else: + result = ov[attrname]() + setattr(self, attrname, result) + del ov[attrname] + return super().__getattribute__(attrname) def pre_proc( self, diff --git a/torchx/specs/test/api_test.py b/torchx/specs/test/api_test.py index 73308de16..5b3503bd0 100644 --- a/torchx/specs/test/api_test.py +++ b/torchx/specs/test/api_test.py @@ -7,6 +7,7 @@ # pyre-strict +import asyncio import os import time import unittest @@ -276,6 +277,28 @@ def test_retry_policies(self) -> None: }, ) + def test_override_role(self) -> None: + default = Role( + "foobar", + "torch", + overrides={"image": lambda: "base", "entrypoint": lambda: "nentry"}, + ) + self.assertEqual("base", default.image) + self.assertEqual("nentry", default.entrypoint) + + def test_async_override_role(self) -> None: + async def update(value: str, time_seconds: int) -> str: + await asyncio.sleep(time_seconds) + return value + + default = Role( + "foobar", + "torch", + overrides={"image": update("base", 1), "entrypoint": update("nentry", 2)}, + ) + self.assertEqual("base", default.image) + self.assertEqual("nentry", default.entrypoint) + class AppHandleTest(unittest.TestCase): def test_parse_malformed_app_handles(self) -> None: