Skip to content

Commit

Permalink
Make torchx scheduler opts support enums (#870)
Browse files Browse the repository at this point in the history
Summary:

Added Support for Enumerations in scheduler options.

This is a bit more generic as it takes in a creator function which converts a CfgVal to another type. There are some limitations on how general it can be based on how the typing is setup. This also makes it tricky to make a convenience function to handle only Enums.

Differential Revision: D55551233
  • Loading branch information
andywag authored and facebook-github-bot committed Apr 2, 2024
1 parent 7880bd7 commit 4227419
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 12 deletions.
38 changes: 34 additions & 4 deletions torchx/schedulers/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import unittest
from datetime import datetime
from enum import Enum
from typing import Iterable, List, Mapping, Optional, TypeVar, Union
from unittest.mock import MagicMock, patch

Expand All @@ -36,6 +37,16 @@
T = TypeVar("T")


class EnumConfig(str, Enum):
option1 = "option1"
option2 = "option2"


class IntEnumConfig(int, Enum):
option1 = 1
option2 = 2


class SchedulerTest(unittest.TestCase):
class MockScheduler(Scheduler[T], WorkspaceMixin[None]):
def __init__(self, session_name: str) -> None:
Expand Down Expand Up @@ -78,6 +89,21 @@ def list(self) -> List[ListAppResponse]:
def _run_opts(self) -> runopts:
opts = runopts()
opts.add("foo", type_=str, required=True, help="required option")
opts.add(
"bar",
type_=EnumConfig,
required=True,
help=f"Test Enum Config {[m.name for m in EnumConfig]}",
creator=lambda x: EnumConfig(x),
),
opts.add(
"ienum",
type_=IntEnumConfig,
required=False,
help=f"Test Enum Config {[m.name for m in IntEnumConfig]}",
creator=lambda x: IntEnumConfig(x),
),

return opts

def resolve_resource(self, resource: Union[str, Resource]) -> Resource:
Expand All @@ -92,12 +118,16 @@ def test_invalid_run_cfg(self) -> None:
scheduler_mock = SchedulerTest.MockScheduler("test_session")
app_mock = MagicMock()

empty_cfg = {}
with self.assertRaises(InvalidRunConfigException):
empty_cfg = {}
scheduler_mock.submit(app_mock, empty_cfg)

bad_type_cfg = {"foo": 100}
with self.assertRaises(InvalidRunConfigException):
scheduler_mock.submit(app_mock, bad_type_cfg)

bad_type_cfg = {"foo": "here", "bar": "temp"}
with self.assertRaises(InvalidRunConfigException):
bad_type_cfg = {"foo": 100}
scheduler_mock.submit(app_mock, bad_type_cfg)

def test_submit_workspace(self) -> None:
Expand All @@ -110,7 +140,7 @@ def test_submit_workspace(self) -> None:

scheduler_mock = SchedulerTest.MockScheduler("test_session")

cfg = {"foo": "asdf"}
cfg = {"foo": "asdf", "bar": EnumConfig["option1"], "ienum": 1}
scheduler_mock.submit(app, cfg, workspace="some_workspace")
self.assertEqual(app.roles[0].image, "some_workspace")

Expand All @@ -131,7 +161,7 @@ def test_role_preproc_called(self) -> None:
app_mock = MagicMock()
app_mock.roles = [MagicMock()]

cfg = {"foo": "bar"}
cfg = {"foo": "bar", "bar": "option2"}
scheduler_mock.submit_dryrun(app_mock, cfg)
role_mock = app_mock.roles[0]
role_mock.pre_proc.assert_called_once()
Expand Down
31 changes: 23 additions & 8 deletions torchx/specs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,7 @@ class runopt:
opt_type: Type[CfgVal]
is_required: bool
help: str
creator: Optional[Callable[[CfgVal], CfgVal]] = None


class runopts:
Expand Down Expand Up @@ -793,13 +794,26 @@ def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
)

# check type (None matches all types)
if val is not None and not runopts.is_type(val, runopt.opt_type):
raise InvalidRunConfigException(
f"Run option: {cfg_key}, must be of type: {get_type_name(runopt.opt_type)},"
f" but was: {val} ({type(val).__name__})",
cfg_key,
cfg,
)
if val is not None:
if (
not runopts.is_type(val, runopt.opt_type)
and runopt.creator is not None
):
try:
val = runopt.creator(val)
except Exception as e:
raise InvalidRunConfigException(
f"Run option failed with error: {e}",
cfg_key,
cfg,
)
if not runopts.is_type(val, runopt.opt_type):
raise InvalidRunConfigException(
f"Run option: {cfg_key}, must be of type: {get_type_name(runopt.opt_type)},"
f" but was: {val} ({type(val).__name__})",
cfg_key,
cfg,
)

# not required and not set, set to default
if val is None:
Expand Down Expand Up @@ -892,6 +906,7 @@ def add(
help: str,
default: CfgVal = None,
required: bool = False,
creator: Optional[Callable[[CfgVal], CfgVal]] = None,
) -> None:
"""
Adds the ``config`` option with the given help string and ``default``
Expand All @@ -909,7 +924,7 @@ def add(
f" Given: {default} ({type(default).__name__})"
)

self._opts[cfg_key] = runopt(default, type_, required, help)
self._opts[cfg_key] = runopt(default, type_, required, help, creator)

def update(self, other: "runopts") -> None:
self._opts.update(other._opts)
Expand Down

0 comments on commit 4227419

Please sign in to comment.