-
Notifications
You must be signed in to change notification settings - Fork 480
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add C++ runtime support for speaker verification models from NeMo.
- Loading branch information
1 parent
68a525a
commit ed33030
Showing
8 changed files
with
345 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
128 changes: 128 additions & 0 deletions
128
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h | ||
// | ||
// Copyright (c) 2023 Xiaomi Corporation | ||
|
||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_ | ||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_ | ||
#include <algorithm> | ||
#include <memory> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "Eigen/Dense" | ||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h" | ||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h" | ||
#include "sherpa-onnx/csrc/transpose.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl { | ||
public: | ||
explicit SpeakerEmbeddingExtractorNeMoImpl( | ||
const SpeakerEmbeddingExtractorConfig &config) | ||
: model_(config) {} | ||
|
||
int32_t Dim() const override { return model_.GetMetaData().output_dim; } | ||
|
||
std::unique_ptr<OnlineStream> CreateStream() const override { | ||
FeatureExtractorConfig feat_config; | ||
const auto &meta_data = model_.GetMetaData(); | ||
feat_config.sampling_rate = meta_data.sample_rate; | ||
feat_config.feature_dim = meta_data.feat_dim; | ||
feat_config.normalize_samples = true; | ||
feat_config.snip_edges = true; | ||
feat_config.frame_shift_ms = meta_data.window_stride_ms; | ||
feat_config.frame_length_ms = meta_data.window_size_ms; | ||
feat_config.low_freq = 0; | ||
feat_config.is_librosa = true; | ||
feat_config.remove_dc_offset = false; | ||
feat_config.window_type = meta_data.window_type; | ||
|
||
return std::make_unique<OnlineStream>(feat_config); | ||
} | ||
|
||
bool IsReady(OnlineStream *s) const override { | ||
return s->GetNumProcessedFrames() < s->NumFramesReady(); | ||
} | ||
|
||
std::vector<float> Compute(OnlineStream *s) const override { | ||
int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames(); | ||
if (num_frames <= 0) { | ||
SHERPA_ONNX_LOGE( | ||
"Please make sure IsReady(s) returns true. num_frames: %d", | ||
num_frames); | ||
return {}; | ||
} | ||
|
||
std::vector<float> features = | ||
s->GetFrames(s->GetNumProcessedFrames(), num_frames); | ||
|
||
s->GetNumProcessedFrames() += num_frames; | ||
|
||
int32_t feat_dim = features.size() / num_frames; | ||
|
||
const auto &meta_data = model_.GetMetaData(); | ||
if (!meta_data.feature_normalize_type.empty()) { | ||
if (meta_data.feature_normalize_type == "per_feature") { | ||
NormalizePerFeature(features.data(), num_frames, feat_dim); | ||
} else { | ||
SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s", | ||
meta_data.feature_normalize_type.c_str()); | ||
exit(-1); | ||
} | ||
} | ||
|
||
if (num_frames % 16 != 0) { | ||
int32_t pad = 16 - num_frames % 16; | ||
features.resize((num_frames + pad) * feat_dim); | ||
} | ||
|
||
auto memory_info = | ||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
|
||
std::array<int64_t, 3> x_shape{1, num_frames, feat_dim}; | ||
Ort::Value x = | ||
Ort::Value::CreateTensor(memory_info, features.data(), features.size(), | ||
x_shape.data(), x_shape.size()); | ||
|
||
x = Transpose12(model_.Allocator(), &x); | ||
|
||
int64_t x_lens = num_frames; | ||
std::array<int64_t, 1> x_lens_shape{1}; | ||
Ort::Value x_lens_tensor = Ort::Value::CreateTensor( | ||
memory_info, &x_lens, 1, x_lens_shape.data(), x_lens_shape.size()); | ||
|
||
Ort::Value embedding = | ||
model_.Compute(std::move(x), std::move(x_lens_tensor)); | ||
std::vector<int64_t> embedding_shape = | ||
embedding.GetTensorTypeAndShapeInfo().GetShape(); | ||
|
||
std::vector<float> ans(embedding_shape[1]); | ||
std::copy(embedding.GetTensorData<float>(), | ||
embedding.GetTensorData<float>() + ans.size(), ans.begin()); | ||
|
||
return ans; | ||
} | ||
|
||
private: | ||
void NormalizePerFeature(float *p, int32_t num_frames, | ||
int32_t feat_dim) const { | ||
auto m = Eigen::Map< | ||
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>( | ||
p, num_frames, feat_dim); | ||
|
||
auto EX = m.colwise().mean(); | ||
auto EX2 = m.array().pow(2).colwise().sum() / num_frames; | ||
auto variance = EX2 - EX.array().pow(2); | ||
auto stddev = variance.array().sqrt(); | ||
|
||
m = (m.rowwise() - EX).array().rowwise() / stddev.array(); | ||
} | ||
|
||
private: | ||
SpeakerEmbeddingExtractorNeMoModel model_; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_ |
28 changes: 28 additions & 0 deletions
28
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h | ||
// | ||
// Copyright (c) 2023 Xiaomi Corporation | ||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_ | ||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_ | ||
|
||
#include <cstdint> | ||
#include <string> | ||
|
||
namespace sherpa_onnx { | ||
|
||
// Metadata of a NeMo speaker embedding model.
//
// The values below are defaults; the model loader overwrites them with the
// values stored in the ONNX model's metadata.
struct SpeakerEmbeddingExtractorNeMoModelMetaData {
  // Dimension of the embedding produced by the model.
  int32_t output_dim = 0;

