Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Kotlin API for Matcha-TTS models. #1668

Merged
merged 3 commits into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/jni.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,8 @@ jobs:

cd ./kotlin-api-examples
./run.sh

- uses: actions/upload-artifact@v4
with:
name: tts-files-${{ matrix.os }}
path: kotlin-api-examples/test-*.wav
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,4 @@ sherpa-onnx-moonshine-tiny-en-int8
sherpa-onnx-moonshine-base-en-int8
harmony-os/SherpaOnnxHar/sherpa_onnx/LICENSE
harmony-os/SherpaOnnxHar/sherpa_onnx/CHANGELOG.md
matcha-icefall-zh-baker
10 changes: 10 additions & 0 deletions kotlin-api-examples/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ function testTts() {
rm vits-piper-en_US-amy-low.tar.bz2
fi

if [ ! -f ./matcha-icefall-zh-baker/model-steps-3.onnx ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2
tar xvf matcha-icefall-zh-baker.tar.bz2
rm matcha-icefall-zh-baker.tar.bz2
fi

if [ ! -f ./hifigan_v2.onnx ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx
fi

out_filename=test_tts.jar
kotlinc-jvm -include-runtime -d $out_filename \
test_tts.kt \
Expand Down
29 changes: 27 additions & 2 deletions kotlin-api-examples/test_tts.kt
Original file line number Diff line number Diff line change
@@ -1,10 +1,35 @@
package com.k2fsa.sherpa.onnx

// Entry point for the TTS examples: invokes each test function below in order.
fun main() {
testTts()
testVits()
testMatcha()
}

fun testTts() {
/**
 * Synthesizes a Chinese sentence with a Matcha-TTS acoustic model plus a
 * HiFi-GAN vocoder and writes the result to `test-zh.wav`.
 *
 * All model files are expected relative to the current working directory:
 * the `matcha-icefall-zh-baker` directory and `hifigan_v2.onnx`.
 */
fun testMatcha() {
    // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
    // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2
    // The configuration is never reassigned, so it is a `val`.
    val config = OfflineTtsConfig(
        model = OfflineTtsModelConfig(
            matcha = OfflineTtsMatchaModelConfig(
                acousticModel = "./matcha-icefall-zh-baker/model-steps-3.onnx",
                vocoder = "./hifigan_v2.onnx",
                tokens = "./matcha-icefall-zh-baker/tokens.txt",
                lexicon = "./matcha-icefall-zh-baker/lexicon.txt",
                dictDir = "./matcha-icefall-zh-baker/dict",
            ),
            numThreads = 1,
            debug = true,
        ),
        // Comma-separated list of rule FSTs passed to the engine as one string.
        ruleFsts = "./matcha-icefall-zh-baker/phone.fst,./matcha-icefall-zh-baker/date.fst,./matcha-icefall-zh-baker/number.fst",
    )
    val tts = OfflineTts(config = config)
    try {
        val audio = tts.generateWithCallback(
            text = "某某银行的副行长和一些行政领导表示,他们去过长江和长白山; 经济不断增长。2024年12月31号,拨打110或者18920240511。123456块钱。",
            callback = ::callback,
        )
        audio.save(filename = "test-zh.wav")
    } finally {
        // Release the TTS object even when generation or saving throws.
        tts.release()
    }
    println("Saved to test-zh.wav")
}

fun testVits() {
// see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
// https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
var config = OfflineTtsConfig(
Expand Down
12 changes: 8 additions & 4 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1727,11 +1727,15 @@ const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
auto p = new SherpaOnnxOnlinePunctuation;
try {
sherpa_onnx::OnlinePunctuationConfig punctuation_config;
punctuation_config.model.cnn_bilstm = SHERPA_ONNX_OR(config->model.cnn_bilstm, "");
punctuation_config.model.bpe_vocab = SHERPA_ONNX_OR(config->model.bpe_vocab, "");
punctuation_config.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1);
punctuation_config.model.cnn_bilstm =
SHERPA_ONNX_OR(config->model.cnn_bilstm, "");
punctuation_config.model.bpe_vocab =
SHERPA_ONNX_OR(config->model.bpe_vocab, "");
punctuation_config.model.num_threads =
SHERPA_ONNX_OR(config->model.num_threads, 1);
punctuation_config.model.debug = config->model.debug;
punctuation_config.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu");
punctuation_config.model.provider =
SHERPA_ONNX_OR(config->model.provider, "cpu");

p->impl =
std::make_unique<sherpa_onnx::OnlinePunctuation>(punctuation_config);
Expand Down
6 changes: 4 additions & 2 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1381,12 +1381,14 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationConfig {
SherpaOnnxOnlinePunctuationModelConfig model;
} SherpaOnnxOnlinePunctuationConfig;

SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation SherpaOnnxOnlinePunctuation;
SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation
SherpaOnnxOnlinePunctuation;

// Create an online punctuation processor. The user has to invoke
// SherpaOnnxDestroyOnlinePunctuation() to free the returned pointer
// to avoid memory leak
SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *
SherpaOnnxCreateOnlinePunctuation(
const SherpaOnnxOnlinePunctuationConfig *config);

// Free a pointer returned by SherpaOnnxCreateOnlinePunctuation()
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/jieba-lexicon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class JiebaLexicon::Impl {

this_sentence.insert(this_sentence.end(), ids.begin(), ids.end());

if (w == "。" || w == "!" || w == "?" || w == ",") {
if (IsPunct(w)) {
ans.emplace_back(std::move(this_sentence));
this_sentence = {};
}
Expand Down
49 changes: 49 additions & 0 deletions sherpa-onnx/jni/offline-tts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
jobject model = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model);

// vits
fid = env->GetFieldID(model_config_cls, "vits",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;");
jobject vits = env->GetObjectField(model, fid);
Expand Down Expand Up @@ -64,6 +65,54 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(vits_cls, "lengthScale", "F");
ans.model.vits.length_scale = env->GetFloatField(vits, fid);

// matcha
fid = env->GetFieldID(model_config_cls, "matcha",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsMatchaModelConfig;");
jobject matcha = env->GetObjectField(model, fid);
jclass matcha_cls = env->GetObjectClass(matcha);

fid = env->GetFieldID(matcha_cls, "acousticModel", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.acoustic_model = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(matcha_cls, "vocoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.vocoder = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(matcha_cls, "lexicon", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.lexicon = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(matcha_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.tokens = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(matcha_cls, "dataDir", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.data_dir = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(matcha_cls, "dictDir", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.dict_dir = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(matcha_cls, "noiseScale", "F");
ans.model.matcha.noise_scale = env->GetFloatField(matcha, fid);

fid = env->GetFieldID(matcha_cls, "lengthScale", "F");
ans.model.matcha.length_scale = env->GetFloatField(matcha, fid);

fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model.num_threads = env->GetIntField(model, fid);

Expand Down
12 changes: 12 additions & 0 deletions sherpa-onnx/kotlin-api/Tts.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,20 @@ data class OfflineTtsVitsModelConfig(
var lengthScale: Float = 1.0f,
)

/**
 * Configuration for a Matcha-TTS model.
 *
 * The JNI layer (`sherpa-onnx/jni/offline-tts.cc`) looks up every property of
 * this class by its exact name and JVM type signature via `GetFieldID` and
 * copies the value into the native Matcha config. Renaming or retyping a
 * property therefore requires a matching change on the native side.
 *
 * @property acousticModel Path to the acoustic model file
 *   (e.g. `./matcha-icefall-zh-baker/model-steps-3.onnx`).
 * @property vocoder Path to the vocoder model file (e.g. `./hifigan_v2.onnx`).
 * @property lexicon Path to the lexicon file (e.g. `lexicon.txt` of the model).
 * @property tokens Path to the tokens file (e.g. `tokens.txt` of the model).
 * @property dataDir Copied to the native `data_dir` field. NOTE(review): its
 *   purpose is not visible here; presumably a data directory used by some
 *   models. Confirm against the native config.
 * @property dictDir Copied to the native `dict_dir` field; the Chinese example
 *   points it at the model's `dict` directory.
 * @property noiseScale Copied to the native `noise_scale` field; defaults to 1.0.
 * @property lengthScale Copied to the native `length_scale` field; defaults to
 *   1.0. NOTE(review): presumably controls speech rate; confirm.
 */
data class OfflineTtsMatchaModelConfig(
var acousticModel: String = "",
var vocoder: String = "",
var lexicon: String = "",
var tokens: String = "",
var dataDir: String = "",
var dictDir: String = "",
var noiseScale: Float = 1.0f,
var lengthScale: Float = 1.0f,
)

data class OfflineTtsModelConfig(
var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(),
var matcha: OfflineTtsMatchaModelConfig = OfflineTtsMatchaModelConfig(),
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
Expand Down
Loading