diff --git a/torchx/runner/config.py b/torchx/runner/config.py index a7c0f4e8c..ba243c7b7 100644 --- a/torchx/runner/config.py +++ b/torchx/runner/config.py @@ -197,7 +197,15 @@ def _configparser() -> configparser.ConfigParser: def _get_scheduler(name: str) -> Scheduler: - schedulers = get_scheduler_factories() + schedulers = { + **get_scheduler_factories(), + **( + get_scheduler_factories( + group="torchx.schedulers.orchestrator", skip_defaults=True + ) + or {} + ), + } if name not in schedulers: raise ValueError( f"`{name}` is not a registered scheduler. Valid scheduler names: {schedulers.keys()}" diff --git a/torchx/runner/test/config_test.py b/torchx/runner/test/config_test.py index 1cd86dd41..959c3cf73 100644 --- a/torchx/runner/test/config_test.py +++ b/torchx/runner/test/config_test.py @@ -470,10 +470,22 @@ def test_dump_and_load_all_registered_schedulers(self) -> None: sfile = StringIO() dump(sfile) - for sched_name, sched in get_scheduler_factories().items(): + scheduler_factories = { + **get_scheduler_factories(), + **( + get_scheduler_factories( + group="torchx.schedulers.orchestrator", skip_defaults=True + ) + or {} + ), + } + + for sched_name, sched in scheduler_factories.items(): sfile.seek(0) # reset the file pos cfg = {} load(scheduler=sched_name, f=sfile, cfg=cfg) - for opt_name, _ in sched("test").run_opts(): - self.assertTrue(opt_name in cfg) + self.assertTrue( + opt_name in cfg, + f"missing {opt_name} in {sched} run opts with cfg {cfg}", + ) diff --git a/torchx/schedulers/__init__.py b/torchx/schedulers/__init__.py index 4fb47b8e8..23af81d4e 100644 --- a/torchx/schedulers/__init__.py +++ b/torchx/schedulers/__init__.py @@ -42,9 +42,11 @@ def run(*args: object, **kwargs: object) -> Scheduler: return run -def get_scheduler_factories() -> Dict[str, SchedulerFactory]: +def get_scheduler_factories( + group: str = "torchx.schedulers", skip_defaults: bool = False +) -> Dict[str, SchedulerFactory]: """ - get_scheduler_factories returns all the available schedulers names and the + get_scheduler_factories returns all the available schedulers names under `group` and the method to instantiate them. The first scheduler in the dictionary is used as the default scheduler. @@ -55,8 +57,9 @@ def get_scheduler_factories() -> Dict[str, SchedulerFactory]: default_schedulers[scheduler] = _defer_load_scheduler(path) return load_group( - "torchx.schedulers", + group, default=default_schedulers, + skip_defaults=skip_defaults, ) diff --git a/torchx/schedulers/test/registry_test.py b/torchx/schedulers/test/registry_test.py index 951cf8e73..e133aafcf 100644 --- a/torchx/schedulers/test/registry_test.py +++ b/torchx/schedulers/test/registry_test.py @@ -22,6 +22,7 @@ def __call__( group: str, default: Dict[str, Any], ignore_missing: Optional[bool] = False, + skip_defaults: bool = False, ) -> Dict[str, Any]: return default diff --git a/torchx/util/entrypoints.py b/torchx/util/entrypoints.py index b2f0e1fe3..9da5626c4 100644 --- a/torchx/util/entrypoints.py +++ b/torchx/util/entrypoints.py @@ -51,8 +51,7 @@ def run(*args: object, **kwargs: object) -> object: # pyre-ignore-all-errors[3, 2] def load_group( - group: str, - default: Optional[Dict[str, Any]] = None, + group: str, default: Optional[Dict[str, Any]] = None, skip_defaults: bool = False ): """ Loads all the entry points specified by ``group`` and returns @@ -72,6 +71,7 @@ def load_group( 1. ``load_group("foo")["bar"]("baz")`` -> equivalent to calling ``this.is.a_fn("baz")`` 1. ``load_group("food")`` -> ``None`` 1. ``load_group("food", default={"hello": this.is.c_fn})["hello"]("world")`` -> equivalent to calling ``this.is.c_fn("world")`` + 1. ``load_group("food", default={"hello": this.is.c_fn}, skip_defaults=True)`` -> ``None`` If the entrypoint is a module (versus a function as shown above), then calling the ``deferred_load_fn`` @@ -90,6 +90,8 @@ def load_group( entrypoints = metadata.entry_points().select(group=group) if len(entrypoints) == 0: + if skip_defaults: + return None return default eps = {} diff --git a/torchx/util/test/entrypoints_test.py b/torchx/util/test/entrypoints_test.py index efa3a4893..45c456c67 100644 --- a/torchx/util/test/entrypoints_test.py +++ b/torchx/util/test/entrypoints_test.py @@ -122,6 +122,11 @@ def test_load_group_with_default(self, _: MagicMock) -> None: self.assertEqual("barbaz", eps["foo"]()) self.assertEqual("foobar", eps["bar"]()) + eps = load_group( + "ep.grp.test.missing", {"foo": barbaz, "bar": foobar}, skip_defaults=True + ) + self.assertIsNone(eps) + @patch(_METADATA_EPS, return_value=_ENTRY_POINTS) def test_load_group_missing(self, _: MagicMock) -> None: with self.assertRaises(AttributeError):