Skip to content

Commit

Permalink
feat: update python bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
ChieloNewctle committed May 9, 2024
1 parent dd4d9b6 commit 6e3e4ee
Show file tree
Hide file tree
Showing 12 changed files with 431 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Cargo Build & Test
name: Cargo Build & Test the Crate

on:
push:
Expand Down
231 changes: 231 additions & 0 deletions .github/workflows/ci_python.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
name: Cargo Build & Test the Python Bindings

defaults:
run:
working-directory: python

on:
push:
branches:
- main
tags:
- "*"
pull_request:
workflow_dispatch:

permissions:
contents: read

jobs:
format:
name: Check Python format
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v5
- name: Install dependencies
run: pip install ruff black
- name: Ruff
run: ruff check .
- name: Black
run: black --check --diff .

rustfmt:
name: Check Rust format
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- run: rustup update stable && rustup default stable
- run: rustup component add rustfmt
- run: cargo fmt --all --check

test:
name: Run tests
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "pypy3.10"]

steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install locally
run: pip install -e ".[test]"
- name: Install additional dependencies
run: pip install pytest-md pytest-emoji
- uses: pavelzw/pytest-action@v2
with:
emoji: false
verbose: true
job-summary: true
- name: Test building wheels
uses: PyO3/maturin-action@v1
with:
sccache: true
manylinux: auto

linux:
runs-on: ubuntu-latest
strategy:
matrix:
target: [x86_64, aarch64, armv7]
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Build wheels
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.target }}
args: --release --out dist --interpreter 3.8 pypy3.8 pypy3.9 pypy3.10
sccache: true
manylinux: auto
- name: Upload wheels
uses: actions/upload-artifact@v4
with:
name: wheels-linux-${{ matrix.target }}
path: dist
- name: pytest
if: ${{ startsWith(matrix.target, 'x86_64') }}
shell: bash
run: |
set -e
pip install --pre "mtc_token_healing[test]" --find-links dist --force-reinstall
pytest --import-mode=importlib
- name: pytest
if: ${{ !startsWith(matrix.target, 'x86') && matrix.target != 'ppc64' }}
uses: uraimo/[email protected]
with:
arch: ${{ matrix.target }}
distro: ubuntu22.04
githubToken: ${{ github.token }}
install: |
apt-get update
apt-get install -y --no-install-recommends python3 python3-pip
pip3 install -U pip
run: |
set -e
pip3 install --pre "mtc_token_healing[test]" --find-links dist --force-reinstall
pytest --import-mode=importlib
windows:
runs-on: windows-latest
strategy:
matrix:
target: [x64]
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-python@v4
with:
python-version: "3.10"
architecture: ${{ matrix.target }}
- name: Build wheels
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.target }}
args: --release --out dist --interpreter 3.8 pypy3.8 pypy3.9 pypy3.10
sccache: true
- name: Upload wheels
uses: actions/upload-artifact@v4
with:
name: wheels-windows-${{ matrix.target }}
path: dist
- name: pytest
if: ${{ !startsWith(matrix.target, 'aarch64') }}
shell: bash
run: |
set -e
pip install --pre "mtc_token_healing[test]" --find-links dist --force-reinstall
pytest --import-mode=importlib
macos:
runs-on: macos-latest
strategy:
matrix:
target: [x86_64, aarch64]
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Build wheels
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.target }}
args: --release --out dist --interpreter 3.8 pypy3.8 pypy3.9 pypy3.10
sccache: true
- name: Upload wheels
uses: actions/upload-artifact@v4
with:
name: wheels-macos-${{ matrix.target }}
path: dist
- name: pytest
if: ${{ !startsWith(matrix.target, 'aarch64') }}
shell: bash
run: |
set -e
pip install --pre "mtc_token_healing[test]" --find-links dist --force-reinstall
pytest --import-mode=importlib
sdist:
needs: [test]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Build sdist
uses: PyO3/maturin-action@v1
with:
command: sdist
args: --out dist
- name: Upload sdist
uses: actions/upload-artifact@v4
with:
name: wheels-sdist
path: dist

