Skip to content

Commit

Permalink
allow configurable scheduler load group
Browse files Browse the repository at this point in the history
Differential Revision: D67290464

Pull Request resolved: #992
  • Loading branch information
lgarg26 authored Dec 17, 2024
1 parent c1a195a commit 90884d7
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 9 deletions.
10 changes: 9 additions & 1 deletion torchx/runner/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,15 @@ def _configparser() -> configparser.ConfigParser:


def _get_scheduler(name: str) -> Scheduler:
schedulers = get_scheduler_factories()
schedulers = {
**get_scheduler_factories(),
**(
get_scheduler_factories(
group="torchx.schedulers.orchestrator", skip_defaults=True
)
or {}
),
}
if name not in schedulers:
raise ValueError(
f"`{name}` is not a registered scheduler. Valid scheduler names: {schedulers.keys()}"
Expand Down
18 changes: 15 additions & 3 deletions torchx/runner/test/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,10 +470,22 @@ def test_dump_and_load_all_registered_schedulers(self) -> None:
sfile = StringIO()
dump(sfile)

for sched_name, sched in get_scheduler_factories().items():
scheduler_factories = {
**get_scheduler_factories(),
**(
get_scheduler_factories(
group="torchx.schedulers.orchestrator", skip_defaults=True
)
or {}
),
}

for sched_name, sched in scheduler_factories.items():
sfile.seek(0) # reset the file pos
cfg = {}
load(scheduler=sched_name, f=sfile, cfg=cfg)

for opt_name, _ in sched("test").run_opts():
self.assertTrue(opt_name in cfg)
self.assertTrue(
opt_name in cfg,
f"missing {opt_name} in {sched} run opts with cfg {cfg}",
)
9 changes: 6 additions & 3 deletions torchx/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ def run(*args: object, **kwargs: object) -> Scheduler:
return run


def get_scheduler_factories() -> Dict[str, SchedulerFactory]:
def get_scheduler_factories(
group: str = "torchx.schedulers", skip_defaults: bool = False
) -> Dict[str, SchedulerFactory]:
"""
get_scheduler_factories returns all the available schedulers names and the
get_scheduler_factories returns all the available schedulers names under `group` and the
method to instantiate them.
The first scheduler in the dictionary is used as the default scheduler.
Expand All @@ -55,8 +57,9 @@ def get_scheduler_factories() -> Dict[str, SchedulerFactory]:
default_schedulers[scheduler] = _defer_load_scheduler(path)

return load_group(
"torchx.schedulers",
group,
default=default_schedulers,
skip_defaults=skip_defaults,
)


Expand Down
1 change: 1 addition & 0 deletions torchx/schedulers/test/registry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __call__(
group: str,
default: Dict[str, Any],
ignore_missing: Optional[bool] = False,
skip_defaults: bool = False,
) -> Dict[str, Any]:
return default

Expand Down
6 changes: 4 additions & 2 deletions torchx/util/entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ def run(*args: object, **kwargs: object) -> object:

# pyre-ignore-all-errors[3, 2]
def load_group(
group: str,
default: Optional[Dict[str, Any]] = None,
group: str, default: Optional[Dict[str, Any]] = None, skip_defaults: bool = False
):
"""
Loads all the entry points specified by ``group`` and returns
Expand All @@ -72,6 +71,7 @@ def load_group(
1. ``load_group("foo")["bar"]("baz")`` -> equivalent to calling ``this.is.a_fn("baz")``
1. ``load_group("food")`` -> ``None``
1. ``load_group("food", default={"hello": this.is.c_fn})["hello"]("world")`` -> equivalent to calling ``this.is.c_fn("world")``
1. ``load_group("food", default={"hello": this.is.c_fn}, skip_defaults=True)`` -> ``None``
If the entrypoint is a module (versus a function as shown above), then calling the ``deferred_load_fn``
Expand All @@ -90,6 +90,8 @@ def load_group(
entrypoints = metadata.entry_points().select(group=group)

if len(entrypoints) == 0:
if skip_defaults:
return None
return default

eps = {}
Expand Down
5 changes: 5 additions & 0 deletions torchx/util/test/entrypoints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ def test_load_group_with_default(self, _: MagicMock) -> None:
self.assertEqual("barbaz", eps["foo"]())
self.assertEqual("foobar", eps["bar"]())

eps = load_group(
"ep.grp.test.missing", {"foo": barbaz, "bar": foobar}, skip_defaults=True
)
self.assertIsNone(eps)

@patch(_METADATA_EPS, return_value=_ENTRY_POINTS)
def test_load_group_missing(self, _: MagicMock) -> None:
with self.assertRaises(AttributeError):
Expand Down

0 comments on commit 90884d7

Please sign in to comment.