diff --git a/sherpa-onnx/csrc/keyword-spotter-impl.h b/sherpa-onnx/csrc/keyword-spotter-impl.h index ded735ff5..6180f9172 100644 --- a/sherpa-onnx/csrc/keyword-spotter-impl.h +++ b/sherpa-onnx/csrc/keyword-spotter-impl.h @@ -38,6 +38,8 @@ class KeywordSpotterImpl { virtual bool IsReady(OnlineStream *s) const = 0; + virtual void Reset(OnlineStream *s) const = 0; + virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0; virtual KeywordResult GetResult(OnlineStream *s) const = 0; diff --git a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h index 759639184..d29b8b58d 100644 --- a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h +++ b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h @@ -195,8 +195,24 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { return s->GetNumProcessedFrames() + model_->ChunkSize() < s->NumFramesReady(); } + void Reset(OnlineStream *s) const override { InitOnlineStream(s); } void DecodeStreams(OnlineStream **ss, int32_t n) const override { + for (int32_t i = 0; i < n; ++i) { + auto s = ss[i]; + auto r = s->GetKeywordResult(true); + int32_t num_trailing_blanks = r.num_trailing_blanks; + // assume subsampling_factor is 4 + // assume frameshift is 0.01 second + float trailing_slience = num_trailing_blanks * 4 * 0.01; + + // it resets automatically after detecting 1.5 seconds of silence + float threshold = 1.5; + if (trailing_slience > threshold) { + Reset(s); + } + } + int32_t chunk_size = model_->ChunkSize(); int32_t chunk_shift = model_->ChunkShift(); diff --git a/sherpa-onnx/csrc/keyword-spotter.cc b/sherpa-onnx/csrc/keyword-spotter.cc index d1bf6d63b..66d0907ab 100644 --- a/sherpa-onnx/csrc/keyword-spotter.cc +++ b/sherpa-onnx/csrc/keyword-spotter.cc @@ -157,6 +157,8 @@ bool KeywordSpotter::IsReady(OnlineStream *s) const { return impl_->IsReady(s); } +void KeywordSpotter::Reset(OnlineStream *s) const { impl_->Reset(s); } + void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const { impl_->DecodeStreams(ss, n); } diff --git a/sherpa-onnx/csrc/keyword-spotter.h b/sherpa-onnx/csrc/keyword-spotter.h index f0c31bdb4..c933f4b23 100644 --- a/sherpa-onnx/csrc/keyword-spotter.h +++ b/sherpa-onnx/csrc/keyword-spotter.h @@ -129,6 +129,9 @@ class KeywordSpotter { */ bool IsReady(OnlineStream *s) const; + // Remember to call it after detecting a keyword + void Reset(OnlineStream *s) const; + /** Decode a single stream. */ void DecodeStream(OnlineStream *s) const { OnlineStream *ss[1] = {s}; diff --git a/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc index a909ff250..cfa46dc91 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc @@ -106,13 +106,15 @@ as the device_name. while (spotter.IsReady(stream.get())) { spotter.DecodeStream(stream.get()); - } - const auto r = spotter.GetResult(stream.get()); - if (!r.keyword.empty()) { - display.Print(keyword_index, r.AsJsonString()); - fflush(stderr); - keyword_index++; + const auto r = spotter.GetResult(stream.get()); + if (!r.keyword.empty()) { + display.Print(keyword_index, r.AsJsonString()); + fflush(stderr); + keyword_index++; + + spotter.Reset(stream.get()); + } } } diff --git a/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc index 903debea9..4d75f9d49 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc @@ -150,13 +150,15 @@ for a list of pre-trained models to download. while (!stop) { while (spotter.IsReady(s.get())) { spotter.DecodeStream(s.get()); - } - const auto r = spotter.GetResult(s.get()); - if (!r.keyword.empty()) { - display.Print(keyword_index, r.AsJsonString()); - fflush(stderr); - keyword_index++; + const auto r = spotter.GetResult(s.get()); + if (!r.keyword.empty()) { + display.Print(keyword_index, r.AsJsonString()); + fflush(stderr); + keyword_index++; + + spotter.Reset(s.get()); + } } Pa_Sleep(20); // sleep for 20ms diff --git a/sherpa-onnx/python/csrc/keyword-spotter.cc b/sherpa-onnx/python/csrc/keyword-spotter.cc index 144992605..4a48ada4f 100644 --- a/sherpa-onnx/python/csrc/keyword-spotter.cc +++ b/sherpa-onnx/python/csrc/keyword-spotter.cc @@ -67,6 +67,7 @@ void PybindKeywordSpotter(py::module *m) { py::arg("keywords"), py::call_guard()) .def("is_ready", &PyClass::IsReady, py::call_guard()) + .def("reset", &PyClass::Reset, py::call_guard()) .def("decode_stream", &PyClass::DecodeStream, py::call_guard()) .def(