release:
name: Release
runs-on: ubuntu-latest
if: "startsWith(github.ref, 'refs/tags/')"
needs: [test, format, rustfmt, linux, windows, macos, sdist]
permissions:
# Used to upload release artifacts
contents: write
steps:
- uses: actions/download-artifact@v4
with:
pattern: wheels-*
merge-multiple: true
- name: Publish to PyPI
uses: PyO3/maturin-action@v1
env:
MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
with:
command: upload
args: --non-interactive --skip-existing *
- name: Upload to GitHub Release
uses: softprops/action-gh-release@v2
with:
files: |
*.whl
*.tar.gz
prerelease: ${{ contains(github.ref, 'alpha') || contains(github.ref, 'beta') }}
9 changes: 5 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ description = "Token healing implementation"
repository = "https://github.com/ModelTC/mtc-token-healing"
homepage = "https://github.com/ModelTC/mtc-token-healing"
documentation = "https://docs.rs/mtc-token-healing"
authors = ["Chielo Newctle <[email protected]>"]

[package]
name = "mtc-token-healing"
Expand All @@ -19,16 +20,16 @@ description.workspace = true
repository.workspace = true
homepage.workspace = true
documentation.workspace = true
authors.workspace = true
readme = "README.md"
authors = ["Chielo Newctle <[email protected]>"]
exclude = ["release-plz.toml", ".github"]
exclude = ["release-plz.toml", ".github", "python"]

[dependencies]
derive_more = "0.99.17"
general-sam = { version = "1.0.0", features = ["trie"] }
pyo3 = { version = "0.21.2", optional = true }
smallvec = "1.13.2"
thiserror = "1.0.59"
thiserror = "1.0.60"

[features]
pyo3 = ["dep:pyo3"]
Expand All @@ -38,7 +39,7 @@ clap = { version = "4.5.4", features = ["derive", "env"] }
color-eyre = "0.6.3"
rand = "0.8.5"
regex = "1.10.4"
serde_json = "1.0.116"
serde_json = "1.0.117"
tokenizers = { version = "0.19.1", features = ["hf-hub", "http"] }
tokio = { version = "1.37.0", features = ["rt-multi-thread"] }

Expand Down
1 change: 1 addition & 0 deletions python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ description.workspace = true
repository.workspace = true
homepage.workspace = true
documentation.workspace = true
authors.workspace = true

[lib]
name = "mtc_token_healing"
Expand Down
18 changes: 17 additions & 1 deletion python/mtc_token_healing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
from .mtc_token_healing import CountInfo
from .mtc_token_healing import (
BestChoice,
CountInfo,
InferRequest,
InferResponse,
Prediction,
VocabPrefixAutomaton,
ReorderedTokenId,
SearchTree,
)

__all__ = [
"BestChoice",
"CountInfo",
"InferRequest",
"InferResponse",
"Prediction",
"VocabPrefixAutomaton",
"ReorderedTokenId",
"SearchTree",
]
9 changes: 9 additions & 0 deletions python/mtc_token_healing/mtc_token_healing.pyi
Original file line number Diff line number Diff line change
@@ -1 +1,10 @@
TokenId = int

class BestChoice: ...
class CountInfo: ...
class InferRequest: ...
class InferResponse: ...
class Prediction: ...
class VocabPrefixAutomaton: ...
class ReorderedTokenId: ...
class SearchTree: ...
2 changes: 2 additions & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,7 @@ classifiers = [
]
dynamic = ["version"]

[tool.maturin]

[project.optional-dependencies]
test = ["pytest"]
12 changes: 11 additions & 1 deletion python/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
use ::mtc_token_healing::CountInfo;
use ::mtc_token_healing::{
vocab::PyVocabPrefixAutomaton, BestChoice, CountInfo, InferRequest, InferResponse, Prediction,
ReorderedTokenId, SearchTree,
};
use pyo3::prelude::*;

#[pymodule]
fn mtc_token_healing(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<BestChoice>()?;
m.add_class::<CountInfo>()?;
m.add_class::<InferRequest>()?;
m.add_class::<InferResponse>()?;
m.add_class::<Prediction>()?;
m.add_class::<PyVocabPrefixAutomaton>()?;
m.add_class::<ReorderedTokenId>()?;
m.add_class::<SearchTree>()?;
Ok(())
}
1 change: 1 addition & 0 deletions src/choice.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::TokenId;

#[derive(Clone, Debug)]
#[cfg_attr(feature = "pyo3", pyo3::pyclass(get_all, frozen))]
pub struct BestChoice {
pub extra_token_ids: Vec<TokenId>,
pub accum_log_prob: f64,
Expand Down
Loading

0 comments on commit 6e3e4ee

Please sign in to comment.