Skip to content

Commit

Permalink
fix: Set the keep_gradient argument to False in methods that can stor…
Browse files Browse the repository at this point in the history
…e the gradient in the final results (#151)
  • Loading branch information
WolodjaZ authored Dec 3, 2024
1 parent bd13684 commit 3de807b
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 6 deletions.
7 changes: 5 additions & 2 deletions examples/hyperview_challenge/attr_showcase.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The plots that we use to visualize our attributions are created from `visualize`. The `visualize_attributes` function visualizes the attribution with two spatial plots on top and two spectral plots below. The spatial plots present correlation of the spatial pixels with the output. If the correlation is negative (red), it means that this pixel lowered the model estimation for class `P`; if positive (green), it increased the prediction; and if around zero (white), then it was not impactful. In the lower plots, we see attributions aggregated per wavelength, showcasing how each wavelength correlated with the output - once again, negative values lowered, positive ones increased, and values close to zero were not impactful. These attributions are aggregated spatially and spectrally with mean over spatial or spectral axis."
"The plots that we use to visualize our attributions are created from the `visualize`. The `visualize_attributes` function visualizes the attribution with two spatial plots on top and two spectral plots below. The spatial plots present correlation of the spatial pixels with the output. If the correlation is negative (red), it means that this pixel lowered the model estimation for class `P`; if positive (green), it increased the prediction; and if around zero (white), then it was not impactful. In the lower plots, we see attributions aggregated per wavelength, showcasing how each wavelength correlated with the output - once again, negative values lowered, positive ones increased, and values close to zero were not impactful. These attributions are aggregated spatially and spectrally with mean over spatial or spectral axis.\n",
"\n",
"`InputXGradient` and `IntegratedGradients` methods can also store gradients for future use. This maybe useful if we want to analyze the gradients in more detail or visualize them in a different way, but for efficiency reasons, `keep_gradient` is set to `False` by default."
]
},
{
Expand Down Expand Up @@ -863,7 +865,8 @@
"- `target` - the target index class to be analyzed: 0 for `P` (Phosphorus) class\n",
"- `n_samples` - the number of perturbed samples to generate\n",
"- `method` - how the final aggregation of all the generatated perturbation attributions should be calculated\n",
"- `baseline` - for the `HyperNoiseTunnel` method, the baseline value to replace the occluded region. It should be a float value."
"- `baseline` - for the `HyperNoiseTunnel` method, the baseline value to replace the occluded region. It should be a float value.\n",
"- `keep_gradient` - whether to keep the gradients for future use. If we choose the `IntegratedGradients` or `InputXGradient` as the base method, we can set this to `True` to store the gradients for future use. By default, it is set to `False`."
]
},
{
Expand Down
11 changes: 10 additions & 1 deletion src/meteors/attr/input_x_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def attribute(
hsi: list[HSI] | HSI,
target: list[int] | int | None = None,
additional_forward_args: Any = None,
keep_gradient: bool = False,
) -> HSIAttributes | list[HSIAttributes]:
"""
Method for generating attributions using the InputXGradient method.
Expand All @@ -52,6 +53,10 @@ def attribute(
containing multiple additional arguments including tensors or any arbitrary python types.
These arguments are provided to forward_func in order following the arguments in inputs.
Note that attributions are not computed with respect to these arguments. Default: None
keep_gradient (bool, optional): Indicates whether to keep the gradient tensors in memory. By the default,
the gradient tensors are removed from the computation graph after the attributions are computed, due
to memory efficiency. If the gradient tensors are needed for further processing, this parameter should
be set to True. Default: False
Returns:
HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s).
Expand Down Expand Up @@ -88,7 +93,11 @@ def attribute(

try:
attributes = [
HSIAttributes(hsi=hsi_image, attributes=attribution, attribution_method=self.get_name())
HSIAttributes(
hsi=hsi_image,
attributes=attribution if keep_gradient else attribution.detach(),
attribution_method=self.get_name(),
)
for hsi_image, attribution in zip(hsi, gradient_attribution)
]
except Exception as e:
Expand Down
12 changes: 11 additions & 1 deletion src/meteors/attr/integrated_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def attribute(
"riemann_right", "riemann_left", "riemann_middle", "riemann_trapezoid", "gausslegendre"
] = "gausslegendre",
return_convergence_delta: bool = False,
keep_gradient: bool = False,
) -> HSIAttributes | list[HSIAttributes]:
"""
Method for generating attributions using the Integrated Gradients method.
Expand Down Expand Up @@ -94,6 +95,10 @@ def attribute(
return_convergence_delta (bool, optional): Indicates whether to return convergence delta or not.
If return_convergence_delta is set to True convergence delta will be returned in a tuple following
attributions. Default: False
keep_gradient (bool, optional): Indicates whether to keep the gradient tensors in memory. By the default,
the gradient tensors are removed from the computation graph after the attributions are computed, due
to memory efficiency. If the gradient tensors are needed for further processing, this parameter should
be set to True. Default: False
Returns:
HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s).
Expand Down Expand Up @@ -156,7 +161,12 @@ def attribute(

try:
attributes = [
HSIAttributes(hsi=hsi_image, attributes=attribution, score=error, attribution_method=self.get_name())
HSIAttributes(
hsi=hsi_image,
attributes=attribution if keep_gradient else attribution.detach(),
score=error,
attribution_method=self.get_name(),
)
for hsi_image, attribution, error in zip(hsi, attributions, approximation_error)
]
except Exception as e:
Expand Down
24 changes: 22 additions & 2 deletions src/meteors/attr/noise_tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def __init__(self, callable: ExplainableModel | Explainer) -> None:
sig = inspect.signature(self.chained_explainer.attribute)
if "abs" in sig.parameters:
self.chained_explainer.attribute = partial(self.chained_explainer.attribute, abs=False) # type: ignore
if "keep_gradient" in sig.parameters:
self.chained_explainer.attribute = partial(self.chained_explainer.attribute, keep_gradient=True) # type: ignore

@staticmethod
@abstractmethod
Expand Down Expand Up @@ -214,6 +216,7 @@ def attribute(
perturbation_axis: None | tuple[int | slice] = None,
stdevs: float | tuple[float, ...] = 1.0,
method: Literal["smoothgrad", "smoothgrad_sq", "vargrad"] = "smoothgrad",
keep_gradient: bool = False,
) -> HSIAttributes | list[HSIAttributes]:
"""
Method for generating attributions using the Noise Tunnel method.
Expand Down Expand Up @@ -254,6 +257,10 @@ def attribute(
each value in the tuple is used for the corresponding input. Default: 1.0
method (Literal["smoothgrad", "smoothgrad_sq", "vargrad"], optional): Smoothing type of the attributions.
smoothgrad, smoothgrad_sq or vargrad Default: smoothgrad if type is not provided.
keep_gradient (bool, optional): Indicates whether to keep the gradient tensors in memory. By the default,
the gradient tensors are removed from the computation graph after the attributions are computed, due
to memory efficiency. If the gradient tensors are needed for further processing, this parameter should
be set to True. Default: False
Returns:
HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s).
Expand Down Expand Up @@ -306,7 +313,11 @@ def attribute(

try:
attributes = [
HSIAttributes(hsi=hsi_image, attributes=attribution, attribution_method=self.get_name())
HSIAttributes(
hsi=hsi_image,
attributes=attribution if keep_gradient else attribution.detach(),
attribution_method=self.get_name(),
)
for hsi_image, attribution in zip(hsi, nt_attributes)
]
except Exception as e:
Expand Down Expand Up @@ -415,6 +426,7 @@ def attribute(
perturbation_prob: float = 0.5,
num_perturbed_bands: int | None = None,
method: Literal["smoothgrad", "smoothgrad_sq", "vargrad"] = "smoothgrad",
keep_gradient: bool = False,
) -> HSIAttributes | list[HSIAttributes]:
"""
Method for generating attributions using the Hyper Noise Tunnel method.
Expand Down Expand Up @@ -453,6 +465,10 @@ def attribute(
If set to None, the bands are perturbed with probability `perturbation_prob` each. Defaults to None.
method (Literal["smoothgrad", "smoothgrad_sq", "vargrad"], optional): Smoothing type of the attributions.
smoothgrad, smoothgrad_sq or vargrad Default: smoothgrad if type is not provided.
keep_gradient (bool, optional): Indicates whether to keep the gradient tensors in memory. By the default,
the gradient tensors are removed from the computation graph after the attributions are computed, due
to memory efficiency. If the gradient tensors are needed for further processing, this parameter should
be set to True. Default: False
Returns:
HSIAttributes | list[HSIAttributes]: The computed attributions for the input hyperspectral image(s).
Expand Down Expand Up @@ -511,7 +527,11 @@ def attribute(

try:
attributes = [
HSIAttributes(hsi=hsi_image, attributes=attribution, attribution_method=self.get_name())
HSIAttributes(
hsi=hsi_image,
attributes=attribution if keep_gradient else attribution.detach(),
attribution_method=self.get_name(),
)
for hsi_image, attribution in zip(hsi, hnt_attributes)
]
except Exception as e:
Expand Down
12 changes: 12 additions & 0 deletions tests/attr/test_attribution_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ def test_integrated_gradients(explainable_toy_model):

attributions = ig.attribute(image, return_convergence_delta=True)
assert attributions.score is not None
assert not attributions.attributes.requires_grad

# Test keeping gradient
attributions = ig.attribute(image, keep_gradient=True)
assert attributions.attributes.shape == image.image.shape
assert attributions.attributes.requires_grad

# Test multiple images
attributions = ig.attribute([image, image], return_convergence_delta=True)
Expand Down Expand Up @@ -176,6 +182,12 @@ def test_input_x_gradient(explainable_toy_model):
assert attributions.attributes.shape == image.image.shape

assert not input_x_gradient.has_convergence_delta()
assert not attributions.attributes.requires_grad

# Test keeping gradient
attributions = input_x_gradient.attribute(image, keep_gradient=True)
assert attributions.attributes.shape == image.image.shape
assert attributions.attributes.requires_grad

# Test multiple images
attributions = input_x_gradient.attribute([image, image])
Expand Down
8 changes: 8 additions & 0 deletions tests/attr/test_noise_tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def test_noise_attribute(explainable_toy_model):

assert attributes is not None
assert attributes.attributes.shape == (5, 5, 5)
assert not attributes.attributes.requires_grad

attributes = noise_tunnel.attribute(image, n_samples=1, method="smoothgrad_sq")

Expand All @@ -170,6 +171,9 @@ def test_noise_attribute(explainable_toy_model):
assert attributes is not None
assert attributes.attributes.shape == (5, 5, 5)

attributes = noise_tunnel.attribute(image, n_samples=1, keep_gradient=True)
assert attributes.attributes.requires_grad

# test changing the image orientation
image.orientation = ("W", "C", "H")
attributes = noise_tunnel.attribute(image, n_samples=1)
Expand Down Expand Up @@ -254,6 +258,7 @@ def test_hyper_attribute(explainable_toy_model):

assert attributes is not None
assert attributes.attributes.shape == (5, 5, 5)
assert not attributes.attributes.requires_grad

attributes = hyper_noise_tunnel.attribute(image, n_samples=1, method="smoothgrad_sq")

Expand All @@ -265,6 +270,9 @@ def test_hyper_attribute(explainable_toy_model):
assert attributes is not None
assert attributes.attributes.shape == (5, 5, 5)

attributes = hyper_noise_tunnel.attribute(image, n_samples=1, keep_gradient=True)
assert attributes.attributes.requires_grad

# test changing the image orientation
image.orientation = ("W", "C", "H")
attributes = hyper_noise_tunnel.attribute(image, n_samples=1)
Expand Down

0 comments on commit 3de807b

Please sign in to comment.