  // Feature dimension expected by the model.
  int32_t feat_dim = 80;

  // Expected sample rate of the input audio.
  int32_t sample_rate = 0;

  // Frame length and frame shift of the feature extractor, in milliseconds.
  int32_t window_size_ms = 25;
  int32_t window_stride_ms = 25;

  // Chinese, English, etc.
  std::string language;

  // How features are normalized before they are given to the model.
  // An empty string means no normalization. The only other value accepted by
  // the NeMo extractor is "per_feature".
  std::string feature_normalize_type;

  // Window function used by the feature extractor.
  // NOTE(review): the model loader falls back to "povey" when the model has
  // no "window_type" entry, so this default is not what a loaded model ends
  // up with in that case.
  std::string window_type = "hann";
};
|
||
} // namespace sherpa_onnx | ||
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_ |
126 changes: 126 additions & 0 deletions
126
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
|
||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h" | ||
|
||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "sherpa-onnx/csrc/macros.h" | ||
#include "sherpa-onnx/csrc/onnx-utils.h" | ||
#include "sherpa-onnx/csrc/session.h" | ||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
// Private implementation of SpeakerEmbeddingExtractorNeMoModel.
//
// It owns the onnxruntime session, the input/output names of the model, and
// the metadata read from the model file.
class SpeakerEmbeddingExtractorNeMoModel::Impl {
 public:
  // Reads the model file given by config.model and initializes the session.
  explicit Impl(const SpeakerEmbeddingExtractorConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(config.model);
      Init(buf.data(), buf.size());
    }
  }

  // Runs the model and returns only its second output, the embeddings.
  //
  // @param x  Input features.
  // @param x_lens  Number of valid frames in `x`.
  Ort::Value Compute(Ort::Value x, Ort::Value x_lens) const {
    std::array<Ort::Value, 2> inputs = {std::move(x), std::move(x_lens)};

    // output_names_ptr_[0] is logits
    // output_names_ptr_[1] is embeddings
    // so we use output_names_ptr_.data() + 1 here to extract only the
    // embeddings. Init() has verified that there are at least 2 outputs.
    auto outputs = sess_->Run({}, input_names_ptr_.data(), inputs.data(),
                              inputs.size(), output_names_ptr_.data() + 1, 1);
    return std::move(outputs[0]);
  }

  OrtAllocator *Allocator() const { return allocator_; }

  const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const {
    return meta_data_;
  }

 private:
  // Creates the session from an in-memory model, collects the input/output
  // names, and reads the metadata. Exits the process if the model does not
  // have the expected outputs or was not exported from NeMo.
  void Init(void *model_data, size_t model_data_length) {
    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // Compute() reads output_names_ptr_[1]; reject models for which that
    // would be out of bounds.
    if (output_names_ptr_.size() < 2) {
      SHERPA_ONNX_LOGE(
          "Expect a model with at least 2 outputs (logits, embeddings). "
          "Given: %d",
          static_cast<int32_t>(output_names_ptr_.size()));
      exit(-1);
    }

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s", os.str().c_str());
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(meta_data_.output_dim, "output_dim");
    SHERPA_ONNX_READ_META_DATA(meta_data_.feat_dim, "feat_dim");
    SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
    SHERPA_ONNX_READ_META_DATA(meta_data_.window_size_ms, "window_size_ms");
    SHERPA_ONNX_READ_META_DATA(meta_data_.window_stride_ms, "window_stride_ms");
    SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");

    SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(
        meta_data_.feature_normalize_type, "feature_normalize_type", "");

    SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.window_type,
                                                "window_type", "povey");

    std::string framework;
    SHERPA_ONNX_READ_META_DATA_STR(framework, "framework");
    if (framework != "nemo") {
      SHERPA_ONNX_LOGE("Expect a NeMo model, given: %s", framework.c_str());
      exit(-1);
    }
  }

 private:
  SpeakerEmbeddingExtractorConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  // The *_ptr_ vectors hold C strings pointing into the matching string
  // vectors, in the form expected by Ort::Session::Run().
  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  SpeakerEmbeddingExtractorNeMoModelMetaData meta_data_;
};
|
||
// Constructs the model from `config`. All the work is done by Impl.
SpeakerEmbeddingExtractorNeMoModel::SpeakerEmbeddingExtractorNeMoModel(
    const SpeakerEmbeddingExtractorConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}
|
||
// Defined in this file, where Impl is a complete type.
// NOTE(review): presumably required because the header only forward-declares
// Impl; confirm against the header.
SpeakerEmbeddingExtractorNeMoModel::~SpeakerEmbeddingExtractorNeMoModel() =
    default;
|
||
// Returns the metadata read from the model. The reference refers to data
// held by Impl.
const SpeakerEmbeddingExtractorNeMoModelMetaData &
SpeakerEmbeddingExtractorNeMoModel::GetMetaData() const {
  const auto &meta_data = impl_->GetMetaData();
  return meta_data;
}
|
||
// Forwards the inputs to Impl::Compute() and returns the tensor it produces.
// Both arguments are consumed (moved from).
Ort::Value SpeakerEmbeddingExtractorNeMoModel::Compute(
    Ort::Value x, Ort::Value x_lens) const {
  Ort::Value embedding = impl_->Compute(std::move(x), std::move(x_lens));
  return embedding;
}
|
||
// Returns the allocator held by Impl.
OrtAllocator *SpeakerEmbeddingExtractorNeMoModel::Allocator() const {
  OrtAllocator *allocator = impl_->Allocator();
  return allocator;
}
|
||
} // namespace sherpa_onnx |
Oops, something went wrong.