Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add C++ runtime for speaker verification models from NeMo #527

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .github/scripts/test-speaker-recognition-python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,19 @@ done
ls -lh
popd

# Download the NeMo speaker-embedding ONNX models into $d/nemo.
# All expansions are quoted so paths containing spaces or glob characters
# are passed through unchanged.
log "Download NeMo models"
model_dir="$d/nemo"
mkdir -p "$model_dir"
pushd "$model_dir"
models=(
  nemo_en_titanet_large.onnx
  nemo_en_titanet_small.onnx
  nemo_en_speakerverification_speakernet.onnx
)
for m in "${models[@]}"; do
  # NOTE(review): the "recongition" spelling in the URL is kept as-is;
  # presumably it matches the actual release tag -- confirm before changing.
  wget -q "https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$m"
done
ls -lh
popd

python3 sherpa-onnx/python/tests/test_speaker_recognition.py --verbose
16 changes: 8 additions & 8 deletions cmake/kaldi-native-fbank.cmake
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
function(download_kaldi_native_fbank)
include(FetchContent)

set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.5.tar.gz")
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.5.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=dce0cb3bc6fece5d8053d8780cb4ce22da57cb57ebec332641661521a0425283")
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.6.tar.gz")
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.6.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=6202a00cd06ba8ff89beb7b6f85cda34e073e94f25fc29e37c519bff0706bf19")

set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
Expand All @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
# If you don't have access to the Internet,
# please pre-download kaldi-native-fbank
set(possible_file_locations
$ENV{HOME}/Downloads/kaldi-native-fbank-1.18.5.tar.gz
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.5.tar.gz
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.5.tar.gz
/tmp/kaldi-native-fbank-1.18.5.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.18.5.tar.gz
$ENV{HOME}/Downloads/kaldi-native-fbank-1.18.6.tar.gz
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.6.tar.gz
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.6.tar.gz
/tmp/kaldi-native-fbank-1.18.6.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.18.6.tar.gz
)

