From 3faa8d3ac50616ee1af70ec1026558f871eef35f Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Tue, 26 Mar 2024 11:24:12 +0000 Subject: [PATCH] Ruff (no formatter) (#325) * Add basic ruff config and use ruff for import sortting * More ruff * ruff pyupgrade * no pre-commit flake8 * clean up template * More ruff fixes * ruff C4 * ruff PT * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * No X | Y in isinstance and no ruff isort * Ruff ICN * G and INP * Ruff Q: Convert all ' to " * Ruff RSE * Ruff TID * NPY and RUF rules * Enable PTH * Ruff RET rule * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/changelog_filter.py | 22 ---- .ruff.toml | 36 +++++- conftest.py | 7 +- dkist/__init__.py | 6 +- dkist/_dev/scm_version.py | 4 +- dkist/config/__init__.py | 6 +- dkist/conftest.py | 54 ++++----- dkist/dataset/dataset.py | 11 +- dkist/dataset/loader.py | 12 +- dkist/dataset/tests/test_dataset.py | 53 ++++----- dkist/dataset/tests/test_load_dataset.py | 14 +-- dkist/dataset/tests/test_plotting.py | 4 +- dkist/dataset/tests/test_tiled_dataset.py | 30 ++--- dkist/dataset/tiled_dataset.py | 4 +- dkist/dataset/utils.py | 78 ++++++------- dkist/io/__init__.py | 4 +- dkist/io/asdf/converters/file_manager.py | 3 +- dkist/io/asdf/converters/models.py | 15 +-- dkist/io/asdf/tests/test_dataset.py | 10 +- dkist/io/asdf/tests/test_tiled_dataset.py | 2 +- dkist/io/dask_utils.py | 2 +- dkist/io/file_manager.py | 21 ++-- dkist/io/loaders.py | 12 +- dkist/io/tests/test_file_manager.py | 36 +++--- dkist/io/tests/test_fits.py | 2 +- dkist/logger.py | 2 +- dkist/net/__init__.py | 4 +- dkist/net/attr_walker.py | 84 +++++++------- dkist/net/attrs.py | 16 +-- dkist/net/attrs_values.py | 6 +- dkist/net/client.py | 23 ++-- dkist/net/globus/auth.py | 40 +++---- dkist/net/globus/endpoints.py | 30 ++--- dkist/net/globus/tests/__init__.py | 0 dkist/net/globus/tests/conftest.py | 4 +- dkist/net/globus/tests/test_auth.py | 8 +- dkist/net/globus/tests/test_endpoints.py | 16 ++- dkist/net/globus/tests/test_transfer.py | 81 +++++++------- dkist/net/globus/transfer.py | 35 +++--- dkist/net/helpers.py | 13 ++- dkist/net/tests/__init__.py | 0 dkist/net/tests/conftest.py | 58 +++++----- dkist/net/tests/strategies.py | 4 +- dkist/net/tests/test_attr_walker.py | 43 ++++---- dkist/net/tests/test_attrs.py | 4 +- dkist/net/tests/test_attrs_values.py | 6 +- dkist/net/tests/test_client.py | 27 ++--- dkist/net/tests/test_helpers.py | 44 ++++---- dkist/tests/generate_aia_dataset.py | 30 +++-- dkist/tests/generate_eit_test_dataset.py | 14 +-- dkist/utils/_model_to_graphviz.py | 8 +- dkist/utils/inventory.py | 21 ++-- dkist/utils/sysinfo.py | 18 +-- dkist/utils/tests/test_inventory.py | 2 +- dkist/wcs/models.py | 67 ++++++------ dkist/wcs/tests/__init__.py | 0 .../wcs/tests/test_coupled_compound_model.py | 20 ++-- dkist/wcs/tests/test_models.py | 59 +++++----- docs/conf.py | 103 +++++++++--------- 59 files changed, 665 insertions(+), 673 deletions(-) delete mode 100644 .github/changelog_filter.py create mode 100644 dkist/net/globus/tests/__init__.py create mode 100644 dkist/net/tests/__init__.py create mode 100644 dkist/wcs/tests/__init__.py diff --git a/.github/changelog_filter.py b/.github/changelog_filter.py deleted file mode 100644 index beb07d63..00000000 --- a/.github/changelog_filter.py +++ /dev/null @@ -1,22 +0,0 @@ -""" 
-A pandoc filter which returns only the content between the start of the file -and the second top level heading. -""" -import io -import sys -import json - -if __name__ == '__main__': - input_stream = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') - source = input_stream.read() - doc = json.loads(source) - - output_blocks = [] - for block in doc["blocks"]: - if output_blocks and block["t"] == "Header" and block["c"][0] == 1: - break - print(block["c"], file=sys.stderr) - output_blocks.append(block) - - doc["blocks"] = output_blocks - sys.stdout.write(json.dumps(doc)) diff --git a/.ruff.toml b/.ruff.toml index 54746ed9..4b57429d 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -5,10 +5,28 @@ exclude = [ "__pycache__", "build", "dkist/version.py", + "dkist/_dev/", ] [lint] -select = ["E", "F"] +select = [ + "F", + "E", + "W", + "UP", + "C4", + "ICN", + "G", + "INP", + "PT", + "Q", + "RSE", + "RET", + "TID", + "PTH", + "NPY", + "RUF", +] extend-ignore = [ # pycodestyle (E, W) "E501", # LineTooLong # TODO! fix @@ -25,6 +43,17 @@ extend-ignore = [ "PT007", # Parametrize should be lists of tuples # TODO! fix "PT011", # Too broad exception assert # TODO! fix "PT023", # Always use () on pytest decorators + # pyupgrade + "UP038", # Use | in isinstance - not compatible with models and is slower + # Returns (RET) + "RET502", # Do not implicitly return None in function able to return non-None value + "RET503", # Missing explicit return at the end of function able to return non-None value + # Pathlib (PTH) + "PTH123", # open() should be replaced by Path.open() + # Ruff + "RUF012", # Mutable class attributes should be annotated with `typing.ClassVar` + "RUF013", # PEP 484 prohibits implicit `Optional` + "RUF015", # Prefer `next(iter(...))` over single element slice ] [lint.per-file-ignores] @@ -39,6 +68,11 @@ extend-ignore = [ ] "__init__.py" = ["E402", "F401", "F403"] "test_*.py" = ["B011", "D", "E402", "PGH001", "S101"] +"dkist/logger.py" = ["PTH"] + +[lint.flake8-import-conventions.extend-aliases] +"astropy.units" = "u" +"sunpy.net.attrs" = "a" [lint.pydocstyle] convention = "numpy" diff --git a/conftest.py b/conftest.py index 6bbc74e6..a17bd8ed 100644 --- a/conftest.py +++ b/conftest.py @@ -1,9 +1,10 @@ -import matplotlib +import matplotlib as mpl -matplotlib.use("Agg") +mpl.use("Agg") def pytest_configure(config): # pre-cache the IERS file for astropy to prevent downloads # which will cause errors with remote_data off - from astropy.utils.iers import IERS_Auto; IERS_Auto.open() + from astropy.utils.iers import IERS_Auto + IERS_Auto.open() diff --git a/dkist/__init__.py b/dkist/__init__.py index 108cbbe3..afc2c99f 100644 --- a/dkist/__init__.py +++ b/dkist/__init__.py @@ -14,7 +14,7 @@ __version__ = "unknown" -__all__ = ['TiledDataset', 'Dataset', 'load_dataset', 'system_info'] +__all__ = ["TiledDataset", "Dataset", "load_dataset", "system_info"] def write_default_config(overwrite=False): @@ -30,5 +30,5 @@ def write_default_config(overwrite=False): # Do internal imports last (so logger etc is initialised) -from dkist.dataset import Dataset, TiledDataset, load_dataset # noqa -from dkist.utils.sysinfo import system_info # noqa +from dkist.dataset import Dataset, TiledDataset, load_dataset +from dkist.utils.sysinfo import system_info diff --git a/dkist/_dev/scm_version.py b/dkist/_dev/scm_version.py index 1bcf0dd9..b9afb1d9 100644 --- a/dkist/_dev/scm_version.py +++ b/dkist/_dev/scm_version.py @@ -5,8 +5,8 @@ try: from setuptools_scm import get_version - version = get_version(root=os.path.join('..', 
'..'), relative_to=__file__) + version = get_version(root=os.path.join("..", ".."), relative_to=__file__) except ImportError: raise except Exception as e: - raise ValueError('setuptools_scm can not determine version.') from e + raise ValueError("setuptools_scm can not determine version.") from e diff --git a/dkist/config/__init__.py b/dkist/config/__init__.py index 9439b923..1f9b6ed1 100644 --- a/dkist/config/__init__.py +++ b/dkist/config/__init__.py @@ -1,12 +1,12 @@ from astropy.config import ConfigItem as _AstropyConfigItem from astropy.config import ConfigNamespace as _AstropyConfigNamespace -__all__ = ['ConfigItem', 'ConfigNamespace'] +__all__ = ["ConfigItem", "ConfigNamespace"] class ConfigNamespace(_AstropyConfigNamespace): - rootname = 'dkist' + rootname = "dkist" class ConfigItem(_AstropyConfigItem): - rootname = 'dkist' + rootname = "dkist" diff --git a/dkist/conftest.py b/dkist/conftest.py index 48940e17..21a21639 100644 --- a/dkist/conftest.py +++ b/dkist/conftest.py @@ -36,7 +36,7 @@ def caplog_dkist(caplog): @pytest.fixture def array(): - shape = 2**np.random.randint(2, 7, size=2) + shape = 2**np.random.randint(2, 7, size=2) # noqa: NPY002 x = np.ones(np.prod(shape)) + 10 x = x.reshape(shape) return da.from_array(x, tuple(shape)) @@ -65,7 +65,7 @@ def identity_gwcs(): """ identity = m.Multiply(1*u.arcsec/u.pixel) & m.Multiply(1*u.arcsec/u.pixel) sky_frame = cf.CelestialFrame(axes_order=(0, 1), - name='helioprojective', + name="helioprojective", reference_frame=Helioprojective(obstime="2018-01-01"), unit=(u.arcsec, u.arcsec), axis_physical_types=("custom:pos.helioprojective.lat", @@ -89,7 +89,7 @@ def identity_gwcs_3d(): identity = (TwoDScale(1 * u.arcsec / u.pixel) & m.Multiply(1 * u.nm / u.pixel)) - sky_frame = cf.CelestialFrame(axes_order=(0, 1), name='helioprojective', + sky_frame = cf.CelestialFrame(axes_order=(0, 1), name="helioprojective", reference_frame=Helioprojective(obstime="2018-01-01"), axes_names=("longitude", "latitude"), unit=(u.arcsec, u.arcsec), @@ -118,7 +118,7 @@ def identity_gwcs_3d_temporal(): identity = (TwoDScale(1 * u.arcsec / u.pixel) & m.Multiply(1 * u.s / u.pixel)) - sky_frame = cf.CelestialFrame(axes_order=(0, 1), name='helioprojective', + sky_frame = cf.CelestialFrame(axes_order=(0, 1), name="helioprojective", reference_frame=Helioprojective(obstime="2018-01-01"), axes_names=("longitude", "latitude"), unit=(u.arcsec, u.arcsec), @@ -145,7 +145,7 @@ def identity_gwcs_4d(): """ identity = (TwoDScale(1 * u.arcsec / u.pixel) & m.Multiply(1 * u.nm/u.pixel) & m.Multiply(1 * u.s/u.pixel)) - sky_frame = cf.CelestialFrame(axes_order=(0, 1), name='helioprojective', + sky_frame = cf.CelestialFrame(axes_order=(0, 1), name="helioprojective", reference_frame=Helioprojective(obstime="2018-01-01"), unit=(u.arcsec, u.arcsec), axis_physical_types=("custom:pos.helioprojective.lon", "custom:pos.helioprojective.lat")) @@ -168,7 +168,7 @@ def identity_gwcs_4d(): # This function lives in dkist_inventory, but is copied here to avoid a test dep -def generate_lookup_table(lookup_table, interpolation='linear', points_unit=u.pix, **kwargs): +def generate_lookup_table(lookup_table, interpolation="linear", points_unit=u.pix, **kwargs): if not isinstance(lookup_table, u.Quantity): raise TypeError("lookup_table must be a Quantity.") @@ -176,9 +176,9 @@ def generate_lookup_table(lookup_table, interpolation='linear', points_unit=u.pi points = (np.arange(lookup_table.size) - 0) * points_unit kwargs = { - 'bounds_error': False, - 'fill_value': np.nan, - 'method': interpolation, + 
"bounds_error": False, + "fill_value": np.nan, + "method": interpolation, **kwargs } @@ -188,9 +188,9 @@ def generate_lookup_table(lookup_table, interpolation='linear', points_unit=u.pi @pytest.fixture def identity_gwcs_5d_stokes(identity_gwcs_4d): stokes_frame = cf.StokesFrame(axes_order=(4,)) - stokes_model = generate_lookup_table([1, 2, 3, 4] * u.one, interpolation='nearest') + stokes_model = generate_lookup_table([1, 2, 3, 4] * u.one, interpolation="nearest") transform = identity_gwcs_4d.forward_transform - frame = cf.CompositeFrame(identity_gwcs_4d.output_frame.frames + [stokes_frame]) + frame = cf.CompositeFrame([*identity_gwcs_4d.output_frame.frames, stokes_frame]) detector_frame = cf.CoordinateFrame(name="detector", naxes=5, axes_order=(0, 1, 2, 3, 4), @@ -209,17 +209,17 @@ def identity_gwcs_5d_stokes(identity_gwcs_4d): @pytest.fixture def dataset(array, identity_gwcs): meta = { - 'inventory': { - 'bucket': 'data', - 'datasetId': 'test_dataset', - 'primaryProposalId': 'test_proposal', - 'asdfObjectKey': 'test_proposal/test_dataset/test_dataset.asdf', - 'browseMovieObjectKey': 'test_proposal/test_dataset/test_dataset.mp4', - 'qualityReportObjectKey': 'test_proposal/test_dataset/test_dataset.pdf', - 'wavelengthMin': 0, - 'wavelengthMax': 0, + "inventory": { + "bucket": "data", + "datasetId": "test_dataset", + "primaryProposalId": "test_proposal", + "asdfObjectKey": "test_proposal/test_dataset/test_dataset.asdf", + "browseMovieObjectKey": "test_proposal/test_dataset/test_dataset.mp4", + "qualityReportObjectKey": "test_proposal/test_dataset/test_dataset.pdf", + "wavelengthMin": 0, + "wavelengthMax": 0, }, - 'headers': Table() + "headers": Table() } identity_gwcs.array_shape = array.shape @@ -231,7 +231,7 @@ def dataset(array, identity_gwcs): # Construct the filename here as a scalar array to make sure that works as # it's what dkist-inventory does - ds._file_manager = FileManager.from_parts(np.array('test1.fits'), 0, 'float', array.shape, + ds._file_manager = FileManager.from_parts(np.array("test1.fits"), 0, "float", array.shape, loader=AstropyFITSLoader) return ds @@ -239,7 +239,7 @@ def dataset(array, identity_gwcs): @pytest.fixture def empty_meta(): - return {'inventory': {}, 'headers': {}} + return {"inventory": {}, "headers": {}} @pytest.fixture @@ -270,16 +270,16 @@ def dataset_4d(identity_gwcs_4d, empty_meta): def eit_dataset(): eitdir = Path(rootdir) / "EIT" with asdf.open(eitdir / "eit_test_dataset.asdf") as f: - return f.tree['dataset'] + return f.tree["dataset"] @pytest.fixture def simple_tiled_dataset(dataset): datasets = [copy.deepcopy(dataset) for i in range(4)] for ds in datasets: - ds.meta['inventory'] = dataset.meta['inventory'] + ds.meta["inventory"] = dataset.meta["inventory"] dataset_array = np.array(datasets).reshape((2,2)) - return TiledDataset(dataset_array, dataset.meta['inventory']) + return TiledDataset(dataset_array, dataset.meta["inventory"]) @pytest.fixture @@ -303,7 +303,7 @@ def small_visp_dataset(): vispdir = Path(rootdir) / "small_visp" with asdf.open(vispdir / "test_visp.asdf") as f: - return f.tree['dataset'] + return f.tree["dataset"] @pytest.fixture(scope="session") diff --git a/dkist/dataset/dataset.py b/dkist/dataset/dataset.py index fee5cf17..a972686d 100644 --- a/dkist/dataset/dataset.py +++ b/dkist/dataset/dataset.py @@ -12,7 +12,7 @@ from .utils import dataset_info_str -__all__ = ['Dataset'] +__all__ = ["Dataset"] class FileManagerDescriptor(NDCubeLinkedDescriptor): @@ -57,7 +57,7 @@ class Dataset(NDCube): Uncertainty in the dataset. 
Should have an attribute uncertainty_type that defines what kind of uncertainty is stored, for example "std" for standard deviation or "var" for variance. A metaclass defining such - an interface is `~astropy.nddata.NDUncertainty` - but isn’t mandatory. + an interface is `~astropy.nddata.NDUncertainty` - but isn't mandatory. If the uncertainty has no such attribute the uncertainty is stored as `~astropy.nddata.UnknownUncertainty`. Defaults to None. @@ -152,7 +152,7 @@ def _slice_headers(self, slice_): file_idx.append(slc) grid = np.mgrid[tuple(file_idx)] file_idx = tuple(grid[i].ravel() for i in range(grid.shape[0])) - flat_idx = np.ravel_multi_index(file_idx[::-1], files_shape[::-1], order='F') + flat_idx = np.ravel_multi_index(file_idx[::-1], files_shape[::-1], order="F") # Explicitly create new header table to ensure consistency # Otherwise would return a reference sometimes and a new table others @@ -193,7 +193,7 @@ def inventory(self): """ Convenience attribute to access the inventory metadata. """ - return self.meta['inventory'] + return self.meta["inventory"] """ Dataset loading and saving routines. @@ -227,8 +227,7 @@ def __repr__(self): Overload the NDData repr because it does not play nice with the dask delayed io. """ prefix = object.__repr__(self) - output = dedent(f"{prefix}\n{self.__str__()}") - return output + return dedent(f"{prefix}\n{self.__str__()}") def __str__(self): return dataset_info_str(self) diff --git a/dkist/dataset/loader.py b/dkist/dataset/loader.py index c28d5d71..ef724726 100644 --- a/dkist/dataset/loader.py +++ b/dkist/dataset/loader.py @@ -123,8 +123,8 @@ def _load_from_path(path: Path): if not path.exists(): raise ValueError(f"{path} does not exist.") return _load_from_asdf(path) - else: - return _load_from_directory(path) + + return _load_from_directory(path) def _load_from_directory(directory): @@ -137,7 +137,8 @@ def _load_from_directory(directory): if not asdf_files: raise ValueError(f"No asdf file found in directory {base_path}.") - elif len(asdf_files) > 1: + + if len(asdf_files) > 1: return _load_from_iterable(asdf_files) asdf_file = asdf_files[0] @@ -156,7 +157,7 @@ def _load_from_asdf(filepath): with importlib_resources.as_file(importlib_resources.files("dkist.io") / "level_1_dataset_schema.yaml") as schema_path: with asdf.open(filepath, custom_schema=schema_path.as_posix(), lazy_load=False, copy_arrays=True) as ff: - ds = ff.tree['dataset'] + ds = ff.tree["dataset"] if isinstance(ds, TiledDataset): for sub in ds.flat: sub.files.basepath = base_path @@ -183,8 +184,7 @@ def _known_types_docs(): def _formatted_types_docstring(known_types): lines = [f"| `{fqn}` - {doc}" for fqn, doc in known_types.items()] - docstring = '\n '.join(lines) - return docstring + return "\n ".join(lines) load_dataset.__doc__ = load_dataset.__doc__.format(types_list=_formatted_types_docstring(_known_types_docs()), diff --git a/dkist/dataset/tests/test_dataset.py b/dkist/dataset/tests/test_dataset.py index 2778abdf..e1b64da7 100644 --- a/dkist/dataset/tests/test_dataset.py +++ b/dkist/dataset/tests/test_dataset.py @@ -1,4 +1,3 @@ -import os from pathlib import Path import dask.array as da @@ -20,7 +19,7 @@ @pytest.fixture def invalid_asdf(tmp_path): filename = Path(tmp_path / "test.asdf") - tree = {'spam': 'eggs'} + tree = {"spam": "eggs"} with asdf.AsdfFile(tree=tree) as af: af.write_to(filename) return filename @@ -38,10 +37,10 @@ def test_missing_quality(dataset): def test_init_missing_meta_keys(identity_gwcs): data = np.zeros(identity_gwcs.array_shape) with 
pytest.raises(ValueError, match=".*must contain the headers table."): - Dataset(data, wcs=identity_gwcs, meta={'inventory': {}}) + Dataset(data, wcs=identity_gwcs, meta={"inventory": {}}) with pytest.raises(ValueError, match=".*must contain the inventory record."): - Dataset(data, wcs=identity_gwcs, meta={'headers': {}}) + Dataset(data, wcs=identity_gwcs, meta={"headers": {}}) def test_repr(dataset, dataset_3d): @@ -72,57 +71,53 @@ def test_dimensions(dataset, dataset_3d): def test_load_from_directory(): - ds = load_dataset(os.path.join(rootdir, 'EIT')) + ds = load_dataset(rootdir / "EIT") assert isinstance(ds.data, da.Array) assert isinstance(ds.wcs, gwcs.WCS) assert_quantity_allclose(ds.dimensions, (11, 128, 128)*u.pix) - assert ds.files.basepath == Path(os.path.join(rootdir, 'EIT')) + assert ds.files.basepath == Path(rootdir / "EIT") def test_from_directory_no_asdf(tmp_path): - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match="No asdf file found"): load_dataset(tmp_path) - assert "No asdf file found" in str(e) def test_from_not_directory(): - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match="does not exist"): load_dataset(rootdir / "notadirectory") - assert "directory argument" in str(e) def test_load_tiled_dataset(): - ds = load_dataset(os.path.join(rootdir, 'test_tiled_dataset-1.0.0_dataset-1.1.0.asdf')) + ds = load_dataset(rootdir / "test_tiled_dataset-1.0.0_dataset-1.1.0.asdf") assert isinstance(ds, TiledDataset) assert ds.shape == (3, 3) def test_load_with_old_methods(): with pytest.warns(DKISTDeprecationWarning): - ds = Dataset.from_directory(os.path.join(rootdir, 'EIT')) + ds = Dataset.from_directory(rootdir / "EIT") assert isinstance(ds.data, da.Array) assert isinstance(ds.wcs, gwcs.WCS) assert_quantity_allclose(ds.dimensions, (11, 128, 128)*u.pix) - assert ds.files.basepath == Path(os.path.join(rootdir, 'EIT')) + assert ds.files.basepath == Path(rootdir / "EIT") - with pytest.warns(DKISTDeprecationWarning) as e: - ds = Dataset.from_asdf(os.path.join(rootdir, 'EIT', "eit_test_dataset.asdf")) + with pytest.warns(DKISTDeprecationWarning): + ds = Dataset.from_asdf(rootdir / "EIT" / "eit_test_dataset.asdf") assert isinstance(ds.data, da.Array) assert isinstance(ds.wcs, gwcs.WCS) assert_quantity_allclose(ds.dimensions, (11, 128, 128)*u.pix) - assert ds.files.basepath == Path(os.path.join(rootdir, 'EIT')) + assert ds.files.basepath == Path(rootdir / "EIT") def test_from_directory_not_dir(): - with pytest.raises(ValueError) as e: - load_dataset(rootdir / 'EIT' / 'eit_2004-03-01T00_00_10.515000.asdf') - assert "must be a directory" in str(e) + with pytest.raises(ValueError, match="asdf does not exist"): + load_dataset(rootdir / "EIT" / "eit_2004-03-01T00_00_10.515000.asdf") def test_load_with_invalid_input(): - with pytest.raises(TypeError) as e: + with pytest.raises(TypeError, match="Input type .* not recognised."): load_dataset(42) - assert "Input type not recognised." 
in str(e) def test_crop_few_slices(dataset_4d): @@ -131,7 +126,7 @@ def test_crop_few_slices(dataset_4d): def test_file_manager(): - dataset = load_dataset(os.path.join(rootdir, 'EIT')) + dataset = load_dataset(rootdir / "EIT") assert dataset.files is dataset._file_manager with pytest.raises(AttributeError): dataset.files = 10 @@ -149,23 +144,23 @@ def test_no_file_manager(dataset_3d): def test_inventory_propery(): - dataset = load_dataset(os.path.join(rootdir, 'EIT')) - assert dataset.inventory == dataset.meta['inventory'] + dataset = load_dataset(rootdir / "EIT") + assert dataset.inventory == dataset.meta["inventory"] def test_header_slicing_single_index(): - dataset = load_dataset(os.path.join(rootdir, 'EIT')) + dataset = load_dataset(rootdir / "EIT") idx = 5 sliced = dataset[idx] sliced_headers = dataset.headers[idx] # Filenames in the header don't match the names of the files because why would you expect those things to be the same - sliced_header_files = sliced_headers['FILENAME'] + '_s.fits' + sliced_header_files = sliced_headers["FILENAME"] + "_s.fits" assert len(sliced.files.filenames) == 1 assert isinstance(sliced_headers, Row) assert sliced.files.filenames[0] == sliced_header_files - assert (sliced.headers['DINDEX3'] == sliced_headers['DINDEX3']).all() + assert (sliced.headers["DINDEX3"] == sliced_headers["DINDEX3"]).all() def test_header_slicing_3D_slice(large_visp_dataset): @@ -181,5 +176,5 @@ def test_header_slicing_3D_slice(large_visp_dataset): sliced_headers = dataset.headers[flat_idx] - assert len(sliced.files.filenames) == len(sliced_headers['FILENAME']) == len(sliced.headers) - assert (sliced.headers['DINDEX3', 'DINDEX4'] == sliced_headers['DINDEX3', 'DINDEX4']).all() + assert len(sliced.files.filenames) == len(sliced_headers["FILENAME"]) == len(sliced.headers) + assert (sliced.headers["DINDEX3", "DINDEX4"] == sliced_headers["DINDEX3", "DINDEX4"]).all() diff --git a/dkist/dataset/tests/test_load_dataset.py b/dkist/dataset/tests/test_load_dataset.py index 3b8eb871..bc24a147 100644 --- a/dkist/dataset/tests/test_load_dataset.py +++ b/dkist/dataset/tests/test_load_dataset.py @@ -49,12 +49,12 @@ def fixture_finder(request): return request.getfixturevalue(request.param) -@pytest.mark.parametrize("fixture_finder", ( +@pytest.mark.parametrize("fixture_finder", [ "asdf_path", "asdf_str", "single_asdf_in_folder", "single_asdf_in_folder_str", - ), + ], indirect=True ) def test_load_single_dataset(fixture_finder): @@ -62,16 +62,16 @@ def test_load_single_dataset(fixture_finder): assert isinstance(ds, Dataset) -@pytest.mark.parametrize("fixture_finder", ( +@pytest.mark.parametrize("fixture_finder", [ ["asdf_path", "asdf_str", "single_asdf_in_folder", "single_asdf_in_folder_str"], ("asdf_path", "asdf_str", "single_asdf_in_folder", "single_asdf_in_folder_str"), - ), + ], indirect=True ) def test_load_multiple(fixture_finder): datasets = load_dataset(fixture_finder) assert isinstance(datasets, list) - assert all([isinstance(ds, Dataset) for ds in datasets]) + assert all(isinstance(ds, Dataset) for ds in datasets) def test_load_from_results(asdf_path, asdf_str): @@ -82,14 +82,14 @@ def test_load_from_results(asdf_path, asdf_str): res = Results([asdf_str, asdf_str]) ds = load_dataset(res) assert isinstance(ds, list) - assert all([isinstance(ds, Dataset) for ds in ds]) + assert all(isinstance(ds, Dataset) for ds in ds) def test_multiple_from_dir(multiple_asdf_in_folder): ds = load_dataset(multiple_asdf_in_folder) assert isinstance(ds, list) assert len(ds) == 2 - assert all([isinstance(d, 
Dataset) for d in ds]) + assert all(isinstance(d, Dataset) for d in ds) def test_tiled_dataset(asdf_tileddataset_path): diff --git a/dkist/dataset/tests/test_plotting.py b/dkist/dataset/tests/test_plotting.py index bc6e65b5..2965db14 100644 --- a/dkist/dataset/tests/test_plotting.py +++ b/dkist/dataset/tests/test_plotting.py @@ -6,7 +6,7 @@ @pytest.mark.mpl_image_compare -@pytest.mark.parametrize("aslice", (np.s_[0, :, :], np.s_[:, 0, :], np.s_[:, :, 0])) +@pytest.mark.parametrize("aslice", [np.s_[0, :, :], np.s_[:, 0, :], np.s_[:, :, 0]]) def test_dataset_projection(dataset_3d, aslice): pytest.importorskip("ndcube", "2.0.2") # https://github.com/sunpy/ndcube/pull/509 ds = dataset_3d[aslice] @@ -17,7 +17,7 @@ def test_dataset_projection(dataset_3d, aslice): @pytest.mark.mpl_image_compare -@pytest.mark.parametrize("aslice", (np.s_[0, :, :], np.s_[:, 0, :], np.s_[:, :, 0])) +@pytest.mark.parametrize("aslice", [np.s_[0, :, :], np.s_[:, 0, :], np.s_[:, :, 0]]) def test_2d_plot(dataset_3d, aslice): fig = plt.figure() dataset_3d[aslice].plot() diff --git a/dkist/dataset/tests/test_tiled_dataset.py b/dkist/dataset/tests/test_tiled_dataset.py index 9da6be7e..4b6b4364 100644 --- a/dkist/dataset/tests/test_tiled_dataset.py +++ b/dkist/dataset/tests/test_tiled_dataset.py @@ -10,25 +10,25 @@ def test_tiled_dataset(simple_tiled_dataset, dataset): assert isinstance(simple_tiled_dataset, TiledDataset) assert simple_tiled_dataset._data[0, 0] in simple_tiled_dataset assert 5 not in simple_tiled_dataset - assert all([isinstance(t, Dataset) for t in simple_tiled_dataset.flat]) - assert all([t.shape == (2,) for t in simple_tiled_dataset]) - assert simple_tiled_dataset.inventory is dataset.meta['inventory'] + assert all(isinstance(t, Dataset) for t in simple_tiled_dataset.flat) + assert all(t.shape == (2,) for t in simple_tiled_dataset) + assert simple_tiled_dataset.inventory is dataset.meta["inventory"] assert simple_tiled_dataset.shape == (2, 2) -@pytest.mark.parametrize("aslice", (np.s_[0,0], +@pytest.mark.parametrize("aslice", [np.s_[0,0], np.s_[0], np.s_[...,0], np.s_[:,1], np.s_[1,1], - np.s_[0:2, :])) + np.s_[0:2, :]]) def test_tiled_dataset_slice(simple_tiled_dataset, aslice): assert np.all(simple_tiled_dataset[aslice] == simple_tiled_dataset._data[aslice]) def test_tiled_dataset_headers(simple_tiled_dataset, dataset): - assert len(simple_tiled_dataset.combined_headers) == len(dataset.meta['headers']) * 4 - assert simple_tiled_dataset.combined_headers.colnames == dataset.meta['headers'].colnames + assert len(simple_tiled_dataset.combined_headers) == len(dataset.meta["headers"]) * 4 + assert simple_tiled_dataset.combined_headers.colnames == dataset.meta["headers"].colnames def test_tiled_dataset_invalid_construction(dataset, dataset_4d): @@ -36,26 +36,26 @@ def test_tiled_dataset_invalid_construction(dataset, dataset_4d): TiledDataset(np.array((dataset, dataset_4d))) with pytest.raises(ValueError, match="physical types do not match"): - TiledDataset(np.array((dataset, dataset_4d)), inventory=dataset.meta['inventory']) + TiledDataset(np.array((dataset, dataset_4d)), inventory=dataset.meta["inventory"]) ds2 = copy.deepcopy(dataset) - ds2.meta['inventory'] = {'hello': 'world'} + ds2.meta["inventory"] = {"hello": "world"} with pytest.raises(ValueError, match="inventory records of all the datasets"): - TiledDataset(np.array((dataset, ds2)), dataset.meta['inventory']) + TiledDataset(np.array((dataset, ds2)), dataset.meta["inventory"]) def test_tiled_dataset_from_components(dataset): shape = (2, 2) 
file_managers = [dataset._file_manager] * 4 wcses = [dataset.wcs] * 4 - header_tables = [dataset.meta['headers']] * 4 - inventory = dataset.meta['inventory'] + header_tables = [dataset.meta["headers"]] * 4 + inventory = dataset.meta["inventory"] tiled_ds = TiledDataset._from_components(shape, file_managers, wcses, header_tables, inventory) assert isinstance(tiled_ds, TiledDataset) assert tiled_ds.shape == shape - assert all([isinstance(t, Dataset) for t in tiled_ds.flat]) + assert all(isinstance(t, Dataset) for t in tiled_ds.flat) for ds, fm, headers in zip(tiled_ds.flat, file_managers, header_tables): assert ds.files == fm - assert ds.meta['inventory'] is inventory - assert ds.meta['headers'] is headers + assert ds.meta["inventory"] is inventory + assert ds.meta["headers"] is headers diff --git a/dkist/dataset/tiled_dataset.py b/dkist/dataset/tiled_dataset.py index a55ed88c..eec8b9de 100644 --- a/dkist/dataset/tiled_dataset.py +++ b/dkist/dataset/tiled_dataset.py @@ -13,7 +13,7 @@ from .dataset import Dataset -__all__ = ['TiledDataset'] +__all__ = ["TiledDataset"] class TiledDataset(Collection): @@ -66,7 +66,7 @@ def __init__(self, dataset_array, inventory=None): self._validate_component_datasets(self._data, inventory) def __contains__(self, x): - return any([ele is x for ele in self._data.flat]) + return any(ele is x for ele in self._data.flat) def __len__(self): return self._data.__len__() diff --git a/dkist/dataset/utils.py b/dkist/dataset/utils.py index 8619370c..bd29f20c 100644 --- a/dkist/dataset/utils.py +++ b/dkist/dataset/utils.py @@ -6,7 +6,7 @@ import gwcs -__all__ = ['dataset_info_str'] +__all__ = ["dataset_info_str"] def dataset_info_str(ds): @@ -30,24 +30,24 @@ def dataset_info_str(ds): elif isinstance(wcs, gwcs.WCS): pixel_axis_names = wcs.input_frame.axes_names else: - pixel_axis_names = [''] * wcs.pixel_n_dim + pixel_axis_names = [""] * wcs.pixel_n_dim pixel_dim_width = max(9, len(str(wcs.pixel_n_dim))) pixel_nam_width = max(9, max(len(x) for x in pixel_axis_names)) pixel_siz_width = max(9, len(str(max(array_shape)))) - s += (('{0:' + str(pixel_dim_width) + 's}').format('Pixel Dim') + ' ' + - ('{0:' + str(pixel_nam_width) + 's}').format('Axis Name') + ' ' + - ('{0:' + str(pixel_siz_width) + 's}').format('Data size') + ' ' + - 'Bounds\n') + s += (("{0:" + str(pixel_dim_width) + "s}").format("Pixel Dim") + " " + + ("{0:" + str(pixel_nam_width) + "s}").format("Axis Name") + " " + + ("{0:" + str(pixel_siz_width) + "s}").format("Data size") + " " + + "Bounds\n") for ipix in range(ds.wcs.pixel_n_dim): - s += (('{0:' + str(pixel_dim_width) + 'd}').format(ipix) + ' ' + - ('{0:' + str(pixel_nam_width) + 's}').format(pixel_axis_names[::-1][ipix] or 'None') + ' ' + + s += (("{0:" + str(pixel_dim_width) + "d}").format(ipix) + " " + + ("{0:" + str(pixel_nam_width) + "s}").format(pixel_axis_names[::-1][ipix] or "None") + " " + (" " * 5 + str(None) if pixel_shape[::-1][ipix] is None else - ('{0:' + str(pixel_siz_width) + 'd}').format(pixel_shape[::-1][ipix])) + ' ' + - '{:s}'.format(str(None if wcs.pixel_bounds is None else wcs.pixel_bounds[::-1][ipix]) + '\n')) - s += '\n' + ("{0:" + str(pixel_siz_width) + "d}").format(pixel_shape[::-1][ipix])) + " " + + "{:s}".format(str(None if wcs.pixel_bounds is None else wcs.pixel_bounds[::-1][ipix]) + "\n")) + s += "\n" # World dimensions table @@ -56,52 +56,52 @@ def dataset_info_str(ds): world_nam_width = max(9, max(len(x) if x is not None else 0 for x in wcs.world_axis_names)) world_typ_width = max(13, max(len(x) if x is not None else 0 for x 
in wcs.world_axis_physical_types)) - s += (('{0:' + str(world_dim_width) + 's}').format('World Dim') + ' ' + - ('{0:' + str(world_nam_width) + 's}').format('Axis Name') + ' ' + - ('{0:' + str(world_typ_width) + 's}').format('Physical Type') + ' ' + - 'Units\n') + s += (("{0:" + str(world_dim_width) + "s}").format("World Dim") + " " + + ("{0:" + str(world_nam_width) + "s}").format("Axis Name") + " " + + ("{0:" + str(world_typ_width) + "s}").format("Physical Type") + " " + + "Units\n") for iwrl in range(wcs.world_n_dim): - name = wcs.world_axis_names[::-1][iwrl] or 'None' - typ = wcs.world_axis_physical_types[::-1][iwrl] or 'None' - unit = wcs.world_axis_units[::-1][iwrl] or 'unknown' + name = wcs.world_axis_names[::-1][iwrl] or "None" + typ = wcs.world_axis_physical_types[::-1][iwrl] or "None" + unit = wcs.world_axis_units[::-1][iwrl] or "unknown" - s += (('{0:' + str(world_dim_width) + 'd}').format(iwrl) + ' ' + - ('{0:' + str(world_nam_width) + 's}').format(name) + ' ' + - ('{0:' + str(world_typ_width) + 's}').format(typ) + ' ' + - '{:s}'.format(unit + '\n')) + s += (("{0:" + str(world_dim_width) + "d}").format(iwrl) + " " + + ("{0:" + str(world_nam_width) + "s}").format(name) + " " + + ("{0:" + str(world_typ_width) + "s}").format(typ) + " " + + "{:s}".format(unit + "\n")) - s += '\n' + s += "\n" # Axis correlation matrix pixel_dim_width = max(3, len(str(wcs.world_n_dim))) - s += 'Correlation between pixel and world axes:\n\n' + s += "Correlation between pixel and world axes:\n\n" - s += (' ' * world_dim_width + ' ' + - ('{0:^' + str(wcs.pixel_n_dim * 5 - 2) + 's}').format('Pixel Dim') + - '\n') + s += (" " * world_dim_width + " " + + ("{0:^" + str(wcs.pixel_n_dim * 5 - 2) + "s}").format("Pixel Dim") + + "\n") - s += (('{0:' + str(world_dim_width) + 's}').format('World Dim') + - ''.join([' ' + ('{0:' + str(pixel_dim_width) + 'd}').format(ipix) + s += (("{0:" + str(world_dim_width) + "s}").format("World Dim") + + "".join([" " + ("{0:" + str(pixel_dim_width) + "d}").format(ipix) for ipix in range(wcs.pixel_n_dim)]) + - '\n') + "\n") matrix = wcs.axis_correlation_matrix[::-1, ::-1] - matrix_str = np.empty(matrix.shape, dtype='U3') - matrix_str[matrix] = 'yes' - matrix_str[~matrix] = 'no' + matrix_str = np.empty(matrix.shape, dtype="U3") + matrix_str[matrix] = "yes" + matrix_str[~matrix] = "no" for iwrl in range(wcs.world_n_dim): - s += (('{0:' + str(world_dim_width) + 'd}').format(iwrl) + - ''.join([' ' + ('{0:>' + str(pixel_dim_width) + 's}').format(matrix_str[iwrl, ipix]) + s += (("{0:" + str(world_dim_width) + "d}").format(iwrl) + + "".join([" " + ("{0:>" + str(pixel_dim_width) + "s}").format(matrix_str[iwrl, ipix]) for ipix in range(wcs.pixel_n_dim)]) + - '\n') + "\n") # Make sure we get rid of the extra whitespace at the end of some lines - return '\n'.join([l.rstrip() for l in s.splitlines()]) + return "\n".join([line.rstrip() for line in s.splitlines()]) def pp_matrix(wcs): @@ -112,10 +112,10 @@ def pp_matrix(wcs): ---------- wcs : `BaseHighLevelWCS` or `BaseLowLevelWCS` """ - slen = np.max([len(l) for l in list(wcs.world_axis_names) + list(wcs.pixel_axis_names)]) + slen = np.max([len(line) for line in list(wcs.world_axis_names) + list(wcs.pixel_axis_names)]) mstr = wcs.axis_correlation_matrix.astype(f" bool: return all((uri, target, dtype, shape)) @staticmethod - def _output_shape_from_ref_array(shape, loader_array) -> Tuple[int]: + def _output_shape_from_ref_array(shape, loader_array) -> tuple[int]: # If the first dimension is one we are going to squash it. 
if shape[0] == 1: shape = shape[1:] if loader_array.size == 1: return shape - else: - return tuple(list(loader_array.shape) + list(shape)) + + return tuple(list(loader_array.shape) + list(shape)) @property - def output_shape(self) -> Tuple[int, ...]: + def output_shape(self) -> tuple[int, ...]: """ The final shape of the reconstructed data array. """ @@ -125,7 +126,7 @@ def basepath(self) -> os.PathLike: return self._basepath @basepath.setter - def basepath(self, value: Optional[Union[os.PathLike, str]]): + def basepath(self, value: os.PathLike | str | None): self._basepath = Path(value).expanduser() if value is not None else None @property @@ -153,7 +154,7 @@ class StripedExternalArrayView(BaseStripedExternalArray): # class. __slots__ = ["parent", "parent_slice"] - def __init__(self, parent: StripedExternalArray, aslice: Union[tuple, slice, int]): + def __init__(self, parent: StripedExternalArray, aslice: tuple | slice | int): self.parent = parent self.parent_slice = tuple(aslice) @@ -340,7 +341,7 @@ def quality_report(self, path=None, overwrite=None): downloaded file if the download was successful, and any errors if it was not. """ - dataset_id = self._ndcube.meta['inventory']['datasetId'] + dataset_id = self._ndcube.meta["inventory"]["datasetId"] url = f"{self._metadata_streamer_url}/quality?datasetId={dataset_id}" if path is None and self.basepath: path = self.basepath @@ -368,7 +369,7 @@ def preview_movie(self, path=None, overwrite=None): downloaded file if the download was successful, and any errors if it was not. """ - dataset_id = self._ndcube.meta['inventory']['datasetId'] + dataset_id = self._ndcube.meta["inventory"]["datasetId"] url = f"{self._metadata_streamer_url}/movie?datasetId={dataset_id}" if path is None and self.basepath: path = self.basepath diff --git a/dkist/io/loaders.py b/dkist/io/loaders.py index d1a3b816..3964b733 100644 --- a/dkist/io/loaders.py +++ b/dkist/io/loaders.py @@ -14,7 +14,7 @@ from dkist import log -__all__ = ['BaseFITSLoader', 'AstropyFITSLoader'] +__all__ = ["BaseFITSLoader", "AstropyFITSLoader"] common_parameters = """ @@ -58,7 +58,7 @@ def __repr__(self): return self.__str__() def __str__(self): - return "".format(self) + return f"" @property def data(self): @@ -79,8 +79,8 @@ def absolute_uri(self): """ if self.basepath: return self.basepath / self.fileuri - else: - return Path(self.fileuri) + + return Path(self.fileuri) @add_common_docstring(append=common_parameters) @@ -105,5 +105,5 @@ def __getitem__(self, slc): hdu = hdul[self.target] if hasattr(hdu, "section"): return hdu.section[slc] - else: - return hdu.data[slc] + + return hdu.data[slc] diff --git a/dkist/io/tests/test_file_manager.py b/dkist/io/tests/test_file_manager.py index 22b8ab0f..c8a067ee 100644 --- a/dkist/io/tests/test_file_manager.py +++ b/dkist/io/tests/test_file_manager.py @@ -9,7 +9,7 @@ from dkist.data.test import rootdir from dkist.io.file_manager import FileManager, StripedExternalArray, StripedExternalArrayView -eitdir = Path(rootdir) / 'EIT' +eitdir = Path(rootdir) / "EIT" @pytest.fixture @@ -31,7 +31,7 @@ def loader_array(file_manager): def test_load_and_slicing(file_manager, loader_array): ext_shape = np.array(loader_array, dtype=object).shape assert file_manager._striped_external_array.loader_array.shape == ext_shape - assert file_manager.output_shape == tuple(list(ext_shape) + [128, 128]) + assert file_manager.output_shape == (*list(ext_shape), 128, 128) array = file_manager._generate_array().compute() assert isinstance(array, np.ndarray) @@ -41,7 +41,7 @@ def 
test_load_and_slicing(file_manager, loader_array): sliced_manager = file_manager[5:8] ext_shape = np.array(loader_array[5:8], dtype=object).shape assert sliced_manager._striped_external_array.loader_array.shape == ext_shape - assert sliced_manager.output_shape == tuple(list(ext_shape) + [128, 128]) + assert sliced_manager.output_shape == (*list(ext_shape), 128, 128) def test_filenames(file_manager, loader_array): @@ -52,7 +52,7 @@ def test_filenames(file_manager, loader_array): def test_dask(file_manager, loader_array): ext_shape = np.array(loader_array, dtype=object).shape assert file_manager._striped_external_array.loader_array.shape == ext_shape - assert file_manager.output_shape == tuple(list(ext_shape) + [128, 128]) + assert file_manager.output_shape == (*list(ext_shape), 128, 128) assert isinstance(file_manager._generate_array(), da.Array) assert_allclose(file_manager._generate_array(), np.array(file_manager._generate_array())) @@ -162,16 +162,14 @@ def test_reprs(file_manager): @pytest.fixture def orchestrate_transfer_mock(mocker): - yield mocker.patch("dkist.net.helpers._orchestrate_transfer_task", + return mocker.patch("dkist.net.helpers._orchestrate_transfer_task", autospec=True) def test_download_default_keywords(dataset, orchestrate_transfer_mock): base_path = Path(net.conf.dataset_path.format(**dataset.meta["inventory"])) folder = Path("/{bucket}/{primaryProposalId}/{datasetId}/".format(**dataset.meta["inventory"])) - file_list = dataset.files.filenames + [folder / "test_dataset.asdf", - folder / "test_dataset.mp4", - folder / "test_dataset.pdf"] + file_list = [*dataset.files.filenames, folder / "test_dataset.asdf", folder / "test_dataset.mp4", folder / "test_dataset.pdf"] file_list = [base_path / fn for fn in file_list] dataset.files.download() @@ -179,7 +177,7 @@ def test_download_default_keywords(dataset, orchestrate_transfer_mock): orchestrate_transfer_mock.assert_called_once_with( file_list, recursive=False, - destination_path=Path('/~'), + destination_path=Path("/~"), destination_endpoint=None, progress=True, wait=True, @@ -201,9 +199,7 @@ def test_download_keywords(dataset, orchestrate_transfer_mock, keywords): base_path = Path(net.conf.dataset_path.format(**dataset.meta["inventory"])) folder = Path("/{bucket}/{primaryProposalId}/{datasetId}/".format(**dataset.meta["inventory"])) - file_list = dataset.files.filenames + [folder / "test_dataset.asdf", - folder / "test_dataset.mp4", - folder / "test_dataset.pdf"] + file_list = [*dataset.files.filenames, folder / "test_dataset.asdf", folder / "test_dataset.mp4", folder / "test_dataset.pdf"] file_list = [base_path / fn for fn in file_list] dataset.files.download(path="/test/", **keywords) @@ -211,7 +207,7 @@ def test_download_keywords(dataset, orchestrate_transfer_mock, keywords): orchestrate_transfer_mock.assert_called_once_with( file_list, recursive=False, - destination_path=Path('/test'), + destination_path=Path("/test"), **keywords ) @@ -222,9 +218,7 @@ def test_download_keywords(dataset, orchestrate_transfer_mock, keywords): def test_download_path_interpolation(dataset, orchestrate_transfer_mock): base_path = Path(net.conf.dataset_path.format(**dataset.meta["inventory"])) folder = Path("/{bucket}/{primaryProposalId}/{datasetId}/".format(**dataset.meta["inventory"])) - file_list = dataset.files.filenames + [folder / "test_dataset.asdf", - folder / "test_dataset.mp4", - folder / "test_dataset.pdf"] + file_list = [*dataset.files.filenames, folder / "test_dataset.asdf", folder / "test_dataset.mp4", folder / 
"test_dataset.pdf"] file_list = [base_path / fn for fn in file_list] dataset.files.download(path="~/{dataset_id}") @@ -232,7 +226,7 @@ def test_download_path_interpolation(dataset, orchestrate_transfer_mock): orchestrate_transfer_mock.assert_called_once_with( file_list, recursive=False, - destination_path=Path('~/test_dataset/'), + destination_path=Path("~/test_dataset/"), destination_endpoint=None, progress=True, wait=True, @@ -254,10 +248,10 @@ def test_length_one_first_array_axis(small_visp_dataset): assert len(small_visp_dataset[:, 5, 5].files.filenames) == 3 -@pytest.mark.parametrize("kwargs", ( +@pytest.mark.parametrize("kwargs", [ {}, {"path": "~/", "overwrite": True} -)) +]) def test_download_quality(mocker, small_visp_dataset, kwargs): simple_download = mocker.patch("dkist.io.file_manager.Downloader.simple_download") from dkist.net import conf @@ -276,10 +270,10 @@ def test_download_quality(mocker, small_visp_dataset, kwargs): ) -@pytest.mark.parametrize("kwargs", ( +@pytest.mark.parametrize("kwargs", [ {}, {"path": "~/", "overwrite": True} -)) +]) def test_download_quality_movie(mocker, small_visp_dataset, kwargs): simple_download = mocker.patch("dkist.io.file_manager.Downloader.simple_download") from dkist.net import conf diff --git a/dkist/io/tests/test_fits.py b/dkist/io/tests/test_fits.py index a2540a5e..484d526b 100644 --- a/dkist/io/tests/test_fits.py +++ b/dkist/io/tests/test_fits.py @@ -10,7 +10,7 @@ from dkist.io.file_manager import FileManager from dkist.io.loaders import AstropyFITSLoader -eitdir = Path(rootdir) / 'EIT' +eitdir = Path(rootdir) / "EIT" @pytest.fixture diff --git a/dkist/logger.py b/dkist/logger.py index c8f526fb..c0162046 100644 --- a/dkist/logger.py +++ b/dkist/logger.py @@ -30,7 +30,7 @@ class DKISTLogger(logging.Logger): def __init__(self, name, level=logging.NOTSET, *, capture_warning_classes=None): super().__init__(name, level=level) - self.capture_warning_classes = tuple(capture_warning_classes) if capture_warning_classes is not None else tuple() + self.capture_warning_classes = tuple(capture_warning_classes) if capture_warning_classes is not None else () self.enable_warnings_capture() diff --git a/dkist/net/__init__.py b/dkist/net/__init__.py index 444118f1..7c4586a5 100644 --- a/dkist/net/__init__.py +++ b/dkist/net/__init__.py @@ -31,5 +31,5 @@ class Conf(_config.ConfigNamespace): conf = Conf() # Put imports after conf so that conf is initialized before import -from .client import DKISTClient # noqa -from .helpers import transfer_complete_datasets # noqa +from .client import DKISTClient +from .helpers import transfer_complete_datasets diff --git a/dkist/net/attr_walker.py b/dkist/net/attr_walker.py index 2ace7f69..b31561b3 100644 --- a/dkist/net/attr_walker.py +++ b/dkist/net/attr_walker.py @@ -36,7 +36,7 @@ def create_from_or(wlk, tree): @walker.add_creator(AttrAnd, DataAttr) def create_new_param(wlk, tree): - params = dict() + params = {} # Use the apply dispatcher to convert the attrs to their query parameters wlk.apply(tree, params) @@ -54,30 +54,30 @@ def iterate_over_and(wlk, tree, params): # SunPy Attrs @walker.add_applier(Time) def _(wlk, attr, params): - return params.update({'endTimeMin': attr.start.isot, - 'startTimeMax': attr.end.isot}) + return params.update({"endTimeMin": attr.start.isot, + "startTimeMax": attr.end.isot}) @walker.add_applier(Instrument) def _(wlk, attr, params): - return params.update({'instrumentNames': attr.value}) + return params.update({"instrumentNames": attr.value}) @walker.add_applier(Wavelength) def _(wlk, 
attr, params): - return params.update({'wavelengthRanges': [attr.min.to_value(u.nm), attr.max.to_value(u.nm)]}) + return params.update({"wavelengthRanges": [attr.min.to_value(u.nm), attr.max.to_value(u.nm)]}) @walker.add_applier(Physobs) def _(wlk, attr, params): if attr.value.lower() == "stokes_parameters": - return params.update({'hasAllStokes': True}) + return params.update({"hasAllStokes": True}) if attr.value.lower() == "intensity": - return params.update({'hasAllStokes': False}) + return params.update({"hasAllStokes": False}) if attr.value.lower() == "spectral_axis": - return params.update({'hasSpectralAxis': True}) + return params.update({"hasSpectralAxis": True}) if attr.value.lower() == "temporal_axis": - return params.update({'hasTemporalAxis': True}) + return params.update({"hasTemporalAxis": True}) # The client should not have accepted the query if we make it this far. raise ValueError(f"Physobs({attr.value}) is not supported by the DKIST client.") # pragma: no cover @@ -86,99 +86,99 @@ def _(wlk, attr, params): # DKIST Attrs @walker.add_applier(PageSize) def _(wlk, attr, params): - return params.update({'pageSize': attr.value}) + return params.update({"pageSize": attr.value}) @walker.add_applier(Page) def _(wlk, attr, params): - return params.update({'pageNumber': attr.value}) + return params.update({"pageNumber": attr.value}) @walker.add_applier(Dataset) def _(wlk, attr, params): - return params.update({'datasetIds': attr.value}) + return params.update({"datasetIds": attr.value}) @walker.add_applier(WavelengthBand) def _(wlk, attr, params): - return params.update({'filterWavelengths': attr.value}) + return params.update({"filterWavelengths": attr.value}) @walker.add_applier(Observable) def _(wlk, attr, params): - return params.update({'observables': attr.value}) + return params.update({"observables": attr.value}) @walker.add_applier(Experiment) def _(wlk, attr, params): - return params.update({'primaryExperimentIds': attr.value}) + return params.update({"primaryExperimentIds": attr.value}) @walker.add_applier(Proposal) def _(wlk, attr, params): - return params.update({'primaryProposalIds': attr.value}) + return params.update({"primaryProposalIds": attr.value}) @walker.add_applier(TargetType) def _(wlk, attr, params): - return params.update({'targetTypes': attr.value}) + return params.update({"targetTypes": attr.value}) @walker.add_applier(Recipe) def _(wlk, attr, params): - return params.update({'recipeId': attr.value}) + return params.update({"recipeId": attr.value}) @walker.add_applier(Embargoed) def _(wlk, attr, params): - return params.update({'isEmbargoed': bool(attr.value)}) + return params.update({"isEmbargoed": bool(attr.value)}) @walker.add_applier(FriedParameter) def _(wlk, attr, params): - return params.update({'qualityAverageFriedParameterMin': attr.min.to_value(u.cm), - 'qualityAverageFriedParameterMax': attr.max.to_value(u.cm)}) + return params.update({"qualityAverageFriedParameterMin": attr.min.to_value(u.cm), + "qualityAverageFriedParameterMax": attr.max.to_value(u.cm)}) @walker.add_applier(PolarimetricAccuracy) def _(wlk, attr, params): - return params.update({'qualityAveragePolarimetricAccuracyMin': attr.min, - 'qualityAveragePolarimetricAccuracyMax': attr.max}) + return params.update({"qualityAveragePolarimetricAccuracyMin": attr.min, + "qualityAveragePolarimetricAccuracyMax": attr.max}) @walker.add_applier(ExposureTime) def _(wlk, attr, params): - return params.update({'exposureTimeMin': attr.min.to_value(u.s), - 'exposureTimeMax': attr.max.to_value(u.s)}) 
+ return params.update({"exposureTimeMin": attr.min.to_value(u.s), + "exposureTimeMax": attr.max.to_value(u.s)}) @walker.add_applier(EmbargoEndTime) def _(wlk, attr, params): - return params.update({'embargoEndDateMin': attr.start.isot, - 'embargoEndDateMax': attr.end.isot}) + return params.update({"embargoEndDateMin": attr.start.isot, + "embargoEndDateMax": attr.end.isot}) @walker.add_applier(SpectralSampling) def _(wlk, attr, params): - return params.update({'averageDatasetSpectralSamplingMin': attr.min.to_value(equivalencies=float), - 'averageDatasetSpectralSamplingMax': attr.max.to_value(equivalencies=float)}) + return params.update({"averageDatasetSpectralSamplingMin": attr.min.to_value(equivalencies=float), + "averageDatasetSpectralSamplingMax": attr.max.to_value(equivalencies=float)}) @walker.add_applier(SpatialSampling) def _(wlk, attr, params): - return params.update({'averageDatasetSpatialSamplingMin': attr.min.to_value(equivalencies=float), - 'averageDatasetSpatialSamplingMax': attr.max.to_value(equivalencies=float)}) + return params.update({"averageDatasetSpatialSamplingMin": attr.min.to_value(equivalencies=float), + "averageDatasetSpatialSamplingMax": attr.max.to_value(equivalencies=float)}) @walker.add_applier(TemporalSampling) def _(wlk, attr, params): - return params.update({'averageDatasetTemporalSamplingMin': attr.min.to_value(u.s), - 'averageDatasetTemporalSamplingMax': attr.max.to_value(u.s)}) + return params.update({"averageDatasetTemporalSamplingMin": attr.min.to_value(u.s), + "averageDatasetTemporalSamplingMax": attr.max.to_value(u.s)}) @walker.add_applier(BrowseMovie) def _(wlk, attr, params): values = {} if attr.movieurl: - values['browseMovieUrl'] = attr.movieurl + values["browseMovieUrl"] = attr.movieurl if attr.movieobjectkey: - values['browseMovieObjectKey'] = attr.movieobjectkey + values["browseMovieObjectKey"] = attr.movieobjectkey return params.update(values) @@ -190,7 +190,7 @@ def _(wlk, attr, params): "intersecting": "rectangleIntersectingBoundingBox"} # strip all spaces and the outer most () - return params.update({search_types[attr.search_type]: str(attr.hpc_bounding_box_arcsec).replace(' ', '')[1:-1]}) + return params.update({search_types[attr.search_type]: str(attr.hpc_bounding_box_arcsec).replace(" ", "")[1:-1]}) @walker.add_applier(Provider) @@ -209,29 +209,29 @@ def _(wlk, attr, params): @walker.add_applier(SummitSoftwareVersion) def _(wlk, attr, params): - return params.update({'highLevelSoftwareVersion': attr.value}) + return params.update({"highLevelSoftwareVersion": attr.value}) @walker.add_applier(WorkflowName) def _(wlk, attr, params): - return params.update({'workflowName': attr.value}) + return params.update({"workflowName": attr.value}) @walker.add_applier(WorkflowVersion) def _(wlk, attr, params): - return params.update({'workflowVersion': attr.value}) + return params.update({"workflowVersion": attr.value}) @walker.add_applier(ObservingProgramExecutionID) def _(wlk, attr, params): - return params.update({'observingProgramExecutionId': attr.value}) + return params.update({"observingProgramExecutionId": attr.value}) @walker.add_applier(InstrumentProgramExecutionID) def _(wlk, attr, params): - return params.update({'instrumentProgramExecutionId': attr.value}) + return params.update({"instrumentProgramExecutionId": attr.value}) @walker.add_applier(HeaderVersion) def _(wlk, attr, params): - return params.update({'headerVersion': attr.value}) + return params.update({"headerVersion": attr.value}) diff --git a/dkist/net/attrs.py 
b/dkist/net/attrs.py index 8aa9a4fe..5297e35c 100644 --- a/dkist/net/attrs.py +++ b/dkist/net/attrs.py @@ -3,7 +3,7 @@ Other attributes provided by `sunpy.net.attrs` are supported by the client. """ -import astropy.units as _u +import astropy.units as _u # noqa: ICN001 import sunpy.net._attrs as _sunpy_attrs from sunpy.coordinates.frames import Helioprojective as _Helioprojective @@ -12,13 +12,13 @@ from sunpy.net.attr import Range as _Range from sunpy.net.attr import SimpleAttr as _SimpleAttr -__all__ = ['PageSize', 'Page', 'Dataset', 'WavelengthBand', 'Embargoed', 'Observable', - 'Experiment', 'Proposal', 'TargetType', 'Recipe', - 'FriedParameter', 'PolarimetricAccuracy', 'ExposureTime', - 'EmbargoEndTime', 'BrowseMovie', 'BoundingBox', - 'SpectralSampling', 'SpatialSampling', 'TemporalSampling', 'SummitSoftwareVersion', - 'WorkflowName', 'WorkflowVersion', 'ObservingProgramExecutionID', - 'InstrumentProgramExecutionID', 'HeaderVersion'] +__all__ = ["PageSize", "Page", "Dataset", "WavelengthBand", "Embargoed", "Observable", + "Experiment", "Proposal", "TargetType", "Recipe", + "FriedParameter", "PolarimetricAccuracy", "ExposureTime", + "EmbargoEndTime", "BrowseMovie", "BoundingBox", + "SpectralSampling", "SpatialSampling", "TemporalSampling", "SummitSoftwareVersion", + "WorkflowName", "WorkflowVersion", "ObservingProgramExecutionID", + "InstrumentProgramExecutionID", "HeaderVersion"] # SimpleAttrs diff --git a/dkist/net/attrs_values.py b/dkist/net/attrs_values.py index 7f0461f5..ccee14c8 100644 --- a/dkist/net/attrs_values.py +++ b/dkist/net/attrs_values.py @@ -7,7 +7,7 @@ import platformdirs -from sunpy.net import attrs as sattrs +from sunpy.net import attrs as sattrs # noqa: ICN001 import dkist.data from dkist import log @@ -118,7 +118,7 @@ def attempt_local_update(*, timeout: int = 1, user_file: Path = None, silence_er # Test that the file we just saved can be parsed as json try: - with open(user_file, "r") as f: + with open(user_file) as f: json.load(f) except Exception: log.error("Downloaded file is not valid JSON.") @@ -158,7 +158,7 @@ def get_search_attrs_values(*, allow_update: bool = True, timeout: int = 1) -> d log.debug("No update to attr values needed.") log.debug("Using attr values from %s", local_path) - with open(local_path, "r") as f: + with open(local_path) as f: search_values = json.load(f) search_values = {param["parameterName"]: param["values"] for param in search_values["parameterValues"]} diff --git a/dkist/net/client.py b/dkist/net/client.py index fdb31950..b49bc411 100644 --- a/dkist/net/client.py +++ b/dkist/net/client.py @@ -2,10 +2,11 @@ import json import urllib.parse import urllib.request -from typing import Any, List, Mapping, Iterable +from typing import Any from textwrap import dedent from functools import partial from collections import defaultdict +from collections.abc import Mapping, Iterable import aiohttp import numpy as np @@ -16,7 +17,7 @@ from astropy.time import Time from sunpy.net import attr -from sunpy.net import attrs as sattrs +from sunpy.net import attrs as sattrs # noqa: ICN001 from sunpy.net.base_client import (BaseClient, QueryResponseRow, QueryResponseTable, convert_row_to_table) from sunpy.util.net import parse_header @@ -38,7 +39,7 @@ class DKISTQueryResponseTable(QueryResponseTable): # Define some class properties to better format the results table. 
# TODO: remove experimentDescription from this list, when we can limit the # length of the field to something nicer - hide_keys: List[str] = [ + hide_keys: list[str] = [ "Storage Bucket", "Full Stokes", "asdf Filename", @@ -85,13 +86,13 @@ def _process_table(results: "DKISTQueryResponseTable") -> "DKISTQueryResponseTab for colname in times: if colname not in results.colnames: continue # pragma: no cover - if not any([v is None for v in results[colname]]): + if not any(v is None for v in results[colname]): results[colname] = Time(results[colname]) for colname, unit in units.items(): if colname not in results.colnames: continue # pragma: no cover - none_values = np.array(results[colname] == None) + none_values = np.array(results[colname] == None) # E711 if none_values.any(): results[colname][none_values] = np.nan results[colname] = u.Quantity(results[colname], unit=unit) @@ -110,7 +111,7 @@ def from_results(cls, responses: Iterable[Mapping[str, Any]], *, client: "DKISTC total_available_results = 0 new_results = defaultdict(list) for response in responses: - total_available_results += response.get('recordCount', 0) + total_available_results += response.get("recordCount", 0) for result in response["searchResults"]: for key, value in result.items(): new_results[INVENTORY_KEY_MAP[key]].append(value) @@ -176,8 +177,8 @@ def search(self, *args) -> DKISTQueryResponseTable: results = [] for url_parameters in queries: - if 'pageSize' not in url_parameters: - url_parameters.update({'pageSize': conf.default_page_size}) + if "pageSize" not in url_parameters: + url_parameters.update({"pageSize": conf.default_page_size}) # TODO make this accept and concatenate multiple wavebands in a search query_string = urllib.parse.urlencode(url_parameters, doseq=True) full_url = f"{self._dataset_search_url}?{query_string}" @@ -201,7 +202,7 @@ def _make_filename(path: os.PathLike, row: QueryResponseRow, cdheader = resp.headers.get("Content-Disposition", None) if cdheader: _, params = parse_header(cdheader) - name = params.get('filename', "") + name = params.get("filename", "") return str(path).format(file=name, **row.response_block_map) @@ -237,7 +238,7 @@ def _can_handle_query(cls, *query) -> bool: supported = set(walker.applymm.registry) # This function is only called with arguments of the query where they are assumed to be ANDed. supported.remove(attr.AttrAnd) - query_attrs = set(type(x) for x in query) + query_attrs = {type(x) for x in query} # The DKIST client only requires that one or more of the support attrs be present. 
if not query_attrs.issubset(supported) or len(query_attrs.intersection(supported)) < 1: @@ -261,7 +262,7 @@ def _can_handle_query(cls, *query) -> bool: @classmethod def _attrs_module(cls): - return 'dkist', 'dkist.net.attrs' + return "dkist", "dkist.net.attrs" @classmethod def register_values(cls): diff --git a/dkist/net/globus/auth.py b/dkist/net/globus/auth.py index 36114bd7..52343621 100644 --- a/dkist/net/globus/auth.py +++ b/dkist/net/globus/auth.py @@ -17,12 +17,12 @@ import globus_sdk import platformdirs -CLIENT_ID = 'dd2d62af-0b44-4e2e-9454-1092c94b46b3' -SCOPES = ('urn:globus:auth:scope:transfer.api.globus.org:all', - 'openid') +CLIENT_ID = "dd2d62af-0b44-4e2e-9454-1092c94b46b3" +SCOPES = ("urn:globus:auth:scope:transfer.api.globus.org:all", + "openid") -__all__ = ['ensure_globus_authorized', 'get_refresh_token_authorizer'] +__all__ = ["ensure_globus_authorized", "get_refresh_token_authorizer"] class AuthenticationError(Exception): @@ -47,18 +47,18 @@ def wait_for_code(self): class RedirectHandler(BaseHTTPRequestHandler): def do_GET(self): self.send_response(200) - self.send_header('Content-type', 'text/html') + self.send_header("Content-type", "text/html") self.end_headers() - self.wfile.write(b'You\'re all set, you can close this window!') + self.wfile.write(b"You're all set, you can close this window!") - code = parse_qs(urlparse(self.path).query).get('code', [''])[0] + code = parse_qs(urlparse(self.path).query).get("code", [""])[0] self.server.return_code(code) def log_message(self, format, *args): return -def start_local_server(listen=('localhost', 0)): +def start_local_server(listen=("localhost", 0)): """ Start a server which will listen for the OAuth2 callback. @@ -91,12 +91,12 @@ def get_cache_contents(): cache_file = get_cache_file_path() if not cache_file.exists(): return {} - else: - try: - with open(cache_file) as fd: - return json.load(fd) - except (IOError, json.JSONDecodeError): - return {} + + try: + with open(cache_file) as fd: + return json.load(fd) + except (OSError, json.JSONDecodeError): + return {} def save_auth_cache(auth_cache): @@ -133,7 +133,7 @@ def do_native_app_authentication(client_id, requested_scopes=None): # pragma: n dict of tokens keyed by service name. 
""" server = start_local_server() - redirect_uri = "http://{a[0]}:{a[1]}".format(a=server.server_address) + redirect_uri = f"http://{server.server_address[0]}:{server.server_address[1]}" client = globus_sdk.NativeAppAuthClient(client_id=client_id) client.oauth2_start_flow(requested_scopes=SCOPES, @@ -141,7 +141,7 @@ def do_native_app_authentication(client_id, requested_scopes=None): # pragma: n refresh_tokens=True) url = client.oauth2_get_authorize_url() - result = webbrowser.open(url, new=1) + webbrowser.open(url, new=1) print("Waiting for completion of Globus Authentication in your webbrowser...") print(f"If your webbrowser has not opened, please go to {url} to authenticate with globus.") @@ -182,15 +182,15 @@ def get_refresh_token_authorizer(force_reauth=False): auth_client = globus_sdk.NativeAppAuthClient(client_id=CLIENT_ID) - transfer_tokens = tokens['transfer.api.globus.org'] + transfer_tokens = tokens["transfer.api.globus.org"] authorizers = {} for scope, transfer_tokens in tokens.items(): authorizers[scope] = globus_sdk.RefreshTokenAuthorizer( - transfer_tokens['refresh_token'], + transfer_tokens["refresh_token"], auth_client, - access_token=transfer_tokens['access_token'], - expires_at=transfer_tokens['expires_at_seconds'], + access_token=transfer_tokens["access_token"], + expires_at=transfer_tokens["expires_at_seconds"], on_refresh=save_auth_cache) return authorizers diff --git a/dkist/net/globus/endpoints.py b/dkist/net/globus/endpoints.py index 064ba92a..17357936 100644 --- a/dkist/net/globus/endpoints.py +++ b/dkist/net/globus/endpoints.py @@ -5,13 +5,13 @@ import urllib import pathlib import webbrowser -from functools import lru_cache +from functools import cache import globus_sdk from .auth import ensure_globus_authorized, get_refresh_token_authorizer -__all__ = ['get_data_center_endpoint_id', 'get_endpoint_id', 'get_directory_listing'] +__all__ = ["get_data_center_endpoint_id", "get_endpoint_id", "get_directory_listing"] def get_transfer_client(force_reauth=False): @@ -27,7 +27,7 @@ def get_transfer_client(force_reauth=False): ------- `globus_sdk.TransferClient` """ - auth = get_refresh_token_authorizer(force_reauth)['transfer.api.globus.org'] + auth = get_refresh_token_authorizer(force_reauth)["transfer.api.globus.org"] return globus_sdk.TransferClient(authorizer=auth) @@ -55,7 +55,7 @@ def get_local_endpoint_id(): return endpoint_id -@lru_cache(maxsize=None) +@cache def get_data_center_endpoint_id(): """ Query the data center for the current globus endpoint ID. 
@@ -93,7 +93,7 @@ def get_endpoint_id(endpoint, tfr_client): tr = None # If there is a space in the endpoint it's not an id - if ' ' not in endpoint: + if " " not in endpoint: try: tr = tfr_client.get_endpoint(endpoint) return endpoint @@ -106,17 +106,17 @@ def get_endpoint_id(endpoint, tfr_client): responses = tr.data["DATA"] + if len(responses) == 0: + raise ValueError(f"No matches found for endpoint '{endpoint}'") + if len(responses) > 1: - display_names = [a['display_name'] for a in responses] + display_names = [a["display_name"] for a in responses] # If we have one and only one exact display name match use that if display_names.count(endpoint) == 1: - return responses[display_names.index(endpoint)]['id'] + return responses[display_names.index(endpoint)]["id"] raise ValueError(f"Multiple matches for endpoint '{endpoint}': {display_names}") - elif len(responses) == 0: - raise ValueError(f"No matches found for endpoint '{endpoint}'") - - return responses[0]['id'] + return responses[0]["id"] @ensure_globus_authorized @@ -134,11 +134,11 @@ def auto_activate_endpoint(endpoint_id, tfr_client): # pragma: no cover """ activation = tfr_client.endpoint_get_activation_requirements(endpoint_id) - needs_activation = bool(activation['DATA']) - activated = activation['activated'] + needs_activation = bool(activation["DATA"]) + activated = activation["activated"] if needs_activation and not activated: r = tfr_client.endpoint_autoactivate(endpoint_id) - if r['code'] == "AutoActivationFailed": + if r["code"] == "AutoActivationFailed": webbrowser.open(f"https://app.globus.org/file-manager?origin_id={endpoint_id}", new=1) input("Press Return after completing activation in your webbrowser...") @@ -180,6 +180,6 @@ def get_directory_listing(path, endpoint=None): auto_activate_endpoint(endpoint_id, tc) response = tc.operation_ls(endpoint_id, path=path.as_posix()) - names = [r['name'] for r in response] + names = [r["name"] for r in response] return [path / n for n in names] diff --git a/dkist/net/globus/tests/__init__.py b/dkist/net/globus/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dkist/net/globus/tests/conftest.py b/dkist/net/globus/tests/conftest.py index 65ed7712..7dd38984 100644 --- a/dkist/net/globus/tests/conftest.py +++ b/dkist/net/globus/tests/conftest.py @@ -7,10 +7,10 @@ @pytest.fixture() def transfer_client(mocker): mocker.patch("globus_sdk.TransferClient.get_submission_id", - return_value={'value': "1234"}) + return_value={"value": "1234"}) mocker.patch("dkist.net.globus.endpoints.get_refresh_token_authorizer", - return_value={'transfer.api.globus.org': None}) + return_value={"transfer.api.globus.org": None}) tc = get_transfer_client() diff --git a/dkist/net/globus/tests/test_auth.py b/dkist/net/globus/tests/test_auth.py index 032f8ea9..3ad5a7fc 100644 --- a/dkist/net/globus/tests/test_auth.py +++ b/dkist/net/globus/tests/test_auth.py @@ -13,7 +13,7 @@ def test_http_server(): server = start_local_server() - redirect_uri = "http://{a[0]}:{a[1]}".format(a=server.server_address) + redirect_uri = f"http://{server.server_address[0]}:{server.server_address[1]}" inp_code = "wibble" requests.get(redirect_uri + f"?code={inp_code}") @@ -74,7 +74,7 @@ def test_save_auth_cache(mocker, tmpdir): assert bool(statinfo.mode & stat.S_IRUSR) assert bool(statinfo.mode & stat.S_IWUSR) - if platform.system() != 'Windows': + if platform.system() != "Windows": # Test that neither "Group" or "Other" have read permissions assert not bool(statinfo.mode & stat.S_IRGRP) assert not 
bool(statinfo.mode & stat.S_IROTH) @@ -94,12 +94,12 @@ def test_get_refresh_token_authorizer(mocker): } mocker.patch("dkist.net.globus.auth.get_cache_contents", return_value=cache) - auth = get_refresh_token_authorizer()['transfer.api.globus.org'] + auth = get_refresh_token_authorizer()["transfer.api.globus.org"] assert isinstance(auth, globus_sdk.RefreshTokenAuthorizer) assert auth.access_token == cache["transfer.api.globus.org"]["access_token"] mocker.patch("dkist.net.globus.auth.do_native_app_authentication", return_value=cache) - auth = get_refresh_token_authorizer(force_reauth=True)['transfer.api.globus.org'] + auth = get_refresh_token_authorizer(force_reauth=True)["transfer.api.globus.org"] assert isinstance(auth, globus_sdk.RefreshTokenAuthorizer) assert auth.access_token == cache["transfer.api.globus.org"]["access_token"] diff --git a/dkist/net/globus/tests/test_endpoints.py b/dkist/net/globus/tests/test_endpoints.py index 453376e9..cf5bc95b 100644 --- a/dkist/net/globus/tests/test_endpoints.py +++ b/dkist/net/globus/tests/test_endpoints.py @@ -48,7 +48,7 @@ def test_get_transfer_client(mocker, transfer_client): assert isinstance(transfer_client, globus_sdk.TransferClient) -@pytest.mark.parametrize("endpoint_id", ("12345", None)) +@pytest.mark.parametrize("endpoint_id", ["12345", None]) def test_get_local_endpoint_id(mocker, endpoint_id): lgcp_mock = mocker.patch("globus_sdk.LocalGlobusConnectPersonal.endpoint_id", new_callable=mocker.PropertyMock) @@ -68,13 +68,12 @@ def test_get_endpoint_id_search(mocker, mock_search, endpoint_search, transfer_c transfer_client = get_transfer_client() # Test exact display name match - endpoint_id = get_endpoint_id('NCAR Data Sharing Service', transfer_client) + endpoint_id = get_endpoint_id("NCAR Data Sharing Service", transfer_client) assert endpoint_id == "dd1ee92a-6d04-11e5-ba46-22000b92c6ec" # Test multiple match fail - with pytest.raises(ValueError) as exc: + with pytest.raises(ValueError, match="Multiple"): get_endpoint_id(" ", transfer_client) - assert "Multiple" in str(exc.value) # Test just one result mock_search.return_value = {"DATA": endpoint_search["DATA"][1:2]} @@ -83,9 +82,8 @@ def test_get_endpoint_id_search(mocker, mock_search, endpoint_search, transfer_c # Test no results mock_search.return_value = {"DATA": []} - with pytest.raises(ValueError) as e_info: + with pytest.raises(ValueError, match="No matches"): get_endpoint_id(" ", transfer_client) - assert "No matches" in str(e_info.value) def test_get_endpoint_id_uuid(mocker, transfer_client, endpoint_search): @@ -95,7 +93,7 @@ def test_get_endpoint_id_uuid(mocker, transfer_client, endpoint_search): new_callable=mocker.PropertyMock) get_ep_mock.return_value = {"DATA": endpoint_search["DATA"][1:2]} - endpoint_id = get_endpoint_id('dd1ee92a-6d04-11e5-ba46-22000b92c6ec', transfer_client) + endpoint_id = get_endpoint_id("dd1ee92a-6d04-11e5-ba46-22000b92c6ec", transfer_client) assert endpoint_id == "dd1ee92a-6d04-11e5-ba46-22000b92c6ec" @@ -124,11 +122,11 @@ def test_directory_listing(mocker, transfer_client, ls_response): return_value=ls_response) ls = get_directory_listing("/") - assert all([isinstance(a, pathlib.Path) for a in ls]) + assert all(isinstance(a, pathlib.Path) for a in ls) assert len(ls) == 13 ls = get_directory_listing("/", "1234") - assert all([isinstance(a, pathlib.Path) for a in ls]) + assert all(isinstance(a, pathlib.Path) for a in ls) assert len(ls) == 13 diff --git a/dkist/net/globus/tests/test_transfer.py b/dkist/net/globus/tests/test_transfer.py index 
93b30c6e..35f186c8 100644 --- a/dkist/net/globus/tests/test_transfer.py +++ b/dkist/net/globus/tests/test_transfer.py @@ -33,30 +33,30 @@ def mock_task_event_list(mocker, transfer_client): { "DATA": [ { - 'DATA_TYPE': 'event', - 'code': 'STARTED', - 'description': 'started', - 'details': + "DATA_TYPE": "event", + "code": "STARTED", + "description": "started", + "details": '{\n "type": "GridFTP Transfer", \n "concurrency": 2, \n "protocol": "Mode S"\n}', - 'is_error': False, - 'parent_task_id': None, - 'time': '2019-05-16 10:13:26+00:00'}, + "is_error": False, + "parent_task_id": None, + "time": "2019-05-16 10:13:26+00:00"}, { - 'DATA_TYPE': 'event', - 'code': 'SUCCEEDED', - 'description': 'succeeded', - 'details': 'Scanned 100 file(s)', - 'is_error': False, - 'parent_task_id': None, - 'time': '2019-05-16 10:13:24+00:00'}, + "DATA_TYPE": "event", + "code": "SUCCEEDED", + "description": "succeeded", + "details": "Scanned 100 file(s)", + "is_error": False, + "parent_task_id": None, + "time": "2019-05-16 10:13:24+00:00"}, { - 'DATA_TYPE': 'event', - 'code': 'STARTED', - 'description': 'started', - 'details': 'Starting sync scan', - 'is_error': False, - 'parent_task_id': None, - 'time': '2019-05-16 10:13:20+00:00'}, + "DATA_TYPE": "event", + "code": "STARTED", + "description": "started", + "details": "Starting sync scan", + "is_error": False, + "parent_task_id": None, + "time": "2019-05-16 10:13:20+00:00"}, ], "DATA_TYPE": "event_list", "limit": 10, @@ -75,7 +75,7 @@ def test_start_transfer(mocker, transfer_client, mock_endpoints): submit_mock = mocker.patch("globus_sdk.TransferClient.submit_transfer", return_value={"task_id": "task_id"}) mocker.patch("globus_sdk.TransferClient.get_submission_id", - return_value={'value': "wibble"}) + return_value={"value": "wibble"}) file_list = list(map(Path, ["/a/name.fits", "/a/name2.fits"])) start_transfer_from_file_list("a", "b", "/", file_list) calls = mock_endpoints.call_args_list @@ -83,11 +83,11 @@ def test_start_transfer(mocker, transfer_client, mock_endpoints): assert calls[1][0][0] == "b" submit_mock.assert_called_once() - transfer_manifest = submit_mock.call_args_list[0][0][0]['DATA'] + transfer_manifest = submit_mock.call_args_list[0][0][0]["DATA"] for filepath, tfr in zip(file_list, transfer_manifest): - assert str(filepath) == tfr['source_path'] - assert os.path.sep + filepath.name == tfr['destination_path'] + assert str(filepath) == tfr["source_path"] + assert os.path.sep + filepath.name == tfr["destination_path"] def test_start_transfer_src_base(mocker, transfer_client, mock_endpoints): @@ -100,11 +100,11 @@ def test_start_transfer_src_base(mocker, transfer_client, mock_endpoints): assert calls[1][0][0] == "b" submit_mock.assert_called_once() - transfer_manifest = submit_mock.call_args_list[0][0][0]['DATA'] + transfer_manifest = submit_mock.call_args_list[0][0][0]["DATA"] for filepath, tfr in zip(file_list, transfer_manifest): - assert str(filepath) == tfr['source_path'] - assert "{0}b{0}".format(os.path.sep) + filepath.name == tfr['destination_path'] + assert str(filepath) == tfr["source_path"] + assert f"{os.path.sep}b{os.path.sep}" + filepath.name == tfr["destination_path"] def test_process_event_list(transfer_client, mock_task_event_list): @@ -113,26 +113,25 @@ def test_process_event_list(transfer_client, mock_task_event_list): message_events) = _process_task_events("1234", set(), transfer_client) assert isinstance(events, set) - assert all([isinstance(e, tuple) for e in events]) - assert all([all([isinstance(item, tuple) for item in e]) 
for e in events]) + assert all(isinstance(e, tuple) for e in events) + assert all(all(isinstance(item, tuple) for item in e) for e in events) print(events) assert len(json_events) == 1 assert isinstance(json_events, tuple) assert isinstance(json_events[0], dict) - assert isinstance(json_events[0]['details'], dict) - assert json_events[0]['code'] == 'STARTED' + assert isinstance(json_events[0]["details"], dict) + assert json_events[0]["code"] == "STARTED" assert len(message_events) == 2 assert isinstance(message_events, tuple) assert isinstance(message_events[0], dict) - assert isinstance(message_events[0]['details'], str) + assert isinstance(message_events[0]["details"], str) def test_process_event_list_message_only(transfer_client, mock_task_event_list): # Filter out the json event - prev_events = tuple(map(lambda x: tuple(x.items()), - mock_task_event_list.return_value)) + prev_events = tuple(tuple(x.items()) for x in mock_task_event_list.return_value) prev_events = set(prev_events[0:1]) (events, @@ -140,8 +139,8 @@ def test_process_event_list_message_only(transfer_client, mock_task_event_list): message_events) = _process_task_events("1234", prev_events, transfer_client) assert isinstance(events, set) - assert all([isinstance(e, tuple) for e in events]) - assert all([all([isinstance(item, tuple) for item in e]) for e in events]) + assert all(isinstance(e, tuple) for e in events) + assert all(all(isinstance(item, tuple) for item in e) for e in events) assert len(json_events) == 0 assert isinstance(json_events, tuple) @@ -149,17 +148,17 @@ def test_process_event_list_message_only(transfer_client, mock_task_event_list): assert len(message_events) == 2 assert isinstance(message_events, tuple) assert isinstance(message_events[0], dict) - assert isinstance(message_events[0]['details'], str) + assert isinstance(message_events[0]["details"], str) def test_get_speed(): - speed = _get_speed({'code': "PROGRESS", 'details': {'mbps': 10}}) + speed = _get_speed({"code": "PROGRESS", "details": {"mbps": 10}}) assert speed == 10 speed = _get_speed({}) assert speed is None - speed = _get_speed({'code': "progress", "details": "hello"}) + speed = _get_speed({"code": "progress", "details": "hello"}) assert speed is None - speed = _get_speed({'code': "progress", "details": {"hello": "world"}}) + speed = _get_speed({"code": "progress", "details": {"hello": "world"}}) assert speed is None diff --git a/dkist/net/globus/transfer.py b/dkist/net/globus/transfer.py index 9f92224a..90e711bb 100644 --- a/dkist/net/globus/transfer.py +++ b/dkist/net/globus/transfer.py @@ -7,7 +7,7 @@ import pathlib import datetime from os import PathLike -from typing import List, Union, Literal +from typing import Literal import globus_sdk from tqdm.auto import tqdm @@ -16,7 +16,7 @@ from .endpoints import (auto_activate_endpoint, get_data_center_endpoint_id, get_endpoint_id, get_local_endpoint_id, get_transfer_client) -__all__ = ['watch_transfer_progress', 'start_transfer_from_file_list'] +__all__ = ["watch_transfer_progress", "start_transfer_from_file_list"] def start_transfer_from_file_list(src_endpoint, dst_endpoint, dst_base_path, file_list, @@ -134,8 +134,7 @@ def _process_task_events(task_id, prev_events, tfr_client): """ # Convert all the events into a (key, value) tuple pair - events = set(map(lambda x: tuple(x.items()), - tfr_client.task_event_list(task_id))) + events = {tuple(x.items()) for x in tfr_client.task_event_list(task_id)} # Drop all events we have seen before new_events = events.difference(prev_events) @@ -146,14 
+145,14 @@ def _process_task_events(task_id, prev_events, tfr_client): def json_loader(x): """Modify the event so the json is a dict.""" - x['details'] = json.loads(x['details']) + x["details"] = json.loads(x["details"]) return x # If some of the events are json events, load the json. if json_events: json_events = tuple(map(dict, map(json_loader, map(dict, json_events)))) else: - json_events = tuple() + json_events = () return events, json_events, message_events @@ -162,8 +161,8 @@ def _get_speed(event): """ A helper function to extract the speed from an event. """ - if event.get('code', "").lower() == "progress" and isinstance(event['details'], dict): - return event['details'].get("mbps") + if event.get("code", "").lower() == "progress" and isinstance(event["details"], dict): + return event["details"].get("mbps") def get_progress_bar(*args, **kwargs): # pragma: no cover @@ -172,10 +171,10 @@ def get_progress_bar(*args, **kwargs): # pragma: no cover """ notebook = tqdm is tqdm_notebook if not notebook: - kwargs['bar_format'] = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]' + kwargs["bar_format"] = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]" else: # TODO: Both having this and not having it breaks things. - kwargs['total'] = kwargs.get("total", 1e9) or 1e9 + kwargs["total"] = kwargs.get("total", 1e9) or 1e9 return tqdm(*args, **kwargs) @@ -212,13 +211,13 @@ def watch_transfer_progress(task_id, tfr_client, poll_interval=5, json_events, message_events) = _process_task_events(task_id, prev_events, tfr_client) - if ('code', 'STARTED') not in prev_events and not started: + if ("code", "STARTED") not in prev_events and not started: started = True progress.write("PENDING: Starting Transfer") # Print status messages if verbose or if they are errors for event in message_events: - if event['is_error'] or verbose: + if event["is_error"] or verbose: progress.write(f"{event['code']}: {event['details']}") for event in json_events: @@ -245,9 +244,9 @@ def watch_transfer_progress(task_id, tfr_client, poll_interval=5, # Get the status of the task to see how many files we have processed. task = tfr_client.get_task(task_id) - status = task['status'] - progress.total = task['files'] - progress.update((task['files_skipped'] + task['files_transferred']) - progress.n) + status = task["status"] + progress.total = task["files"] + progress.update((task["files_skipped"] + task["files_transferred"]) - progress.n) # If the status of the task is not active we are finished. 
if status != "ACTIVE": @@ -264,12 +263,12 @@ def watch_transfer_progress(task_id, tfr_client, poll_interval=5, progress.close() -def _orchestrate_transfer_task(file_list: List[PathLike], - recursive: List[bool], +def _orchestrate_transfer_task(file_list: list[PathLike], + recursive: list[bool], destination_path: PathLike = "/~/", destination_endpoint: str = None, *, - progress: Union[bool, Literal["verbose"]] = True, + progress: bool | Literal["verbose"] = True, wait: bool = True, label=None): """ diff --git a/dkist/net/helpers.py b/dkist/net/helpers.py index 512519aa..7627d06a 100644 --- a/dkist/net/helpers.py +++ b/dkist/net/helpers.py @@ -3,8 +3,9 @@ """ import datetime from os import PathLike -from typing import List, Union, Literal, Iterable, Optional +from typing import Literal from pathlib import Path +from collections.abc import Iterable from astropy import table @@ -20,7 +21,7 @@ __all__ = ["transfer_complete_datasets"] -def _get_dataset_inventory(dataset_id: Union[str, Iterable[str]]) -> DKISTQueryResponseTable: # pragma: no cover +def _get_dataset_inventory(dataset_id: str | Iterable[str]) -> DKISTQueryResponseTable: # pragma: no cover """ Do a search for a single dataset id """ @@ -34,12 +35,12 @@ def _get_dataset_inventory(dataset_id: Union[str, Iterable[str]]) -> DKISTQueryR return results -def transfer_complete_datasets(datasets: Union[str, Iterable[str], QueryResponseRow, DKISTQueryResponseTable, UnifiedResponse], +def transfer_complete_datasets(datasets: str | Iterable[str] | QueryResponseRow | DKISTQueryResponseTable | UnifiedResponse, path: PathLike = "/~/", destination_endpoint: str = None, - progress: Union[bool, Literal["verbose"]] = True, + progress: bool | Literal["verbose"] = True, wait: bool = True, - label: Optional[str] = None) -> Union[List[str], str]: + label: str | None = None) -> list[str] | str: """ Transfer one or more complete datasets to a path on a globus endpoint. 
@@ -100,7 +101,7 @@ def transfer_complete_datasets(datasets: Union[str, Iterable[str], QueryResponse if len(datasets) > 1: datasets = table.vstack(datasets, metadata_conflicts="silent") - elif isinstance(datasets, str) or all((isinstance(d, str) for d in datasets)): + elif isinstance(datasets, str) or all(isinstance(d, str) for d in datasets): # If we are passed just dataset IDs as strings search for them to get the inventory records datasets = _get_dataset_inventory(datasets) diff --git a/dkist/net/tests/__init__.py b/dkist/net/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dkist/net/tests/conftest.py b/dkist/net/tests/conftest.py index c56d1be4..50521c4e 100644 --- a/dkist/net/tests/conftest.py +++ b/dkist/net/tests/conftest.py @@ -14,7 +14,7 @@ def all_dkist_attrs_classes(request): return getattr(da, request.param) -@pytest.fixture(params=da.__all__ + ['Time', 'Instrument', 'Wavelength', 'Physobs']) +@pytest.fixture(params=[*da.__all__, "Time", "Instrument", "Wavelength", "Physobs"]) def all_attrs_classes(request): at = getattr(da, request.param, None) return at or getattr(a, request.param) @@ -28,32 +28,32 @@ def api_param_names(): Excludes ones with input dependent query params """ return { - a.Time: ('endTimeMin', 'startTimeMax'), - a.Instrument: ('instrumentNames',), - a.Wavelength: ('wavelengthRanges',), - a.Physobs: ('hasAllStokes',), - a.Provider: tuple(), - da.Dataset: ('datasetIds',), - da.WavelengthBand: ('filterWavelengths',), - da.Observable: ('observables',), - da.Experiment: ('primaryExperimentIds',), - da.Proposal: ('primaryProposalIds',), - da.TargetType: ('targetTypes',), - da.Recipe: ('recipeId',), - da.Embargoed: ('isEmbargoed',), - da.FriedParameter: ('qualityAverageFriedParameterMin', 'qualityAverageFriedParameterMax'), - da.PolarimetricAccuracy: ('qualityAveragePolarimetricAccuracyMin', 'qualityAveragePolarimetricAccuracyMax'), - da.ExposureTime: ('exposureTimeMin', 'exposureTimeMax'), - da.EmbargoEndTime: ('embargoEndDateMin', 'embargoEndDateMax'), - da.SpectralSampling: ('averageDatasetSpectralSamplingMin', 'averageDatasetSpectralSamplingMax'), - da.SpatialSampling: ('averageDatasetSpatialSamplingMin', 'averageDatasetSpatialSamplingMax'), - da.TemporalSampling: ('averageDatasetTemporalSamplingMin', 'averageDatasetTemporalSamplingMax'), - da.Page: ('pageNumber',), - da.PageSize: ('pageSize',), - da.SummitSoftwareVersion: ('highLevelSoftwareVersion',), - da.WorkflowName: ('workflowName',), - da.WorkflowVersion: ('workflowVersion',), - da.ObservingProgramExecutionID: ('observingProgramExecutionId',), - da.InstrumentProgramExecutionID: ('instrumentProgramExecutionId',), - da.HeaderVersion: ('headerVersion',), + a.Time: ("endTimeMin", "startTimeMax"), + a.Instrument: ("instrumentNames",), + a.Wavelength: ("wavelengthRanges",), + a.Physobs: ("hasAllStokes",), + a.Provider: (), + da.Dataset: ("datasetIds",), + da.WavelengthBand: ("filterWavelengths",), + da.Observable: ("observables",), + da.Experiment: ("primaryExperimentIds",), + da.Proposal: ("primaryProposalIds",), + da.TargetType: ("targetTypes",), + da.Recipe: ("recipeId",), + da.Embargoed: ("isEmbargoed",), + da.FriedParameter: ("qualityAverageFriedParameterMin", "qualityAverageFriedParameterMax"), + da.PolarimetricAccuracy: ("qualityAveragePolarimetricAccuracyMin", "qualityAveragePolarimetricAccuracyMax"), + da.ExposureTime: ("exposureTimeMin", "exposureTimeMax"), + da.EmbargoEndTime: ("embargoEndDateMin", "embargoEndDateMax"), + da.SpectralSampling: 
("averageDatasetSpectralSamplingMin", "averageDatasetSpectralSamplingMax"), + da.SpatialSampling: ("averageDatasetSpatialSamplingMin", "averageDatasetSpatialSamplingMax"), + da.TemporalSampling: ("averageDatasetTemporalSamplingMin", "averageDatasetTemporalSamplingMax"), + da.Page: ("pageNumber",), + da.PageSize: ("pageSize",), + da.SummitSoftwareVersion: ("highLevelSoftwareVersion",), + da.WorkflowName: ("workflowName",), + da.WorkflowVersion: ("workflowVersion",), + da.ObservingProgramExecutionID: ("observingProgramExecutionId",), + da.InstrumentProgramExecutionID: ("instrumentProgramExecutionId",), + da.HeaderVersion: ("headerVersion",), } diff --git a/dkist/net/tests/strategies.py b/dkist/net/tests/strategies.py index 27207dd8..c3837381 100644 --- a/dkist/net/tests/strategies.py +++ b/dkist/net/tests/strategies.py @@ -29,7 +29,7 @@ def get_registered_values(): def _generate_from_register_values(attr_type): possible_values = get_registered_values()[attr_type] - possible_values = list(map(lambda x: x[0], possible_values)) + possible_values = [x[0] for x in possible_values] return st.builds(attr_type, st.sampled_from(possible_values)) @@ -44,7 +44,7 @@ def _supported_attr_types(): @st.composite def _browse_movie(draw): - return a.dkist.BrowseMovie(**draw(st.dictionaries(st.sampled_from(('movieurl', 'movieobjectkey')), + return a.dkist.BrowseMovie(**draw(st.dictionaries(st.sampled_from(("movieurl", "movieobjectkey")), st.text(), min_size=1))) diff --git a/dkist/net/tests/test_attr_walker.py b/dkist/net/tests/test_attr_walker.py index d21dc7bc..440cd794 100644 --- a/dkist/net/tests/test_attr_walker.py +++ b/dkist/net/tests/test_attr_walker.py @@ -24,26 +24,26 @@ def query_or_instrument(): """ return (a.Instrument("VBI") | a.Instrument("VISP")) & a.Time("2020/06/01", "2020/06/02") -@pytest.fixture(scope="function") +@pytest.fixture() def boundingbox_params(): """ Create possible bounding box input coordinates and args for inputs to the bounding box tests. 
""" bottom_left_icrs = SkyCoord(ICRS(ra=1 * u.deg, dec=2 * u.deg, distance=150000000 * u.km), - obstime='2021-01-02T12:34:56') + obstime="2021-01-02T12:34:56") top_right_icrs = SkyCoord(ICRS(ra=3 * u.deg, dec=4 * u.deg, distance=150000000 * u.km), - obstime='2021-01-02T12:34:56') + obstime="2021-01-02T12:34:56") bottom_left_vector_icrs = SkyCoord([ICRS(ra=1 * u.deg, dec=2 * u.deg, distance=150000000 * u.km), ICRS(ra=3 * u.deg, dec=4 * u.deg, distance=150000000 * u.km)], - obstime='2021-01-02T12:34:56') - bottom_left = SkyCoord(1 * u.deg, 1 * u.deg, frame='heliographic_stonyhurst', obstime='2021-01-02T12:34:56') - top_right = SkyCoord(2 * u.deg, 2 * u.deg, frame='heliographic_stonyhurst', obstime='2021-01-02T12:34:56') + obstime="2021-01-02T12:34:56") + bottom_left = SkyCoord(1 * u.deg, 1 * u.deg, frame="heliographic_stonyhurst", obstime="2021-01-02T12:34:56") + top_right = SkyCoord(2 * u.deg, 2 * u.deg, frame="heliographic_stonyhurst", obstime="2021-01-02T12:34:56") width = 3.4 * u.deg height = 1.2 * u.deg - yield { + return { # bottom_left, top_right, width, height "bottom left vector icrs": [bottom_left_vector_icrs, None, None, None], "bottom left top right icrs": [bottom_left_icrs, top_right_icrs, None, None], @@ -52,13 +52,12 @@ def boundingbox_params(): } -@pytest.fixture(scope="function", - params=["bottom left vector icrs", +@pytest.fixture(params=["bottom left vector icrs", "bottom left top right icrs", "bottom left top right", "bottom left width height",],) def boundingbox_param(request, boundingbox_params): - yield boundingbox_params[request.param] + return boundingbox_params[request.param] def test_walker_single(all_attrs_classes, api_param_names): @@ -92,14 +91,14 @@ def test_walker_single(all_attrs_classes, api_param_names): elif issubclass(all_attrs_classes, da.BrowseMovie): at = all_attrs_classes(movieurl="klsdjalkjd", movieobjectkey="lkajsd") - api_param_names[all_attrs_classes] = ('browseMovieUrl', 'browseMovieObjectKey') + api_param_names[all_attrs_classes] = ("browseMovieUrl", "browseMovieObjectKey") elif issubclass(all_attrs_classes, da.BoundingBox): bottom_left = SkyCoord([ICRS(ra=1 * u.deg, dec=2 * u.deg, distance=150000000 * u.km), ICRS(ra=3 * u.deg, dec=4 * u.deg, distance=150000000 * u.km)], - obstime='2021-01-02T12:34:56') + obstime="2021-01-02T12:34:56") at = all_attrs_classes(bottom_left=bottom_left) - api_param_names[all_attrs_classes] = ('rectangleContainingBoundingBox',) + api_param_names[all_attrs_classes] = ("rectangleContainingBoundingBox",) if not at: pytest.skip(f"Not testing {all_attrs_classes!r}") @@ -112,11 +111,11 @@ def test_walker_single(all_attrs_classes, api_param_names): assert not set(api_param_names[all_attrs_classes]).difference(params[0].keys()) -@pytest.mark.parametrize("search,search_type", +@pytest.mark.parametrize(("search", "search_type"), [ - ('containing', 'rectangleContainingBoundingBox'), - ('contained', 'rectangleContainedByBoundingBox'), - ('intersecting', 'rectangleIntersectingBoundingBox'), + ("containing", "rectangleContainingBoundingBox"), + ("contained", "rectangleContainedByBoundingBox"), + ("intersecting", "rectangleIntersectingBoundingBox"), ] ) def test_boundingbox(search, search_type, boundingbox_param): @@ -125,7 +124,7 @@ def test_boundingbox(search, search_type, boundingbox_param): out = walker.create(bb_query) assert len(out) == 1 - assert all([isinstance(a, dict) for a in out]) + assert all(isinstance(a, dict) for a in out) # can't verify exact coordinates, they change a bit for key in out[0].keys(): @@ -133,11 
+132,11 @@ def test_boundingbox(search, search_type, boundingbox_param): for value in out[0].values(): # want to make sure the value is of the format (flt, flt), (flt, flt) - coordinate_regex = re.compile(r'^(\()(-?\d+)(\.\d+)?(,)(-?\d+)(\.\d+)?(\))(,)(\()(-?\d+)(\.\d+)?(,)(-?\d+)(\.\d+)?(\))$') + coordinate_regex = re.compile(r"^(\()(-?\d+)(\.\d+)?(,)(-?\d+)(\.\d+)?(\))(,)(\()(-?\d+)(\.\d+)?(,)(-?\d+)(\.\d+)?(\))$") assert coordinate_regex.search(value) def test_args_browsemovie(): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Either movieurl or movieobjectkey must be specified"): da.BrowseMovie() @@ -162,7 +161,7 @@ def test_and_simple(query_and_simple): out = walker.create(query_and_simple) assert len(out) == 1 assert isinstance(out, list) - assert all([isinstance(a, dict) for a in out]) + assert all(isinstance(a, dict) for a in out) assert out == [ { @@ -177,7 +176,7 @@ def test_or_instrument(query_or_instrument): out = walker.create(query_or_instrument) assert len(out) == 2 assert isinstance(out, list) - assert all([isinstance(a, dict) for a in out]) + assert all(isinstance(a, dict) for a in out) assert out == [ { diff --git a/dkist/net/tests/test_attrs.py b/dkist/net/tests/test_attrs.py index 6ae46465..74252a5b 100644 --- a/dkist/net/tests/test_attrs.py +++ b/dkist/net/tests/test_attrs.py @@ -11,8 +11,8 @@ def test_embargoed_inputs(): assert not da.Embargoed.false.value assert da.Embargoed.true.value - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="is_embargoed must be either True or False"): da.Embargoed("neither up nor down") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="is_embargoed must be either True or False"): da.Embargoed(42) diff --git a/dkist/net/tests/test_attrs_values.py b/dkist/net/tests/test_attrs_values.py index d5f1e59d..34305a7b 100644 --- a/dkist/net/tests/test_attrs_values.py +++ b/dkist/net/tests/test_attrs_values.py @@ -126,7 +126,7 @@ def test_attempt_local_update_error_download(mocker, caplog_dkist, tmp_homedir, ("dkist", logging.ERROR, "Failed to download new attrs values."), ] - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="This is a value error"): success = attempt_local_update(silence_errors=False) @@ -151,7 +151,7 @@ def test_attempt_local_update_fail_invalid_download(mocker, tmp_path, caplog_dki success = attempt_local_update(user_file=json_file, silence_errors=False) -@pytest.mark.parametrize("user_file, update_needed, allow_update, should_update", [ +@pytest.mark.parametrize(("user_file", "update_needed", "allow_update", "should_update"), [ ("user_file", False, True, False), ("user_file", True, True, True), ("user_file", True, False, False), @@ -171,4 +171,4 @@ def test_get_search_attrs_values(mocker, caplog_dkist, values_in_home, user_file assert isinstance(attr_values, dict) # Test that some known attrs are in the result - assert set((a.Instrument, a.dkist.HeaderVersion, a.dkist.WorkflowName)).issubset(attr_values.keys()) + assert {a.Instrument, a.dkist.HeaderVersion, a.dkist.WorkflowName}.issubset(attr_values.keys()) diff --git a/dkist/net/tests/test_client.py b/dkist/net/tests/test_client.py index 998700a8..85256095 100644 --- a/dkist/net/tests/test_client.py +++ b/dkist/net/tests/test_client.py @@ -1,6 +1,6 @@ import json -import hypothesis.strategies as st # noqa +import hypothesis.strategies as st import parfive import pytest from hypothesis import HealthCheck, given, settings @@ -12,7 +12,7 @@ import dkist.net from dkist.net.client import 
DKISTClient, DKISTQueryResponseTable -from dkist.net.tests import strategies as dst # noqa +from dkist.net.tests import strategies as dst from dkist.utils.inventory import INVENTORY_KEY_MAP @@ -25,7 +25,7 @@ def client(): @pytest.mark.remote_data def test_search(client): # TODO: Write an online test to verify real behaviour once there is stable data - res = client.search(a.Time("2019/01/01", "2021/01/01")) + client.search(a.Time("2019/01/01", "2021/01/01")) @pytest.mark.remote_data @@ -39,7 +39,8 @@ def test_search_by_time(client, time): res = client.search(time, a.Instrument("VBI")) assert len(res) == 1 assert res[0]["Primary Proposal ID"] == "pid_1_50" - assert res[0]["Start Time"].value == '2022-12-27T19:27:42.338' and res[0]["End Time"].value == '2022-12-27T20:00:09.005' + assert res[0]["Start Time"].value == "2022-12-27T19:27:42.338" + assert res[0]["End Time"].value == "2022-12-27T20:00:09.005" @pytest.fixture def empty_query_response(): @@ -98,8 +99,8 @@ def example_api_response(): @pytest.fixture def expected_table_keys(): translated_keys = set(INVENTORY_KEY_MAP.values()) - removed_keys = {'Wavelength Min', 'Wavelength Max'} - added_keys = {'Wavelength'} + removed_keys = {"Wavelength Min", "Wavelength Max"} + added_keys = {"Wavelength"} expected_keys = translated_keys - removed_keys expected_keys.update(added_keys) return expected_keys @@ -134,14 +135,14 @@ def test_query_response_from_results_unknown_field(empty_query_response, example This test asserts that if the API starts returning new fields we don't error, they get passed though verbatim. """ dclient = DKISTClient() - example_api_response["searchResults"][0].update({'spamEggs': 'Some Spam'}) + example_api_response["searchResults"][0].update({"spamEggs": "Some Spam"}) qr = DKISTQueryResponseTable.from_results([example_api_response], client=dclient) assert len(qr) == 1 assert isinstance(qr.client, DKISTClient) assert qr.client is dclient assert isinstance(qr[0], QueryResponseRow) - assert set(qr.colnames).difference(expected_table_keys) == {'spamEggs'} + assert set(qr.colnames).difference(expected_table_keys) == {"spamEggs"} assert set(qr.colnames).isdisjoint(INVENTORY_KEY_MAP.keys()) @@ -202,14 +203,14 @@ def test_can_handle_query(client, query): assert client._can_handle_query(query) -@pytest.mark.parametrize("query", ( +@pytest.mark.parametrize("query", [ a.Instrument("bob"), a.Physobs("who's got the button"), a.Level(2), (a.Instrument("VBI"), a.Level(0)), (a.Instrument("VBI"), a.Detector("test")), - tuple(), -)) + (), +]) def test_cant_handle_query(client, query): """Some examples of invalid queries.""" assert not client._can_handle_query(query) @@ -220,7 +221,7 @@ def test_cant_handle_query(client, query): @given(st.one_of(dst.query_and(), dst.query_or(), dst.query_or_composite())) def test_fido_valid(mocker, mocked_client, query): # Test that Fido is passing through our queries to our client - mocked_search = mocker.patch('dkist.net.client.DKISTClient.search') + mocked_search = mocker.patch("dkist.net.client.DKISTClient.search") mocked_search.return_value = DKISTQueryResponseTable() Fido.search(query) @@ -240,7 +241,7 @@ def test_fetch_with_headers(httpserver, tmpdir, mocked_client): headers={"Content-Disposition": "attachment; filename=abcd.asdf"} ) - response = DKISTQueryResponseTable({'Dataset ID': ['abcd']}) + response = DKISTQueryResponseTable({"Dataset ID": ["abcd"]}) with dkist.net.conf.set_temp("download_endpoint", httpserver.url_for("/download")): downloader = parfive.Downloader() mocked_client.fetch(response, 
downloader=downloader, path=tmpdir / "{file}") diff --git a/dkist/net/tests/test_helpers.py b/dkist/net/tests/test_helpers.py index e26542bb..6223d2d7 100644 --- a/dkist/net/tests/test_helpers.py +++ b/dkist/net/tests/test_helpers.py @@ -11,7 +11,7 @@ @pytest.fixture def orchestrate_transfer_mock(mocker): - yield mocker.patch("dkist.net.helpers._orchestrate_transfer_task", autospec=True) + return mocker.patch("dkist.net.helpers._orchestrate_transfer_task", autospec=True) @pytest.mark.parametrize( @@ -31,8 +31,8 @@ def test_download_default_keywords(orchestrate_transfer_mock, keywords): "Dataset ID": "AAAA", "Primary Proposal ID": "pm_1_10", "Storage Bucket": "data", - 'Wavelength Max': 856, - 'Wavelength Min': 854, + "Wavelength Max": 856, + "Wavelength Min": 854, } ]), **keywords @@ -49,7 +49,7 @@ def test_download_default_keywords(orchestrate_transfer_mock, keywords): def test_transfer_unavailable_data(mocker): - get_inv_mock = mocker.patch( + mocker.patch( "dkist.net.client.DKISTClient.search", autospec=True, return_value=[], @@ -68,8 +68,8 @@ def test_transfer_from_dataset_id(mocker, orchestrate_transfer_mock): "Dataset ID": "AAAA", "Primary Proposal ID": "pm_1_10", "Storage Bucket": "data", - 'Wavelength Max': 856, - 'Wavelength Min': 854, + "Wavelength Max": 856, + "Wavelength Min": 854, } ]), ) @@ -98,15 +98,15 @@ def test_transfer_from_multiple_dataset_id(mocker, orchestrate_transfer_mock): "Dataset ID": "AAAA", "Primary Proposal ID": "pm_1_10", "Storage Bucket": "data", - 'Wavelength Max': 856, - 'Wavelength Min': 854, + "Wavelength Max": 856, + "Wavelength Min": 854, }, { "Dataset ID": "BBBB", "Primary Proposal ID": "pm_1_10", "Storage Bucket": "data", - 'Wavelength Max': 856, - 'Wavelength Min': 854, + "Wavelength Max": 856, + "Wavelength Min": 854, } ]), ) @@ -145,8 +145,8 @@ def test_transfer_from_table(orchestrate_transfer_mock, mocker): "Dataset ID": ["A", "B"], "Primary Proposal ID": ["pm_1_10", "pm_2_20"], "Storage Bucket": ["data", "data"], - 'Wavelength Max': [856, 856], - 'Wavelength Min': [854, 854], + "Wavelength Max": [856, 856], + "Wavelength Min": [854, 854], }, ) @@ -177,8 +177,8 @@ def test_transfer_from_length_one_table(orchestrate_transfer_mock, mocker): "Dataset ID": ["A"], "Primary Proposal ID": ["pm_1_10"], "Storage Bucket": ["data"], - 'Wavelength Max': [856], - 'Wavelength Min': [854], + "Wavelength Max": [856], + "Wavelength Min": [854], }, ) @@ -203,8 +203,8 @@ def test_transfer_from_row(orchestrate_transfer_mock, mocker): "Dataset ID": ["A"], "Primary Proposal ID": ["pm_1_10"], "Storage Bucket": ["data"], - 'Wavelength Max': [856], - 'Wavelength Min': [854], + "Wavelength Max": [856], + "Wavelength Min": [854], }, ) @@ -230,8 +230,8 @@ def test_transfer_from_UnifiedResponse(orchestrate_transfer_mock, mocker): "Dataset ID": ["A"], "Primary Proposal ID": ["pm_1_10"], "Storage Bucket": ["data"], - 'Wavelength Max': [856], - 'Wavelength Min': [854], + "Wavelength Max": [856], + "Wavelength Min": [854], }, ), DKISTQueryResponseTable( @@ -239,8 +239,8 @@ def test_transfer_from_UnifiedResponse(orchestrate_transfer_mock, mocker): "Dataset ID": ["B"], "Primary Proposal ID": ["pm_2_20"], "Storage Bucket": ["data"], - 'Wavelength Max': [856], - 'Wavelength Min': [854], + "Wavelength Max": [856], + "Wavelength Min": [854], }, ), ) @@ -276,8 +276,8 @@ def test_transfer_path_interpolation(orchestrate_transfer_mock, mocker): "Dataset ID": "AAAA", "Primary Proposal ID": "pm_1_10", "Storage Bucket": "data", - 'Wavelength Max': 856, - 'Wavelength Min': 854, + "Wavelength 
Max": 856, + "Wavelength Min": 854, "Instrument": "HIT", # Highly Imaginary Telescope } ]), diff --git a/dkist/tests/generate_aia_dataset.py b/dkist/tests/generate_aia_dataset.py index 4bc606a4..d85d0655 100644 --- a/dkist/tests/generate_aia_dataset.py +++ b/dkist/tests/generate_aia_dataset.py @@ -1,5 +1,4 @@ import os -import glob from pathlib import Path import numpy as np @@ -77,7 +76,7 @@ def references_from_filenames(filename_array, relative_to=None): with fits.open(filepath) as hdul: hdu_index = 1 hdu = hdul[hdu_index] - dtype = BITPIX2DTYPE[hdu.header['BITPIX']] + dtype = BITPIX2DTYPE[hdu.header["BITPIX"]] shape = tuple(reversed(hdu.shape)) # Convert paths to relative paths @@ -94,8 +93,8 @@ def references_from_filenames(filename_array, relative_to=None): def main(): from dkist_inventory.transforms import generate_lookup_table - path = Path('~/sunpy/data/jsocflare/').expanduser() - files = glob.glob(str(path / '*.fits')) + path = Path("~/sunpy/data/jsocflare/").expanduser() + files = path.glob("*.fits") # requestid = 'JSOC_20180831_1097' requestid = None @@ -107,11 +106,11 @@ def main(): requestid, path=str(path), overwrite=False).wait() files = [] for f in filesd.values(): - files.append(f['path']) + files.append(f["path"]) else: results = Fido.search( - a.jsoc.Time('2017-09-06T12:00:00', '2017-09-06T12:02:00'), - a.jsoc.Series('aia.lev1_euv_12s'), a.jsoc.Segment('image'), + a.jsoc.Time("2017-09-06T12:00:00", "2017-09-06T12:02:00"), + a.jsoc.Series("aia.lev1_euv_12s"), a.jsoc.Segment("image"), a.jsoc.Notify("stuart@cadair.com")) print(results) @@ -134,14 +133,13 @@ def main(): for i, filepath in enumerate(files): with fits.open(filepath) as hdul: header = hdul[1].header - time = parse_time(header['DATE-OBS']) + time = parse_time(header["DATE-OBS"]) if i == 0: - root_header = header start_time = time inds.append(i) times.append(time) seconds.append((time - start_time).total_seconds()) - waves.append(header['WAVELNTH']) + waves.append(header["WAVELNTH"]) # Construct an array and sort it by wavelength and time arr = np.array((inds, seconds, waves)).T @@ -161,7 +159,7 @@ def main(): # this assumes all wavelength images are taken at the same time time_coords = np.array( [t.isoformat() for t in times])[list_sorter].reshape(shape)[0, :] - wave_coords = np.array(waves)[list_sorter].reshape(shape)[:, 0] + np.array(waves)[list_sorter].reshape(shape)[:, 0] smap0 = sunpy.map.Map(files[0]) spatial = map_to_transform(smap0) @@ -174,7 +172,7 @@ def main(): wave_frame = cf.SpectralFrame(axes_order=(3, ), unit=u.AA, name="wavelength", axes_names=("wavelength", )) time_frame = cf.TemporalFrame( axes_order=(2, ), unit=u.s, reference_time=Time(time_coords[0]), name="time", axes_names=("time", )) - sky_frame = cf.CelestialFrame(axes_order=(0, 1), name='helioprojective', + sky_frame = cf.CelestialFrame(axes_order=(0, 1), name="helioprojective", reference_frame=smap0.coordinate_frame, axes_names=("helioprojective longitude", "helioprojective latitude")) @@ -195,15 +193,15 @@ def main(): ea = references_from_filenames(cube, relative_to=str(path)) tree = { - 'gwcs': wcs, - 'dataset': ea, + "gwcs": wcs, + "dataset": ea, } with asdf.AsdfFile(tree) as ff: # ff.write_to("test.asdf") - filename = str(path / "aia_{}.asdf".format(time_coords[0])) + filename = str(path / f"aia_{time_coords[0]}.asdf") ff.write_to(filename) - print("Saved to : {}".format(filename)) + print(f"Saved to : {filename}") # import sys; sys.exit(0) diff --git a/dkist/tests/generate_eit_test_dataset.py 
b/dkist/tests/generate_eit_test_dataset.py index 12aa9e2c..d88634a4 100644 --- a/dkist/tests/generate_eit_test_dataset.py +++ b/dkist/tests/generate_eit_test_dataset.py @@ -78,15 +78,15 @@ def main(): with fits.open(filepath) as hdul: header = hdul[0].header headers.append(dict(header)) - headers[-1].pop('') - headers[-1].pop('COMMENT') - headers[-1].pop('HISTORY') + headers[-1].pop("") + headers[-1].pop("COMMENT") + headers[-1].pop("HISTORY") headers[-1]["DNAXIS"] = 3 headers[-1]["DNAXIS3"] = len(files) headers[-1]["DAAXES"] = 2 headers[-1]["DEAXES"] = 1 headers[-1]["DINDEX3"] = i + 1 - time = parse_time(header['DATE-OBS']) + time = parse_time(header["DATE-OBS"]) if i == 0: start_time = time inds.append(i) @@ -104,7 +104,7 @@ def main(): hcubemodel = spatial & timemodel - sky_frame = cf.CelestialFrame(axes_order=(0, 1), name='helioprojective', + sky_frame = cf.CelestialFrame(axes_order=(0, 1), name="helioprojective", reference_frame=smap0.coordinate_frame, axes_names=("helioprojective longitude", "helioprojective latitude")) time_frame = cf.TemporalFrame(axes_order=(2, ), unit=u.s, @@ -135,13 +135,13 @@ def main(): ds._file_manager = ac tree = { - 'dataset': ds + "dataset": ds } with asdf.AsdfFile(tree) as ff: filename = rootdir / "EIT" / "eit_test_dataset.asdf" ff.write_to(filename) - print("Saved to : {}".format(filename)) + print(f"Saved to : {filename}") ds.plot() diff --git a/dkist/utils/_model_to_graphviz.py b/dkist/utils/_model_to_graphviz.py index 989a7184..e22c8ad2 100644 --- a/dkist/utils/_model_to_graphviz.py +++ b/dkist/utils/_model_to_graphviz.py @@ -47,16 +47,16 @@ def model_to_subgraph(model, inputs=None, outputs=None): label = model.__class__.name subgraph = pydot.Subgraph(f"{id(model)}_subgraph", label=label) - model_node = pydot.Node(name=id(model), label=label, shape='box') + model_node = pydot.Node(name=id(model), label=label, shape="box") subgraph.add_node(model_node) for inp, label in zip(inputs, input_labels): - input_node = pydot.Node(name=inp, label=label, shape='none') + input_node = pydot.Node(name=inp, label=label, shape="none") subgraph.add_node(input_node) subgraph.add_edge(pydot.Edge(input_node, model_node)) for out, label in zip(outputs, output_labels): - output_node = pydot.Node(name=out, label=label, shape='none') + output_node = pydot.Node(name=out, label=label, shape="none") subgraph.add_node(output_node) subgraph.add_edge(pydot.Edge(model_node, output_node)) @@ -136,7 +136,7 @@ def recursively_find_node(top, name): return node for sg in top.get_subgraph_list(): - labels = [n.get_label() for n in sg.get_node_list()] + [n.get_label() for n in sg.get_node_list()] if node := recursively_find_node(sg, name): return node diff --git a/dkist/utils/inventory.py b/dkist/utils/inventory.py index 084bea61..4b5d6e6a 100644 --- a/dkist/utils/inventory.py +++ b/dkist/utils/inventory.py @@ -3,12 +3,11 @@ """ import re import string -from typing import Dict from collections import defaultdict from astropy.table import Table -__all__ = ['dehumanize_inventory', 'humanize_inventory', 'INVENTORY_KEY_MAP'] +__all__ = ["dehumanize_inventory", "humanize_inventory", "INVENTORY_KEY_MAP"] class DefaultMap(defaultdict): @@ -19,7 +18,7 @@ def __missing__(self, key): return key -INVENTORY_KEY_MAP: Dict[str, str] = DefaultMap(None, { +INVENTORY_KEY_MAP: dict[str, str] = DefaultMap(None, { "asdfObjectKey": "asdf Filename", "boundingBox": "Bounding Box", "browseMovieObjectKey": "Movie Filename", @@ -78,9 +77,9 @@ def __missing__(self, key): def _key_clean(key): - key = re.sub('[%s]' % 
re.escape(string.punctuation), '_', key) - key = key.replace(' ', '_') - key = ''.join(char for char in key + key = re.sub("[%s]" % re.escape(string.punctuation), "_", key) + key = key.replace(" ", "_") + key = "".join(char for char in key if char.isidentifier() or char.isnumeric()) return key.lower() @@ -93,11 +92,11 @@ def path_format_keys(keymap): def _path_format_table(keymap=INVENTORY_KEY_MAP): - t = Table({'Inventory Keyword': list(keymap.keys()), 'Path Key': path_format_keys(keymap)}) - return '\n'.join(t.pformat(max_lines=-1, html=True)) + t = Table({"Inventory Keyword": list(keymap.keys()), "Path Key": path_format_keys(keymap)}) + return "\n".join(t.pformat(max_lines=-1, html=True)) -def humanize_inventory(inventory: Dict[str, str]) -> Dict[str, str]: +def humanize_inventory(inventory: dict[str, str]) -> dict[str, str]: """ Convert an inventory dict to have human readable keys. """ @@ -113,13 +112,13 @@ def path_format_inventory(human_inv): Given a single humanized inventory record return a dict for formatting paths. """ # Putting this here because of circular imports - from ..net.client import DKISTQueryResponseTable as Table + from dkist.net.client import DKISTQueryResponseTable as Table t = Table.from_results([{"searchResults": [human_inv]}], client=None) return t[0].response_block_map -def dehumanize_inventory(humanized_inventory: Dict[str, str]) -> Dict[str, str]: +def dehumanize_inventory(humanized_inventory: dict[str, str]) -> dict[str, str]: """ Convert a human readable inventory dict back to the original keys. """ diff --git a/dkist/utils/sysinfo.py b/dkist/utils/sysinfo.py index 9c7bcde8..369c681a 100644 --- a/dkist/utils/sysinfo.py +++ b/dkist/utils/sysinfo.py @@ -4,7 +4,7 @@ import sunpy.extern.distro as distro from sunpy.util.sysinfo import find_dependencies, get_keys_list, get_requirements -__all__ = ['system_info'] +__all__ = ["system_info"] def system_info(): @@ -13,7 +13,7 @@ def system_info(): """ package_name = "dkist" requirements = get_requirements(package_name) - base_reqs = get_keys_list(requirements['required']) + base_reqs = get_keys_list(requirements["required"]) missing_packages, installed_packages = find_dependencies(package=package_name) extra_prop = {"System": platform.system(), "Arch": f"{platform.architecture()[0]}, ({platform.processor()})", @@ -27,19 +27,19 @@ def system_info(): print() print("General") print("#######") - if sys_prop['System'] == "Linux": + if sys_prop["System"] == "Linux": print(f"OS: {distro.name()} ({distro.version()}, Linux {platform.release()})") - elif sys_prop['System'] == "Darwin": + elif sys_prop["System"] == "Darwin": print(f"OS: Mac OS {platform.mac_ver()[0]}") - elif sys_prop['System'] == "Windows": + elif sys_prop["System"] == "Windows": print(f"OS: Windows {platform.release()} {platform.version()}") else: print("Unknown OS") - for sys_info in ['Arch', package_name]: - print(f'{sys_info}: {sys_prop[sys_info]}') - print(f'Installation path: {distribution(package_name)._path}') + for sys_info in ["Arch", package_name]: + print(f"{sys_info}: {sys_prop[sys_info]}") + print(f"Installation path: {distribution(package_name)._path}") print() print("Required Dependencies") print("#####################") for req in base_reqs: - print(f'{req}: {sys_prop[req]}') + print(f"{req}: {sys_prop[req]}") diff --git a/dkist/utils/tests/test_inventory.py b/dkist/utils/tests/test_inventory.py index f2629ad6..86c80b36 100644 --- a/dkist/utils/tests/test_inventory.py +++ b/dkist/utils/tests/test_inventory.py @@ -114,7 +114,7 @@ def 
test_path_format_table(): "qualityReportObjectKey": "Quality Report Filename", } table = _path_format_table(test_keymap) - table = table[table.find('\n')+1:] + table = table[table.find("\n")+1:] assert table == output diff --git a/dkist/wcs/models.py b/dkist/wcs/models.py index 97d08cc9..27a2957e 100755 --- a/dkist/wcs/models.py +++ b/dkist/wcs/models.py @@ -1,6 +1,7 @@ from abc import ABC -from typing import Union, Literal, Iterable +from typing import Literal from itertools import product +from collections.abc import Iterable import numpy as np @@ -31,11 +32,11 @@ def generate_celestial_transform( - crpix: Union[Iterable[float], u.Quantity], - cdelt: Union[Iterable[float], u.Quantity], - pc: Union[ArrayLike, u.Quantity], - crval: Union[Iterable[float], u.Quantity], - lon_pole: Union[float, u.Quantity] = None, + crpix: Iterable[float] | u.Quantity, + cdelt: Iterable[float] | u.Quantity, + pc: ArrayLike | u.Quantity, + crval: Iterable[float] | u.Quantity, + lon_pole: float | u.Quantity = None, projection: Model = m.Pix2Sky_TAN(), ) -> CompoundModel: """ @@ -134,9 +135,9 @@ def _validate_table_shapes(pc_table, crval_table): table_shape = crval_table.shape[:-1] if pc_table.shape == (2, 2): - pc_table = np.broadcast_to(pc_table, list(table_shape) + [2, 2], subok=True) + pc_table = np.broadcast_to(pc_table, [*list(table_shape), 2, 2], subok=True) if crval_table.shape == (2,): - crval_table = np.broadcast_to(crval_table, list(table_shape) + [2], subok=True) + crval_table = np.broadcast_to(crval_table, [*list(table_shape), 2], subok=True) return table_shape, pc_table, crval_table @@ -197,7 +198,7 @@ def transform_at_index(self, ind, crpix=None, cdelt=None, lon_pole=None): if (np.array(ind) > np.array(self.table_shape) - 1).any() or (np.array(ind) < 0).any(): return m.Const1D(fill_val) & m.Const1D(fill_val) - sct = generate_celestial_transform( + return generate_celestial_transform( crpix=crpix, cdelt=cdelt, pc=self.pc_table[ind], @@ -206,7 +207,6 @@ def transform_at_index(self, ind, crpix=None, cdelt=None, lon_pole=None): projection=self.projection, ) - return sct def _map_transform(self, *arrays, crpix, cdelt, lon_pole, inverse=False): # We need to broadcast the arrays together so they are all the same shape @@ -274,9 +274,9 @@ def input_units(self): for d in dims[:self.n_inputs-2]: units[d] = u.pix return units - else: - dims = ["x", "y", "z", "q", "m"] - return {d: u.pix for d in dims[:self.n_inputs]} + + dims = ["x", "y", "z", "q", "m"] + return {d: u.pix for d in dims[:self.n_inputs]} class VaryingCelestialTransform(BaseVaryingCelestialTransform): @@ -293,7 +293,7 @@ class VaryingCelestialTransform(BaseVaryingCelestialTransform): @property def inverse(self): - ivct = InverseVaryingCelestialTransform( + return InverseVaryingCelestialTransform( crpix=self.crpix, cdelt=self.cdelt, lon_pole=self.lon_pole, @@ -301,7 +301,6 @@ def inverse(self): crval_table=self.crval_table, projection=self.projection, ) - return ivct class VaryingCelestialTransform2D(BaseVaryingCelestialTransform): @@ -309,7 +308,7 @@ class VaryingCelestialTransform2D(BaseVaryingCelestialTransform): @property def inverse(self): - ivct = InverseVaryingCelestialTransform2D( + return InverseVaryingCelestialTransform2D( crpix=self.crpix, cdelt=self.cdelt, lon_pole=self.lon_pole, @@ -317,7 +316,6 @@ def inverse(self): crval_table=self.crval_table, projection=self.projection, ) - return ivct class VaryingCelestialTransform3D(BaseVaryingCelestialTransform): @@ -325,7 +323,7 @@ class 
VaryingCelestialTransform3D(BaseVaryingCelestialTransform): @property def inverse(self): - ivct = InverseVaryingCelestialTransform3D( + return InverseVaryingCelestialTransform3D( crpix=self.crpix, cdelt=self.cdelt, lon_pole=self.lon_pole, @@ -333,7 +331,6 @@ def inverse(self): crval_table=self.crval_table, projection=self.projection, ) - return ivct class InverseVaryingCelestialTransform(BaseVaryingCelestialTransform): @@ -536,14 +533,14 @@ def __repr__(self): } def varying_celestial_transform_from_tables( - crpix: Union[Iterable[float], u.Quantity], - cdelt: Union[Iterable[float], u.Quantity], - pc_table: Union[ArrayLike, u.Quantity], - crval_table: Union[Iterable[float], u.Quantity], - lon_pole: Union[float, u.Quantity] = None, + crpix: Iterable[float] | u.Quantity, + cdelt: Iterable[float] | u.Quantity, + pc_table: ArrayLike | u.Quantity, + crval_table: Iterable[float] | u.Quantity, + lon_pole: float | u.Quantity = None, projection: Model = m.Pix2Sky_TAN(), inverse: bool = False, - slit: Union[None, Literal[0, 1]] = None, + slit: None | Literal[0, 1] = None, ) -> BaseVaryingCelestialTransform: """ Generate a `.BaseVaryingCelestialTransform` based on the dimensionality of the tables. @@ -601,7 +598,7 @@ def inputs(self, value): @property def input_units(self): - return {f'x{idx}': u.pix for idx in range(self.n_inputs)} + return {f"x{idx}": u.pix for idx in range(self.n_inputs)} @property def outputs(self): @@ -613,9 +610,9 @@ def outputs(self, value): @property def return_units(self): - return {'y': u.pix} + return {"y": u.pix} - def __init__(self, array_shape, order='C', **kwargs): + def __init__(self, array_shape, order="C", **kwargs): if len(array_shape) < 2 or np.prod(array_shape) < 1: raise ValueError("array_shape must be at least 2D and have values >= 1") self.array_shape = tuple(array_shape) @@ -624,8 +621,8 @@ def __init__(self, array_shape, order='C', **kwargs): self.order = order super().__init__(**kwargs) # super dunder init sets inputs and outputs to default values so set what we want here - self.inputs = tuple([f'x{idx}' for idx in range(self.n_inputs)]) - self.outputs = 'y', + self.inputs = tuple([f"x{idx}" for idx in range(self.n_inputs)]) + self.outputs = "y", def evaluate(self, *inputs_): """Evaluate the forward ravel for a given tuple of pixel values.""" @@ -675,7 +672,7 @@ def inputs(self, value): @property def input_units(self): - return {'x': u.pix} + return {"x": u.pix} @property def n_outputs(self): @@ -691,9 +688,9 @@ def outputs(self, value): @property def return_units(self): - return {f'y{idx}': u.pix for idx in range(self.n_outputs)} + return {f"y{idx}": u.pix for idx in range(self.n_outputs)} - def __init__(self, array_shape, order='C', **kwargs): + def __init__(self, array_shape, order="C", **kwargs): if len(array_shape) < 2 or np.prod(array_shape) < 1: raise ValueError("array_shape must be at least 2D and have values >= 1") self.array_shape = array_shape @@ -702,8 +699,8 @@ def __init__(self, array_shape, order='C', **kwargs): self.order = order super().__init__(**kwargs) # super dunder init sets inputs and outputs to default values so set what we want here - self.inputs = 'x', - self.outputs = tuple([f'y{idx}' for idx in range(self.n_outputs)]) + self.inputs = "x", + self.outputs = tuple([f"y{idx}" for idx in range(self.n_outputs)]) def evaluate(self, input_): """Evaluate the reverse ravel (unravel) for a given pixel value.""" diff --git a/dkist/wcs/tests/__init__.py b/dkist/wcs/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/dkist/wcs/tests/test_coupled_compound_model.py b/dkist/wcs/tests/test_coupled_compound_model.py index 90075c74..f3a00a7a 100644 --- a/dkist/wcs/tests/test_coupled_compound_model.py +++ b/dkist/wcs/tests/test_coupled_compound_model.py @@ -133,11 +133,11 @@ def test_coupled_sep_2d_extra(vct_2d_pc, linear_time): def test_coupled_slit_no_repeat(linear_time): pc_table = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 10)] * u.pix - kwargs = dict(crpix=(5, 5) * u.pix, - cdelt=(1, 1) * u.arcsec/u.pix, - crval_table=(0, 0) * u.arcsec, - lon_pole=180 * u.deg, - slit=1) + kwargs = {"crpix": (5, 5) * u.pix, + "cdelt": (1, 1) * u.arcsec/u.pix, + "crval_table": (0, 0) * u.arcsec, + "lon_pole": 180 * u.deg, + "slit": 1} vct_slit = varying_celestial_transform_from_tables(pc_table=pc_table, **kwargs) @@ -152,11 +152,11 @@ def test_coupled_slit_with_repeat(linear_time): pc_table = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 15)] * u.pix pc_table = pc_table.reshape((5, 3, 2, 2)) - kwargs = dict(crpix=(5, 5) * u.pix, - cdelt=(1, 1) * u.arcsec/u.pix, - crval_table=(0, 0) * u.arcsec, - lon_pole=180 * u.deg, - slit=1) + kwargs = {"crpix": (5, 5) * u.pix, + "cdelt": (1, 1) * u.arcsec/u.pix, + "crval_table": (0, 0) * u.arcsec, + "lon_pole": 180 * u.deg, + "slit": 1} vct_slit = varying_celestial_transform_from_tables(pc_table=pc_table, **kwargs) diff --git a/dkist/wcs/tests/test_models.py b/dkist/wcs/tests/test_models.py index 395fb5c7..5677e2b9 100755 --- a/dkist/wcs/tests/test_models.py +++ b/dkist/wcs/tests/test_models.py @@ -93,7 +93,7 @@ def test_varying_transform_pc(): assert u.allclose(vct.inverse(*world, 5*u.pix), pixel[:2], atol=0.01*u.pix) -@pytest.mark.parametrize(("pixel", "lon_shape"), ( +@pytest.mark.parametrize(("pixel", "lon_shape"), [ ((*np.mgrid[0:10, 0:10] * u.pix, np.arange(10) * u.pix), (10, 10)), (np.mgrid[0:10, 0:10, 0:5] * u.pix, (10, 10, 5)), ((2 * u.pix, 2 * u.pix, np.arange(10) * u.pix), (10,)), @@ -101,7 +101,7 @@ def test_varying_transform_pc(): np.arange(10) * u.pix, np.arange(10)[..., None] * u.pix), (10, 10)), (np.mgrid[0:1024, 0:1000, 0:2] * u.pix, (1024, 1000, 2)), -)) +]) def test_varying_transform_pc_shapes(pixel, lon_shape): varying_matrix_lt = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 10)] * u.pix @@ -267,7 +267,7 @@ def test_varying_transform_4d_pc_unitless(): assert np.isnan(vct(0, 0, -10, 0)).all() -@pytest.mark.parametrize(("pixel", "lon_shape"), ( +@pytest.mark.parametrize(("pixel", "lon_shape"), [ ((*np.mgrid[0:5, 0:5] * u.pix, np.arange(5) * u.pix, 0 * u.pix), (5, 5)), (np.mgrid[0:10, 0:10, 0:5, 0:3] * u.pix, (10, 10, 5, 3)), ((2 * u.pix, 2 * u.pix, 0*u.pix, np.arange(3) * u.pix), (3,)), @@ -275,7 +275,7 @@ def test_varying_transform_4d_pc_unitless(): np.arange(10) * u.pix, np.arange(5)[..., None] * u.pix, np.arange(3)[..., None, None]), (3, 5, 10)), -)) +]) def test_varying_transform_4d_pc_shapes(pixel, lon_shape): varying_matrix_lt = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 15)] * u.pix varying_matrix_lt = varying_matrix_lt.reshape((5, 3, 2, 2)) @@ -300,11 +300,11 @@ def test_vct_dispatch(): varying_matrix_lt = varying_matrix_lt.reshape((2, 2, 2, 2, 2, 2)) crval_table = list(zip(np.arange(1, 17), np.arange(17, 33))) * u.arcsec crval_table = crval_table.reshape((2, 2, 2, 2, 2)) - kwargs = dict( - crpix=(5, 5) * u.pix, - cdelt=(1, 1) * u.arcsec/u.pix, - lon_pole=180 * u.deg, - ) + kwargs = { + "crpix": (5, 5) * u.pix, + "cdelt": (1, 1) * u.arcsec/u.pix, + "lon_pole": 180 * u.deg, + } vct = 
varying_celestial_transform_from_tables( pc_table=varying_matrix_lt[0, 0, 0], @@ -342,11 +342,11 @@ def test_vct_shape_errors(): crval_table = list(zip(np.arange(1, 16), np.arange(16, 31))) * u.arcsec crval_table = crval_table.reshape((5, 3, 2)) - kwargs = dict( - crpix=(5, 5) * u.pix, - cdelt=(1, 1) * u.arcsec/u.pix, - lon_pole=180 * u.deg, - ) + kwargs = { + "crpix": (5, 5) * u.pix, + "cdelt": (1, 1) * u.arcsec/u.pix, + "lon_pole": 180 * u.deg, + } with pytest.raises(ValueError, match="only be constructed with a 1-dimensional"): VaryingCelestialTransform(crval_table=crval_table, pc_table=pc_table, **kwargs) @@ -377,7 +377,7 @@ def test_vct_slit_bounds(slit): slit=slit, ) -@pytest.mark.parametrize("num_varying_axes", [pytest.param(1, id='1D'), pytest.param(2, id='2D'), pytest.param(3, id='3D')]) +@pytest.mark.parametrize("num_varying_axes", [pytest.param(1, id="1D"), pytest.param(2, id="2D"), pytest.param(3, id="3D")]) @pytest.mark.parametrize("slit", [pytest.param(1, id="spectrograph"), pytest.param(None, id="imager")]) @pytest.mark.parametrize("has_units", [pytest.param(True, id="With Units"), pytest.param(False, id="Without Units")]) def test_vct(has_units, slit, num_varying_axes): @@ -419,8 +419,8 @@ def test_vct(has_units, slit, num_varying_axes): atol *= u.pix for i in range(num_sensor_axes): sensor_axis_pts[i] *= u.pix - grid = np.meshgrid(*sensor_axis_pts, *varying_axis_pts, indexing='ij') - grid2 = np.meshgrid(*sensor_axis_pts, *varying_axis_pts_1, indexing='ij') + grid = np.meshgrid(*sensor_axis_pts, *varying_axis_pts, indexing="ij") + grid2 = np.meshgrid(*sensor_axis_pts, *varying_axis_pts_1, indexing="ij") # the portion of the grid due to the varying axes coordinates varying_axes_grid = grid[num_sensor_axes:] @@ -448,7 +448,7 @@ def test_vct(has_units, slit, num_varying_axes): # grid2 has coordinates outside the lut boundaries and should have nans world2 = vct(*grid2) - assert np.any(np.isnan([item for item in world2])) + assert np.any(np.isnan(list(world2))) def _evaluate_ravel(array_shape, inputs, order="C"): @@ -464,8 +464,7 @@ def _evaluate_ravel(array_shape, inputs, order="C"): inputs = inputs[::-1] rounded_inputs = rounded_inputs[::-1] offsets = np.cumprod(array_shape[1:][::-1])[::-1] - result = np.dot(offsets, rounded_inputs[:-1]) + inputs[-1] - return result + return np.dot(offsets, rounded_inputs[:-1]) + inputs[-1] def _evaluate_unravel(array_shape, index, order="C"): @@ -477,7 +476,7 @@ def _evaluate_unravel(array_shape, index, order="C"): curr_offset = index # This if test is to handle multidimensional inputs properly if isinstance(index, np.ndarray): - output_shape = tuple([len(array_shape), len(index)]) + output_shape = (len(array_shape), len(index)) else: output_shape = len(array_shape) indices = np.zeros(output_shape, dtype=float) @@ -490,7 +489,7 @@ def _evaluate_unravel(array_shape, index, order="C"): return tuple(indices) -@pytest.mark.parametrize("ndim", [pytest.param(2, id='2D'), pytest.param(3, id='3D')]) +@pytest.mark.parametrize("ndim", [pytest.param(2, id="2D"), pytest.param(3, id="3D")]) @pytest.mark.parametrize("has_units", [pytest.param(True, id="With Units"), pytest.param(False, id="Without Units")]) @pytest.mark.parametrize("input_type", [pytest.param("array", id="Array Inputs"), pytest.param("scalar", id="Scalar Inputs")]) def test_ravel_model(ndim, has_units, input_type): @@ -530,7 +529,7 @@ def test_ravel_model(ndim, has_units, input_type): assert np.allclose(round_trip, expected_ravel) -@pytest.mark.parametrize("ndim", [pytest.param(2, 
id='2D'), pytest.param(3, id='3D')]) +@pytest.mark.parametrize("ndim", [pytest.param(2, id="2D"), pytest.param(3, id="3D")]) @pytest.mark.parametrize("has_units", [pytest.param(True, id="With Units"), pytest.param(False, id="Without Units")]) @pytest.mark.parametrize("input_type", [pytest.param("array", id="Array Inputs"), pytest.param("scalar", id="Scalar Inputs")]) def test_raveled_tabular1d(ndim, has_units, input_type): @@ -579,7 +578,7 @@ def test_raveled_tabular1d(ndim, has_units, input_type): assert np.allclose(raveled_tab.inverse.inverse(*inputs), expected_ravel) -@pytest.mark.parametrize("ndim", [pytest.param(2, id='2D'), pytest.param(3, id='3D')]) +@pytest.mark.parametrize("ndim", [pytest.param(2, id="2D"), pytest.param(3, id="3D")]) @pytest.mark.parametrize("order", ["C", "F"]) def test_ravel_ordering(ndim, order): rng = default_rng() @@ -595,21 +594,23 @@ def test_ravel_ordering(ndim, order): assert int(ravel_value) == values[tuple(inputs)] -@pytest.mark.parametrize("ndim", [pytest.param(2, id='2D'), pytest.param(3, id='3D')]) +@pytest.mark.parametrize("ndim", [pytest.param(2, id="2D"), pytest.param(3, id="3D")]) @pytest.mark.parametrize("order", ["C", "F"]) def test_ravel_repr(ndim, order): rng = default_rng() array_shape = tuple(rng.integers(1, 21, ndim)) ravel = Ravel(array_shape, order=order) unravel = ravel.inverse - assert str(array_shape) in repr(ravel) and order in repr(ravel) - assert str(array_shape) in repr(unravel) and order in repr(unravel) + assert str(array_shape) in repr(ravel) + assert order in repr(ravel) + assert str(array_shape) in repr(unravel) + assert order in repr(unravel) @pytest.mark.parametrize("array_shape", [(0, 1), (1, 0), (1,)]) @pytest.mark.parametrize("ravel", [Ravel, Unravel]) def test_ravel_bad_array_shape(array_shape, ravel): - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match="array_shape must be at least 2D and have values >= 1"): ravel(array_shape) @@ -617,7 +618,7 @@ def test_ravel_bad_array_shape(array_shape, ravel): @pytest.mark.parametrize("ravel", [Ravel, Unravel]) def test_ravel_bad_order(order, ravel): array_shape=(2, 2, 2) - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match="order kwarg must be one of 'C' or 'F'"): ravel(array_shape, order) diff --git a/docs/conf.py b/docs/conf.py index 839a938c..e493697d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -2,12 +2,13 @@ Configuration file for the Sphinx documentation builder. 
""" # -- stdlib imports ------------------------------------------------------------ +import datetime import os import sys -import datetime import warnings -from pkg_resources import get_distribution + from packaging.version import Version +from pkg_resources import get_distribution # -- Check for dependencies ---------------------------------------------------- doc_requires = get_distribution("dkist").requires(extras=("docs",)) @@ -15,7 +16,7 @@ for requirement in doc_requires: try: get_distribution(requirement) - except Exception as e: + except Exception: missing_requirements.append(requirement.name) if missing_requirements: print( @@ -26,22 +27,22 @@ # -- Read the Docs Specific Configuration -------------------------------------- # This needs to be done before sunpy is imported -on_rtd = os.environ.get('READTHEDOCS', None) == 'True' +on_rtd = os.environ.get("READTHEDOCS", None) == "True" if on_rtd: - os.environ['SUNPY_CONFIGDIR'] = '/home/docs/' - os.environ['HOME'] = '/home/docs/' - os.environ['LANG'] = 'C' - os.environ['LC_ALL'] = 'C' - os.environ['HIDE_PARFIVE_PROGESS'] = 'True' + os.environ["SUNPY_CONFIGDIR"] = "/home/docs/" + os.environ["HOME"] = "/home/docs/" + os.environ["LANG"] = "C" + os.environ["LC_ALL"] = "C" + os.environ["HIDE_PARFIVE_PROGESS"] = "True" # -- Non stdlib imports -------------------------------------------------------- -import dkist # NOQA -from dkist import __version__ # NOQA +import dkist # noqa +from dkist import __version__ # -- Project information ------------------------------------------------------- -project = 'DKIST' -author = 'NSO / AURA' -copyright = '{}, {}'.format(datetime.datetime.now().year, author) +project = "DKIST" +author = "NSO / AURA" +copyright = f"{datetime.datetime.now().year}, {author}" # The full version, including alpha/beta/rc tags release = __version__ @@ -54,30 +55,30 @@ # Suppress warnings about overriding directives as we overload some of the # doctest extensions. -suppress_warnings = ['app.add_directive', ] +suppress_warnings = ["app.add_directive", ] # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'matplotlib.sphinxext.plot_directive', - 'sphinx_automodapi.automodapi', - 'sphinx_automodapi.smart_resolver', - 'sphinx_changelog', - 'sphinx.ext.autodoc', - 'sphinx.ext.coverage', - 'sphinx.ext.doctest', - 'sphinx.ext.inheritance_diagram', - 'sphinx.ext.intersphinx', - 'sphinx.ext.mathjax', - 'sphinx.ext.napoleon', - 'sphinx.ext.todo', - 'sphinx.ext.viewcode', - 'sphinx_autodoc_typehints', # must be loaded after napoleon - 'sunpy.util.sphinx.doctest', - 'sunpy.util.sphinx.generate', - 'myst_nb', - 'sphinx_design', + "matplotlib.sphinxext.plot_directive", + "sphinx_automodapi.automodapi", + "sphinx_automodapi.smart_resolver", + "sphinx_changelog", + "sphinx.ext.autodoc", + "sphinx.ext.coverage", + "sphinx.ext.doctest", + "sphinx.ext.inheritance_diagram", + "sphinx.ext.intersphinx", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", + "sphinx.ext.todo", + "sphinx.ext.viewcode", + "sphinx_autodoc_typehints", # must be loaded after napoleon + "sunpy.util.sphinx.doctest", + "sunpy.util.sphinx.generate", + "myst_nb", + "sphinx_design", ] # Add any paths that contain templates here, relative to this directory. @@ -90,20 +91,20 @@ # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. 
These files are copied -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'jupyter_execute', '**/*_NOTES.md'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "jupyter_execute", "**/*_NOTES.md"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: -source_suffix = '.rst' +source_suffix = ".rst" -myst_enable_extensions = ['colon_fence', 'dollarmath', 'substitution'] +myst_enable_extensions = ["colon_fence", "dollarmath", "substitution"] # The master toctree document. -master_doc = 'index' +master_doc = "index" # The reST default role (used for this markup: `text`) to use for all # documents. Set to the "smart" one. -default_role = 'obj' +default_role = "obj" napoleon_use_rtype = False @@ -139,11 +140,11 @@ ), "astropy": ("https://docs.astropy.org/en/stable/", None), "parfive": ("https://parfive.readthedocs.io/en/stable/", None), - "sunpy": ('https://docs.sunpy.org/en/stable/', None), - "ndcube": ('https://docs.sunpy.org/projects/ndcube/en/latest/', None), - "gwcs": ('https://gwcs.readthedocs.io/en/latest/', None), - "asdf": ('https://asdf.readthedocs.io/en/latest/', None), - "dask": ('https://dask.pydata.org/en/latest/', None), + "sunpy": ("https://docs.sunpy.org/en/stable/", None), + "ndcube": ("https://docs.sunpy.org/projects/ndcube/en/latest/", None), + "gwcs": ("https://gwcs.readthedocs.io/en/latest/", None), + "asdf": ("https://asdf.readthedocs.io/en/latest/", None), + "dask": ("https://dask.pydata.org/en/latest/", None), } # -- Options for HTML output --------------------------------------------------- @@ -159,12 +160,12 @@ graphviz_output_format = "svg" graphviz_dot_args = [ - '-Nfontsize=10', - '-Nfontname=Helvetica Neue, Helvetica, Arial, sans-serif', - '-Efontsize=10', - '-Efontname=Helvetica Neue, Helvetica, Arial, sans-serif', - '-Gfontsize=10', - '-Gfontname=Helvetica Neue, Helvetica, Arial, sans-serif' + "-Nfontsize=10", + "-Nfontname=Helvetica Neue, Helvetica, Arial, sans-serif", + "-Efontsize=10", + "-Efontname=Helvetica Neue, Helvetica, Arial, sans-serif", + "-Gfontsize=10", + "-Gfontname=Helvetica Neue, Helvetica, Arial, sans-serif" ] # Use a high-contrast code style from accessible-pygments @@ -175,7 +176,7 @@ # -- MyST_NB ------------------------------------------------------------------- nb_execution_allow_errors = False nb_execution_in_temp = True -nb_execution_mode = 'auto' +nb_execution_mode = "auto" nb_execution_timeout = 300 -nb_output_stderr = 'show' +nb_output_stderr = "show" nb_execution_show_tb = True
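
The hunks above repeat the same handful of mechanical conversions throughout the tree: double-quoted strings, PEP 604 unions and builtin generics in annotations, imports of abstract collections from collections.abc, dict literals in place of dict() calls, and returning expressions directly instead of binding them to a throwaway name. The sketch below is hypothetical code, not taken from the dkist tree; the function name, values, and rule-code comments are illustrative only, and are meant just to show each conversion once in a self-contained form.

# A minimal, hypothetical example of the style conversions applied by this patch.
from collections.abc import Iterable  # UP035: ABCs come from collections.abc, not typing


def summarise(values: Iterable[float] | None = None) -> dict[str, str]:
    # UP007: "Iterable[float] | None" replaces Union[Iterable[float], None]
    # UP006: builtin "dict[str, str]" replaces typing.Dict[str, str]
    values = list(values) if values is not None else [1.0, 2.0]
    # C408 + Q000: a double-quoted dict literal replaces dict(count=..., total=...)
    # RET504: the result is returned directly rather than assigned to a temporary first
    return {"count": f"{len(values)}", "total": f"{sum(values):.1f}"}


# Usage: summarise([1, 2, 3]) returns {"count": "3", "total": "6.0"}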