Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add C++ runtime for models from 3d-speaker #523

Merged
merged 7 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions .github/scripts/test-speaker-recognition-python.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env bash

set -e

log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

d=/tmp/sr-models
mkdir -p $d

pushd $d
log "Download test waves"
git clone https://github.com/csukuangfj/sr-data
popd

log "Download wespeaker models"
model_dir=$d/wespeaker
mkdir -p $model_dir
pushd $model_dir
models=(
en_voxceleb_CAM++.onnx
en_voxceleb_CAM++_LM.onnx
en_voxceleb_resnet152_LM.onnx
en_voxceleb_resnet221_LM.onnx
en_voxceleb_resnet293_LM.onnx
en_voxceleb_resnet34.onnx
en_voxceleb_resnet34_LM.onnx
zh_cnceleb_resnet34.onnx
zh_cnceleb_resnet34_LM.onnx
)
for m in ${models[@]}; do
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$m
done
ls -lh
popd

log "Download 3d-speaker models"
model_dir=$d/3dspeaker
mkdir -p $model_dir
pushd $model_dir
models=(
speech_campplus_sv_en_voxceleb_16k.onnx
speech_campplus_sv_zh-cn_16k-common.onnx
speech_eres2net_base_200k_sv_zh-cn_16k-common.onnx
speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
speech_eres2net_sv_en_voxceleb_16k.onnx
speech_eres2net_sv_zh-cn_16k-common.onnx
)
for m in ${models[@]}; do
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$m
done
ls -lh
popd


python3 sherpa-onnx/python/tests/test_speaker_recognition.py --verbose
1 change: 1 addition & 0 deletions .github/workflows/run-python-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ jobs:
- name: Test sherpa-onnx
shell: bash
run: |
.github/scripts/test-speaker-recognition-python.sh
.github/scripts/test-python.sh

