From fd68d872627eb69f979787c757a67697313b2a2a Mon Sep 17 00:00:00 2001 From: Brett Date: Thu, 24 Oct 2024 16:41:07 -0400 Subject: [PATCH] raise ValueError if class_alias lookup returns more than 1 step --- src/stpipe/utilities.py | 20 ++++++++- tests/test_utilities.py | 98 +++++++++++++++++++++++++++++++++-------- 2 files changed, 97 insertions(+), 21 deletions(-) diff --git a/src/stpipe/utilities.py b/src/stpipe/utilities.py index 6a6b9551..5ba436c0 100644 --- a/src/stpipe/utilities.py +++ b/src/stpipe/utilities.py @@ -33,13 +33,29 @@ def resolve_step_class_alias(name): else: scope, class_name = None, name + # track all found steps keyed by package name + found_class_names = {} for info in entry_points.get_steps(): if scope and info.package_name != scope: continue if info.class_alias is not None and class_name == info.class_alias: - return info.class_name + found_class_names[info.package_name] = info - return name + if not found_class_names: + return name + + if len(found_class_names) == 1: + return found_class_names.popitem()[1].class_name + + # class alias resolved to several possible steps + scopes = list(found_class_names.keys()) + msg = ( + f"class alias {name} matched more than 1 step. Please provide " + "the package name along with the step name. One of:\n" + ) + for scope in scopes: + msg += f" {scope}::{name}\n" + raise ValueError(msg) def import_class(full_name, subclassof=object, config_file=None): diff --git a/tests/test_utilities.py b/tests/test_utilities.py index cc409b65..9ca34e64 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -56,39 +56,99 @@ def test_import_func_no_module(): import_func("foo") -@pytest.mark.parametrize( - "name, resolve", - ( - ("foo_step", True), - ("stpipe::foo_step", True), - ("some_other_package::foo_step", False), - ), -) -def test_class_alias_lookup(name, resolve, monkeypatch): +@pytest.fixture() +def mock_entry_points(monkeypatch, request): # as the test class above isn't registered via an entry point # we mock the entry points here class FakeDist: - name = "stpipe" - version = "dev" + def __init__(self, name): + self.name = name + self.version = "dev" class FakeEntryPoint: - dist = FakeDist() + def __init__(self, dist_name, steps): + self.dist = FakeDist(dist_name) + self.steps = steps def load(self): def loader(): - return [("Foo", "foo_step", False)] + return self.steps return loader def fake_entrypoints(group=None): - return [FakeEntryPoint()] + return [FakeEntryPoint(k, v) for k, v in request.param.items()] import importlib_metadata monkeypatch.setattr(importlib_metadata, "entry_points", fake_entrypoints) + yield + - resolved_name = resolve_step_class_alias(name) - if resolve: - assert resolved_name == Foo.__name__ - else: - assert resolved_name == name +@pytest.mark.parametrize("name", ("foo_step", "stpipe::foo_step")) +@pytest.mark.parametrize( + "mock_entry_points", [{"stpipe": [("Foo", "foo_step", False)]}], indirect=True +) +def test_class_alias_lookup(name, mock_entry_points): + """ + Test that a step name can be resolved if either: + - only a single step is found that matches + - a step is found and a valid package name was provided + """ + assert resolve_step_class_alias(name) == "Foo" + + +@pytest.mark.parametrize("name", ("bar_step", "other_package::foo_step")) +@pytest.mark.parametrize( + "mock_entry_points", [{"stpipe": [("Foo", "foo_step", False)]}], indirect=True +) +def test_class_alias_lookup_fallthrough(name, mock_entry_points): + """ + Test that passing in an unknown class alias or an alias scoped + to a different package falls through to returning the unresolved + class_alias (to match previous behavior). + """ + assert resolve_step_class_alias(name) == name + + +@pytest.mark.parametrize("name", ("aaa::foo_step", "zzz::foo_step")) +@pytest.mark.parametrize( + "mock_entry_points", + [ + { + "aaa": [("Foo", "foo_step", False)], + "zzz": [("Foo", "foo_step", False)], + } + ], + indirect=True, +) +def test_class_alias_lookup_scoped(name, mock_entry_points): + """ + Test the lookup succeeds if more than 1 package + provides a matching step name but the "scope" (package name) + is provided on lookup. + """ + assert resolve_step_class_alias(name) == "Foo" + + +@pytest.mark.parametrize( + "mock_entry_points", + [ + { + "aaa": [("Foo", "foo_step", False)], + "zzz": [("Foo", "foo_step", False)], + } + ], + indirect=True, +) +def test_class_alias_lookup_conflict(mock_entry_points): + """ + Test that an ambiguous lookup (a class alias that resolves + to more than 1 step from different packages) results in + an error. + When the package name is provided, tes + """ + with pytest.raises(ValueError) as err: + resolve_step_class_alias("foo_step") + assert err.match("aaa::foo_step") + assert err.match("zzz::foo_step")