diff --git a/scopesim/effects/surface_list.py b/scopesim/effects/surface_list.py index d17322a2..252c5203 100644 --- a/scopesim/effects/surface_list.py +++ b/scopesim/effects/surface_list.py @@ -156,7 +156,7 @@ def plot(self, which="x", wavelength=None, *, axes=None, **kwargs): """ if axes is None: - fig, axes = figure_factory(len(which), 1) + fig, axes = figure_factory(len(which), 1, iterable_axes=True) else: fig = axes.figure self._axes_guard(which, axes) diff --git a/scopesim/effects/ter_curves.py b/scopesim/effects/ter_curves.py index 51058c7e..cfd9a548 100644 --- a/scopesim/effects/ter_curves.py +++ b/scopesim/effects/ter_curves.py @@ -186,8 +186,7 @@ def plot(self, which="x", wavelength=None, *, axes=None, **kwargs): """ if axes is None: - fig, axes = figure_factory(len(which), 1) - # figsize=(10, 5) + fig, axes = figure_factory(len(which), 1, iterable_axes=True) else: fig = axes.figure _guard_plot_axes(which, axes) @@ -627,8 +626,7 @@ def plot(self, which="x", wavelength=None, *, axes=None, **kwargs): """ if axes is None: - fig, axes = figure_factory(len(which), 1) - # figsize=(10, 5) + fig, axes = figure_factory(len(which), 1, iterable_axes=True) else: fig = axes.figure _guard_plot_axes(which, axes) diff --git a/scopesim/utils.py b/scopesim/utils.py index 19f0d6db..d653b40c 100644 --- a/scopesim/utils.py +++ b/scopesim/utils.py @@ -1032,7 +1032,10 @@ def close_loop(iterable: Iterable) -> Generator: def figure_factory(nrows=1, ncols=1, **kwargs): """Default way to init fig and ax, to easily modify later.""" + iterable_axes = kwargs.pop("iterable_axes", False) fig, ax = plt.subplots(nrows, ncols, **kwargs) + if iterable_axes and not isinstance(ax, Iterable): + ax = (ax,) return fig, ax