- uses: actions/upload-artifact@v3
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ set(sources
# speaker embedding extractor
list(APPEND sources
speaker-embedding-extractor-impl.cc
speaker-embedding-extractor-wespeaker-model.cc
speaker-embedding-extractor-model.cc
speaker-embedding-extractor.cc
speaker-embedding-manager.cc
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,31 +1,32 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h
// sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "Eigen/Dense"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h"

namespace sherpa_onnx {

class SpeakerEmbeddingExtractorWeSpeakerImpl
class SpeakerEmbeddingExtractorGeneralImpl
: public SpeakerEmbeddingExtractorImpl {
public:
explicit SpeakerEmbeddingExtractorWeSpeakerImpl(
explicit SpeakerEmbeddingExtractorGeneralImpl(
const SpeakerEmbeddingExtractorConfig &config)
: model_(config) {}

int32_t Dim() const override { return model_.GetMetaData().output_dim; }

std::unique_ptr<OnlineStream> CreateStream() const override {
FeatureExtractorConfig feat_config;
auto meta_data = model_.GetMetaData();
const auto &meta_data = model_.GetMetaData();
feat_config.sampling_rate = meta_data.sample_rate;
feat_config.normalize_samples = meta_data.normalize_samples;

Expand All @@ -52,6 +53,17 @@ class SpeakerEmbeddingExtractorWeSpeakerImpl

int32_t feat_dim = features.size() / num_frames;

const auto &meta_data = model_.GetMetaData();
if (!meta_data.feature_normalize_type.empty()) {
if (meta_data.feature_normalize_type == "global-mean") {
SubtractGlobalMean(features.data(), num_frames, feat_dim);
} else {
SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s",
meta_data.feature_normalize_type.c_str());
exit(-1);
}
}

auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

Expand All @@ -71,9 +83,19 @@ class SpeakerEmbeddingExtractorWeSpeakerImpl
}

private:
SpeakerEmbeddingExtractorWeSpeakerModel model_;
void SubtractGlobalMean(float *p, int32_t num_frames,
int32_t feat_dim) const {
auto m = Eigen::Map<
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
p, num_frames, feat_dim);

m = m.rowwise() - m.colwise().mean();
}

private:
SpeakerEmbeddingExtractorModel model_;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
9 changes: 7 additions & 2 deletions sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h"

namespace sherpa_onnx {

namespace {

enum class ModelType {
kWeSpeaker,
k3dSpeaker,
kUnkown,
};

Expand Down Expand Up @@ -49,6 +50,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,

if (model_type.get() == std::string("wespeaker")) {
return ModelType::kWeSpeaker;
} else if (model_type.get() == std::string("3d-speaker")) {
return ModelType::k3dSpeaker;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
Expand All @@ -68,7 +71,9 @@ SpeakerEmbeddingExtractorImpl::Create(

switch (model_type) {
case ModelType::kWeSpeaker:
return std::make_unique<SpeakerEmbeddingExtractorWeSpeakerImpl>(config);
// fall through
case ModelType::k3dSpeaker:
return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
case ModelType::kUnkown:
SHERPA_ONNX_LOGE(
"Unknown model type in for speaker embedding extractor!");
Expand Down
28 changes: 28 additions & 0 deletions sherpa-onnx/csrc/speaker-embedding-extractor-model-meta-data.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-model-meta-data.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_META_DATA_H_

#include <cstdint>
#include <string>

namespace sherpa_onnx {

struct SpeakerEmbeddingExtractorModelMetaData {
int32_t output_dim = 0;
int32_t sample_rate = 0;

// for wespeaker models, it is 0;
// for 3d-speaker models, it is 1
int32_t normalize_samples = 1;

// Chinese, English, etc.
std::string language;

// for 3d-speaker, it is global-mean
std::string feature_normalize_type;
};

} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_META_DATA_H_
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc
// sherpa-onnx/csrc/speaker-embedding-extractor-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
// Copyright (c) 2023-2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h"

#include <string>
#include <utility>
Expand All @@ -11,11 +11,11 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-model-meta-data.h"

namespace sherpa_onnx {

class SpeakerEmbeddingExtractorWeSpeakerModel::Impl {
class SpeakerEmbeddingExtractorModel::Impl {
public:
explicit Impl(const SpeakerEmbeddingExtractorConfig &config)
: config_(config),
Expand All @@ -37,7 +37,7 @@ class SpeakerEmbeddingExtractorWeSpeakerModel::Impl {
return std::move(outputs[0]);
}

const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const {
const SpeakerEmbeddingExtractorModelMetaData &GetMetaData() const {
return meta_data_;
}

Expand Down Expand Up @@ -65,10 +65,13 @@ class SpeakerEmbeddingExtractorWeSpeakerModel::Impl {
"normalize_samples");
SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");

SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(
meta_data_.feature_normalize_type, "feature_normalize_type", "");

std::string framework;
SHERPA_ONNX_READ_META_DATA_STR(framework, "framework");
if (framework != "wespeaker") {
SHERPA_ONNX_LOGE("Expect a wespeaker model, given: %s",
if (framework != "wespeaker" && framework != "3d-speaker") {
SHERPA_ONNX_LOGE("Expect a wespeaker or a 3d-speaker model, given: %s",
framework.c_str());
exit(-1);
}
Expand All @@ -88,24 +91,21 @@ class SpeakerEmbeddingExtractorWeSpeakerModel::Impl {
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;

SpeakerEmbeddingExtractorWeSpeakerModelMetaData meta_data_;
SpeakerEmbeddingExtractorModelMetaData meta_data_;
};

SpeakerEmbeddingExtractorWeSpeakerModel::
SpeakerEmbeddingExtractorWeSpeakerModel(
const SpeakerEmbeddingExtractorConfig &config)
SpeakerEmbeddingExtractorModel::SpeakerEmbeddingExtractorModel(
const SpeakerEmbeddingExtractorConfig &config)
: impl_(std::make_unique<Impl>(config)) {}

SpeakerEmbeddingExtractorWeSpeakerModel::
~SpeakerEmbeddingExtractorWeSpeakerModel() = default;
SpeakerEmbeddingExtractorModel::~SpeakerEmbeddingExtractorModel() = default;

const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &
SpeakerEmbeddingExtractorWeSpeakerModel::GetMetaData() const {
const SpeakerEmbeddingExtractorModelMetaData &
SpeakerEmbeddingExtractorModel::GetMetaData() const {
return impl_->GetMetaData();
}

Ort::Value SpeakerEmbeddingExtractorWeSpeakerModel::Compute(
Ort::Value x) const {
Ort::Value SpeakerEmbeddingExtractorModel::Compute(Ort::Value x) const {
return impl_->Compute(std::move(x));
}

Expand Down
37 changes: 37 additions & 0 deletions sherpa-onnx/csrc/speaker-embedding-extractor-model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-model.h
//
// Copyright (c) 2023-2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_

#include <memory>

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/speaker-embedding-extractor-model-meta-data.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"

namespace sherpa_onnx {

class SpeakerEmbeddingExtractorModel {
public:
explicit SpeakerEmbeddingExtractorModel(
const SpeakerEmbeddingExtractorConfig &config);

~SpeakerEmbeddingExtractorModel();

const SpeakerEmbeddingExtractorModelMetaData &GetMetaData() const;

/**
* @param x A float32 tensor of shape (N, T, C)
* @return A float32 tensor of shape (N, C)
*/
Ort::Value Compute(Ort::Value x) const;

private:
class Impl;
std::unique_ptr<Impl> impl_;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_

This file was deleted.

Loading
Loading