diff --git a/plantseg/core/zoo.py b/plantseg/core/zoo.py index 9e1699fc..2e0bffb6 100644 --- a/plantseg/core/zoo.py +++ b/plantseg/core/zoo.py @@ -342,7 +342,7 @@ def get_model_by_id(self, model_id: str): https://bioimage-io.github.io/collection-bioimage-io/rdfs/10.5281/zenodo.8401064/8429203/rdf.yaml """ - if not self.models_bioimageio: + if not hasattr(self, 'models_bioimageio'): self.refresh_bioimageio_zoo_urls() if model_id not in self.models_bioimageio.index: @@ -458,13 +458,13 @@ def _is_plantseg_model(self, collection_entry: dict) -> bool: def get_bioimageio_zoo_plantseg_model_names(self) -> list[str]: """Return a list of model names in the BioImage.IO Model Zoo tagged with 'plantseg'.""" - if not self.models_bioimageio: + if not hasattr(self, 'models_bioimageio'): self.refresh_bioimageio_zoo_urls() return sorted(model_zoo.models_bioimageio[model_zoo.models_bioimageio["supported"]].index.to_list()) def get_bioimageio_zoo_all_model_names(self) -> list[str]: """Return a list of all model names in the BioImage.IO Model Zoo.""" - if not self.models_bioimageio: + if not hasattr(self, 'models_bioimageio'): self.refresh_bioimageio_zoo_urls() return sorted(model_zoo.models_bioimageio.index.to_list()) diff --git a/tests/core/test_zoo.py b/tests/core/test_zoo.py index 2e0788a2..8a1b92f5 100644 --- a/tests/core/test_zoo.py +++ b/tests/core/test_zoo.py @@ -47,6 +47,8 @@ def test_model_output_normalisation(self, model_name): class TestBioImageIOModelZoo: """Test the BioImage.IO model zoo""" + model_zoo.refresh_bioimageio_zoo_urls() + @pytest.mark.parametrize("model_id", MODEL_IDS) def test_get_model_by_id(self, model_id): """Try to load a model from the BioImage.IO model zoo by ID.""" @@ -62,9 +64,3 @@ def test_halo_computation_for_bioimageio_model(self, model_id): model, _, _ = model_zoo.get_model_by_id(model_id) halo = model_zoo.compute_halo(model) assert halo == 44 - - def test_models_bioimageio(self): - """`model_zoo` has no `models_bioimageio` attribute until `.refresh_bioimageio_zoo_urls()` is called.""" - assert not hasattr(model_zoo, 'models_bioimageio') - model_zoo.refresh_bioimageio_zoo_urls() - assert hasattr(model_zoo, 'models_bioimageio')