Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
enh: apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Jon Haitz Legarreta Gorroño <[email protected]>
  • Loading branch information
oesteban and jhlegarreta committed Jul 9, 2024
1 parent af6064a commit 5a9a967
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions src/eddymotion/model/dipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def compute_spherical_covariance(

_ensure_positive_scale(a)

return np.where(theta <= a, 1 - 3 * theta / (2 * a) + theta ** 3 / (2 * a ** 3), 0)
return np.where(theta <= a, 1 - 3 * theta / (2 * a) + theta**3 / (2 * a**3), 0)


def compute_derivative(
Expand Down Expand Up @@ -367,9 +367,9 @@ def compute_derivative(
min_angles = theta > a

if weighting == "spherical":
deriv_a = 1.5 * (theta[min_angles] / a ** 2 - theta[min_angles] ** 3 / a ** 4)
deriv_a = 1.5 * (theta[min_angles] / a**2 - theta[min_angles] ** 3 / a**4)
elif weighting == "exponential":
deriv_a = np.exp(-theta[min_angles] / a) * theta[min_angles] / a ** 2
deriv_a = np.exp(-theta[min_angles] / a) * theta[min_angles] / a**2
else:
raise ValueError(f"Unknown kernel weighting '{weighting}'.")

Expand Down Expand Up @@ -464,6 +464,27 @@ def __init__(
a_bounds: tuple[float, float] = (1e-5, np.pi),
sigma_sq_bounds: tuple[float, float] = (1e-5, 1e4),
):
r"""
Initialize a kernel with pairwise angles.
Parameters
----------
weighting : :obj:`str`
The type of kernel to build (either "exponential", "sphere", or "test").
lambda_s : :obj:`float`
The :math:`\lambda_s` hyperparameter.
a : :obj:`float`
Minimum angle in rads.
sigma_sq : :obj:`float`
Error allowed in collinear orientations.
lambda_s_bounds : :obj:`tuple`
Bounds for the :math:`\lambda_s` hyperparameter.
a_bounds : :obj:`tuple`
Bounds for the a parameter.
sigma_sq_bounds : :obj:`tuple`
Bounds for the error parameter.
"""
self._weighting = weighting # For __repr__
self.lambda_s = lambda_s
self.a = a
Expand Down Expand Up @@ -492,9 +513,9 @@ def __call__(self, gtab_X, gtab_Y=None, eval_gradient=False):
----------
gtab_X: :obj:`~dipy.core.gradients.GradientTable`
Gradient table (X)
gtab_Y: :obj:`~dipy.core.gradients.GradientTable`
gtab_Y: :obj:`~dipy.core.gradients.GradientTable`, optional
Gradient table (Y, optional)
eval_gradient : :obj:`bool`
eval_gradient : :obj:`bool`, optional
Determines whether the gradient with respect to the log of
the kernel hyperparameter is computed.
Only supported when gtab_Y is ``None``.
Expand Down

0 comments on commit 5a9a967

Please sign in to comment.