Use espeak-ng for coqui-ai/TTS VITS English models. (#466)
csukuangfj authored Dec 6, 2023
1 parent 3b90e85 commit 23cf92d
Showing 10 changed files with 230 additions and 93 deletions.
@@ -23,7 +23,7 @@ data class OfflineTtsModelConfig(
data class OfflineTtsConfig(
var model: OfflineTtsModelConfig,
var ruleFsts: String = "",
var maxNumSentences: Int = 2,
var maxNumSentences: Int = 1,
)

class GeneratedAudio(
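
A note on this first hunk: the Kotlin default for maxNumSentences drops from 2 to 1, i.e. long input text is split into sentences and synthesized one sentence per model call instead of two, presumably so that, with a callback, audio for the first sentence is available sooner. A small illustrative sketch of that batching idea (plain Python, not library code; batch_sentences below is a hypothetical helper):

```python
# Illustrative only: how a cap like maxNumSentences groups sentences
# before synthesis. With the new default of 1, each sentence is generated
# (and can be handed to a playback callback) on its own.
from typing import Iterator, List


def batch_sentences(sentences: List[str], max_num_sentences: int = 1) -> Iterator[List[str]]:
    for i in range(0, len(sentences), max_num_sentences):
        yield sentences[i : i + max_num_sentences]


text = ["This is the first sentence.", "Here is the second.", "And a third."]
for batch in batch_sentences(text, max_num_sentences=1):
    print(batch)  # one sentence per batch
```
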
3 changes: 3 additions & 0 deletions python-api-examples/offline-tts-play.py
@@ -311,6 +311,9 @@ def main():

if len(audio.samples) == 0:
print("Error in generating audios. Please read previous error messages.")
global killed
killed = True
play_back_thread.join()
return

elapsed_seconds = end - start
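
The three added lines make offline-tts-play.py stop its playback thread and return cleanly when generation yields no samples, instead of leaving the thread blocked. A self-contained sketch of the same flag-and-join pattern (the queue and worker below are stand-ins for the script's actual audio playback loop):

```python
# A minimal stand-alone sketch of the shutdown pattern: a shared `killed`
# flag tells the playback worker to stop, and join() makes sure it has
# exited before main() returns.
import queue
import threading
import time

killed = False
buffer: queue.Queue = queue.Queue()


def play_back_worker() -> None:
    while not killed:
        try:
            chunk = buffer.get(timeout=0.1)  # generated samples would arrive here
        except queue.Empty:
            continue
        time.sleep(0.01)  # the real script writes `chunk` to the sound device


play_back_thread = threading.Thread(target=play_back_worker)
play_back_thread.start()


def stop_playback_on_error() -> None:
    # Mirrors the added lines: flip the flag, then wait for the thread.
    global killed
    killed = True
    play_back_thread.join()


stop_playback_on_error()
```
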
31 changes: 23 additions & 8 deletions scripts/apk/generate-tts-apk-script.py
@@ -33,6 +33,23 @@ class TtsModel:
data_dir: Optional[str] = None


def get_coqui_models() -> List[TtsModel]:
# English (coqui-ai/TTS)
models = [
TtsModel(model_dir="vits-coqui-en-ljspeech"),
TtsModel(model_dir="vits-coqui-en-ljspeech-neon"),
TtsModel(model_dir="vits-coqui-en-vctk"),
# TtsModel(model_dir="vits-coqui-en-jenny"),
]

for m in models:
m.data_dir = m.model_dir + "/" + "espeak-ng-data"
m.model_name = "model.onnx"
m.lang = "en"

return models


def get_piper_models() -> List[TtsModel]:
models = [
TtsModel(model_dir="vits-piper-ar_JO-kareem-low"),
@@ -137,6 +154,7 @@ def get_piper_models() -> List[TtsModel]:
TtsModel(model_dir="vits-piper-vi_VN-vivos-x_low"),
TtsModel(model_dir="vits-piper-zh_CN-huayan-medium"),
]

for m in models:
m.data_dir = m.model_dir + "/" + "espeak-ng-data"
m.model_name = m.model_dir[len("vits-piper-") :] + ".onnx"
@@ -145,7 +163,7 @@ def get_piper_models() -> List[TtsModel]:
return models


def get_all_models() -> List[TtsModel]:
def get_vits_models() -> List[TtsModel]:
return [
# Chinese
TtsModel(
@@ -202,12 +220,6 @@ def get_all_models() -> List[TtsModel]:
lang="zh",
rule_fsts="vits-zh-hf-theresa/rule.fst",
),
# English (coqui-ai/TTS)
# fmt: off
TtsModel(model_dir="vits-coqui-en-ljspeech", model_name="model.onnx", lang="en"),
TtsModel(model_dir="vits-coqui-en-ljspeech-neon", model_name="model.onnx", lang="en"),
TtsModel(model_dir="vits-coqui-en-vctk", model_name="model.onnx", lang="en"),
# TtsModel(model_dir="vits-coqui-en-jenny", model_name="model.onnx", lang="en"),
# English (US)
TtsModel(model_dir="vits-vctk", model_name="vits-vctk.onnx", lang="en"),
TtsModel(model_dir="vits-ljs", model_name="vits-ljs.onnx", lang="en"),
@@ -225,8 +237,11 @@ def main():
s = f.read()
template = environment.from_string(s)
d = dict()
# all_model_list = get_all_models()

# all_model_list = get_vits_models()
all_model_list = get_piper_models()
all_model_list += get_coqui_models()

num_models = len(all_model_list)

num_per_runner = num_models // total
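
main() now assembles the APK build list from get_piper_models() plus the new get_coqui_models() and divides it among CI runners with num_per_runner = num_models // total. The actual slicing is outside this hunk, so the sketch below only illustrates how such sharding typically works; index and total are assumed to be the runner id and runner count passed on the command line:

```python
# A rough sketch, not the actual script: split the combined model list
# into `total` contiguous shards and return shard number `index`.
from typing import List


def shard(models: List[str], index: int, total: int) -> List[str]:
    num_models = len(models)
    num_per_runner = num_models // total
    start = index * num_per_runner
    # let the last runner pick up the remainder
    end = num_models if index == total - 1 else start + num_per_runner
    return models[start:end]


all_models = [f"model-{i}" for i in range(10)]
for i in range(3):
    print(i, shard(all_models, index=i, total=3))
```
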
34 changes: 23 additions & 11 deletions sherpa-onnx/csrc/offline-tts-vits-impl.h
@@ -69,12 +69,16 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
}
#endif

int32_t SampleRate() const override { return model_->SampleRate(); }
int32_t SampleRate() const override {
return model_->GetMetaData().sample_rate;
}

GeneratedAudio Generate(
const std::string &_text, int64_t sid = 0, float speed = 1.0,
GeneratedAudioCallback callback = nullptr) const override {
int32_t num_speakers = model_->NumSpeakers();
const auto &meta_data = model_->GetMetaData();
int32_t num_speakers = meta_data.num_speakers;

if (num_speakers == 0 && sid != 0) {
SHERPA_ONNX_LOGE(
"This is a single-speaker model and supports only sid 0. Given sid: "
@@ -105,14 +109,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
}

std::vector<std::vector<int64_t>> x =
frontend_->ConvertTextToTokenIds(text, model_->Voice());
frontend_->ConvertTextToTokenIds(text, meta_data.voice);

if (x.empty() || (x.size() == 1 && x[0].empty())) {
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
return {};
}

if (model_->AddBlank() && config_.model.vits.data_dir.empty()) {
if (meta_data.add_blank && config_.model.vits.data_dir.empty()) {
for (auto &k : x) {
k = AddBlank(k);
}
@@ -189,25 +193,33 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
private:
#if __ANDROID_API__ >= 9
void InitFrontend(AAssetManager *mgr) {
if (model_->IsPiper() && !config_.model.vits.data_dir.empty()) {
const auto &meta_data = model_->GetMetaData();

if ((meta_data.is_piper || meta_data.is_coqui) &&
!config_.model.vits.data_dir.empty()) {
frontend_ = std::make_unique<PiperPhonemizeLexicon>(
mgr, config_.model.vits.tokens, config_.model.vits.data_dir);
mgr, config_.model.vits.tokens, config_.model.vits.data_dir,
meta_data);
} else {
frontend_ = std::make_unique<Lexicon>(
mgr, config_.model.vits.lexicon, config_.model.vits.tokens,
model_->Punctuations(), model_->Language(), config_.model.debug);
meta_data.punctuations, meta_data.language, config_.model.debug);
}
}
#endif

void InitFrontend() {
if (model_->IsPiper() && !config_.model.vits.data_dir.empty()) {
const auto &meta_data = model_->GetMetaData();

if ((meta_data.is_piper || meta_data.is_coqui) &&
!config_.model.vits.data_dir.empty()) {
frontend_ = std::make_unique<PiperPhonemizeLexicon>(
config_.model.vits.tokens, config_.model.vits.data_dir);
config_.model.vits.tokens, config_.model.vits.data_dir,
model_->GetMetaData());
} else {
frontend_ = std::make_unique<Lexicon>(
config_.model.vits.lexicon, config_.model.vits.tokens,
model_->Punctuations(), model_->Language(), config_.model.debug);
meta_data.punctuations, meta_data.language, config_.model.debug);
}
}

@@ -256,7 +268,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
const float *p = audio.GetTensorData<float>();

GeneratedAudio ans;
ans.sample_rate = model_->SampleRate();
ans.sample_rate = model_->GetMetaData().sample_rate;
ans.samples = std::vector<float>(p, p + total);
return ans;
}
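
With is_coqui now routed to PiperPhonemizeLexicon, the coqui-ai/TTS English exports use the bundled espeak-ng-data directory instead of a lexicon file. A hedged end-to-end sketch through the Python API — class and field names are assumed to match the repo's python-api-examples, paths are placeholders, and soundfile is just one way to save the output:

```python
# Assumed usage of a coqui English model with espeak-ng phonemization;
# adjust paths to wherever the downloaded model lives.
import sherpa_onnx
import soundfile as sf

config = sherpa_onnx.OfflineTtsConfig(
    model=sherpa_onnx.OfflineTtsModelConfig(
        vits=sherpa_onnx.OfflineTtsVitsModelConfig(
            model="./vits-coqui-en-ljspeech/model.onnx",
            lexicon="",  # not needed when data_dir points to espeak-ng-data
            tokens="./vits-coqui-en-ljspeech/tokens.txt",
            data_dir="./vits-coqui-en-ljspeech/espeak-ng-data",
        ),
        num_threads=1,
    ),
    rule_fsts="",
    max_num_sentences=1,
)

tts = sherpa_onnx.OfflineTts(config)
audio = tts.generate("How are you doing today?", sid=0, speed=1.0)
sf.write("out.wav", audio.samples, samplerate=audio.sample_rate, subtype="PCM_16")
```
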
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/offline-tts-vits-model-config.cc
@@ -46,7 +46,8 @@ bool OfflineTtsVitsModelConfig::Validate() const {

if (data_dir.empty()) {
if (lexicon.empty()) {
SHERPA_ONNX_LOGE("Please provide --vits-lexicon");
SHERPA_ONNX_LOGE(
"Please provide --vits-lexicon if you leave --vits-data-dir empty");
return false;
}

34 changes: 34 additions & 0 deletions sherpa-onnx/csrc/offline-tts-vits-model-metadata.h
@@ -0,0 +1,34 @@
// sherpa-onnx/csrc/offline-tts-vits-model-metadata.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_METADATA_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_METADATA_H_

#include <cstdint>
#include <string>

namespace sherpa_onnx {

struct OfflineTtsVitsModelMetaData {
int32_t sample_rate;
int32_t add_blank = 0;
int32_t num_speakers = 0;

std::string punctuations;
std::string language;
std::string voice;

bool is_piper = false;
bool is_coqui = false;

// the following options are for models from coqui-ai/TTS
int32_t blank_id = 0;
int32_t bos_id = 0;
int32_t eos_id = 0;
int32_t use_eos_bos = 0;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_METADATA_H_
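
The fields above are read out of custom metadata embedded in the ONNX file (see the SHERPA_ONNX_READ_META_DATA calls in offline-tts-vits-model.cc below). A quick way to inspect what a particular model exports, using the public onnxruntime Python API (the model path is a placeholder):

```python
# Print the custom metadata keys that sherpa-onnx consumes for VITS models.
import onnxruntime as ort

sess = ort.InferenceSession("./vits-coqui-en-ljspeech/model.onnx")
meta = sess.get_modelmeta().custom_metadata_map  # dict of str -> str

for key in (
    "sample_rate", "add_blank", "n_speakers", "punctuation", "language",
    "voice", "comment", "blank_id", "bos_id", "eos_id", "use_eos_bos",
):
    print(key, "=", meta.get(key, "<missing>"))
```
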
87 changes: 42 additions & 45 deletions sherpa-onnx/csrc/offline-tts-vits-model.cc
@@ -38,22 +38,14 @@ class OfflineTtsVitsModel::Impl {
#endif

Ort::Value Run(Ort::Value x, int64_t sid, float speed) {
if (is_piper_) {
return RunVitsPiper(std::move(x), sid, speed);
if (meta_data_.is_piper || meta_data_.is_coqui) {
return RunVitsPiperOrCoqui(std::move(x), sid, speed);
}

return RunVits(std::move(x), sid, speed);
}

int32_t SampleRate() const { return sample_rate_; }

bool AddBlank() const { return add_blank_; }

std::string Punctuations() const { return punctuations_; }
std::string Language() const { return language_; }
std::string Voice() const { return voice_; }
bool IsPiper() const { return is_piper_; }
int32_t NumSpeakers() const { return num_speakers_; }
const OfflineTtsVitsModelMetaData &GetMetaData() const { return meta_data_; }

private:
void Init(void *model_data, size_t model_data_length) {
@@ -70,27 +62,52 @@ class OfflineTtsVitsModel::Impl {
std::ostringstream os;
os << "---vits model---\n";
PrintModelMetadata(os, meta_data);

os << "----------input names----------\n";
int32_t i = 0;
for (const auto &s : input_names_) {
os << i << " " << s << "\n";
++i;
}
os << "----------output names----------\n";
i = 0;
for (const auto &s : output_names_) {
os << i << " " << s << "\n";
++i;
}

SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}

Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(add_blank_, "add_blank", 0);
SHERPA_ONNX_READ_META_DATA(num_speakers_, "n_speakers");
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(punctuations_, "punctuation",
"");
SHERPA_ONNX_READ_META_DATA_STR(language_, "language");
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(voice_, "voice", "");
SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.add_blank, "add_blank",
0);
SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "n_speakers");
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.punctuations,
"punctuation", "");
SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.voice, "voice", "");

SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.blank_id, "blank_id", 0);
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.bos_id, "bos_id", 0);
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.eos_id, "eos_id", 0);
SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.use_eos_bos,
"use_eos_bos", 0);

std::string comment;
SHERPA_ONNX_READ_META_DATA_STR(comment, "comment");
if (comment.find("piper") != std::string::npos ||
comment.find("coqui") != std::string::npos) {
is_piper_ = true;

if (comment.find("piper") != std::string::npos) {
meta_data_.is_piper = true;
}

if (comment.find("coqui") != std::string::npos) {
meta_data_.is_coqui = true;
}
}

Ort::Value RunVitsPiper(Ort::Value x, int64_t sid, float speed) {
Ort::Value RunVitsPiperOrCoqui(Ort::Value x, int64_t sid, float speed) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

@@ -213,14 +230,7 @@ class OfflineTtsVitsModel::Impl {
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;

int32_t sample_rate_;
int32_t add_blank_;
int32_t num_speakers_;
std::string punctuations_;
std::string language_;
std::string voice_;

bool is_piper_ = false;
OfflineTtsVitsModelMetaData meta_data_;
};

OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config)
@@ -239,21 +249,8 @@ Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/,
return impl_->Run(std::move(x), sid, speed);
}

int32_t OfflineTtsVitsModel::SampleRate() const { return impl_->SampleRate(); }

bool OfflineTtsVitsModel::AddBlank() const { return impl_->AddBlank(); }

std::string OfflineTtsVitsModel::Punctuations() const {
return impl_->Punctuations();
}

std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); }
std::string OfflineTtsVitsModel::Voice() const { return impl_->Voice(); }

bool OfflineTtsVitsModel::IsPiper() const { return impl_->IsPiper(); }

int32_t OfflineTtsVitsModel::NumSpeakers() const {
return impl_->NumSpeakers();
const OfflineTtsVitsModelMetaData &OfflineTtsVitsModel::GetMetaData() const {
return impl_->GetMetaData();
}

} // namespace sherpa_onnx
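
Impl::Init() above tells piper exports apart from coqui exports by substring-matching the free-form comment metadata entry and setting is_piper / is_coqui on the metadata struct. The same check, reduced to a few lines of Python with a stand-in metadata dict:

```python
# Mirrors the substring checks on the "comment" metadata entry.
def detect_frontend(meta: dict) -> dict:
    comment = meta.get("comment", "")
    return {"is_piper": "piper" in comment, "is_coqui": "coqui" in comment}


print(detect_frontend({"comment": "A VITS model converted from coqui-ai/TTS"}))
# -> {'is_piper': False, 'is_coqui': True}
```
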
13 changes: 2 additions & 11 deletions sherpa-onnx/csrc/offline-tts-vits-model.h
@@ -15,6 +15,7 @@

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h"

namespace sherpa_onnx {

@@ -39,17 +40,7 @@ class OfflineTtsVitsModel {
*/
Ort::Value Run(Ort::Value x, int64_t sid = 0, float speed = 1.0);

// Sample rate of the generated audio
int32_t SampleRate() const;

// true to insert a blank between each token
bool AddBlank() const;

std::string Punctuations() const;
std::string Language() const; // e.g., Chinese, English, German, etc.
std::string Voice() const; // e.g., en-us, for espeak-ng
bool IsPiper() const;
int32_t NumSpeakers() const;
const OfflineTtsVitsModelMetaData &GetMetaData() const;

private:
class Impl;
(The diffs for the remaining changed files are not rendered on this page.)
