Skip to content

Commit

Permalink
Add C++ runtime for speaker verification models from NeMo (#527)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Jan 13, 2024
1 parent 68a525a commit 2024e96
Show file tree
Hide file tree
Showing 20 changed files with 405 additions and 24 deletions.
14 changes: 14 additions & 0 deletions .github/scripts/test-speaker-recognition-python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,19 @@ done
ls -lh
popd

log "Download NeMo models"
model_dir=$d/nemo
mkdir -p $model_dir
pushd $model_dir
models=(
nemo_en_titanet_large.onnx
nemo_en_titanet_small.onnx
nemo_en_speakerverification_speakernet.onnx
)
for m in ${models[@]}; do
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$m
done
ls -lh
popd

python3 sherpa-onnx/python/tests/test_speaker_recognition.py --verbose
16 changes: 8 additions & 8 deletions cmake/kaldi-native-fbank.cmake
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
function(download_kaldi_native_fbank)
include(FetchContent)

set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.5.tar.gz")
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.5.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=dce0cb3bc6fece5d8053d8780cb4ce22da57cb57ebec332641661521a0425283")
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.6.tar.gz")
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.6.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=6202a00cd06ba8ff89beb7b6f85cda34e073e94f25fc29e37c519bff0706bf19")

set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
Expand All @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
# If you don't have access to the Internet,
# please pre-download kaldi-native-fbank
set(possible_file_locations
$ENV{HOME}/Downloads/kaldi-native-fbank-1.18.5.tar.gz
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.5.tar.gz
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.5.tar.gz
/tmp/kaldi-native-fbank-1.18.5.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.18.5.tar.gz
$ENV{HOME}/Downloads/kaldi-native-fbank-1.18.6.tar.gz
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.6.tar.gz
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.6.tar.gz
/tmp/kaldi-native-fbank-1.18.6.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.18.6.tar.gz
)

foreach(f IN LISTS possible_file_locations)
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ set(sources
list(APPEND sources
speaker-embedding-extractor-impl.cc
speaker-embedding-extractor-model.cc
speaker-embedding-extractor-nemo-model.cc
speaker-embedding-extractor.cc
speaker-embedding-manager.cc
)
Expand Down
9 changes: 8 additions & 1 deletion sherpa-onnx/csrc/features.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,12 @@ class FeatureExtractor::Impl {
public:
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
opts_.frame_opts.dither = 0;
opts_.frame_opts.snip_edges = false;
opts_.frame_opts.snip_edges = config.snip_edges;
opts_.frame_opts.samp_freq = config.sampling_rate;
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
opts_.frame_opts.window_type = config.window_type;

opts_.mel_opts.num_bins = config.feature_dim;

Expand All @@ -52,6 +56,9 @@ class FeatureExtractor::Impl {
// https://github.com/k2-fsa/sherpa-onnx/issues/514
opts_.mel_opts.high_freq = -400;

opts_.mel_opts.low_freq = config.low_freq;
opts_.mel_opts.is_librosa = config.is_librosa;

fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}

Expand Down
8 changes: 8 additions & 0 deletions sherpa-onnx/csrc/features.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ struct FeatureExtractorConfig {
// If false, we will multiply the inputs by 32768
bool normalize_samples = true;

bool snip_edges = false;
float frame_shift_ms = 10.0f; // in milliseconds.
float frame_length_ms = 25.0f; // in milliseconds.
int32_t low_freq = 20;
bool is_librosa = false;
bool remove_dc_offset = true; // Subtract mean of wave before FFT.
std::string window_type = "povey"; // e.g. Hamming window

std::string ToString() const;

void Register(ParseOptions *po);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h
//
// Copyright (c) 2023 Xiaomi Corporation
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
Expand Down
8 changes: 7 additions & 1 deletion sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
//
// Copyright (c) 2023 Xiaomi Corporation
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h"

namespace sherpa_onnx {

Expand All @@ -14,6 +15,7 @@ namespace {
enum class ModelType {
kWeSpeaker,
k3dSpeaker,
kNeMo,
kUnkown,
};

Expand Down Expand Up @@ -52,6 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kWeSpeaker;
} else if (model_type.get() == std::string("3d-speaker")) {
return ModelType::k3dSpeaker;
} else if (model_type.get() == std::string("nemo")) {
return ModelType::kNeMo;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
Expand All @@ -74,6 +78,8 @@ SpeakerEmbeddingExtractorImpl::Create(
// fall through
case ModelType::k3dSpeaker:
return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
case ModelType::kNeMo:
return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(config);
case ModelType::kUnkown:
SHERPA_ONNX_LOGE(
"Unknown model type in for speaker embedding extractor!");
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
//
// Copyright (c) 2023 Xiaomi Corporation
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/speaker-embedding-extractor-model.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-model.cc
//
// Copyright (c) 2023-2024 Xiaomi Corporation
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h"

Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/speaker-embedding-extractor-model.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-model.h
//
// Copyright (c) 2023-2024 Xiaomi Corporation
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_

Expand Down
128 changes: 128 additions & 0 deletions sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "Eigen/Dense"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h"
#include "sherpa-onnx/csrc/transpose.h"

namespace sherpa_onnx {

class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl {
public:
explicit SpeakerEmbeddingExtractorNeMoImpl(
const SpeakerEmbeddingExtractorConfig &config)
: model_(config) {}

int32_t Dim() const override { return model_.GetMetaData().output_dim; }

std::unique_ptr<OnlineStream> CreateStream() const override {
FeatureExtractorConfig feat_config;
const auto &meta_data = model_.GetMetaData();
feat_config.sampling_rate = meta_data.sample_rate;
feat_config.feature_dim = meta_data.feat_dim;
feat_config.normalize_samples = true;
feat_config.snip_edges = true;
feat_config.frame_shift_ms = meta_data.window_stride_ms;
feat_config.frame_length_ms = meta_data.window_size_ms;
feat_config.low_freq = 0;
feat_config.is_librosa = true;
feat_config.remove_dc_offset = false;
feat_config.window_type = meta_data.window_type;

return std::make_unique<OnlineStream>(feat_config);
}

bool IsReady(OnlineStream *s) const override {
return s->GetNumProcessedFrames() < s->NumFramesReady();
}

std::vector<float> Compute(OnlineStream *s) const override {
int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames();
if (num_frames <= 0) {
SHERPA_ONNX_LOGE(
"Please make sure IsReady(s) returns true. num_frames: %d",
num_frames);
return {};
}

std::vector<float> features =
s->GetFrames(s->GetNumProcessedFrames(), num_frames);

s->GetNumProcessedFrames() += num_frames;

int32_t feat_dim = features.size() / num_frames;

const auto &meta_data = model_.GetMetaData();
if (!meta_data.feature_normalize_type.empty()) {
if (meta_data.feature_normalize_type == "per_feature") {
NormalizePerFeature(features.data(), num_frames, feat_dim);
} else {
SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s",
meta_data.feature_normalize_type.c_str());
exit(-1);
}
}

if (num_frames % 16 != 0) {
int32_t pad = 16 - num_frames % 16;
features.resize((num_frames + pad) * feat_dim);
}

auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
Ort::Value x =
Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
x_shape.data(), x_shape.size());

x = Transpose12(model_.Allocator(), &x);

int64_t x_lens = num_frames;
std::array<int64_t, 1> x_lens_shape{1};
Ort::Value x_lens_tensor = Ort::Value::CreateTensor(
memory_info, &x_lens, 1, x_lens_shape.data(), x_lens_shape.size());

Ort::Value embedding =
model_.Compute(std::move(x), std::move(x_lens_tensor));
std::vector<int64_t> embedding_shape =
embedding.GetTensorTypeAndShapeInfo().GetShape();

std::vector<float> ans(embedding_shape[1]);
std::copy(embedding.GetTensorData<float>(),
embedding.GetTensorData<float>() + ans.size(), ans.begin());

return ans;
}

private:
void NormalizePerFeature(float *p, int32_t num_frames,
int32_t feat_dim) const {
auto m = Eigen::Map<
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
p, num_frames, feat_dim);

auto EX = m.colwise().mean();
auto EX2 = m.array().pow(2).colwise().sum() / num_frames;
auto variance = EX2 - EX.array().pow(2);
auto stddev = variance.array().sqrt();

m = (m.rowwise() - EX).array().rowwise() / stddev.array();
}

private:
SpeakerEmbeddingExtractorNeMoModel model_;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_

#include <cstdint>
#include <string>

namespace sherpa_onnx {

struct SpeakerEmbeddingExtractorNeMoModelMetaData {
int32_t output_dim = 0;
int32_t feat_dim = 80;
int32_t sample_rate = 0;
int32_t window_size_ms = 25;
int32_t window_stride_ms = 25;

// Chinese, English, etc.
std::string language;

// for 3d-speaker, it is global-mean
std::string feature_normalize_type;
std::string window_type = "hann";
};

} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_
Loading

0 comments on commit 2024e96

Please sign in to comment.