diff --git a/torchx/schedulers/test/api_test.py b/torchx/schedulers/test/api_test.py index 7f8ecf8ac..b6a2206ee 100644 --- a/torchx/schedulers/test/api_test.py +++ b/torchx/schedulers/test/api_test.py @@ -10,6 +10,7 @@ import unittest from datetime import datetime +from enum import Enum from typing import Iterable, List, Mapping, Optional, TypeVar, Union from unittest.mock import MagicMock, patch @@ -36,6 +37,16 @@ T = TypeVar("T") +class EnumConfig(str, Enum): + option1 = "option1" + option2 = "option2" + + +class IntEnumConfig(int, Enum): + option1 = 1 + option2 = 2 + + class SchedulerTest(unittest.TestCase): class MockScheduler(Scheduler[T], WorkspaceMixin[None]): def __init__(self, session_name: str) -> None: @@ -78,6 +89,21 @@ def list(self) -> List[ListAppResponse]: def _run_opts(self) -> runopts: opts = runopts() opts.add("foo", type_=str, required=True, help="required option") + opts.add( + "bar", + type_=EnumConfig, + required=True, + help=f"Test Enum Config {[m.name for m in EnumConfig]}", + creator=lambda x: EnumConfig(x), + ), + opts.add( + "ienum", + type_=IntEnumConfig, + required=False, + help=f"Test Enum Config {[m.name for m in IntEnumConfig]}", + creator=lambda x: IntEnumConfig(x), + ), + return opts def resolve_resource(self, resource: Union[str, Resource]) -> Resource: @@ -92,12 +118,16 @@ def test_invalid_run_cfg(self) -> None: scheduler_mock = SchedulerTest.MockScheduler("test_session") app_mock = MagicMock() + empty_cfg = {} with self.assertRaises(InvalidRunConfigException): - empty_cfg = {} scheduler_mock.submit(app_mock, empty_cfg) + bad_type_cfg = {"foo": 100} + with self.assertRaises(InvalidRunConfigException): + scheduler_mock.submit(app_mock, bad_type_cfg) + + bad_type_cfg = {"foo": "here", "bar": "temp"} with self.assertRaises(InvalidRunConfigException): - bad_type_cfg = {"foo": 100} scheduler_mock.submit(app_mock, bad_type_cfg) def test_submit_workspace(self) -> None: @@ -110,7 +140,7 @@ def test_submit_workspace(self) -> None: scheduler_mock = SchedulerTest.MockScheduler("test_session") - cfg = {"foo": "asdf"} + cfg = {"foo": "asdf", "bar": EnumConfig["option1"], "ienum": 1} scheduler_mock.submit(app, cfg, workspace="some_workspace") self.assertEqual(app.roles[0].image, "some_workspace") @@ -131,7 +161,7 @@ def test_role_preproc_called(self) -> None: app_mock = MagicMock() app_mock.roles = [MagicMock()] - cfg = {"foo": "bar"} + cfg = {"foo": "bar", "bar": "option2"} scheduler_mock.submit_dryrun(app_mock, cfg) role_mock = app_mock.roles[0] role_mock.pre_proc.assert_called_once() diff --git a/torchx/specs/api.py b/torchx/specs/api.py index cb3a174e3..28c43f9b3 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -702,6 +702,7 @@ class runopt: opt_type: Type[CfgVal] is_required: bool help: str + creator: Optional[Callable[[CfgVal], CfgVal]] = None class runopts: @@ -793,13 +794,26 @@ def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]: ) # check type (None matches all types) - if val is not None and not runopts.is_type(val, runopt.opt_type): - raise InvalidRunConfigException( - f"Run option: {cfg_key}, must be of type: {get_type_name(runopt.opt_type)}," - f" but was: {val} ({type(val).__name__})", - cfg_key, - cfg, - ) + if val is not None: + if ( + not runopts.is_type(val, runopt.opt_type) + and runopt.creator is not None + ): + try: + val = runopt.creator(val) + except Exception as e: + raise InvalidRunConfigException( + f"Run option failed with error: {e}", + cfg_key, + cfg, + ) + if not runopts.is_type(val, runopt.opt_type): + raise InvalidRunConfigException( + f"Run option: {cfg_key}, must be of type: {get_type_name(runopt.opt_type)}," + f" but was: {val} ({type(val).__name__})", + cfg_key, + cfg, + ) # not required and not set, set to default if val is None: @@ -892,6 +906,7 @@ def add( help: str, default: CfgVal = None, required: bool = False, + creator: Optional[Callable[[CfgVal], CfgVal]] = None, ) -> None: """ Adds the ``config`` option with the given help string and ``default`` @@ -909,7 +924,7 @@ def add( f" Given: {default} ({type(default).__name__})" ) - self._opts[cfg_key] = runopt(default, type_, required, help) + self._opts[cfg_key] = runopt(default, type_, required, help, creator) def update(self, other: "runopts") -> None: self._opts.update(other._opts)