Skip to content

Commit

Permalink
Fix keyword spotting.
Browse files Browse the repository at this point in the history
See also #1417
  • Loading branch information
csukuangfj committed Jan 6, 2025
1 parent 930986b commit 206d3f7
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 12 deletions.
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/keyword-spotter-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class KeywordSpotterImpl {

virtual bool IsReady(OnlineStream *s) const = 0;

virtual void Reset(OnlineStream *s) const = 0;

virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;

virtual KeywordResult GetResult(OnlineStream *s) const = 0;
Expand Down
16 changes: 16 additions & 0 deletions sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,24 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
return s->GetNumProcessedFrames() + model_->ChunkSize() <
s->NumFramesReady();
}
void Reset(OnlineStream *s) const override { InitOnlineStream(s); }

void DecodeStreams(OnlineStream **ss, int32_t n) const override {
for (int32_t i = 0; i < n; ++i) {
auto s = ss[i];
auto r = s->GetKeywordResult(true);
int32_t num_trailing_blanks = r.num_trailing_blanks;
// assume subsampling_factor is 4
// assume frameshift is 0.01 second
float trailing_slience = num_trailing_blanks * 4 * 0.01;

// it resets automatically after detecting 1.5 seconds of silence
float threshold = 1.5;
if (trailing_slience > threshold) {
Reset(s);
}
}

int32_t chunk_size = model_->ChunkSize();
int32_t chunk_shift = model_->ChunkShift();

Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/keyword-spotter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ bool KeywordSpotter::IsReady(OnlineStream *s) const {
return impl_->IsReady(s);
}

void KeywordSpotter::Reset(OnlineStream *s) const { impl_->Reset(s); }

void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const {
impl_->DecodeStreams(ss, n);
}
Expand Down
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/keyword-spotter.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ class KeywordSpotter {
*/
bool IsReady(OnlineStream *s) const;

// Remember to call it after detecting a keyword
void Reset(OnlineStream *s) const;

/** Decode a single stream. */
void DecodeStream(OnlineStream *s) const {
OnlineStream *ss[1] = {s};
Expand Down
14 changes: 8 additions & 6 deletions sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,15 @@ as the device_name.

while (spotter.IsReady(stream.get())) {
spotter.DecodeStream(stream.get());
}

const auto r = spotter.GetResult(stream.get());
if (!r.keyword.empty()) {
display.Print(keyword_index, r.AsJsonString());
fflush(stderr);
keyword_index++;
const auto r = spotter.GetResult(stream.get());
if (!r.keyword.empty()) {
display.Print(keyword_index, r.AsJsonString());
fflush(stderr);
keyword_index++;

spotter.Reset(stream.get());
}
}
}

Expand Down
14 changes: 8 additions & 6 deletions sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,15 @@ for a list of pre-trained models to download.
while (!stop) {
while (spotter.IsReady(s.get())) {
spotter.DecodeStream(s.get());
}

const auto r = spotter.GetResult(s.get());
if (!r.keyword.empty()) {
display.Print(keyword_index, r.AsJsonString());
fflush(stderr);
keyword_index++;
const auto r = spotter.GetResult(s.get());
if (!r.keyword.empty()) {
display.Print(keyword_index, r.AsJsonString());
fflush(stderr);
keyword_index++;

spotter.Reset(s.get());
}
}

Pa_Sleep(20); // sleep for 20ms
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/python/csrc/keyword-spotter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ void PybindKeywordSpotter(py::module *m) {
py::arg("keywords"), py::call_guard<py::gil_scoped_release>())
.def("is_ready", &PyClass::IsReady,
py::call_guard<py::gil_scoped_release>())
.def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>())
.def("decode_stream", &PyClass::DecodeStream,
py::call_guard<py::gil_scoped_release>())
.def(
Expand Down

0 comments on commit 206d3f7

Please sign in to comment.