From 0c8d57739c131e269663f201de927e4106a2f04b Mon Sep 17 00:00:00 2001 From: Andy Wagner Date: Tue, 17 Sep 2024 15:46:24 -0700 Subject: [PATCH] Adding Override Option for TorchX Role (#956) Summary: Pull Request resolved: https://github.com/pytorch/torchx/pull/956 Adds a generic way to override the internal values of the Role. Allows async overriding of role values and enable Async Packaging Reviewed By: Sanjay-Ganeshan Differential Revision: D62591176 --- torchx/specs/api.py | 20 ++++++++++++++++++++ torchx/specs/test/api_test.py | 23 +++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/torchx/specs/api.py b/torchx/specs/api.py index 18b6fd588..489e3498a 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -7,7 +7,9 @@ # pyre-strict +import asyncio import copy +import inspect import json import re import typing @@ -370,6 +372,24 @@ class Role: mounts: List[Union[BindMount, VolumeMount, DeviceMount]] = field( default_factory=list ) + overrides: Dict[str, Any] = field(default_factory=dict) + + # pyre-ignore + def __getattribute__(self, attrname: str) -> Any: + if attrname == "overrides": + return super().__getattribute__(attrname) + try: + ov = super().__getattribute__("overrides") + except AttributeError: + ov = {} + if attrname in ov: + if inspect.isawaitable(ov[attrname]): + result = asyncio.get_event_loop().run_until_complete(ov[attrname]) + else: + result = ov[attrname]() + setattr(self, attrname, result) + del ov[attrname] + return super().__getattribute__(attrname) def pre_proc( self, diff --git a/torchx/specs/test/api_test.py b/torchx/specs/test/api_test.py index 73308de16..5b3503bd0 100644 --- a/torchx/specs/test/api_test.py +++ b/torchx/specs/test/api_test.py @@ -7,6 +7,7 @@ # pyre-strict +import asyncio import os import time import unittest @@ -276,6 +277,28 @@ def test_retry_policies(self) -> None: }, ) + def test_override_role(self) -> None: + default = Role( + "foobar", + "torch", + overrides={"image": lambda: "base", "entrypoint": lambda: "nentry"}, + ) + self.assertEqual("base", default.image) + self.assertEqual("nentry", default.entrypoint) + + def test_async_override_role(self) -> None: + async def update(value: str, time_seconds: int) -> str: + await asyncio.sleep(time_seconds) + return value + + default = Role( + "foobar", + "torch", + overrides={"image": update("base", 1), "entrypoint": update("nentry", 2)}, + ) + self.assertEqual("base", default.image) + self.assertEqual("nentry", default.entrypoint) + class AppHandleTest(unittest.TestCase): def test_parse_malformed_app_handles(self) -> None: