Skip to content

Commit

Permalink
Add C++ runtime and Python APIs for Moonshine models (k2-fsa#1473)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Oct 26, 2024
1 parent ad3e0eb commit a6f110e
Show file tree
Hide file tree
Showing 33 changed files with 1,572 additions and 36 deletions.
50 changes: 50 additions & 0 deletions .github/scripts/test-offline-moonshine.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#!/usr/bin/env bash

set -e

log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

export GIT_CLONE_PROTECTION_ACTIVE=false

echo "EXE is $EXE"
echo "PATH: $PATH"

which $EXE

names=(
tiny
base
)

for name in ${names[@]}; do
log "------------------------------------------------------------"
log "Run $name"
log "------------------------------------------------------------"

repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-$name.tar.bz2
repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-$name-en-int8.tar.bz2
curl -SL -O $repo_url
tar xvf sherpa-onnx-moonshine-$name-en-int8.tar.bz2
rm sherpa-onnx-moonshine-$name-en-int8.tar.bz2
repo=sherpa-onnx-moonshine-$name-en-int8
log "Start testing ${repo_url}"

log "test int8 onnx"

time $EXE \
--moonshine-preprocessor=$repo/preprocess.onnx \
--moonshine-encoder=$repo/encode.int8.onnx \
--moonshine-uncached-decoder=$repo/uncached_decode.int8.onnx \
--moonshine-cached-decoder=$repo/cached_decode.int8.onnx \
--tokens=$repo/tokens.txt \
--num-threads=2 \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav

rm -rf $repo
done
10 changes: 10 additions & 0 deletions .github/scripts/test-python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "test offline Moonshine"

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2

python3 ./python-api-examples/offline-moonshine-decode-files.py

rm -rf sherpa-onnx-moonshine-tiny-en-int8

log "test offline speaker diarization"

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
Expand Down
13 changes: 13 additions & 0 deletions .github/workflows/linux.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,19 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: install/*

- name: Test offline Moonshine
if: matrix.build_type != 'Debug'
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
readelf -d build/bin/sherpa-onnx-offline
.github/scripts/test-offline-moonshine.sh
du -h -d1 .
- name: Test offline CTC
shell: bash
run: |
Expand Down
11 changes: 9 additions & 2 deletions .github/workflows/macos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,15 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
- name: Test offline Moonshine
if: matrix.build_type != 'Debug'
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-moonshine.sh
- name: Test C++ API
shell: bash
run: |
Expand Down Expand Up @@ -243,8 +252,6 @@ jobs:
.github/scripts/test-offline-whisper.sh
- name: Test online transducer
shell: bash
run: |
Expand Down
8 changes: 8 additions & 0 deletions .github/workflows/windows-x64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ jobs:
name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
path: build/install/*

- name: Test offline Moonshine for windows x64
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline.exe
.github/scripts/test-offline-moonshine.sh
- name: Test C++ API
shell: bash
run: |
Expand Down
8 changes: 8 additions & 0 deletions .github/workflows/windows-x86.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ jobs:
name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
path: build/install/*

- name: Test offline Moonshine for windows x86
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline.exe
.github/scripts/test-offline-moonshine.sh
- name: Test C++ API
shell: bash
run: |
Expand Down
117 changes: 108 additions & 9 deletions python-api-examples/generate-subtitles.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,19 @@
--feature-dim=80 \
/path/to/test.mp4
(3) For Whisper models
(3) For Moonshine models
./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
--moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
--moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
--moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
--moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
--tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \
--num-threads=2 \
/path/to/test.mp4
(4) For Whisper models
./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
Expand All @@ -58,7 +70,7 @@
--num-threads=2 \
/path/to/test.mp4
(4) For SenseVoice CTC models
(5) For SenseVoice CTC models
./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
Expand All @@ -68,7 +80,7 @@
/path/to/test.mp4
(5) For WeNet CTC models
(6) For WeNet CTC models
./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
Expand All @@ -83,6 +95,7 @@
used in this file.
"""
import argparse
import datetime as dt
import shutil
import subprocess
import sys
Expand Down Expand Up @@ -157,7 +170,7 @@ def get_args():
parser.add_argument(
"--num-threads",
type=int,
default=1,
default=2,
help="Number of threads for neural network computation",
)

Expand Down Expand Up @@ -208,6 +221,34 @@ def get_args():
""",
)

parser.add_argument(
"--moonshine-preprocessor",
default="",
type=str,
help="Path to moonshine preprocessor model",
)

parser.add_argument(
"--moonshine-encoder",
default="",
type=str,
help="Path to moonshine encoder model",
)

parser.add_argument(
"--moonshine-uncached-decoder",
default="",
type=str,
help="Path to moonshine uncached decoder model",
)

parser.add_argument(
"--moonshine-cached-decoder",
default="",
type=str,
help="Path to moonshine cached decoder model",
)

parser.add_argument(
"--decoding-method",
type=str,
Expand Down Expand Up @@ -263,6 +304,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder

assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
Expand All @@ -284,6 +331,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder

assert_file_exists(args.paraformer)

Expand All @@ -300,6 +353,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder

assert_file_exists(args.sense_voice)
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
Expand All @@ -312,6 +371,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
elif args.wenet_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder

assert_file_exists(args.wenet_ctc)

Expand All @@ -327,6 +392,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
elif args.whisper_encoder:
assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder)
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder

recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
encoder=args.whisper_encoder,
Expand All @@ -339,6 +410,22 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
task=args.whisper_task,
tail_paddings=args.whisper_tail_paddings,
)
elif args.moonshine_preprocessor:
assert_file_exists(args.moonshine_preprocessor)
assert_file_exists(args.moonshine_encoder)
assert_file_exists(args.moonshine_uncached_decoder)
assert_file_exists(args.moonshine_cached_decoder)

recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
preprocessor=args.moonshine_preprocessor,
encoder=args.moonshine_encoder,
uncached_decoder=args.moonshine_uncached_decoder,
cached_decoder=args.moonshine_cached_decoder,
tokens=args.tokens,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
debug=args.debug,
)
else:
raise ValueError("Please specify at least one model")

Expand Down Expand Up @@ -424,28 +511,32 @@ def main():
segment_list = []

print("Started!")
start_t = dt.datetime.now()
num_processed_samples = 0

is_silence = False
is_eof = False
# TODO(fangjun): Support multithreads
while True:
# *2 because int16_t has two bytes
data = process.stdout.read(frames_per_read * 2)
if not data:
if is_silence:
if is_eof:
break
is_silence = True
# The converted audio file does not have a mute data of 1 second or more at the end, which will result in the loss of the last segment data
is_eof = True
# pad 1 second at the end of the file for the VAD
data = np.zeros(1 * args.sample_rate, dtype=np.int16)

samples = np.frombuffer(data, dtype=np.int16)
samples = samples.astype(np.float32) / 32768

num_processed_samples += samples.shape[0]

buffer = np.concatenate([buffer, samples])
while len(buffer) > window_size:
vad.accept_waveform(buffer[:window_size])
buffer = buffer[window_size:]

if is_silence:
if is_eof:
vad.flush()

streams = []
Expand All @@ -471,6 +562,11 @@ def main():
seg.text = stream.result.text
segment_list.append(seg)

end_t = dt.datetime.now()
elapsed_seconds = (end_t - start_t).total_seconds()
duration = num_processed_samples / 16000
rtf = elapsed_seconds / duration

srt_filename = Path(args.sound_file).with_suffix(".srt")
with open(srt_filename, "w", encoding="utf-8") as f:
for i, seg in enumerate(segment_list):
Expand All @@ -479,6 +575,9 @@ def main():
print("", file=f)

print(f"Saved to {srt_filename}")
print(f"Audio duration:\t{duration:.3f} s")
print(f"Elapsed:\t{elapsed_seconds:.3f} s")
print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
print("Done!")


Expand Down
Loading

0 comments on commit a6f110e

Please sign in to comment.