Skip to content

Commit

Permalink
Remove the initial argument to jax.nn.softmax and `jax.nn.log_sof…
Browse files Browse the repository at this point in the history
…tmax`.

This argument was deprecated in JAX v0.4.27 and has no effect in JAX v0.4.27 and later.

PiperOrigin-RevId: 693023366
  • Loading branch information
Jake VanderPlas authored and Google-ML-Automation committed Nov 4, 2024
1 parent 26c0c5c commit e9acaa8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* The deprecated module `jax.experimental.export` has been removed. It was replaced
by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export)
for information on migrating to the new API.
* The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
has been removed, after being deprecated in v0.4.27.
* The following deprecated methods and functions in {mod}`jax.export` have
been removed:
* `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect
Expand Down
15 changes: 6 additions & 9 deletions jax/_src/nn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import math
import numpy as np
from typing import Any, Literal
import warnings

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -502,7 +501,7 @@ def glu(x: ArrayLike, axis: int = -1) -> Array:
def log_softmax(x: ArrayLike,
axis: int | tuple[int, ...] | None = -1,
where: ArrayLike | None = None,
initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array:
initial: Unspecified = _UNSPECIFIED) -> Array:
r"""Log-Softmax function.
Computes the logarithm of the :code:`softmax` function, which rescales
Expand All @@ -528,10 +527,9 @@ def log_softmax(x: ArrayLike,
See also:
:func:`softmax`
"""
# TODO(jakevdp): remove the initial argument after JAX v0.4.40.
if initial is not _UNSPECIFIED:
# Added 2024-4-10
warnings.warn("The initial argument to log_softmax is deprecated, and no longer has any effect.",
DeprecationWarning, stacklevel=2)
raise TypeError("The initial argument to jax.nn.log_softmax was removed in JAX v0.4.36.")
del initial
numpy_util.check_arraylike("log_softmax", x)
x_arr = jnp.asarray(x)
Expand All @@ -551,7 +549,7 @@ def log_softmax(x: ArrayLike,
def softmax(x: ArrayLike,
axis: int | tuple[int, ...] | None = -1,
where: ArrayLike | None = None,
initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array:
initial: Unspecified = _UNSPECIFIED) -> Array:
r"""Softmax function.
Computes the function which rescales elements to the range :math:`[0, 1]`
Expand All @@ -577,10 +575,9 @@ def softmax(x: ArrayLike,
See also:
:func:`log_softmax`
"""
# TODO(jakevdp): remove the initial argument after JAX v0.4.40.
if initial is not _UNSPECIFIED:
# Added 2024-4-10
warnings.warn("The initial argument to softmax is deprecated, and no longer has any effect.",
DeprecationWarning, stacklevel=2)
raise TypeError("The initial argument to jax.nn.softmax was removed in JAX v0.4.36.")
del initial
if config.softmax_custom_jvp.value:
# mypy is confused by the `functools.partial` application in the definition
Expand Down

0 comments on commit e9acaa8

Please sign in to comment.