diff --git a/cxx-api-examples/sense-voice-cxx-api.cc b/cxx-api-examples/sense-voice-cxx-api.cc index 15d752058..ea642b980 100644 --- a/cxx-api-examples/sense-voice-cxx-api.cc +++ b/cxx-api-examples/sense-voice-cxx-api.cc @@ -19,7 +19,7 @@ #include "sherpa-onnx/c-api/cxx-api.h" int32_t main() { - using namespace sherpa_onnx::cxx; + using namespace sherpa_onnx::cxx; // NOLINT OfflineRecognizerConfig config; config.model_config.sense_voice.model = diff --git a/cxx-api-examples/streaming-zipformer-cxx-api.cc b/cxx-api-examples/streaming-zipformer-cxx-api.cc index 5a49dcfc9..ac4abc479 100644 --- a/cxx-api-examples/streaming-zipformer-cxx-api.cc +++ b/cxx-api-examples/streaming-zipformer-cxx-api.cc @@ -20,7 +20,7 @@ #include "sherpa-onnx/c-api/cxx-api.h" int32_t main() { - using namespace sherpa_onnx::cxx; + using namespace sherpa_onnx::cxx; // NOLINT OnlineRecognizerConfig config; // please see diff --git a/cxx-api-examples/whisper-cxx-api.cc b/cxx-api-examples/whisper-cxx-api.cc index 82f0ddb53..348d115bd 100644 --- a/cxx-api-examples/whisper-cxx-api.cc +++ b/cxx-api-examples/whisper-cxx-api.cc @@ -19,7 +19,7 @@ #include "sherpa-onnx/c-api/cxx-api.h" int32_t main() { - using namespace sherpa_onnx::cxx; + using namespace sherpa_onnx::cxx; // NOLINT OfflineRecognizerConfig config; config.model_config.whisper.encoder = diff --git a/scripts/check_style_cpplint.sh b/scripts/check_style_cpplint.sh index eedc9afc1..ea419242a 100755 --- a/scripts/check_style_cpplint.sh +++ b/scripts/check_style_cpplint.sh @@ -71,6 +71,9 @@ function is_source_code_file() { } function check_style() { + if [[ $1 == mfc-example* ]]; then + return + fi python3 $cpplint_src $1 || abort $1 } @@ -99,7 +102,7 @@ function do_check() { ;; 2) echo "Check all files" - files=$(find $sherpa_onnx_dir/sherpa-onnx/csrc $sherpa_onnx_dir/sherpa-onnx/python $sherpa_onnx_dir/scripts/node-addon-api/src $sherpa_onnx_dir/sherpa-onnx/jni $sherpa_onnx_dir/sherpa-onnx/c-api -name "*.h" -o -name "*.cc") + files=$(find $sherpa_onnx_dir/cxx-api-examples $sherpa_onnx_dir/c-api-examples $sherpa_onnx_dir/sherpa-onnx/csrc $sherpa_onnx_dir/sherpa-onnx/python $sherpa_onnx_dir/scripts/node-addon-api/src $sherpa_onnx_dir/sherpa-onnx/jni $sherpa_onnx_dir/sherpa-onnx/c-api -name "*.h" -o -name "*.cc") ;; *) echo "Check last commit" diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index aaedc3be0..51d712eb8 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -5,6 +5,7 @@ #define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ #include +#include #include #include #include @@ -135,9 +136,32 @@ class OfflineSpeakerDiarizationPyannoteImpl } auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels); + + // The embedding model may output NaN. valid_indexes contains indexes + // in chunk_speaker_samples_list_pair.second that don't lead to + // NaN embeddings. + std::vector valid_indexes; + valid_indexes.reserve(chunk_speaker_samples_list_pair.second.size()); + Matrix2D embeddings = ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second, - std::move(callback), callback_arg); + &valid_indexes, std::move(callback), callback_arg); + + if (valid_indexes.size() != chunk_speaker_samples_list_pair.second.size()) { + std::vector chunk_speaker_pair; + std::vector> sample_indexes; + + chunk_speaker_pair.reserve(valid_indexes.size()); + sample_indexes.reserve(valid_indexes.size()); + for (auto i : valid_indexes) { + chunk_speaker_pair.push_back(chunk_speaker_samples_list_pair.first[i]); + sample_indexes.push_back( + std::move(chunk_speaker_samples_list_pair.second[i])); + } + + chunk_speaker_samples_list_pair.first = std::move(chunk_speaker_pair); + chunk_speaker_samples_list_pair.second = std::move(sample_indexes); + } std::vector cluster_labels = clustering_->Cluster( &embeddings(0, 0), embeddings.rows(), embeddings.cols()); @@ -431,13 +455,17 @@ class OfflineSpeakerDiarizationPyannoteImpl Matrix2D ComputeEmbeddings( const float *audio, int32_t n, const std::vector> &sample_indexes, + std::vector *valid_indexes, OfflineSpeakerDiarizationProgressCallback callback, void *callback_arg) const { const auto &meta_data = segmentation_model_.GetModelMetaData(); int32_t sample_rate = meta_data.sample_rate; Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim()); + auto IsNaNWrapper = [](float f) -> bool { return std::isnan(f); }; + int32_t k = 0; + int32_t cur_row_index = 0; for (const auto &v : sample_indexes) { auto stream = embedding_extractor_.CreateStream(); for (const auto &p : v) { @@ -459,7 +487,12 @@ class OfflineSpeakerDiarizationPyannoteImpl std::vector embedding = embedding_extractor_.Compute(stream.get()); - std::copy(embedding.begin(), embedding.end(), &ans(k, 0)); + if (std::none_of(embedding.begin(), embedding.end(), IsNaNWrapper)) { + // a valid embedding + std::copy(embedding.begin(), embedding.end(), &ans(cur_row_index, 0)); + cur_row_index += 1; + valid_indexes->push_back(k); + } k += 1; @@ -468,6 +501,11 @@ class OfflineSpeakerDiarizationPyannoteImpl } } + if (k != cur_row_index) { + auto seq = Eigen::seqN(0, cur_row_index); + ans = ans(seq, Eigen::all); + } + return ans; } diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h b/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h index 66ad15af3..7e0883085 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h @@ -122,7 +122,7 @@ class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl { auto variance = EX2 - EX.array().pow(2); auto stddev = variance.array().sqrt(); - m = (m.rowwise() - EX).array().rowwise() / stddev.array(); + m = (m.rowwise() - EX).array().rowwise() / (stddev.array() + 1e-5); } private: