Skip to content

Commit

Permalink
refactored the remaining fstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
Fersoil authored and Fersoil committed Dec 23, 2024
1 parent fad4f14 commit fe75082
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 49 deletions.
14 changes: 8 additions & 6 deletions src/meteors/attr/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def ensure_torch_tensor(value: np.ndarray | torch.Tensor, context: str) -> torch
return value

if isinstance(value, np.ndarray):
logger.debug(f"Converting {context} from NumPy array to PyTorch tensor")
logger.debug("Converting {} from NumPy array to PyTorch tensor".format(context))
return torch.from_numpy(value)

raise TypeError("{} must be a NumPy array or PyTorch tensor".format(context))
Expand Down Expand Up @@ -101,7 +101,7 @@ def validate_shapes(attributes: torch.Tensor, hsi: HSI) -> None:
"""
if list(attributes.shape) != list(hsi.image.shape):
raise ShapeMismatchError(
f"Attributes and HSI have different, unmatching shapes: {attributes.shape}, {hsi.image.shape}"
"Attributes and HSI have different, unmatching shapes: {}, {}".format(attributes.shape, hsi.image.shape)
)


Expand Down Expand Up @@ -149,8 +149,8 @@ def align_band_names_with_mask(
)
else:
logger.info(
f"Adding 'not_included' to band names because {value} ids "
"is present in the mask and not in band names"
"Adding 'not_included' to band names because {} ids "
"is present in the mask and not in band names".format(value)
)
band_names["not_included"] = value
band_name_values.add(value)
Expand Down Expand Up @@ -184,7 +184,9 @@ def validate_attribution_method(value: str | None) -> str | None:
value = value.title()
if value not in AVAILABLE_ATTRIBUTION_METHODS:
logger.warning(
f"Unknown attribution method: {value}. The core implemented methods are {AVAILABLE_ATTRIBUTION_METHODS}"
"Unknown attribution method: {}. The core implemented methods are {}".format(
value, AVAILABLE_ATTRIBUTION_METHODS
)
)
return value

Expand Down Expand Up @@ -226,7 +228,7 @@ def resolve_inference_device_attributes(device: str | torch.device | None, info:
if not isinstance(device, torch.device):
raise TypeError("Device should be a string or torch device")

logger.debug(f"Device for inference: {device.type}")
logger.debug("Device for inference: {}".format(device.type))
return device


Expand Down
8 changes: 5 additions & 3 deletions src/meteors/attr/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def validate_and_transform_baseline(baseline: int | float | torch.Tensor | None,
elif isinstance(baseline, torch.Tensor):
if baseline.shape != hsi.image.shape:
raise ShapeMismatchError(
f"Passed baseline and HSI have incorrect shapes: {baseline.shape} and {hsi.image.shape}"
"Passed baseline and HSI have incorrect shapes: {} and {}".format(baseline.shape, hsi.image.shape)
)
if not isinstance(baseline, torch.Tensor):
raise TypeError("Expected torch.Tensor | int | float as baseline, but got {}".format(type(baseline)))
Expand Down Expand Up @@ -86,11 +86,13 @@ def __init__(self, callable: ExplainableModel | Explainer) -> None:
self.chained_explainer = callable
self.explainable_model: ExplainableModel = callable.explainable_model
logger.debug(
f"Initializing {self.__class__.__name__} explainer on model {callable.explainable_model} chained with {callable.__class__.__name__}"
"Initializing {} explainer on model {} chained with {}".format(
self.__class__.__name__, callable.explainable_model, callable.__class__.__name__
)
)
else:
self.explainable_model = callable
logger.debug(f"Initializing {self.__class__.__name__} explainer on model {callable}")
logger.debug("Initializing {} explainer on model {}".format(self.__class__.__name__, callable))

self.forward_func = self.explainable_model.forward_func

Expand Down
66 changes: 40 additions & 26 deletions src/meteors/attr/lime.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,8 @@ def validate_band_format(bands: dict[str | tuple[str, ...], BandType], variable_
continue
raise TypeError(
(
f"{variable_name} should be either a value, list of values, "
"tuple of two values or list of tuples of two values."
)
"{} should be either a value, list of values, " "tuple of two values or list of tuples of two values."
).format(variable_name)
)


Expand Down Expand Up @@ -180,9 +179,9 @@ def validate_segment_format(
):
raise ValueError(
(
f"Each segment range should be a tuple or list of two numbers of data type {dtype} (start, end). "
f"Where start < end. But got: {segment}"
)
"Each segment range should be a tuple or list of two numbers of data type {} (start, end). "
"Where start < end. But got: {}"
).format(dtype, segment)
)
return segment

Expand Down Expand Up @@ -223,7 +222,7 @@ def adjust_and_validate_segment_ranges(
for start, end in segment_ranges:
if start < 0:
if end > 0:
logger.debug(f"Adjusting segment start from {start} to 0")
logger.debug("Adjusting segment start from {} to 0".format(start))
start = 0
else:
raise ValueError("Segment range {} is out of bounds".format((start, end)))
Expand Down Expand Up @@ -284,13 +283,13 @@ def validate_mask_shape(mask_type: Literal["segmentation", "band"], hsi: HSI, ma

if mask_type == "segmentation" and ("H" in orientation_mismatches or "W" in orientation_mismatches):
raise ValueError(
f"Image and mask orientation mismatch: {hsi.orientation} and {mask_shape}."
"Image and mask orientation mismatch: {} and {}.".format(hsi.orientation, mask_shape)
+ "Segmentation mask should differ only in the band dimension"
)

if mask_type == "band" and "C" in orientation_mismatches:
raise ValueError(
f"Image and mask orientation mismatch: {hsi.orientation} and {mask_shape}."
"Image and mask orientation mismatch: {} and {}.".format(hsi.orientation, mask_shape)
+ "Band mask should differ only in the height and width dimensions"
)

Expand Down Expand Up @@ -480,7 +479,9 @@ def get_band_mask(
]
ignored_params_str = " and ".join(ignored_params)
logger.info(
f"Only the band names will be used to create the band mask. The additional parameters {ignored_params_str} will be ignored."
"Only the band names will be used to create the band mask. The additional parameters {} will be ignored.".format(
ignored_params_str
)
)
try:
validate_band_names(band_names)
Expand All @@ -503,7 +504,9 @@ def get_band_mask(
)
except Exception as e:
raise ValueError(
f"Incorrect band ranges wavelengths provided, please check if provided wavelengths are correct: {e}"
"Incorrect band ranges wavelengths provided, please check if provided wavelengths are correct: {}".format(
e
)
) from e
elif band_indices is not None:
logger.debug("Getting band mask from band groups given by ranges of indices")
Expand All @@ -512,7 +515,9 @@ def get_band_mask(
band_groups = Lime._get_band_indices_from_input_band_indices(hsi.wavelengths, band_indices)
except Exception as e:
raise ValueError(
f"Incorrect band ranges indices provided, please check if provided indices are correct: {e}"
"Incorrect band ranges indices provided, please check if provided indices are correct: {}".format(
e
)
) from e

return Lime._create_tensor_band_mask(
Expand Down Expand Up @@ -577,7 +582,9 @@ def _extract_bands_from_spyndex(segment_name: list[str] | tuple[str, ...] | str)
band_names_segment.append(band_name)
else:
raise BandSelectionError(
f"Invalid band name {band_name}, band name must be either in `spyndex.indices` or `spyndex.bands`"
"Invalid band name {}, band name must be either in `spyndex.indices` or `spyndex.bands`".format(
band_name
)
)

return tuple(set(band_names_segment)) if len(band_names_segment) > 1 else band_names_segment[0]
Expand Down Expand Up @@ -654,9 +661,11 @@ def _get_band_wavelengths_indices_from_band_names(

if min_wavelength > wavelengths.max() or max_wavelength < wavelengths.min():
logger.debug(
f"Band {band_name} is not present in the given wavelengths. "
f"Band ranges from {min_wavelength} nm to {max_wavelength} nm and the HSI wavelengths "
f"range from {wavelengths.min():.2f} nm to {wavelengths.max():.2f} nm. The given band will be skipped"
"Band {} is not present in the given wavelengths. "
"Band ranges from {:.2f} nm to {:.2f} nm and the HSI wavelengths "
"range from {:.2f} nm to {:.2f} nm. The given band will be skipped".format(
band_name, min_wavelength, max_wavelength, wavelengths.min(), wavelengths.max()
)
)
else:
segment_indices_ranges += Lime._convert_wavelengths_to_indices(
Expand Down Expand Up @@ -832,8 +841,9 @@ def _check_overlapping_segments(dict_labels_to_indices: dict[str | tuple[str, ..
label_second_str = label_second if isinstance(label_second, str) else "/".join(label_second)

logger.warning(
f"Segments {label_first_str} and {label_second_str} are overlapping,"
" overlapping wavelengths will be assigned to only one"
"Segments {} and {} are overlapping, overlapping wavelengths will be assigned to only one".format(
label_first_str, label_second_str
)
)

@staticmethod
Expand Down Expand Up @@ -866,8 +876,9 @@ def _validate_and_create_dict_labels_to_segment_ids(
if len(dict_labels_to_segment_ids) != len(segment_labels):
raise ValueError(
(
f"Incorrect dict_labels_to_segment_ids - length mismatch. Expected: "
f"{len(segment_labels)}, Actual: {len(dict_labels_to_segment_ids)}"
"Incorrect dict_labels_to_segment_ids - length mismatch. Expected: {}, " "Actual: {}".format(
len(segment_labels), len(dict_labels_to_segment_ids)
)
)
)

Expand Down Expand Up @@ -912,8 +923,8 @@ def _create_single_dim_band_mask(
if not are_indices_valid:
raise ValueError(
(
f"Indices for segment {segment_label} are out of bounds for the one-dimensional band mask"
f"of shape {band_mask_single_dim.shape}"
"Indices for segment {} are out of bounds for the one-dimensional band mask "
"of shape {}".format(segment_label, band_mask_single_dim.shape)
)
)
band_mask_single_dim[segment_indices] = segment_id
Expand Down Expand Up @@ -952,8 +963,7 @@ def _create_tensor_band_mask(
device = hsi.device
segment_labels = list(dict_labels_to_indices.keys())

logger.debug(f"Creating a band mask on the device {device} using {len(segment_labels)} segments")

logger.debug("Creating a band mask on the device {} using {} segments".format(device, len(segment_labels)))
# Check for overlapping segments
Lime._check_overlapping_segments(dict_labels_to_indices)

Expand Down Expand Up @@ -1147,7 +1157,9 @@ def get_spatial_attributes(

if len(hsi) != len(segmentation_mask):
raise ValueError(
f"Number of segmentation masks should be equal to the number of HSI images provided, provided {len(segmentation_mask)}"
"Number of segmentation masks should be equal to the number of HSI images provided, provided {}".format(
len(segmentation_mask)
)
)

segmentation_mask = [
Expand Down Expand Up @@ -1303,7 +1315,9 @@ def get_spectral_attributes(
logger.debug("Reusing the same band mask for all images")
else:
raise ValueError(
f"Number of band masks should be equal to the number of HSI images provided, provided {len(band_mask)}"
"Number of band masks should be equal to the number of HSI images provided, provided {}".format(
len(band_mask)
)
)

band_mask = [
Expand Down
8 changes: 5 additions & 3 deletions src/meteors/attr/noise_tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,10 @@ def attribute(
if isinstance(stdevs, tuple):
if len(stdevs) != len(hsi):
raise ValueError(
"The number of stdevs must match the number of input images, number of stdevs:"
f"{len(stdevs)}, number of input images: {len(hsi)}"
"The number of stdevs must match the number of input images, number of stdevs: {}, "
"number of input images: {}".format(len(stdevs), len(hsi))
)

else:
stdevs = tuple([stdevs] * len(hsi))

Expand Down Expand Up @@ -401,7 +402,8 @@ def perturb_input(
else:
if num_perturbed_bands < 0 or num_perturbed_bands > input.shape[0]:
raise ValueError(
f"Cannot perturb {num_perturbed_bands} bands in the input with {input.shape[0]} channels. The number of perturbed bands must be in the range [0, {input.shape[0]}]"
"Cannot perturb {} bands in the input with {} channels. The number of perturbed bands must be "
"in the range [0, {}]".format(num_perturbed_bands, input.shape[0], input.shape[0])
)

channels_to_be_perturbed = torch_random_choice(
Expand Down
10 changes: 7 additions & 3 deletions src/meteors/hsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def resolve_inference_device_hsi(device: str | torch.device | None, info: Valida
if not isinstance(device, torch.device):
raise TypeError("Device should be a string or torch device")

logger.debug(f"Device for inference: {device.type}")
logger.debug("Device for inference: {}".format(device.type))
return device


Expand Down Expand Up @@ -158,7 +158,9 @@ def validate_shapes(wavelengths: torch.Tensor, image: torch.Tensor, spectral_axi
"""
if wavelengths.shape[0] != image.shape[spectral_axis]:
raise ShapeMismatchError(
f"Length of wavelengths must match the number of channels in the image. Passed {wavelengths.shape[0]} wavelengths for {image.shape[spectral_axis]} channels",
"Length of wavelengths must match the number of channels in the image. Passed {} wavelengths for {} channels".format(
wavelengths.shape[0], image.shape[spectral_axis]
),
)


