forked from k2-fsa/sherpa-onnx
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support exporting models to onnx from 3D-Speaker (k2-fsa#522)
- Loading branch information
1 parent
5526691
commit 07e2b9a
Showing
10 changed files
with
442 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
name: export-3dspeaker-to-onnx | ||
|
||
on: | ||
workflow_dispatch: | ||
|
||
concurrency: | ||
group: export-3dspeaker-to-onnx-${{ github.ref }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
export-3dspeaker-to-onnx: | ||
if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj' | ||
name: export 3d-speaker to ONNX | ||
runs-on: ${{ matrix.os }} | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
os: [macos-latest] | ||
python-version: ["3.8"] | ||
|
||
steps: | ||
- uses: actions/checkout@v4 | ||
|
||
- name: Setup Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
|
||
- name: Run | ||
shell: bash | ||
run: | | ||
cd scripts/3dspeaker | ||
./run.sh | ||
mv -v *.onnx ../.. | ||
- name: Release | ||
uses: svenstaro/upload-release-action@v2 | ||
with: | ||
file_glob: true | ||
file: ./*.onnx | ||
overwrite: true | ||
repo_name: k2-fsa/sherpa-onnx | ||
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} | ||
tag: speaker-recongition-models |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Introduction | ||
|
||
This directory contains scripts | ||
about exporting models from https://github.com/alibaba-damo-academy/3D-Speaker | ||
to `onnx` so that they can be used in `sherpa-onnx`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright 2023-2024 Xiaomi Corp. (authors: Fangjun Kuang) | ||
|
||
import argparse | ||
import json | ||
import os | ||
import pathlib | ||
import re | ||
from typing import Dict | ||
|
||
import onnx | ||
import torch | ||
from infer_sv import supports | ||
from modelscope.hub.snapshot_download import snapshot_download | ||
from speakerlab.utils.builder import dynamic_import | ||
|
||
|
||
def add_meta_data(filename: str, meta_data: Dict[str, str]): | ||
"""Add meta data to an ONNX model. It is changed in-place. | ||
Args: | ||
filename: | ||
Filename of the ONNX model to be changed. | ||
meta_data: | ||
Key-value pairs. | ||
""" | ||
model = onnx.load(filename) | ||
for key, value in meta_data.items(): | ||
meta = model.metadata_props.add() | ||
meta.key = key | ||
meta.value = str(value) | ||
|
||
onnx.save(model, filename) | ||
|
||
|
||
def get_args(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--model", | ||
type=str, | ||
required=True, | ||
choices=[ | ||
"speech_campplus_sv_en_voxceleb_16k", | ||
"speech_campplus_sv_zh-cn_16k-common", | ||
"speech_eres2net_sv_en_voxceleb_16k", | ||
"speech_eres2net_sv_zh-cn_16k-common", | ||
"speech_eres2net_base_200k_sv_zh-cn_16k-common", | ||
"speech_eres2net_base_sv_zh-cn_3dspeaker_16k", | ||
"speech_eres2net_large_sv_zh-cn_3dspeaker_16k", | ||
], | ||
) | ||
return parser.parse_args() | ||
|
||
|
||
@torch.no_grad() | ||
def main(): | ||
args = get_args() | ||
local_model_dir = "pretrained" | ||
model_id = f"damo/{args.model}" | ||
conf = supports[model_id] | ||
cache_dir = snapshot_download( | ||
model_id, | ||
revision=conf["revision"], | ||
) | ||
cache_dir = pathlib.Path(cache_dir) | ||
|
||
save_dir = os.path.join(local_model_dir, model_id.split("/")[1]) | ||
save_dir = pathlib.Path(save_dir) | ||
save_dir.mkdir(exist_ok=True, parents=True) | ||
|
||
download_files = ["examples", conf["model_pt"]] | ||
for src in cache_dir.glob("*"): | ||
if re.search("|".join(download_files), src.name): | ||
dst = save_dir / src.name | ||
try: | ||
dst.unlink() | ||
except FileNotFoundError: | ||
pass | ||
dst.symlink_to(src) | ||
pretrained_model = save_dir / conf["model_pt"] | ||
pretrained_state = torch.load(pretrained_model, map_location="cpu") | ||
|
||
model = conf["model"] | ||
embedding_model = dynamic_import(model["obj"])(**model["args"]) | ||
embedding_model.load_state_dict(pretrained_state) | ||
embedding_model.eval() | ||
|
||
with open(f"{cache_dir}/configuration.json") as f: | ||
json_config = json.loads(f.read()) | ||
print(json_config) | ||
|
||
T = 100 | ||
C = 80 | ||
x = torch.rand(1, T, C) | ||
filename = f"{args.model}.onnx" | ||
torch.onnx.export( | ||
embedding_model, | ||
x, | ||
filename, | ||
opset_version=13, | ||
input_names=["x"], | ||
output_names=["embedding"], | ||
dynamic_axes={ | ||
"x": {0: "N", 1: "T"}, | ||
"embeddings": {0: "N"}, | ||
}, | ||
) | ||
|
||
# all models from 3d-speaker expect input samples in the range | ||
# [-1, 1] | ||
normalize_samples = 1 | ||
|
||
# all models from 3d-speaker normalize the features by the global mean | ||
feature_normalize_type = "global-mean" | ||
sample_rate = json_config["model"]["model_config"]["sample_rate"] | ||
|
||
feat_dim = conf["model"]["args"]["feat_dim"] | ||
assert feat_dim == 80, feat_dim | ||
|
||
output_dim = conf["model"]["args"]["embedding_size"] | ||
|
||
if "zh-cn" in args.model: | ||
language = "Chinese" | ||
elif "en" in args.model: | ||
language = "English" | ||
else: | ||
raise ValueError(f"Unsupported language for model {args.model}") | ||
|
||
comment = f"This model is from damo/{args.model}" | ||
url = f"https://www.modelscope.cn/models/damo/{args.model}/summary" | ||
|
||
meta_data = { | ||
"framework": "3d-speaker", | ||
"language": language, | ||
"url": url, | ||
"comment": comment, | ||
"sample_rate": sample_rate, | ||
"output_dim": output_dim, | ||
"normalize_samples": normalize_samples, | ||
"feature_normalize_type": feature_normalize_type, | ||
} | ||
print(meta_data) | ||
add_meta_data(filename=filename, meta_data=meta_data) | ||
|
||
|
||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#!/usr/bin/env bash | ||
|
||
set -e | ||
|
||
function install_3d_speaker() { | ||
echo "Install 3D-Speaker" | ||
git clone https://github.com/alibaba-damo-academy/3D-Speaker.git | ||
pushd 3D-Speaker | ||
pip install -q -r ./requirements.txt | ||
pip install -q modelscope onnx onnxruntime kaldi-native-fbank | ||
popd | ||
} | ||
|
||
function download_test_data() { | ||
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav | ||
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav | ||
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav | ||
|
||
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_en_16k.wav | ||
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_en_16k.wav | ||
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_en_16k.wav | ||
} | ||
|
||
install_3d_speaker | ||
|
||
download_test_data | ||
|
||
export PYTHONPATH=$PWD/3D-Speaker:$PYTHONPATH | ||
export PYTHONPATH=$PWD/3D-Speaker/speakerlab/bin:$PYTHONPATH | ||
|
||
models=( | ||
speech_campplus_sv_en_voxceleb_16k | ||
speech_campplus_sv_zh-cn_16k-common | ||
speech_eres2net_sv_en_voxceleb_16k | ||
speech_eres2net_sv_zh-cn_16k-common | ||
speech_eres2net_base_200k_sv_zh-cn_16k-common | ||
speech_eres2net_base_sv_zh-cn_3dspeaker_16k | ||
speech_eres2net_large_sv_zh-cn_3dspeaker_16k | ||
) | ||
for model in ${models[@]}; do | ||
echo "--------------------$model--------------------" | ||
python3 ./export-onnx.py --model $model | ||
|
||
python3 ./test-onnx.py \ | ||
--model ${model}.onnx \ | ||
--file1 ./speaker1_a_cn_16k.wav \ | ||
--file2 ./speaker1_b_cn_16k.wav | ||
|
||
python3 ./test-onnx.py \ | ||
--model ${model}.onnx \ | ||
--file1 ./speaker1_a_cn_16k.wav \ | ||
--file2 ./speaker2_a_cn_16k.wav | ||
|
||
python3 ./test-onnx.py \ | ||
--model ${model}.onnx \ | ||
--file1 ./speaker1_a_en_16k.wav \ | ||
--file2 ./speaker1_b_en_16k.wav | ||
|
||
python3 ./test-onnx.py \ | ||
--model ${model}.onnx \ | ||
--file1 ./speaker1_a_en_16k.wav \ | ||
--file2 ./speaker2_a_en_16k.wav | ||
done |
Oops, something went wrong.