-
Notifications
You must be signed in to change notification settings - Fork 477
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add C++ runtime for speaker verification models from NeMo (#527)
- Loading branch information
1 parent
68a525a
commit 2024e96
Showing
20 changed files
with
405 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
128 changes: 128 additions & 0 deletions
128
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
|
||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_ | ||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_ | ||
#include <algorithm> | ||
#include <memory> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "Eigen/Dense" | ||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h" | ||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h" | ||
#include "sherpa-onnx/csrc/transpose.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl { | ||
public: | ||
explicit SpeakerEmbeddingExtractorNeMoImpl( | ||
const SpeakerEmbeddingExtractorConfig &config) | ||
: model_(config) {} | ||
|
||
int32_t Dim() const override { return model_.GetMetaData().output_dim; } | ||
|
||
std::unique_ptr<OnlineStream> CreateStream() const override { | ||
FeatureExtractorConfig feat_config; | ||
const auto &meta_data = model_.GetMetaData(); | ||
feat_config.sampling_rate = meta_data.sample_rate; | ||
feat_config.feature_dim = meta_data.feat_dim; | ||
feat_config.normalize_samples = true; | ||
feat_config.snip_edges = true; | ||
feat_config.frame_shift_ms = meta_data.window_stride_ms; | ||
feat_config.frame_length_ms = meta_data.window_size_ms; | ||
feat_config.low_freq = 0; | ||
feat_config.is_librosa = true; | ||
feat_config.remove_dc_offset = false; | ||
feat_config.window_type = meta_data.window_type; | ||
|
||
return std::make_unique<OnlineStream>(feat_config); | ||
} | ||
|
||
bool IsReady(OnlineStream *s) const override { | ||
return s->GetNumProcessedFrames() < s->NumFramesReady(); | ||
} | ||
|
||
std::vector<float> Compute(OnlineStream *s) const override { | ||
int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames(); | ||
if (num_frames <= 0) { | ||
SHERPA_ONNX_LOGE( | ||
"Please make sure IsReady(s) returns true. num_frames: %d", | ||
num_frames); | ||
return {}; | ||
} | ||
|
||
std::vector<float> features = | ||
s->GetFrames(s->GetNumProcessedFrames(), num_frames); | ||
|
||
s->GetNumProcessedFrames() += num_frames; | ||
|
||
int32_t feat_dim = features.size() / num_frames; | ||
|
||
const auto &meta_data = model_.GetMetaData(); | ||
if (!meta_data.feature_normalize_type.empty()) { | ||
if (meta_data.feature_normalize_type == "per_feature") { | ||
NormalizePerFeature(features.data(), num_frames, feat_dim); | ||
} else { | ||
SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s", | ||
meta_data.feature_normalize_type.c_str()); | ||
exit(-1); | ||
} | ||
} | ||
|
||
if (num_frames % 16 != 0) { | ||
int32_t pad = 16 - num_frames % 16; | ||
features.resize((num_frames + pad) * feat_dim); | ||
} | ||
|
||
auto memory_info = | ||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
|
||
std::array<int64_t, 3> x_shape{1, num_frames, feat_dim}; | ||
Ort::Value x = | ||
Ort::Value::CreateTensor(memory_info, features.data(), features.size(), | ||
x_shape.data(), x_shape.size()); | ||
|
||
x = Transpose12(model_.Allocator(), &x); | ||
|
||
int64_t x_lens = num_frames; | ||
std::array<int64_t, 1> x_lens_shape{1}; | ||
Ort::Value x_lens_tensor = Ort::Value::CreateTensor( | ||
memory_info, &x_lens, 1, x_lens_shape.data(), x_lens_shape.size()); | ||
|
||
Ort::Value embedding = | ||
model_.Compute(std::move(x), std::move(x_lens_tensor)); | ||
std::vector<int64_t> embedding_shape = | ||
embedding.GetTensorTypeAndShapeInfo().GetShape(); | ||
|
||
std::vector<float> ans(embedding_shape[1]); | ||
std::copy(embedding.GetTensorData<float>(), | ||
embedding.GetTensorData<float>() + ans.size(), ans.begin()); | ||
|
||
return ans; | ||
} | ||
|
||
private: | ||
void NormalizePerFeature(float *p, int32_t num_frames, | ||
int32_t feat_dim) const { | ||
auto m = Eigen::Map< | ||
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>( | ||
p, num_frames, feat_dim); | ||
|
||
auto EX = m.colwise().mean(); | ||
auto EX2 = m.array().pow(2).colwise().sum() / num_frames; | ||
auto variance = EX2 - EX.array().pow(2); | ||
auto stddev = variance.array().sqrt(); | ||
|
||
m = (m.rowwise() - EX).array().rowwise() / stddev.array(); | ||
} | ||
|
||
private: | ||
SpeakerEmbeddingExtractorNeMoModel model_; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_ |
28 changes: 28 additions & 0 deletions
28
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_ | ||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_ | ||
|
||
#include <cstdint> | ||
#include <string> | ||
|
||
namespace sherpa_onnx { | ||
|
||
struct SpeakerEmbeddingExtractorNeMoModelMetaData { | ||
int32_t output_dim = 0; | ||
int32_t feat_dim = 80; | ||
int32_t sample_rate = 0; | ||
int32_t window_size_ms = 25; | ||
int32_t window_stride_ms = 25; | ||
|
||
// Chinese, English, etc. | ||
std::string language; | ||
|
||
// for 3d-speaker, it is global-mean | ||
std::string feature_normalize_type; | ||
std::string window_type = "hann"; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_ |
Oops, something went wrong.