foreach(f IN LISTS possible_file_locations)
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ set(sources
list(APPEND sources
speaker-embedding-extractor-impl.cc
speaker-embedding-extractor-model.cc
speaker-embedding-extractor-nemo-model.cc
speaker-embedding-extractor.cc
speaker-embedding-manager.cc
)
Expand Down
9 changes: 8 additions & 1 deletion sherpa-onnx/csrc/features.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,12 @@ class FeatureExtractor::Impl {
public:
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
opts_.frame_opts.dither = 0;
opts_.frame_opts.snip_edges = false;
opts_.frame_opts.snip_edges = config.snip_edges;
opts_.frame_opts.samp_freq = config.sampling_rate;
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
opts_.frame_opts.window_type = config.window_type;

opts_.mel_opts.num_bins = config.feature_dim;

Expand All @@ -52,6 +56,9 @@ class FeatureExtractor::Impl {
// https://github.com/k2-fsa/sherpa-onnx/issues/514
opts_.mel_opts.high_freq = -400;

opts_.mel_opts.low_freq = config.low_freq;
opts_.mel_opts.is_librosa = config.is_librosa;

fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}

Expand Down
8 changes: 8 additions & 0 deletions sherpa-onnx/csrc/features.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ struct FeatureExtractorConfig {
// If false, we will multiply the inputs by 32768
bool normalize_samples = true;

bool snip_edges = false;
float frame_shift_ms = 10.0f; // in milliseconds.
float frame_length_ms = 25.0f; // in milliseconds.
int32_t low_freq = 20;
bool is_librosa = false;
bool remove_dc_offset = true; // Subtract mean of wave before FFT.
std::string window_type = "povey"; // e.g. Hamming window

std::string ToString() const;

void Register(ParseOptions *po);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h
//
// Copyright (c) 2023 Xiaomi Corporation
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
Expand Down
8 changes: 7 additions & 1 deletion sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
//
// Copyright (c) 2023 Xiaomi Corporation
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h"

namespace sherpa_onnx {

Expand All @@ -14,6 +15,7 @@ namespace {
// Kind of speaker-embedding model. GetModelType() derives it from the
// model's "model_type" string, and Create() uses it to pick the
// implementation class.
enum class ModelType {
// NOTE(review): the model_type string that maps to this value is outside
// the visible part of GetModelType() -- confirm there.
kWeSpeaker,
// model_type == "3d-speaker"; handled by SpeakerEmbeddingExtractorGeneralImpl.
k3dSpeaker,
// model_type == "nemo"; handled by SpeakerEmbeddingExtractorNeMoImpl.
kNeMo,
// Returned for any other model_type, after logging an error.
// The identifier is misspelled ("Unkown") but is referenced by name
// elsewhere, so it is left as-is here.
kUnkown,
};

Expand Down Expand Up @@ -52,6 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kWeSpeaker;
} else if (model_type.get() == std::string("3d-speaker")) {
return ModelType::k3dSpeaker;
} else if (model_type.get() == std::string("nemo")) {
return ModelType::kNeMo;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
Expand All @@ -74,6 +78,8 @@ SpeakerEmbeddingExtractorImpl::Create(
// fall through
case ModelType::k3dSpeaker:
return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
case ModelType::kNeMo:
return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(config);
case ModelType::kUnkown:
SHERPA_ONNX_LOGE(
"Unknown model type in for speaker embedding extractor!");
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
//
// Copyright (c) 2023 Xiaomi Corporation
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/speaker-embedding-extractor-model.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-model.cc
//
// Copyright (c) 2023-2024 Xiaomi Corporation
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h"

Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/speaker-embedding-extractor-model.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-model.h
//
// Copyright (c) 2023-2024 Xiaomi Corporation
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_

Expand Down
128 changes: 128 additions & 0 deletions sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
#include <algorithm>
#include <array>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>

#include "Eigen/Dense"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h"
#include "sherpa-onnx/csrc/transpose.h"

namespace sherpa_onnx {

// Speaker-embedding extractor for models whose metadata reports
// model_type "nemo".
//
// Features are computed by an OnlineStream configured from the model's
// metadata, optionally normalized per feature dimension, and then fed to
// the ONNX model together with the number of frames.
class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl {
 public:
  explicit SpeakerEmbeddingExtractorNeMoImpl(
      const SpeakerEmbeddingExtractorConfig &config)
      : model_(config) {}

  // Returns the embedding dimension stored in the model metadata.
  int32_t Dim() const override { return model_.GetMetaData().output_dim; }

  // Creates a stream whose feature extractor uses the sample rate, feature
  // dimension, window size/stride and window type from the model metadata.
  // The remaining options (snip_edges, low_freq, is_librosa,
  // remove_dc_offset) are fixed to the values below.
  std::unique_ptr<OnlineStream> CreateStream() const override {
    FeatureExtractorConfig feat_config;
    const auto &meta_data = model_.GetMetaData();
    feat_config.sampling_rate = meta_data.sample_rate;
    feat_config.feature_dim = meta_data.feat_dim;
    feat_config.normalize_samples = true;
    feat_config.snip_edges = true;
    feat_config.frame_shift_ms = meta_data.window_stride_ms;
    feat_config.frame_length_ms = meta_data.window_size_ms;
    feat_config.low_freq = 0;
    feat_config.is_librosa = true;
    feat_config.remove_dc_offset = false;
    feat_config.window_type = meta_data.window_type;

    return std::make_unique<OnlineStream>(feat_config);
  }

  // Returns true if the stream has frames that Compute() has not consumed.
  bool IsReady(OnlineStream *s) const override {
    return s->GetNumProcessedFrames() < s->NumFramesReady();
  }

  // Consumes all frames that are ready but not yet processed and returns
  // the embedding computed from them.
  //
  // Returns an empty vector (after logging an error) if there are no
  // unprocessed frames. Terminates the process via exit(-1) if the model
  // requests a feature normalization type other than "per_feature".
  std::vector<float> Compute(OnlineStream *s) const override {
    int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames();
    if (num_frames <= 0) {
      SHERPA_ONNX_LOGE(
          "Please make sure IsReady(s) returns true. num_frames: %d",
          num_frames);
      return {};
    }

    std::vector<float> features =
        s->GetFrames(s->GetNumProcessedFrames(), num_frames);

    // Mark the frames as consumed so the next call starts after them.
    s->GetNumProcessedFrames() += num_frames;

    int32_t feat_dim = features.size() / num_frames;

    const auto &meta_data = model_.GetMetaData();
    if (!meta_data.feature_normalize_type.empty()) {
      if (meta_data.feature_normalize_type == "per_feature") {
        NormalizePerFeature(features.data(), num_frames, feat_dim);
      } else {
        SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s",
                         meta_data.feature_normalize_type.c_str());
        exit(-1);
      }
    }

    // NOTE(review): this grows the buffer with zeros up to a multiple of 16
    // frames, but x_shape below still uses num_frames, so the tensor does not
    // cover the padded rows. Confirm whether the padded length was meant to
    // be part of the tensor shape.
    if (num_frames % 16 != 0) {
      int32_t pad = 16 - num_frames % 16;
      features.resize((num_frames + pad) * feat_dim);
    }

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    // The tensor wraps features.data() without copying.
    std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
    Ort::Value x =
        Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
                                 x_shape.data(), x_shape.size());

    // Presumably swaps axes 1 and 2, i.e. (1, T, C) -> (1, C, T) -- confirm
    // against Transpose12 in transpose.h.
    x = Transpose12(model_.Allocator(), &x);

    int64_t x_lens = num_frames;
    std::array<int64_t, 1> x_lens_shape{1};
    Ort::Value x_lens_tensor = Ort::Value::CreateTensor(
        memory_info, &x_lens, 1, x_lens_shape.data(), x_lens_shape.size());

    Ort::Value embedding =
        model_.Compute(std::move(x), std::move(x_lens_tensor));
    std::vector<int64_t> embedding_shape =
        embedding.GetTensorTypeAndShapeInfo().GetShape();

    // Only the first embedding_shape[1] values are copied; this assumes the
    // output has at least two dimensions.
    std::vector<float> ans(embedding_shape[1]);
    std::copy(embedding.GetTensorData<float>(),
              embedding.GetTensorData<float>() + ans.size(), ans.begin());

    return ans;
  }

 private:
  // Normalizes, in place, every feature dimension (column) of the row-major
  // (num_frames x feat_dim) matrix at p to zero mean and unit standard
  // deviation. The variance divides by num_frames.
  //
  // The mean and standard deviation are materialized into concrete vectors
  // before p is overwritten, so the final assignment does not read
  // statistics from partially updated data. The variance is computed from
  // the centered values, so it cannot become negative through floating-point
  // cancellation, and a small constant is added to the standard deviation so
  // that a constant column (or a single frame) produces zeros rather than
  // NaN/Inf.
  void NormalizePerFeature(float *p, int32_t num_frames,
                           int32_t feat_dim) const {
    constexpr float kEpsilon = 1e-5f;

    using RowMajorMatrix = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic,
                                         Eigen::RowMajor>;
    Eigen::Map<RowMajorMatrix> m(p, num_frames, feat_dim);

    const Eigen::RowVectorXf mean = m.colwise().mean();
    const Eigen::RowVectorXf stddev =
        (m.rowwise() - mean).array().square().colwise().mean().sqrt().matrix();

    m = (m.rowwise() - mean).array().rowwise() / (stddev.array() + kEpsilon);
  }

 private:
  SpeakerEmbeddingExtractorNeMoModel model_;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_

#include <cstdint>
#include <string>

namespace sherpa_onnx {

// Metadata describing a NeMo speaker-embedding model. The values below are
// only the defaults of a freshly constructed object.
struct SpeakerEmbeddingExtractorNeMoModelMetaData {
  // Dimension of the embedding produced by the model.
  int32_t output_dim{0};

  // Feature dimension of the model's input features.
  int32_t feat_dim{80};

  // Sample rate expected by the feature extractor.
  int32_t sample_rate{0};

  // Analysis window length, in milliseconds.
  int32_t window_size_ms{25};

  // Hop between successive windows, in milliseconds.
  // NOTE(review): the default equals window_size_ms; presumably it is
  // overwritten when the model metadata is loaded -- confirm.
  int32_t window_stride_ms{25};

  // Language of the model, e.g., Chinese, English, etc.
  std::string language;

  // Feature normalization to apply before running the model. Empty means
  // no normalization; the NeMo extractor handles "per_feature".
  std::string feature_normalize_type;

  // Name of the window function used for feature extraction.
  std::string window_type{"hann"};
};

} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_
Loading
Loading