Skip to content

Commit

Permalink
fix: Consistency of terminology for Attributes (#157)
Browse files Browse the repository at this point in the history
  • Loading branch information
WolodjaZ authored Dec 11, 2024
1 parent 25a49f6 commit d2374a8
Show file tree
Hide file tree
Showing 11 changed files with 143 additions and 143 deletions.
4 changes: 2 additions & 2 deletions docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ The architecture of the package can be seen on the UML diagram:
- change_orientation
- to

::: src.meteors.attr.attributes.HSISpatialAttributes
::: src.meteors.attr.attributes.HSIAttributesSpatial
options:
heading_level: 3
show_bases: true
show_root_heading: true
show_root_full_path: false

::: src.meteors.attr.attributes.HSISpectralAttributes
::: src.meteors.attr.attributes.HSIAttributesSpectral
options:
heading_level: 3
show_bases: true
Expand Down
6 changes: 3 additions & 3 deletions src/meteors/attr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .attributes import HSIAttributes, HSISpatialAttributes, HSISpectralAttributes
from .attributes import HSIAttributes, HSIAttributesSpatial, HSIAttributesSpectral
from .explainer import Explainer

from .lime import Lime
Expand All @@ -10,8 +10,8 @@

__all__ = [
"HSIAttributes",
"HSISpatialAttributes",
"HSISpectralAttributes",
"HSIAttributesSpatial",
"HSIAttributesSpectral",
"Explainer",
"IntegratedGradients",
"InputXGradient",
Expand Down
8 changes: 4 additions & 4 deletions src/meteors/attr/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def change_orientation(self, target_orientation: tuple[str, str, str] | list[str
return attrs


class HSISpatialAttributes(HSIAttributes):
class HSIAttributesSpatial(HSIAttributes):
"""Represents spatial attributes of an hsi used for explanation.
Attributes:
Expand Down Expand Up @@ -440,7 +440,7 @@ def flattened_attributes(self) -> torch.Tensor:
Examples:
>>> segmentation_mask = torch.zeros((3, 2, 2))
>>> attrs = HSISpatialAttributes(hsi, attributes, score=0.5, segmentation_mask=segmentation_mask)
>>> attrs = HSIAttributesSpatial(hsi, attributes, score=0.5, segmentation_mask=segmentation_mask)
>>> attrs.flattened_attributes
tensor([[0., 0.],
[0., 0.]])
Expand All @@ -459,7 +459,7 @@ def _validate_hsi_attributions_and_mask(self) -> None:
super()._validate_hsi_attributions_and_mask()


class HSISpectralAttributes(HSIAttributes):
class HSIAttributesSpectral(HSIAttributes):
"""Represents an hsi with spectral attributes used for explanation.
Attributes:
Expand Down Expand Up @@ -495,7 +495,7 @@ def band_mask(self) -> torch.Tensor:
Examples:
>>> band_names = {"R": 0, "G": 1, "B": 2}
>>> attrs = HSISpectralAttributes(hsi, attributes, score=0.5, mask=band_mask)
>>> attrs = HSIAttributesSpectral(hsi, attributes, score=0.5, mask=band_mask)
>>> attrs.flattened_band_mask
torch.tensor([0, 1, 2])
"""
Expand Down
28 changes: 14 additions & 14 deletions src/meteors/attr/lime.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import spyndex


from .attributes import HSISpatialAttributes, HSISpectralAttributes, ensure_torch_tensor
from .attributes import HSIAttributesSpatial, HSIAttributesSpectral, ensure_torch_tensor
from .explainer import Explainer
from .lime_base import Lime as LimeBase
from meteors import HSI
Expand Down Expand Up @@ -973,15 +973,15 @@ def attribute( # type: ignore
attribution_type: Literal["spatial", "spectral"] | None = None,
additional_forward_args: Any = None,
**kwargs: Any,
) -> HSISpatialAttributes | HSISpectralAttributes | list[HSISpatialAttributes] | list[HSISpectralAttributes]:
) -> HSIAttributesSpatial | HSIAttributesSpectral | list[HSIAttributesSpatial] | list[HSIAttributesSpectral]:
"""A wrapper function to attribute the image using the LIME method. It executes either the
`get_spatial_attributes` or `get_spectral_attributes` method based on the provided `attribution_type`. For more
detailed description of the methods, please refer to the respective method documentation.
Args:
hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
The output will be a list of HSISpatialAttributes or HSISpectralAttributes objects.
The output will be a list of HSIAttributesSpatial or HSIAttributesSpectral objects.
target (list[int] | int | None, optional): target class index for computing the attributions. If None,
methods assume that the output has only one class. If the output has multiple classes, the target index
must be provided. For multiple input images, a list of target indices can be provided, one for each
Expand All @@ -998,7 +998,7 @@ def attribute( # type: ignore
kwargs (Any): Additional keyword arguments for the LIME method.
Returns:
HSISpectralAttributes | HSISpatialAttributes | list[HSISpectralAttributes | HSISpatialAttributes]:
HSIAttributesSpectral | HSIAttributesSpatial | list[HSIAttributesSpectral | HSIAttributesSpatial]:
The computed attributions Spectral or Spatial for the input hyperspectral image(s).
if a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
Expand Down Expand Up @@ -1045,21 +1045,21 @@ def get_spatial_attributes(
segmentation_method: Literal["slic", "patch"] = "slic",
additional_forward_args: Any = None,
**segmentation_method_params: Any,
) -> list[HSISpatialAttributes] | HSISpatialAttributes:
) -> list[HSIAttributesSpatial] | HSIAttributesSpatial:
"""
Get spatial attributes of an hsi image using the LIME method. Based on the provided hsi and segmentation mask
LIME method attributes the `superpixels` provided by the segmentation mask. Please refer to the original paper
`https://arxiv.org/abs/1602.04938` for more details or to Christoph Molnar's book
`https://christophm.github.io/interpretable-ml-book/lime.html`.
This function attributes the hyperspectral image using the LIME (Local Interpretable Model-Agnostic Explanations)
method for spatial data. It returns an `HSISpatialAttributes` object that contains the hyperspectral image,,
method for spatial data. It returns an `HSIAttributesSpatial` object that contains the hyperspectral image,,
the attributions, the segmentation mask, and the score of the interpretable model used for the explanation.
Args:
hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
The output will be a list of HSISpatialAttributes objects.
The output will be a list of HSIAttributesSpatial objects.
segmentation_mask (np.ndarray | torch.Tensor | list[np.ndarray | torch.Tensor] | None, optional):
A segmentation mask according to which the attribution should be performed.
The segmentation mask should have a 2D or 3D shape, which can be broadcastable to the shape of the
Expand Down Expand Up @@ -1089,7 +1089,7 @@ def get_spatial_attributes(
**segmentation_method_params (Any): Additional parameters for the segmentation method.
Returns:
HSISpatialAttributes | list[HSISpatialAttributes]: An object containing the image, the attributions,
HSIAttributesSpatial | list[HSIAttributesSpatial]: An object containing the image, the attributions,
the segmentation mask, and the score of the interpretable model used for the explanation.
Raises:
Expand Down Expand Up @@ -1175,7 +1175,7 @@ def get_spatial_attributes(

try:
spatial_attribution = [
HSISpatialAttributes(
HSIAttributesSpatial(
hsi=hsi_img,
attributes=lime_attr,
mask=segmentation_mask[idx].expand_as(hsi_img.image),
Expand All @@ -1199,20 +1199,20 @@ def get_spectral_attributes(
verbose: bool = False,
additional_forward_args: Any = None,
band_names: list[str | list[str]] | dict[tuple[str, ...] | str, int] | None = None,
) -> HSISpectralAttributes | list[HSISpectralAttributes]:
) -> HSIAttributesSpectral | list[HSIAttributesSpectral]:
"""
Attributes the hsi image using LIME method for spectral data. Based on the provided hsi and band mask, the LIME
method attributes the hsi based on `superbands` (clustered bands) provided by the band mask.
Please refer to the original paper `https://arxiv.org/abs/1602.04938` for more details or to
Christoph Molnar's book `https://christophm.github.io/interpretable-ml-book/lime.html`.
The function returns a HSISpectralAttributes object that contains the image, the attributions, the band mask,
The function returns a HSIAttributesSpectral object that contains the image, the attributions, the band mask,
the band names, and the score of the interpretable model used for the explanation.
Args:
hsi (list[HSI] | HSI): Input hyperspectral image(s) for which the attributions are to be computed.
If a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
The output will be a list of HSISpatialAttributes objects.
The output will be a list of HSIAttributesSpatial objects.
band_mask (np.ndarray | torch.Tensor | list[np.ndarray | torch.Tensor] | None, optional): Band mask that
is used for the spectral attribution. The band mask should have a 1D or 3D shape, which can be
broadcastable to the shape of the input image. The only dimensions on which the image and the mask shapes
Expand Down Expand Up @@ -1240,7 +1240,7 @@ def get_spectral_attributes(
band_names (list[str] | dict[str | tuple[str, ...], int] | None, optional): Band names. Defaults to None.
Returns:
HSISpectralAttributes | list[HSISpectralAttributes]: An object containing the image, the attributions,
HSIAttributesSpectral | list[HSIAttributesSpectral]: An object containing the image, the attributions,
the band mask, the band names, and the score of the interpretable model used for the explanation.
Raises:
Expand Down Expand Up @@ -1338,7 +1338,7 @@ def get_spectral_attributes(

try:
spectral_attribution = [
HSISpectralAttributes(
HSIAttributesSpectral(
hsi=hsi_img,
attributes=lime_attr,
mask=band_mask[idx].expand_as(hsi_img.image),
Expand Down
14 changes: 7 additions & 7 deletions src/meteors/attr/occlusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from captum.attr import Occlusion as CaptumOcclusion

from .attributes import HSIAttributes, HSISpatialAttributes, HSISpectralAttributes
from .attributes import HSIAttributes, HSIAttributesSpatial, HSIAttributesSpectral
from .explainer import Explainer, validate_and_transform_baseline
from meteors import HSI
from meteors.models import ExplainableModel
Expand Down Expand Up @@ -214,7 +214,7 @@ def get_spatial_attributes(
additional_forward_args: Any = None,
perturbations_per_eval: int = 1,
show_progress: bool = False,
) -> HSISpatialAttributes | list[HSISpatialAttributes]:
) -> HSIAttributesSpatial | list[HSIAttributesSpatial]:
"""Compute spatial attributions for the input HSI using the Occlusion method. In this case, the sliding window
is applied to the spatial dimensions only.
Expand Down Expand Up @@ -256,7 +256,7 @@ def get_spatial_attributes(
show_progress (bool, optional): If True, displays a progress bar. Defaults to False.
Returns:
HSISpatialAttributes | list[HSISpatialAttributes]: The computed attributions for the input hyperspectral image(s).
HSIAttributesSpatial | list[HSIAttributesSpatial]: The computed attributions for the input hyperspectral image(s).
if a list of HSI objects is provided, the attributions are computed for each HSI object in the list.
Raises:
Expand Down Expand Up @@ -334,7 +334,7 @@ def get_spatial_attributes(

try:
spatial_attributes = [
HSISpatialAttributes(
HSIAttributesSpatial(
hsi=hsi_image, attributes=attribution, attribution_method=self.get_name(), mask=mask
)
for hsi_image, attribution, mask in zip(hsi, occlusion_attributions, segment_mask)
Expand All @@ -354,7 +354,7 @@ def get_spectral_attributes(
additional_forward_args: Any = None,
perturbations_per_eval: int = 1,
show_progress: bool = False,
) -> HSISpectralAttributes | list[HSISpectralAttributes]:
) -> HSIAttributesSpectral | list[HSIAttributesSpectral]:
"""Compute spectral attributions for the input HSI using the Occlusion method. In this case, the sliding window
is applied to the spectral dimension only.
Expand Down Expand Up @@ -396,7 +396,7 @@ def get_spectral_attributes(
show_progress (bool, optional): If True, displays a progress bar. Defaults to False.
Returns:
HSISpectralAttributes | list[HSISpectralAttributes]: The computed attributions for the input hyperspectral
HSIAttributesSpectral | list[HSIAttributesSpectral]: The computed attributions for the input hyperspectral
image(s). if a list of HSI objects is provided, the attributions are computed for each HSI object in
the list.
Expand Down Expand Up @@ -484,7 +484,7 @@ def get_spectral_attributes(

try:
spectral_attributes = [
HSISpectralAttributes(
HSIAttributesSpectral(
hsi=hsi_image,
attributes=attribution,
attribution_method=self.get_name(),
Expand Down
Loading

0 comments on commit d2374a8

Please sign in to comment.