From 665b869f033bb181145ca550dfb824d301e8016d Mon Sep 17 00:00:00 2001 From: ductranminh Date: Thu, 1 Feb 2024 20:33:22 +0700 Subject: [PATCH] Add context biasing for mobile (#568) --- .../java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt | 8 ++++-- .../java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt | 8 ++++-- sherpa-onnx/jni/jni.cc | 28 +++++++++++++++---- swift-api-examples/SherpaOnnx.swift | 21 ++++++++++++-- 4 files changed, 52 insertions(+), 13 deletions(-) diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt index e3d60a207..dfd8a4d80 100644 --- a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt +++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt @@ -85,7 +85,7 @@ class SherpaOnnx( acceptWaveform(ptr, samples, sampleRate) fun inputFinished() = inputFinished(ptr) - fun reset(recreate: Boolean = false) = reset(ptr, recreate = recreate) + fun reset(recreate: Boolean = false, hotwords: String = "") = reset(ptr, recreate, hotwords) fun decode() = decode(ptr) fun isEndpoint(): Boolean = isEndpoint(ptr) fun isReady(): Boolean = isReady(ptr) @@ -93,6 +93,9 @@ class SherpaOnnx( val text: String get() = getText(ptr) + val tokens: Array + get() = getTokens(ptr) + private external fun delete(ptr: Long) private external fun new( @@ -107,10 +110,11 @@ class SherpaOnnx( private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int) private external fun inputFinished(ptr: Long) private external fun getText(ptr: Long): String - private external fun reset(ptr: Long, recreate: Boolean) + private external fun reset(ptr: Long, recreate: Boolean, hotwords: String) private external fun decode(ptr: Long) private external fun isEndpoint(ptr: Long): Boolean private external fun isReady(ptr: Long): Boolean + private external fun getTokens(ptr: Long): Array companion object { init { diff --git a/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt index 08b3d8999..228f19c1f 100644 --- a/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt +++ b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt @@ -120,7 +120,7 @@ class SherpaOnnx( acceptWaveform(ptr, samples, sampleRate) fun inputFinished() = inputFinished(ptr) - fun reset(recreate: Boolean = false) = reset(ptr, recreate = recreate) + fun reset(recreate: Boolean = false, hotwords: String = "") = reset(ptr, recreate, hotwords) fun decode() = decode(ptr) fun isEndpoint(): Boolean = isEndpoint(ptr) fun isReady(): Boolean = isReady(ptr) @@ -128,6 +128,9 @@ class SherpaOnnx( val text: String get() = getText(ptr) + val tokens: Array + get() = getTokens(ptr) + private external fun delete(ptr: Long) private external fun new( @@ -142,10 +145,11 @@ class SherpaOnnx( private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int) private external fun inputFinished(ptr: Long) private external fun getText(ptr: Long): String - private external fun reset(ptr: Long, recreate: Boolean) + private external fun reset(ptr: Long, recreate: Boolean, hotwords: String) private external fun decode(ptr: Long) private external fun isEndpoint(ptr: Long): Boolean private external fun isReady(ptr: Long): Boolean + private external fun getTokens(ptr: Long): Array companion object { init { diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index a5c829844..e52abc37d 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -76,11 +76,24 @@ class SherpaOnnx { bool IsReady() const { return recognizer_.IsReady(stream_.get()); } - void Reset(bool recreate) { - if (recreate) { - stream_ = recognizer_.CreateStream(); + // If keywords is an empty string, it just recreates the decoding stream + // If keywords is not empty, it will create a new decoding stream with + // the given keywords appended to the default keywords. + void Reset(bool recreate, const std::string &keywords = {}) { + if (keywords.empty()) { + if (recreate) { + stream_ = recognizer_.CreateStream(); + } else { + recognizer_.Reset(stream_.get()); + } } else { - recognizer_.Reset(stream_.get()); + auto stream = recognizer_.CreateStream(keywords); + // Set new keywords failed, the stream_ will not be updated. + if (stream != nullptr) { + stream_ = std::move(stream); + } else { + SHERPA_ONNX_LOGE("Failed to set keywords: %s", keywords.c_str()); + } } } @@ -1509,9 +1522,12 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_delete( SHERPA_ONNX_EXTERN_C JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset( - JNIEnv *env, jobject /*obj*/, jlong ptr, jboolean recreate) { + JNIEnv *env, jobject /*obj*/, + jlong ptr, jboolean recreate, jstring keywords) { auto model = reinterpret_cast(ptr); - model->Reset(recreate); + const char *p_keywords = env->GetStringUTFChars(keywords, nullptr); + model->Reset(recreate, p_keywords); + env->ReleaseStringUTFChars(keywords, p_keywords); } SHERPA_ONNX_EXTERN_C diff --git a/swift-api-examples/SherpaOnnx.swift b/swift-api-examples/SherpaOnnx.swift index 397d92e67..8a5ea907e 100644 --- a/swift-api-examples/SherpaOnnx.swift +++ b/swift-api-examples/SherpaOnnx.swift @@ -188,7 +188,7 @@ class SherpaOnnxOnlineRecongitionResult { class SherpaOnnxRecognizer { /// A pointer to the underlying counterpart in C let recognizer: OpaquePointer! - let stream: OpaquePointer! + var stream: OpaquePointer! /// Constructor taking a model config init( @@ -237,8 +237,23 @@ class SherpaOnnxRecognizer { /// Reset the recognizer, which clears the neural network model state /// and the state for decoding. - func reset() { - Reset(recognizer, stream) + /// If hotwords is an empty string, it just recreates the decoding stream + /// If hotwords is not empty, it will create a new decoding stream with + /// the given hotWords appended to the default hotwords. + func reset(hotwords: String? = nil) { + guard let words = hotwords, !words.isEmpty else { + Reset(recognizer, stream) + return + } + + words.withCString { cString in + let newStream = CreateOnlineStreamWithHotwords(recognizer, cString) + // lock while release and replace stream + objc_sync_enter(self) + DestroyOnlineStream(stream) + stream = newStream + objc_sync_exit(self) + } } /// Signal that no more audio samples would be available.