diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d0189db14..fc5733ba0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -12,7 +12,6 @@ jobs: build-and-test: name: "Python ${{ matrix.python-version }} on ${{ matrix.os }} jax=${{ matrix.jax-version }}" runs-on: "${{ matrix.os }}" - strategy: matrix: python-version: ["3.9", "3.10", "3.11"] @@ -22,7 +21,6 @@ jobs: - python-version: "3.9" os: "ubuntu-latest" jax-version: "0.4.27" # Keep version in sync with pyproject.toml and copy.bara.sky! - steps: - uses: "actions/checkout@v2" - uses: "actions/setup-python@v4" @@ -33,13 +31,72 @@ jobs: - name: Run CI tests run: JAX_VERSION="${{ matrix.jax-version }}" bash test.sh shell: bash + doctests: + name: "Doctests on ${{ matrix.os }} with Python ${{ matrix.python-version }}" + runs-on: "${{ matrix.os }}" + strategy: + matrix: + python-version: ["3.11"] # only build docs with a somewhat latest python + os: [ubuntu-latest] + steps: + - uses: "actions/checkout@v2" + - uses: "actions/setup-python@v4" + with: + python-version: "${{ matrix.python-version }}" + cache: "pip" + cache-dependency-path: 'pyproject.toml' + - name: Build docs and run doctests + run: | + python3 -m pip install --quiet --editable ".[docs]" + cd docs + make html + make doctest # run doctests + shell: bash + linting: + name: "Lint check with flake8 and pylint" + runs-on: "ubuntu-latest" + steps: + - uses: "actions/checkout@v2" + - uses: "actions/setup-python@v4" + with: + python-version: "3.11" + cache: "pip" + cache-dependency-path: "pyproject.toml" + - name: Install linting dependencies + run: | + pip install -U pip setuptools wheel + pip install -U flake8 pytest-xdist pylint pylint-exit + - name: Lint with flake8 + run: | + python3 -m flake8 --select=E9,F63,F7,F82,E225,E251 --show-source --statistics + - name: Lint module files with pylint + run: | + PYLINT_ARGS="-efail -wfail -cfail -rfail" + python3 -m pylint --rcfile=.pylintrc $(find optax -name '*.py' | grep -v 'test.py' | xargs) -d E1102 || pylint-exit $PYLINT_ARGS $? + - name: Lint test files with pylint + run: | + PYLINT_ARGS="-efail -wfail -cfail -rfail" + python3 -m pylint --rcfile=.pylintrc $(find optax -name '*_test.py' | xargs) -d W0212,E1102 || pylint-exit $PYLINT_ARGS $? + ruff-lint: + name: "Lint check with ruff" + runs-on: "ubuntu-latest" + steps: + - uses: "actions/checkout@v2" + - uses: "actions/setup-python@v4" + with: + python-version: "3.11" + cache: "pip" + cache-dependency-path: "pyproject.toml" + - name: Install ruff and lint check + run: | + pip install -U ruff + ruff check . markdown-link-check: name: "Check links in markdown files" runs-on: "ubuntu-latest" steps: - name: Checkout repository uses: actions/checkout@v4 - - name: Check links uses: gaurav-nelson/github-action-markdown-link-check@v1 with: diff --git a/examples/contrib/reduce_on_plateau.ipynb b/examples/contrib/reduce_on_plateau.ipynb index 4187947ad..05e3d9c2b 100644 --- a/examples/contrib/reduce_on_plateau.ipynb +++ b/examples/contrib/reduce_on_plateau.ipynb @@ -372,7 +372,7 @@ "source": [ "opt = optax.chain(\n", " optax.adam(LEARNING_RATE),\n", - " reduce_on_plateau(\n", + " contrib.reduce_on_plateau(\n", " patience=PATIENCE,\n", " cooldown=COOLDOWN,\n", " factor=FACTOR,\n", @@ -759,7 +759,7 @@ } ], "source": [ - "transform = reduce_on_plateau(\n", + "transform = contrib.reduce_on_plateau(\n", " patience=PATIENCE,\n", " cooldown=COOLDOWN,\n", " factor=FACTOR,\n", diff --git a/examples/linear_assignment_problem.ipynb b/examples/linear_assignment_problem.ipynb index e994f00bd..f77d62477 100644 --- a/examples/linear_assignment_problem.ipynb +++ b/examples/linear_assignment_problem.ipynb @@ -55,7 +55,7 @@ "outputs": [], "source": [ "import networkx as nx\n", - "from jax import numpy as jnp, random\n", + "from jax import random\n", "import optax\n", "from matplotlib import pyplot as plt" ] diff --git a/examples/nanolm.ipynb b/examples/nanolm.ipynb index b098b69c9..f34db019c 100644 --- a/examples/nanolm.ipynb +++ b/examples/nanolm.ipynb @@ -629,7 +629,7 @@ } ], "source": [ - "plt.title(f\"Convergence of adamw (train loss)\")\n", + "plt.title(\"Convergence of adamw (train loss)\")\n", "plt.plot(all_train_losses, label=\"train\", lw=3)\n", "plt.plot(\n", " jnp.arange(0, len(all_eval_losses) * N_FREQ_EVAL, N_FREQ_EVAL),\n", diff --git a/optax/schedules/_inject.py b/optax/schedules/_inject.py index 43bd2feef..2eb19e945 100644 --- a/optax/schedules/_inject.py +++ b/optax/schedules/_inject.py @@ -30,7 +30,7 @@ def _convert_floats(x, dtype): """Convert float-like inputs to dtype, rest pass through.""" - if jax.dtypes.scalar_type_of(x) == float: + if jax.dtypes.scalar_type_of(x) is float: return jnp.asarray(x, dtype=dtype) return x diff --git a/pyproject.toml b/pyproject.toml index 7bccb6db3..3b01fa880 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,3 +87,17 @@ dp-accounting = [ [tool.setuptools.packages.find] include = ["README.md", "LICENSE"] exclude = ["*_test.py"] + +[tool.ruff.lint] +select = [ + "F", + "E", +] +ignore = [ + "E731", # lambdas are allowed + "E501", # don't check line lengths + "F401", # allow unused imports + "E402", # allow modules not at top of file + "E741", # allow "l" as a variable name + "E703", # allow semicolons (for jupyter notebooks) +] diff --git a/test.sh b/test.sh index e6d258535..f387d95bc 100755 --- a/test.sh +++ b/test.sh @@ -101,4 +101,7 @@ make html make doctest # run doctests cd .. +pip install -U ruff +ruff check . + echo "All tests passed. Congrats!"