Skip to content

Commit

Permalink
Add C++ runtime support for speaker verification models from NeMo.
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Jan 13, 2024
1 parent 68a525a commit ed33030
Show file tree
Hide file tree
Showing 8 changed files with 345 additions and 1 deletion.
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ set(sources
list(APPEND sources
speaker-embedding-extractor-impl.cc
speaker-embedding-extractor-model.cc
speaker-embedding-extractor-nemo-model.cc
speaker-embedding-extractor.cc
speaker-embedding-manager.cc
)
Expand Down
9 changes: 8 additions & 1 deletion sherpa-onnx/csrc/features.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,12 @@ class FeatureExtractor::Impl {
public:
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
opts_.frame_opts.dither = 0;
opts_.frame_opts.snip_edges = false;
opts_.frame_opts.snip_edges = config.snip_edges;
opts_.frame_opts.samp_freq = config.sampling_rate;
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
opts_.frame_opts.window_type = config.window_type;

opts_.mel_opts.num_bins = config.feature_dim;

Expand All @@ -52,6 +56,9 @@ class FeatureExtractor::Impl {
// https://github.com/k2-fsa/sherpa-onnx/issues/514
opts_.mel_opts.high_freq = -400;

opts_.mel_opts.low_freq = config.low_freq;
opts_.mel_opts.is_librosa = config.is_librosa;

fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}

Expand Down
8 changes: 8 additions & 0 deletions sherpa-onnx/csrc/features.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ struct FeatureExtractorConfig {
// If false, we will multiply the inputs by 32768
bool normalize_samples = true;

bool snip_edges = false;
float frame_shift_ms = 10.0f; // in milliseconds.
float frame_length_ms = 25.0f; // in milliseconds.
int32_t low_freq = 20;
bool is_librosa = false;
bool remove_dc_offset = true; // Subtract mean of wave before FFT.
std::string window_type = "povey"; // e.g. Hamming window

std::string ToString() const;

void Register(ParseOptions *po);
Expand Down
6 changes: 6 additions & 0 deletions sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h"

namespace sherpa_onnx {

Expand All @@ -14,6 +15,7 @@ namespace {
enum class ModelType {
kWeSpeaker,
k3dSpeaker,
kNeMo,
kUnkown,
};

Expand Down Expand Up @@ -52,6 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kWeSpeaker;
} else if (model_type.get() == std::string("3d-speaker")) {
return ModelType::k3dSpeaker;
} else if (model_type.get() == std::string("nemo")) {
return ModelType::kNeMo;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
Expand All @@ -74,6 +78,8 @@ SpeakerEmbeddingExtractorImpl::Create(
// fall through
case ModelType::k3dSpeaker:
return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
case ModelType::kNeMo:
return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(config);
case ModelType::kUnkown:
SHERPA_ONNX_LOGE(
"Unknown model type in for speaker embedding extractor!");
Expand Down
128 changes: 128 additions & 0 deletions sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "Eigen/Dense"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h"
#include "sherpa-onnx/csrc/transpose.h"

namespace sherpa_onnx {

class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl {
public:
explicit SpeakerEmbeddingExtractorNeMoImpl(
const SpeakerEmbeddingExtractorConfig &config)
: model_(config) {}

int32_t Dim() const override { return model_.GetMetaData().output_dim; }

std::unique_ptr<OnlineStream> CreateStream() const override {
FeatureExtractorConfig feat_config;
const auto &meta_data = model_.GetMetaData();
feat_config.sampling_rate = meta_data.sample_rate;
feat_config.feature_dim = meta_data.feat_dim;
feat_config.normalize_samples = true;
feat_config.snip_edges = true;
feat_config.frame_shift_ms = meta_data.window_stride_ms;
feat_config.frame_length_ms = meta_data.window_size_ms;
feat_config.low_freq = 0;
feat_config.is_librosa = true;
feat_config.remove_dc_offset = false;
feat_config.window_type = meta_data.window_type;

return std::make_unique<OnlineStream>(feat_config);
}

bool IsReady(OnlineStream *s) const override {
return s->GetNumProcessedFrames() < s->NumFramesReady();
}

std::vector<float> Compute(OnlineStream *s) const override {
int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames();
if (num_frames <= 0) {
SHERPA_ONNX_LOGE(
"Please make sure IsReady(s) returns true. num_frames: %d",
num_frames);
return {};
}

std::vector<float> features =
s->GetFrames(s->GetNumProcessedFrames(), num_frames);

s->GetNumProcessedFrames() += num_frames;

int32_t feat_dim = features.size() / num_frames;

const auto &meta_data = model_.GetMetaData();
if (!meta_data.feature_normalize_type.empty()) {
if (meta_data.feature_normalize_type == "per_feature") {
NormalizePerFeature(features.data(), num_frames, feat_dim);
} else {
SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s",
meta_data.feature_normalize_type.c_str());
exit(-1);
}
}

if (num_frames % 16 != 0) {
int32_t pad = 16 - num_frames % 16;
features.resize((num_frames + pad) * feat_dim);
}

auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
Ort::Value x =
Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
x_shape.data(), x_shape.size());

x = Transpose12(model_.Allocator(), &x);

int64_t x_lens = num_frames;
std::array<int64_t, 1> x_lens_shape{1};
Ort::Value x_lens_tensor = Ort::Value::CreateTensor(
memory_info, &x_lens, 1, x_lens_shape.data(), x_lens_shape.size());

Ort::Value embedding =
model_.Compute(std::move(x), std::move(x_lens_tensor));
std::vector<int64_t> embedding_shape =
embedding.GetTensorTypeAndShapeInfo().GetShape();

std::vector<float> ans(embedding_shape[1]);
std::copy(embedding.GetTensorData<float>(),
embedding.GetTensorData<float>() + ans.size(), ans.begin());

return ans;
}

private:
void NormalizePerFeature(float *p, int32_t num_frames,
int32_t feat_dim) const {
auto m = Eigen::Map<
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
p, num_frames, feat_dim);

auto EX = m.colwise().mean();
auto EX2 = m.array().pow(2).colwise().sum() / num_frames;
auto variance = EX2 - EX.array().pow(2);
auto stddev = variance.array().sqrt();

m = (m.rowwise() - EX).array().rowwise() / stddev.array();
}

private:
SpeakerEmbeddingExtractorNeMoModel model_;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_

#include <cstdint>
#include <string>

namespace sherpa_onnx {

struct SpeakerEmbeddingExtractorNeMoModelMetaData {
int32_t output_dim = 0;
int32_t feat_dim = 80;
int32_t sample_rate = 0;
int32_t window_size_ms = 25;
int32_t window_stride_ms = 25;

// Chinese, English, etc.
std::string language;

// for 3d-speaker, it is global-mean
std::string feature_normalize_type;
std::string window_type = "hann";
};

} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_
126 changes: 126 additions & 0 deletions sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h"

#include <string>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h"

namespace sherpa_onnx {

class SpeakerEmbeddingExtractorNeMoModel::Impl {
public:
explicit Impl(const SpeakerEmbeddingExtractorConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(config.model);
Init(buf.data(), buf.size());
}
}

Ort::Value Compute(Ort::Value x, Ort::Value x_lens) const {
std::array<Ort::Value, 2> inputs = {std::move(x), std::move(x_lens)};

// output_names_ptr_[0] is logits
// output_names_ptr_[1] is embeddings
// so we use output_names_ptr_.data() + 1 here to extract only the
// embeddings
auto outputs = sess_->Run({}, input_names_ptr_.data(), inputs.data(),
inputs.size(), output_names_ptr_.data() + 1, 1);
return std::move(outputs[0]);
}

OrtAllocator *Allocator() const { return allocator_; }

const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const {
return meta_data_;
}

private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);

GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}

Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(meta_data_.output_dim, "output_dim");
SHERPA_ONNX_READ_META_DATA(meta_data_.feat_dim, "feat_dim");
SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
SHERPA_ONNX_READ_META_DATA(meta_data_.window_size_ms, "window_size_ms");
SHERPA_ONNX_READ_META_DATA(meta_data_.window_stride_ms, "window_stride_ms");
SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");

SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(
meta_data_.feature_normalize_type, "feature_normalize_type", "");

SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.window_type,
"window_type", "povey");

std::string framework;
SHERPA_ONNX_READ_META_DATA_STR(framework, "framework");
if (framework != "nemo") {
SHERPA_ONNX_LOGE("Expect a NeMo model, given: %s", framework.c_str());
exit(-1);
}
}

private:
SpeakerEmbeddingExtractorConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;

std::unique_ptr<Ort::Session> sess_;

std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;

std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;

SpeakerEmbeddingExtractorNeMoModelMetaData meta_data_;
};

SpeakerEmbeddingExtractorNeMoModel::SpeakerEmbeddingExtractorNeMoModel(
const SpeakerEmbeddingExtractorConfig &config)
: impl_(std::make_unique<Impl>(config)) {}

SpeakerEmbeddingExtractorNeMoModel::~SpeakerEmbeddingExtractorNeMoModel() =
default;

const SpeakerEmbeddingExtractorNeMoModelMetaData &
SpeakerEmbeddingExtractorNeMoModel::GetMetaData() const {
return impl_->GetMetaData();
}

Ort::Value SpeakerEmbeddingExtractorNeMoModel::Compute(
Ort::Value x, Ort::Value x_lens) const {
return impl_->Compute(std::move(x), std::move(x_lens));
}

OrtAllocator *SpeakerEmbeddingExtractorNeMoModel::Allocator() const {
return impl_->Allocator();
}

} // namespace sherpa_onnx
Loading

0 comments on commit ed33030

Please sign in to comment.