Skip to content

Commit

Permalink
Merge branch 'main' of github.com:Sm00thix/IKPLS
Browse files Browse the repository at this point in the history
  • Loading branch information
Sm00thix committed Nov 9, 2023
2 parents e3a1115 + 2ed738b commit 5e31580
Show file tree
Hide file tree
Showing 17 changed files with 191 additions and 69 deletions.
50 changes: 50 additions & 0 deletions .github/actions/build/action.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
name: Build
description: Test

runs:
# strategy:
# fail-fast: false
# matrix:
# # os: [ubuntu-latest, windows-latest, macos-latest]
# os: [ubuntu-latest]
# # python-version: ["3.9", "3.10", "3.11", "3.12"]
# python-version: ["3.10"]
# include:
# - os: ubuntu-latest
# path: ~/.cache/pip

# runs-on: ${{ matrix.os }}
using: "composite"
steps:
- uses: actions/checkout@v3

- name: Set up Python #${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: '3.10'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install build --user
pip install -U "jax[cpu]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
shell: bash

- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
shell: bash

- name: Build a binary wheel and a source tarball
run: python3 -m build
shell: bash

- name: Store the distribution packages
uses: actions/upload-artifact@v3
with:
name: python-package-distributions
path: dist/
13 changes: 13 additions & 0 deletions .github/actions/publish/action.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name: Publish

runs:
using: "composite"
steps:
- name: Download all the dists
uses: actions/download-artifact@v3
with:
name: python-package-distributions
path: dist/

- name: Publish distribution 📦 to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
24 changes: 24 additions & 0 deletions .github/actions/test/action.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name: Test
description: Test

inputs:
PYTHON_VERSION:
description: test
required: true

runs:
using: "composite"
steps:
- name: Test
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade pandas
python -m pip install --upgrade numpy
python -m pip install --upgrade tqdm
python -m pip install --upgrade scikit-learn
python -m pip install flake8 pytest
python -m pip install build --user
pip install -U "jax[cpu]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install pytest pytest-cov
python3 -m pytest tests --doctest-modules --junitxml=junit/test-results.xml --cov=ikpls/ --cov-report=xml --cov-report=html
shell: bash
30 changes: 30 additions & 0 deletions .github/workflows/pull_request.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Python package

on:
pull_request:
branches: [ "main" ]

jobs:
test_package:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest] # os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.10"] # python-version: ["3.9", "3.10", "3.11", "3.12"]
env:
JAX_ENABLE_X64: True
steps:
- uses: actions/checkout@v3
- uses: ./.github/actions/test

build_package:
runs-on: ubuntu-latest
env:
JAX_ENABLE_X64: True
steps:
- uses: actions/checkout@v3
- uses: ./.github/actions/build
53 changes: 0 additions & 53 deletions .github/workflows/python-package.yml

This file was deleted.

42 changes: 42 additions & 0 deletions .github/workflows/workflow.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Python package

on:
push:
branches: [ "main" ]

jobs:
test_package:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest] # os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.10"] # python-version: ["3.9", "3.10", "3.11", "3.12"]
env:
JAX_ENABLE_X64: True
steps:
- uses: actions/checkout@v3
- uses: ./.github/actions/test

build_package:
runs-on: ubuntu-latest
env:
JAX_ENABLE_X64: True
steps:
- uses: actions/checkout@v3
- uses: ./.github/actions/build

publish_package:
needs: [build_package, test_package]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: ./.github/actions/publish
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/ikpls
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

!.github/*
File renamed without changes.
2 changes: 1 addition & 1 deletion algorithms/jax_ikpls_alg_1.py → ikpls/jax_ikpls_alg_1.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from algorithms.jax_ikpls_base import PLSBase
from ikpls.jax_ikpls_base import PLSBase
import jax
from jax.experimental import host_callback
import jax.numpy as jnp
Expand Down
2 changes: 1 addition & 1 deletion algorithms/jax_ikpls_alg_2.py → ikpls/jax_ikpls_alg_2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from algorithms.jax_ikpls_base import PLSBase
from ikpls.jax_ikpls_base import PLSBase
import jax
from jax.experimental import host_callback
import jax.numpy as jnp
Expand Down
File renamed without changes.
File renamed without changes.
19 changes: 19 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[tool.poetry]
name = "ikpls"
version = "0.1.1"
description = ""
authors = ["Sm00thix <[email protected]>"]
license = "Apache-2.0"
readme = "README.md"
repository = "https://github.com/Sm00thix/IKPLS"

[tool.poetry.dependencies]
python = ">=3.9, <3.13"
numpy = "^1.26.1"
jax = "^0.4.19"
scikit-learn = "^1.3.2"
tqdm = "^4.66.1"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
5 changes: 0 additions & 5 deletions requirements.txt

This file was deleted.

4 changes: 2 additions & 2 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from . import load_data, test_consistency
from . import load_data, test_ikpls

__all__ = ["load_data", "test_consistency"]
__all__ = ["load_data", "test_ikpls"]
8 changes: 4 additions & 4 deletions tests/test_consistency.py → tests/test_ikpls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from jax import numpy as jnp
from typing import Tuple, Callable
from sklearn.cross_decomposition import PLSRegression as SkPLS
from algorithms.jax_ikpls_alg_1 import PLS as JAX_Alg_1
from algorithms.jax_ikpls_alg_2 import PLS as JAX_Alg_2
from algorithms.numpy_ikpls import PLS as NpPLS
from ikpls.jax_ikpls_alg_1 import PLS as JAX_Alg_1
from ikpls.jax_ikpls_alg_2 import PLS as JAX_Alg_2
from ikpls.numpy_ikpls import PLS as NpPLS
from . import load_data


Expand Down Expand Up @@ -2234,7 +2234,7 @@ def jax_rmse_per_component(
] + sk_models[i].intercept_
sk_preds[i] = sk_pred
assert_allclose(
sk_pred[-1], sk_models[i].predict(X), atol=0, rtol=1e-14
sk_pred[-1], sk_models[i].predict(X), atol=0, rtol=1e-13
) # Sanity check. SkPLS also uses the maximum number of components in its predict method.

# Compute RMSE on the validation predictions
Expand Down
6 changes: 3 additions & 3 deletions time_pls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
gen_random_data,
SK_PLS_All_Components,
)
from algorithms.numpy_ikpls import PLS as NP_PLS
from algorithms.jax_ikpls_alg_1 import PLS as JAX_PLS_Alg_1
from algorithms.jax_ikpls_alg_2 import PLS as JAX_PLS_Alg_2
from ikpls.numpy_ikpls import PLS as NP_PLS
from ikpls.jax_ikpls_alg_1 import PLS as JAX_PLS_Alg_1
from ikpls.jax_ikpls_alg_2 import PLS as JAX_PLS_Alg_2

if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand Down

0 comments on commit 5e31580

Please sign in to comment.