From 1b40cd3aa0f8c8cb26cc0cfc40884d9dd604aad5 Mon Sep 17 00:00:00 2001 From: Stephan Finkensieper Date: Wed, 4 Sep 2024 14:55:23 +0000 Subject: [PATCH] Refactor dataset wrapper --- satpy/readers/mviri_l1b_fiduceo_nc.py | 167 ++++++++++++++------------ 1 file changed, 91 insertions(+), 76 deletions(-) diff --git a/satpy/readers/mviri_l1b_fiduceo_nc.py b/satpy/readers/mviri_l1b_fiduceo_nc.py index e0297f7ec3..4a619ac866 100644 --- a/satpy/readers/mviri_l1b_fiduceo_nc.py +++ b/satpy/readers/mviri_l1b_fiduceo_nc.py @@ -452,53 +452,50 @@ def is_high_resol(resolution): return resolution == HIGH_RESOL -class DatasetWrapper: - """Helper class for accessing the dataset.""" - - def __init__(self, nc): - """Wrap the given dataset.""" - self.nc = nc - - # decode data - self._decode_cf() - # rename duplicate dimensions - self._fix_duplicate_dimensions(self.nc) - - - @property - def attrs(self): - """Exposes dataset attributes.""" - return self.nc.attrs - - def __getitem__(self, item): - """Get a variable from the dataset.""" - ds = self.nc[item] - if self._should_dims_be_renamed(ds): - ds = self._rename_dims(ds) - elif self._coordinates_not_assigned(ds): - ds = self._reassign_coords(ds) +class DatasetPreprocessor: + def preprocess(self, ds): + ds = self._rename_vars(ds) + ds = self._decode_cf(ds) + ds = self._fix_duplicate_dimensions(ds) + self._reassign_coords(ds) self._cleanup_attrs(ds) return ds - def _should_dims_be_renamed(self, ds): - """Determine whether dataset dimensions need to be renamed.""" - return "y_ir_wv" in ds.dims or "y_tie" in ds.dims - - def _rename_dims(self, ds): - """Rename dataset dimensions to match satpy's expectations.""" + def _rename_vars(self, ds): + """Rename variables to match satpy's expectations.""" new_names = { - "y_ir_wv": "y", - "x_ir_wv": "x", - "y_tie": "y", - "x_tie": "x" + "time_ir_wv": "time", } - for old_name, new_name in new_names.items(): - if old_name in ds.dims: - ds = ds.rename({old_name: new_name}) + new_names_avail = { + old: new + for old, new in new_names.items() + if old in ds + } + return ds.rename(new_names_avail) + + def _decode_cf(self, ds): + # remove time before decoding and add again. + time_dims, time = self._decode_time(ds) + ds = ds.drop_vars(time.name) + ds = xr.decode_cf(ds) + ds[time.name] = (time_dims, time.values) return ds - def _coordinates_not_assigned(self, ds): - return "y" in ds.dims and "y" not in ds.coords + def _decode_time(self, ds): + time = ds["time"] + time_dims = time.dims + time = xr.where(time == time.attrs["_FillValue"], np.datetime64("NaT"), + (time + time.attrs["add_offset"]).astype("datetime64[s]").astype("datetime64[ns]")) + return (time_dims, time) + + def _fix_duplicate_dimensions(self, ds): + ds = ds.copy() + ds.variables["covariance_spectral_response_function_vis"].dims = ("srf_size_1", "srf_size_2") + ds = ds.drop_dims("srf_size") + ds.variables["channel_correlation_matrix_independent"].dims = ("channel_1", "channel_2") + ds.variables["channel_correlation_matrix_structured"].dims = ("channel_1", "channel_2") + ds = ds.drop_dims("channel") + return ds def _reassign_coords(self, ds): """Re-assign coordinates. @@ -506,58 +503,75 @@ def _reassign_coords(self, ds): For some reason xarray doesn't assign coordinates to all high resolution data variables. """ - return ds.assign_coords({"y": self.nc.coords["y"], - "x": self.nc.coords["x"]}) + for var_name, data_array in ds.data_vars.items(): + if self._coordinates_not_assigned(data_array): + ds[var_name] = data_array.assign_coords( + { + "y": ds.coords["y"], + "x": ds.coords["x"] + } + ) + + def _coordinates_not_assigned(self, ds): + return "y" in ds.dims and "y" not in ds.coords def _cleanup_attrs(self, ds): """Cleanup dataset attributes.""" # Remove ancillary_variables attribute to avoid downstream # satpy warnings. - ds.attrs.pop("ancillary_variables", None) + for data_array in ds.data_vars.values(): + data_array.attrs.pop("ancillary_variables", None) - def _decode_cf(self): - # remove time before decoding and add again. - time_dims, time = self._decode_time() - self.nc = self.nc.drop_vars(time.name) - self.nc = xr.decode_cf(self.nc) - self.nc[time.name] = (time_dims, time.values) - - def _decode_time(self): - time = self.get_time() - time_dims = self.nc[time.name].dims - time = xr.where(time == time.attrs["_FillValue"], np.datetime64("NaT"), - (time + time.attrs["add_offset"]).astype("datetime64[s]").astype("datetime64[ns]")) - return (time_dims, time) +class DatasetAccessor: + """Helper class for accessing the dataset.""" - def _fix_duplicate_dimensions(self, nc): - nc.variables["covariance_spectral_response_function_vis"].dims = ("srf_size_1", "srf_size_2") - self.nc = nc.drop_dims("srf_size") - nc.variables["channel_correlation_matrix_independent"].dims = ("channel_1", "channel_2") - nc.variables["channel_correlation_matrix_structured"].dims = ("channel_1", "channel_2") - self.nc = nc.drop_dims("channel") + def __init__(self, ds): + """Wrap the given dataset.""" + self.ds = ds - def get_time(self): - """Get time coordinate. + @property + def attrs(self): + """Exposes dataset attributes.""" + return self.ds.attrs - Variable is sometimes named "time" and sometimes "time_ir_wv". - """ - try: - return self["time_ir_wv"] - except KeyError: - return self["time"] + def __getitem__(self, item): + """Get a variable from the dataset.""" + data_array = self.ds[item] + if self._should_dims_be_renamed(data_array): + return self._rename_dims(data_array) + return data_array + + def _should_dims_be_renamed(self, data_array): + """Determine whether dataset dimensions need to be renamed.""" + return "y_ir_wv" in data_array.dims or "y_tie" in data_array.dims + + def _rename_dims(self, data_array): + """Rename dataset dimensions to match satpy's expectations.""" + new_names = { + "y_ir_wv": "y", + "x_ir_wv": "x", + "y_tie": "y", + "x_tie": "x" + } + new_names_avail = { + old: new + for old, new in new_names.items() + if old in data_array.dims + } + return data_array.rename(new_names_avail) def get_xy_coords(self, resolution): """Get x and y coordinates for the given resolution.""" if is_high_resol(resolution): - return self.nc.coords["x"], self.nc.coords["y"] - return self.nc.coords["x_ir_wv"], self.nc.coords["x_ir_wv"] + return self.ds.coords["x"], self.ds.coords["y"] + return self.ds.coords["x_ir_wv"], self.ds.coords["x_ir_wv"] def get_image_size(self, resolution): """Get image size for the given resolution.""" if is_high_resol(resolution): - return self.nc.coords["y"].size - return self.nc.coords["y_ir_wv"].size + return self.ds.coords["y"].size + return self.ds.coords["y_ir_wv"].size class FiduceoMviriBase(BaseFileHandler): @@ -592,7 +606,8 @@ def __init__(self, filename, filename_info, filetype_info, # noqa: D417 mask_and_scale=False, ) - self.nc = DatasetWrapper(nc_raw) + nc_preproc = DatasetPreprocessor().preprocess(nc_raw) + self.nc = DatasetAccessor(nc_preproc) self.projection_longitude = self._get_projection_longitude(filename_info) @@ -741,7 +756,7 @@ def _get_acq_time_uncached(self, resolution): Note that the acquisition time does not increase monotonically with the scanline number due to the scan pattern and rectification. """ - time2d = self.nc.get_time() + time2d = self.nc["time"] _, target_y = self.nc.get_xy_coords(resolution) return Interpolator.interp_acq_time(time2d, target_y=target_y.values)