diff --git a/.github/scripts/test-speaker-recognition-python.sh b/.github/scripts/test-speaker-recognition-python.sh index 7d6eff9ff..22b1367de 100755 --- a/.github/scripts/test-speaker-recognition-python.sh +++ b/.github/scripts/test-speaker-recognition-python.sh @@ -57,5 +57,19 @@ done ls -lh popd +log "Download NeMo models" +model_dir=$d/nemo +mkdir -p $model_dir +pushd $model_dir +models=( +nemo_en_titanet_large.onnx +nemo_en_titanet_small.onnx +nemo_en_speakerverification_speakernet.onnx +) +for m in ${models[@]}; do + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$m +done +ls -lh +popd python3 sherpa-onnx/python/tests/test_speaker_recognition.py --verbose diff --git a/cmake/kaldi-native-fbank.cmake b/cmake/kaldi-native-fbank.cmake index 38751b67c..ea1c27d46 100644 --- a/cmake/kaldi-native-fbank.cmake +++ b/cmake/kaldi-native-fbank.cmake @@ -1,9 +1,9 @@ function(download_kaldi_native_fbank) include(FetchContent) - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.5.tar.gz") - set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.5.tar.gz") - set(kaldi_native_fbank_HASH "SHA256=dce0cb3bc6fece5d8053d8780cb4ce22da57cb57ebec332641661521a0425283") + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.6.tar.gz") + set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.6.tar.gz") + set(kaldi_native_fbank_HASH "SHA256=6202a00cd06ba8ff89beb7b6f85cda34e073e94f25fc29e37c519bff0706bf19") set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank) # If you don't have access to the Internet, # please pre-download kaldi-native-fbank set(possible_file_locations - 
$ENV{HOME}/Downloads/kaldi-native-fbank-1.18.5.tar.gz - ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.5.tar.gz - ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.5.tar.gz - /tmp/kaldi-native-fbank-1.18.5.tar.gz - /star-fj/fangjun/download/github/kaldi-native-fbank-1.18.5.tar.gz + $ENV{HOME}/Downloads/kaldi-native-fbank-1.18.6.tar.gz + ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.6.tar.gz + ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.6.tar.gz + /tmp/kaldi-native-fbank-1.18.6.tar.gz + /star-fj/fangjun/download/github/kaldi-native-fbank-1.18.6.tar.gz ) foreach(f IN LISTS possible_file_locations) diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 38fef1c5d..c889f8cee 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -100,6 +100,7 @@ set(sources list(APPEND sources speaker-embedding-extractor-impl.cc speaker-embedding-extractor-model.cc + speaker-embedding-extractor-nemo-model.cc speaker-embedding-extractor.cc speaker-embedding-manager.cc ) diff --git a/sherpa-onnx/csrc/features.cc b/sherpa-onnx/csrc/features.cc index e81c0dfb4..8c3a488e7 100644 --- a/sherpa-onnx/csrc/features.cc +++ b/sherpa-onnx/csrc/features.cc @@ -41,8 +41,12 @@ class FeatureExtractor::Impl { public: explicit Impl(const FeatureExtractorConfig &config) : config_(config) { opts_.frame_opts.dither = 0; - opts_.frame_opts.snip_edges = false; + opts_.frame_opts.snip_edges = config.snip_edges; opts_.frame_opts.samp_freq = config.sampling_rate; + opts_.frame_opts.frame_shift_ms = config.frame_shift_ms; + opts_.frame_opts.frame_length_ms = config.frame_length_ms; + opts_.frame_opts.remove_dc_offset = config.remove_dc_offset; + opts_.frame_opts.window_type = config.window_type; opts_.mel_opts.num_bins = config.feature_dim; @@ -52,6 +56,9 @@ class FeatureExtractor::Impl { // https://github.com/k2-fsa/sherpa-onnx/issues/514 opts_.mel_opts.high_freq = -400; + opts_.mel_opts.low_freq = config.low_freq; + opts_.mel_opts.is_librosa = 
config.is_librosa; + fbank_ = std::make_unique(opts_); } diff --git a/sherpa-onnx/csrc/features.h b/sherpa-onnx/csrc/features.h index 497dd01cc..31107d3c5 100644 --- a/sherpa-onnx/csrc/features.h +++ b/sherpa-onnx/csrc/features.h @@ -28,6 +28,14 @@ struct FeatureExtractorConfig { // If false, we will multiply the inputs by 32768 bool normalize_samples = true; + bool snip_edges = false; + float frame_shift_ms = 10.0f; // in milliseconds. + float frame_length_ms = 25.0f; // in milliseconds. + int32_t low_freq = 20; + bool is_librosa = false; + bool remove_dc_offset = true; // Subtract mean of wave before FFT. + std::string window_type = "povey"; // e.g. Hamming window + std::string ToString() const; void Register(ParseOptions *po); diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h b/sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h index eb87d9043..e819bd067 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h @@ -1,6 +1,6 @@ // sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h // -// Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2024 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_ #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_ diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc b/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc index 46cdfa61d..a9babec92 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc @@ -1,11 +1,12 @@ // sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc // -// Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2024 Xiaomi Corporation #include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h" 
+#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h" namespace sherpa_onnx { @@ -14,6 +15,7 @@ namespace { enum class ModelType { kWeSpeaker, k3dSpeaker, + kNeMo, kUnkown, }; @@ -52,6 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, return ModelType::kWeSpeaker; } else if (model_type.get() == std::string("3d-speaker")) { return ModelType::k3dSpeaker; + } else if (model_type.get() == std::string("nemo")) { + return ModelType::kNeMo; } else { SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); return ModelType::kUnkown; @@ -74,6 +78,8 @@ SpeakerEmbeddingExtractorImpl::Create( // fall through case ModelType::k3dSpeaker: return std::make_unique(config); + case ModelType::kNeMo: + return std::make_unique(config); case ModelType::kUnkown: SHERPA_ONNX_LOGE( "Unknown model type in for speaker embedding extractor!"); diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-impl.h b/sherpa-onnx/csrc/speaker-embedding-extractor-impl.h index fa84b43e2..02362f89b 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor-impl.h +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-impl.h @@ -1,6 +1,6 @@ // sherpa-onnx/csrc/speaker-embedding-extractor-impl.h // -// Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2024 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-model.cc b/sherpa-onnx/csrc/speaker-embedding-extractor-model.cc index fedfcab54..2c9930f8b 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor-model.cc +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-model.cc @@ -1,6 +1,6 @@ // sherpa-onnx/csrc/speaker-embedding-extractor-model.cc // -// Copyright (c) 2023-2024 Xiaomi Corporation +// Copyright (c) 2024 Xiaomi Corporation #include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h" diff --git 
a/sherpa-onnx/csrc/speaker-embedding-extractor-model.h b/sherpa-onnx/csrc/speaker-embedding-extractor-model.h index 3fa94ef3f..d5f179678 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor-model.h +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-model.h @@ -1,6 +1,6 @@ // sherpa-onnx/csrc/speaker-embedding-extractor-model.h // -// Copyright (c) 2023-2024 Xiaomi Corporation +// Copyright (c) 2024 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_ #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_ diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h b/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h new file mode 100644 index 000000000..6678758c2 --- /dev/null +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h @@ -0,0 +1,128 @@ +// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_ +#include +#include +#include +#include + +#include "Eigen/Dense" +#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h" +#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h" +#include "sherpa-onnx/csrc/transpose.h" + +namespace sherpa_onnx { + +class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl { + public: + explicit SpeakerEmbeddingExtractorNeMoImpl( + const SpeakerEmbeddingExtractorConfig &config) + : model_(config) {} + + int32_t Dim() const override { return model_.GetMetaData().output_dim; } + + std::unique_ptr CreateStream() const override { + FeatureExtractorConfig feat_config; + const auto &meta_data = model_.GetMetaData(); + feat_config.sampling_rate = meta_data.sample_rate; + feat_config.feature_dim = meta_data.feat_dim; + feat_config.normalize_samples = true; + feat_config.snip_edges = true; + feat_config.frame_shift_ms = meta_data.window_stride_ms; + 
feat_config.frame_length_ms = meta_data.window_size_ms; + feat_config.low_freq = 0; + feat_config.is_librosa = true; + feat_config.remove_dc_offset = false; + feat_config.window_type = meta_data.window_type; + + return std::make_unique(feat_config); + } + + bool IsReady(OnlineStream *s) const override { + return s->GetNumProcessedFrames() < s->NumFramesReady(); + } + + std::vector Compute(OnlineStream *s) const override { + int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames(); + if (num_frames <= 0) { + SHERPA_ONNX_LOGE( + "Please make sure IsReady(s) returns true. num_frames: %d", + num_frames); + return {}; + } + + std::vector features = + s->GetFrames(s->GetNumProcessedFrames(), num_frames); + + s->GetNumProcessedFrames() += num_frames; + + int32_t feat_dim = features.size() / num_frames; + + const auto &meta_data = model_.GetMetaData(); + if (!meta_data.feature_normalize_type.empty()) { + if (meta_data.feature_normalize_type == "per_feature") { + NormalizePerFeature(features.data(), num_frames, feat_dim); + } else { + SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s", + meta_data.feature_normalize_type.c_str()); + exit(-1); + } + } + + if (num_frames % 16 != 0) { + int32_t pad = 16 - num_frames % 16; + features.resize((num_frames + pad) * feat_dim); + } + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array x_shape{1, num_frames, feat_dim}; + Ort::Value x = + Ort::Value::CreateTensor(memory_info, features.data(), features.size(), + x_shape.data(), x_shape.size()); + + x = Transpose12(model_.Allocator(), &x); + + int64_t x_lens = num_frames; + std::array x_lens_shape{1}; + Ort::Value x_lens_tensor = Ort::Value::CreateTensor( + memory_info, &x_lens, 1, x_lens_shape.data(), x_lens_shape.size()); + + Ort::Value embedding = + model_.Compute(std::move(x), std::move(x_lens_tensor)); + std::vector embedding_shape = + embedding.GetTensorTypeAndShapeInfo().GetShape(); + + std::vector 
ans(embedding_shape[1]); + std::copy(embedding.GetTensorData(), + embedding.GetTensorData() + ans.size(), ans.begin()); + + return ans; + } + + private: + void NormalizePerFeature(float *p, int32_t num_frames, + int32_t feat_dim) const { + auto m = Eigen::Map< + Eigen::Matrix>( + p, num_frames, feat_dim); + + auto EX = m.colwise().mean(); + auto EX2 = m.array().pow(2).colwise().sum() / num_frames; + auto variance = EX2 - EX.array().pow(2); + auto stddev = variance.array().sqrt(); + + m = (m.rowwise() - EX).array().rowwise() / stddev.array(); + } + + private: + SpeakerEmbeddingExtractorNeMoModel model_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_ diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h b/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h new file mode 100644 index 000000000..f0ff1f7ba --- /dev/null +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h @@ -0,0 +1,28 @@ +// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_ + +#include +#include + +namespace sherpa_onnx { + +struct SpeakerEmbeddingExtractorNeMoModelMetaData { + int32_t output_dim = 0; + int32_t feat_dim = 80; + int32_t sample_rate = 0; + int32_t window_size_ms = 25; + int32_t window_stride_ms = 25; + + // Chinese, English, etc. 
+ std::string language; + + // for NeMo, it is either empty (no normalization) or per_feature + std::string feature_normalize_type; + std::string window_type = "hann"; +}; + +} // namespace sherpa_onnx +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_ diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc b/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc new file mode 100644 index 000000000..4e257dcf3 --- /dev/null +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc @@ -0,0 +1,126 @@ +// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h" + +namespace sherpa_onnx { + +class SpeakerEmbeddingExtractorNeMoModel::Impl { + public: + explicit Impl(const SpeakerEmbeddingExtractorConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.model); + Init(buf.data(), buf.size()); + } + } + + Ort::Value Compute(Ort::Value x, Ort::Value x_lens) const { + std::array inputs = {std::move(x), std::move(x_lens)}; + + // output_names_ptr_[0] is logits + // output_names_ptr_[1] is embeddings + // so we use output_names_ptr_.data() + 1 here to extract only the + // embeddings + auto outputs = sess_->Run({}, input_names_ptr_.data(), inputs.data(), + inputs.size(), output_names_ptr_.data() + 1, 1); + return std::move(outputs[0]); + } + + OrtAllocator *Allocator() const { return allocator_; } + + const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const { + return meta_data_; + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, 
model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(meta_data_.output_dim, "output_dim"); + SHERPA_ONNX_READ_META_DATA(meta_data_.feat_dim, "feat_dim"); + SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate"); + SHERPA_ONNX_READ_META_DATA(meta_data_.window_size_ms, "window_size_ms"); + SHERPA_ONNX_READ_META_DATA(meta_data_.window_stride_ms, "window_stride_ms"); + SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language"); + + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT( + meta_data_.feature_normalize_type, "feature_normalize_type", ""); + + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.window_type, + "window_type", "povey"); + + std::string framework; + SHERPA_ONNX_READ_META_DATA_STR(framework, "framework"); + if (framework != "nemo") { + SHERPA_ONNX_LOGE("Expect a NeMo model, given: %s", framework.c_str()); + exit(-1); + } + } + + private: + SpeakerEmbeddingExtractorConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + SpeakerEmbeddingExtractorNeMoModelMetaData meta_data_; +}; + +SpeakerEmbeddingExtractorNeMoModel::SpeakerEmbeddingExtractorNeMoModel( + const SpeakerEmbeddingExtractorConfig &config) + : impl_(std::make_unique(config)) {} + +SpeakerEmbeddingExtractorNeMoModel::~SpeakerEmbeddingExtractorNeMoModel() = + default; + +const SpeakerEmbeddingExtractorNeMoModelMetaData & 
+SpeakerEmbeddingExtractorNeMoModel::GetMetaData() const { + return impl_->GetMetaData(); +} + +Ort::Value SpeakerEmbeddingExtractorNeMoModel::Compute( + Ort::Value x, Ort::Value x_lens) const { + return impl_->Compute(std::move(x), std::move(x_lens)); +} + +OrtAllocator *SpeakerEmbeddingExtractorNeMoModel::Allocator() const { + return impl_->Allocator(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h b/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h new file mode 100644 index 000000000..9678139e9 --- /dev/null +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h @@ -0,0 +1,40 @@ +// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_H_ + +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h" +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" + +namespace sherpa_onnx { + +class SpeakerEmbeddingExtractorNeMoModel { + public: + explicit SpeakerEmbeddingExtractorNeMoModel( + const SpeakerEmbeddingExtractorConfig &config); + + ~SpeakerEmbeddingExtractorNeMoModel(); + + const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const; + + /** + * @param x A float32 tensor of shape (N, C, T) + * @param x_len An int64 tensor of shape (N,) + * @return A float32 tensor of shape (N, C) + */ + Ort::Value Compute(Ort::Value x, Ort::Value x_len) const; + + OrtAllocator *Allocator() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_H_ diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor.cc b/sherpa-onnx/csrc/speaker-embedding-extractor.cc index 7826e4fb6..f7d6c9b12 100644 --- 
a/sherpa-onnx/csrc/speaker-embedding-extractor.cc +++ b/sherpa-onnx/csrc/speaker-embedding-extractor.cc @@ -1,6 +1,6 @@ // sherpa-onnx/csrc/speaker-embedding-extractor.cc // -// Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2024 Xiaomi Corporation #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor.h b/sherpa-onnx/csrc/speaker-embedding-extractor.h index cb23d40c0..2d536aa54 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor.h +++ b/sherpa-onnx/csrc/speaker-embedding-extractor.h @@ -1,6 +1,6 @@ // sherpa-onnx/csrc/speaker-embedding-extractor.h // -// Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2024 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ diff --git a/sherpa-onnx/csrc/speaker-embedding-manager-test.cc b/sherpa-onnx/csrc/speaker-embedding-manager-test.cc index 0e1603c2b..6f115ca55 100644 --- a/sherpa-onnx/csrc/speaker-embedding-manager-test.cc +++ b/sherpa-onnx/csrc/speaker-embedding-manager-test.cc @@ -1,6 +1,6 @@ // sherpa-onnx/csrc/speaker-embedding-manager-test.cc // -// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com) +// Copyright (c) 2024 Jingzhao Ou (jingzhao.ou@gmail.com) #include "sherpa-onnx/csrc/speaker-embedding-manager.h" diff --git a/sherpa-onnx/csrc/speaker-embedding-manager.cc b/sherpa-onnx/csrc/speaker-embedding-manager.cc index 02894436d..dead72289 100644 --- a/sherpa-onnx/csrc/speaker-embedding-manager.cc +++ b/sherpa-onnx/csrc/speaker-embedding-manager.cc @@ -1,6 +1,6 @@ // sherpa-onnx/csrc/speaker-embedding-manager.cc // -// Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2024 Xiaomi Corporation #include "sherpa-onnx/csrc/speaker-embedding-manager.h" diff --git a/sherpa-onnx/csrc/speaker-embedding-manager.h b/sherpa-onnx/csrc/speaker-embedding-manager.h index 25f85a930..66df665df 100644 --- a/sherpa-onnx/csrc/speaker-embedding-manager.h +++ 
b/sherpa-onnx/csrc/speaker-embedding-manager.h @@ -1,6 +1,6 @@ // sherpa-onnx/csrc/speaker-embedding-manager.h // -// Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2024 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ diff --git a/sherpa-onnx/python/tests/test_speaker_recognition.py b/sherpa-onnx/python/tests/test_speaker_recognition.py index e05ae2a01..bd7c8edb6 100755 --- a/sherpa-onnx/python/tests/test_speaker_recognition.py +++ b/sherpa-onnx/python/tests/test_speaker_recognition.py @@ -56,7 +56,7 @@ def load_speaker_embedding_model(model_filename): return extractor -def test_wespeaker_model(model_filename: str): +def test_zh_models(model_filename: str): model_filename = str(model_filename) if "en" in model_filename: print(f"skip {model_filename}") @@ -114,8 +114,9 @@ def test_wespeaker_model(model_filename: str): assert ans == name, (name, ans) -def test_3dspeaker_model(model_filename: str): - extractor = load_speaker_embedding_model(str(model_filename)) +def test_en_and_zh_models(model_filename: str): + model_filename = str(model_filename) + extractor = load_speaker_embedding_model(model_filename) manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim) filenames = [ @@ -124,7 +125,14 @@ def test_3dspeaker_model(model_filename: str): "speaker1_a_en_16k", "speaker2_a_en_16k", ] + is_en = "en" in model_filename for filename in filenames: + if is_en and "cn" in filename: + continue + + if not is_en and "en" in filename: + continue + name = filename.rsplit("_", maxsplit=1)[0] data, sample_rate = read_wave( f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav" @@ -145,6 +153,11 @@ def test_3dspeaker_model(model_filename: str): "speaker1_b_en_16k", ] for filename in filenames: + if is_en and "cn" in filename: + continue + + if not is_en and "en" in filename: + continue print(filename) name = filename.rsplit("_", maxsplit=1)[0] name = name.replace("b_cn", "a_cn") @@ -178,7 
+191,8 @@ def test_wespeaker_models(self): return for filename in model_dir.glob("*.onnx"): print(filename) - test_wespeaker_model(filename) + test_zh_models(filename) + test_en_and_zh_models(filename) def test_3dpeaker_models(self): model_dir = Path(d) / "3dspeaker" @@ -187,7 +201,16 @@ def test_3dpeaker_models(self): return for filename in model_dir.glob("*.onnx"): print(filename) - test_3dspeaker_model(filename) + test_en_and_zh_models(filename) + + def test_nemo_models(self): + model_dir = Path(d) / "nemo" + if not model_dir.is_dir(): + print(f"{model_dir} does not exist - skip it") + return + for filename in model_dir.glob("*.onnx"): + print(filename) + test_en_and_zh_models(filename) if __name__ == "__main__":