Expand Down Expand Up @@ -224,7 +226,9 @@ def process_and_validate_binary_mask(
binary_mask = binary_mask.expand_as(image)
except RuntimeError:
raise ShapeMismatchError(
f"Mismatch in shapes of binary mask and HSI. Passed shapes are respectively: {binary_mask.shape}, {image.shape}"
"Mismatch in shapes of binary mask and HSI. Passed shapes are respectively: {}, {}".format(
binary_mask.shape, image.shape
)
)

return binary_mask
Expand Down
11 changes: 7 additions & 4 deletions src/meteors/shap/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def ensure_data_type_and_reshape(data: np.ndarray | torch.Tensor | pd.DataFrame)
converted_data = np.array(data)
except Exception as e:
raise TypeError(
f"Expected NumPy array | Torch Tensor | Pandas DataFrame as data, but got {type(data)} and failed to convert to NumPy array"
"Expected NumPy array | Torch Tensor | Pandas DataFrame as data, but got {} and failed to convert to NumPy array".format(
type(data)
)
) from e
if not np.issubdtype(converted_data.dtype, np.number):
raise TypeError("Expected numeric data, but got {}".format(converted_data.dtype))
Expand Down Expand Up @@ -132,7 +134,7 @@ def _validate_shapes(self):
else:
raise ShapeMismatchError(
"Shape of the explanations does not match the shape of the input data. "
f"Expected {data_shape}, but got {self.explanations.shape}"
"Expected {}, but got {}".format(data_shape, explanation_shape)
)
elif len(explanation_shape) == len(data_shape):
if explanation_shape[-2] == data_shape[-1] and data_shape[0] == 1:
Expand All @@ -148,8 +150,9 @@ def _validate_shapes(self):
return

raise ShapeMismatchError(
"Shape of the explanations does not match the shape of the input data. "
f"Expected {data_shape}, but got {explanation_shape}"
"Shape of the explanations does not match the shape of the input data. " "Expected {}, but got {}".format(
data_shape, explanation_shape
)
)

@model_validator(mode="after")
Expand Down
5 changes: 3 additions & 2 deletions src/meteors/visualize/attr_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ def visualize_attributes(
"""
if image_attributes.hsi.orientation != ("H", "W", "C"):
logger.info(
f"The orientation of the image is not (H, W, C): {image_attributes.hsi.orientation}. "
f"Changing it to (H, W, C) for visualization."
"The orientation of the image is not (H, W, C): {}. " "Changing it to (H, W, C) for visualization.".format(
image_attributes.hsi.orientation
)
)
rotated_attributes_dataclass = image_attributes.change_orientation("HWC", inplace=False)
else:
Expand Down
6 changes: 4 additions & 2 deletions src/meteors/visualize/shap_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,8 +627,10 @@ def wavelengths_bar(

if len(transformations_list) > cmap.N:
raise ValueError(
f"Number of transformations ({len(transformations_list)}) is greater than the number of "
f"colors in the colormap ({cmap.N}). Please provide a colormap with more colors."
"Number of transformations ({}) is greater than the number of "
"colors in the colormap ({}). Please provide a colormap with more colors.".format(
len(transformations_list), cmap.N
)
)

# the bottoms of the current bars
Expand Down

0 comments on commit fe75082

Please sign in to comment.