From 6c9f0ac7c6bf7d8cc4defa9b840f3acc300178e3 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 5 Dec 2023 19:06:14 +0800 Subject: [PATCH] Use espeak-ng for coqui-ai/TTS VITS English models. --- python-api-examples/offline-tts-play.py | 3 + sherpa-onnx/csrc/offline-tts-vits-impl.h | 34 ++++-- .../csrc/offline-tts-vits-model-config.cc | 3 +- .../csrc/offline-tts-vits-model-metadata.h | 34 ++++++ sherpa-onnx/csrc/offline-tts-vits-model.cc | 87 ++++++++------- sherpa-onnx/csrc/offline-tts-vits-model.h | 13 +-- sherpa-onnx/csrc/piper-phonemize-lexicon.cc | 102 +++++++++++++++--- sherpa-onnx/csrc/piper-phonemize-lexicon.h | 9 +- 8 files changed, 201 insertions(+), 84 deletions(-) create mode 100644 sherpa-onnx/csrc/offline-tts-vits-model-metadata.h diff --git a/python-api-examples/offline-tts-play.py b/python-api-examples/offline-tts-play.py index d205e7422..c01db51e5 100755 --- a/python-api-examples/offline-tts-play.py +++ b/python-api-examples/offline-tts-play.py @@ -311,6 +311,9 @@ def main(): if len(audio.samples) == 0: print("Error in generating audios. 
Please read previous error messages.") + global killed + killed = True + play_back_thread.join() return elapsed_seconds = end - start diff --git a/sherpa-onnx/csrc/offline-tts-vits-impl.h b/sherpa-onnx/csrc/offline-tts-vits-impl.h index bb6555700..f1c043204 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-impl.h +++ b/sherpa-onnx/csrc/offline-tts-vits-impl.h @@ -69,12 +69,16 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { } #endif - int32_t SampleRate() const override { return model_->SampleRate(); } + int32_t SampleRate() const override { + return model_->GetMetaData().sample_rate; + } GeneratedAudio Generate( const std::string &_text, int64_t sid = 0, float speed = 1.0, GeneratedAudioCallback callback = nullptr) const override { - int32_t num_speakers = model_->NumSpeakers(); + const auto &meta_data = model_->GetMetaData(); + int32_t num_speakers = meta_data.num_speakers; + if (num_speakers == 0 && sid != 0) { SHERPA_ONNX_LOGE( "This is a single-speaker model and supports only sid 0. 
Given sid: " @@ -105,14 +109,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { } std::vector> x = - frontend_->ConvertTextToTokenIds(text, model_->Voice()); + frontend_->ConvertTextToTokenIds(text, meta_data.voice); if (x.empty() || (x.size() == 1 && x[0].empty())) { SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str()); return {}; } - if (model_->AddBlank() && config_.model.vits.data_dir.empty()) { + if (meta_data.add_blank && config_.model.vits.data_dir.empty()) { for (auto &k : x) { k = AddBlank(k); } @@ -189,25 +193,33 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { private: #if __ANDROID_API__ >= 9 void InitFrontend(AAssetManager *mgr) { - if (model_->IsPiper() && !config_.model.vits.data_dir.empty()) { + const auto &meta_data = model_->GetMetaData(); + + if ((meta_data.is_piper || meta_data.is_coqui) && + !config_.model.vits.data_dir.empty()) { frontend_ = std::make_unique( - mgr, config_.model.vits.tokens, config_.model.vits.data_dir); + mgr, config_.model.vits.tokens, config_.model.vits.data_dir, + meta_data); } else { frontend_ = std::make_unique( mgr, config_.model.vits.lexicon, config_.model.vits.tokens, - model_->Punctuations(), model_->Language(), config_.model.debug); + meta_data.punctuations, meta_data.language, config_.model.debug); } } #endif void InitFrontend() { - if (model_->IsPiper() && !config_.model.vits.data_dir.empty()) { + const auto &meta_data = model_->GetMetaData(); + + if ((meta_data.is_piper || meta_data.is_coqui) && + !config_.model.vits.data_dir.empty()) { frontend_ = std::make_unique( - config_.model.vits.tokens, config_.model.vits.data_dir); + config_.model.vits.tokens, config_.model.vits.data_dir, + meta_data); } else { frontend_ = std::make_unique( config_.model.vits.lexicon, config_.model.vits.tokens, - model_->Punctuations(), model_->Language(), config_.model.debug); + meta_data.punctuations, meta_data.language, config_.model.debug); } } @@ -256,7 +268,7 @@ class OfflineTtsVitsImpl : 
public OfflineTtsImpl { const float *p = audio.GetTensorData(); GeneratedAudio ans; - ans.sample_rate = model_->SampleRate(); + ans.sample_rate = model_->GetMetaData().sample_rate; ans.samples = std::vector(p, p + total); return ans; } diff --git a/sherpa-onnx/csrc/offline-tts-vits-model-config.cc b/sherpa-onnx/csrc/offline-tts-vits-model-config.cc index b9fce0f6b..22ccec354 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model-config.cc +++ b/sherpa-onnx/csrc/offline-tts-vits-model-config.cc @@ -46,7 +46,8 @@ bool OfflineTtsVitsModelConfig::Validate() const { if (data_dir.empty()) { if (lexicon.empty()) { - SHERPA_ONNX_LOGE("Please provide --vits-lexicon"); + SHERPA_ONNX_LOGE( + "Please provide --vits-lexicon if you leave --vits-data-dir empty"); return false; } diff --git a/sherpa-onnx/csrc/offline-tts-vits-model-metadata.h b/sherpa-onnx/csrc/offline-tts-vits-model-metadata.h new file mode 100644 index 000000000..9356519aa --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-vits-model-metadata.h @@ -0,0 +1,34 @@ +// sherpa-onnx/csrc/offline-tts-vits-model-metadata.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_METADATA_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_METADATA_H_ + +#include <cstdint> +#include <string> + +namespace sherpa_onnx { + +struct OfflineTtsVitsModelMetaData { + int32_t sample_rate = 0; // 0 until read from the "sample_rate" metadata + int32_t add_blank = 0; + int32_t num_speakers = 0; + + std::string punctuations; + std::string language; + std::string voice; + + bool is_piper = false; + bool is_coqui = false; + + // the following options are for models from coqui-ai/TTS + int32_t blank_id = 0; + int32_t bos_id = 0; + int32_t eos_id = 0; + int32_t use_eos_bos = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_METADATA_H_ diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.cc b/sherpa-onnx/csrc/offline-tts-vits-model.cc index 31e3a7c31..b0604a6b5 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model.cc +++ 
b/sherpa-onnx/csrc/offline-tts-vits-model.cc @@ -38,22 +38,14 @@ class OfflineTtsVitsModel::Impl { #endif Ort::Value Run(Ort::Value x, int64_t sid, float speed) { - if (is_piper_) { - return RunVitsPiper(std::move(x), sid, speed); + if (meta_data_.is_piper || meta_data_.is_coqui) { + return RunVitsPiperOrCoqui(std::move(x), sid, speed); } return RunVits(std::move(x), sid, speed); } - int32_t SampleRate() const { return sample_rate_; } - - bool AddBlank() const { return add_blank_; } - - std::string Punctuations() const { return punctuations_; } - std::string Language() const { return language_; } - std::string Voice() const { return voice_; } - bool IsPiper() const { return is_piper_; } - int32_t NumSpeakers() const { return num_speakers_; } + const OfflineTtsVitsModelMetaData &GetMetaData() const { return meta_data_; } private: void Init(void *model_data, size_t model_data_length) { @@ -70,27 +62,52 @@ class OfflineTtsVitsModel::Impl { std::ostringstream os; os << "---vits model---\n"; PrintModelMetadata(os, meta_data); + + os << "----------input names----------\n"; + int32_t i = 0; + for (const auto &s : input_names_) { + os << i << " " << s << "\n"; + ++i; + } + os << "----------output names----------\n"; + i = 0; + for (const auto &s : output_names_) { + os << i << " " << s << "\n"; + ++i; + } + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); } Ort::AllocatorWithDefaultOptions allocator; // used in the macro below - SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate"); - SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(add_blank_, "add_blank", 0); - SHERPA_ONNX_READ_META_DATA(num_speakers_, "n_speakers"); - SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(punctuations_, "punctuation", - ""); - SHERPA_ONNX_READ_META_DATA_STR(language_, "language"); - SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(voice_, "voice", ""); + SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate"); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.add_blank, "add_blank", + 0); + 
SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "n_speakers"); + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.punctuations, + "punctuation", ""); + SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language"); + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.voice, "voice", ""); + + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.blank_id, "blank_id", 0); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.bos_id, "bos_id", 0); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.eos_id, "eos_id", 0); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.use_eos_bos, + "use_eos_bos", 0); std::string comment; SHERPA_ONNX_READ_META_DATA_STR(comment, "comment"); - if (comment.find("piper") != std::string::npos || - comment.find("coqui") != std::string::npos) { - is_piper_ = true; + + if (comment.find("piper") != std::string::npos) { + meta_data_.is_piper = true; + } + + if (comment.find("coqui") != std::string::npos) { + meta_data_.is_coqui = true; } } - Ort::Value RunVitsPiper(Ort::Value x, int64_t sid, float speed) { + Ort::Value RunVitsPiperOrCoqui(Ort::Value x, int64_t sid, float speed) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); @@ -213,14 +230,7 @@ class OfflineTtsVitsModel::Impl { std::vector output_names_; std::vector output_names_ptr_; - int32_t sample_rate_; - int32_t add_blank_; - int32_t num_speakers_; - std::string punctuations_; - std::string language_; - std::string voice_; - - bool is_piper_ = false; + OfflineTtsVitsModelMetaData meta_data_; }; OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config) @@ -239,21 +249,8 @@ Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/, return impl_->Run(std::move(x), sid, speed); } -int32_t OfflineTtsVitsModel::SampleRate() const { return impl_->SampleRate(); } - -bool OfflineTtsVitsModel::AddBlank() const { return impl_->AddBlank(); } - -std::string OfflineTtsVitsModel::Punctuations() const { - return 
impl_->Punctuations(); -} - -std::string OfflineTtsVitsModel::Language() const { return impl_->Language(); } -std::string OfflineTtsVitsModel::Voice() const { return impl_->Voice(); } - -bool OfflineTtsVitsModel::IsPiper() const { return impl_->IsPiper(); } - -int32_t OfflineTtsVitsModel::NumSpeakers() const { - return impl_->NumSpeakers(); +const OfflineTtsVitsModelMetaData &OfflineTtsVitsModel::GetMetaData() const { + return impl_->GetMetaData(); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.h b/sherpa-onnx/csrc/offline-tts-vits-model.h index 7708144c6..7d51efa2c 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model.h +++ b/sherpa-onnx/csrc/offline-tts-vits-model.h @@ -15,6 +15,7 @@ #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/offline-tts-model-config.h" +#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h" namespace sherpa_onnx { @@ -39,17 +40,7 @@ class OfflineTtsVitsModel { */ Ort::Value Run(Ort::Value x, int64_t sid = 0, float speed = 1.0); - // Sample rate of the generated audio - int32_t SampleRate() const; - - // true to insert a blank between each token - bool AddBlank() const; - - std::string Punctuations() const; - std::string Language() const; // e.g., Chinese, English, German, etc. 
- std::string Voice() const; // e.g., en-us, for espeak-ng - bool IsPiper() const; - int32_t NumSpeakers() const; + const OfflineTtsVitsModelMetaData &GetMetaData() const; private: class Impl; diff --git a/sherpa-onnx/csrc/piper-phonemize-lexicon.cc b/sherpa-onnx/csrc/piper-phonemize-lexicon.cc index 857f5d192..c27d92a11 100644 --- a/sherpa-onnx/csrc/piper-phonemize-lexicon.cc +++ b/sherpa-onnx/csrc/piper-phonemize-lexicon.cc @@ -57,10 +57,17 @@ static std::unordered_map ReadTokens(std::istream &is) { s = conv.from_bytes(sym); if (s.size() != 1) { + // for tokens.txt from coqui-ai/TTS, the last token is <BLNK> + if (s.size() == 6 && s[0] == '<' && s[1] == 'B' && s[2] == 'L' && + s[3] == 'N' && s[4] == 'K' && s[5] == '>') { + continue; + } + SHERPA_ONNX_LOGE("Error when reading tokens at Line %s. size: %d", line.c_str(), static_cast(s.size())); exit(-1); } + char32_t c = s[0]; if (token2id.count(c)) { @@ -77,7 +84,7 @@ static std::unordered_map ReadTokens(std::istream &is) { // see the function "phonemes_to_ids" from // https://github.com/rhasspy/piper/blob/master/notebooks/piper_inference_(ONNX).ipynb -static std::vector PhonemesToIds( +static std::vector PiperPhonemesToIds( const std::unordered_map &token2id, const std::vector &phonemes) { // see @@ -104,6 +111,60 @@ static std::vector PhonemesToIds( return ans; } +static std::vector CoquiPhonemesToIds( + const std::unordered_map &token2id, + const std::vector &phonemes, + const OfflineTtsVitsModelMetaData &meta_data) { + // see + // https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/text/tokenizer.py#L87 + int32_t use_eos_bos = meta_data.use_eos_bos; + int32_t bos_id = meta_data.bos_id; + int32_t eos_id = meta_data.eos_id; + int32_t blank_id = meta_data.blank_id; + int32_t add_blank = meta_data.add_blank; + + std::vector ans; + if (add_blank) { + ans.reserve(phonemes.size() * 2 + 3); + } else { + ans.reserve(phonemes.size() + 2); + } + + if (use_eos_bos) { + ans.push_back(bos_id); + } + + if (add_blank) { + 
ans.push_back(blank_id); + + for (auto p : phonemes) { + if (token2id.count(p)) { + ans.push_back(token2id.at(p)); + ans.push_back(blank_id); + } else { + SHERPA_ONNX_LOGE("Skip unknown phonemes. Unicode codepoint: \\U+%04x.", + static_cast(p)); + } + } + } else { + // not adding blank + for (auto p : phonemes) { + if (token2id.count(p)) { + ans.push_back(token2id.at(p)); + } else { + SHERPA_ONNX_LOGE("Skip unknown phonemes. Unicode codepoint: \\U+%04x.", + static_cast(p)); + } + } + } + + if (use_eos_bos) { + ans.push_back(eos_id); + } + + return ans; +} + void InitEspeak(const std::string &data_dir) { static std::once_flag init_flag; std::call_once(init_flag, [data_dir]() { @@ -119,21 +180,23 @@ void InitEspeak(const std::string &data_dir) { }); } -PiperPhonemizeLexicon::PiperPhonemizeLexicon(const std::string &tokens, - const std::string &data_dir) - : data_dir_(data_dir) { +PiperPhonemizeLexicon::PiperPhonemizeLexicon( + const std::string &tokens, const std::string &data_dir, + const OfflineTtsVitsModelMetaData &meta_data) + : meta_data_(meta_data) { { std::ifstream is(tokens); token2id_ = ReadTokens(is); } - InitEspeak(data_dir_); + InitEspeak(data_dir); } #if __ANDROID_API__ >= 9 -PiperPhonemizeLexicon::PiperPhonemizeLexicon(AAssetManager *mgr, - const std::string &tokens, - const std::string &data_dir) { +PiperPhonemizeLexicon::PiperPhonemizeLexicon( + AAssetManager *mgr, const std::string &tokens, const std::string &data_dir, + const OfflineTtsVitsModelMetaData &meta_data) + : meta_data_(meta_data) { { auto buf = ReadFile(mgr, tokens); std::istrstream is(buf.data(), buf.size()); @@ -141,8 +204,9 @@ PiperPhonemizeLexicon::PiperPhonemizeLexicon(AAssetManager *mgr, } // We should copy the directory of espeak-ng-data from the asset to - // some internal or external storage and then pass the directory to data_dir. - InitEspeak(data_dir_); + // some internal or external storage and then pass the directory to + // data_dir. 
+ InitEspeak(data_dir); } #endif @@ -160,9 +224,21 @@ std::vector> PiperPhonemizeLexicon::ConvertTextToTokenIds( std::vector> ans; std::vector phoneme_ids; - for (const auto &p : phonemes) { - phoneme_ids = PhonemesToIds(token2id_, p); - ans.push_back(std::move(phoneme_ids)); + + if (meta_data_.is_piper) { + for (const auto &p : phonemes) { + phoneme_ids = PiperPhonemesToIds(token2id_, p); + ans.push_back(std::move(phoneme_ids)); + } + } else if (meta_data_.is_coqui) { + for (const auto &p : phonemes) { + phoneme_ids = CoquiPhonemesToIds(token2id_, p, meta_data_); + ans.push_back(std::move(phoneme_ids)); + } + + } else { + SHERPA_ONNX_LOGE("Unsupported model"); + exit(-1); } return ans; diff --git a/sherpa-onnx/csrc/piper-phonemize-lexicon.h b/sherpa-onnx/csrc/piper-phonemize-lexicon.h index d2cdad2a8..842d80e0c 100644 --- a/sherpa-onnx/csrc/piper-phonemize-lexicon.h +++ b/sherpa-onnx/csrc/piper-phonemize-lexicon.h @@ -15,25 +15,28 @@ #endif #include "sherpa-onnx/csrc/offline-tts-frontend.h" +#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h" namespace sherpa_onnx { class PiperPhonemizeLexicon : public OfflineTtsFrontend { public: - PiperPhonemizeLexicon(const std::string &tokens, const std::string &data_dir); + PiperPhonemizeLexicon(const std::string &tokens, const std::string &data_dir, + const OfflineTtsVitsModelMetaData &meta_data); #if __ANDROID_API__ >= 9 PiperPhonemizeLexicon(AAssetManager *mgr, const std::string &tokens, - const std::string &data_dir); + const std::string &data_dir, + const OfflineTtsVitsModelMetaData &meta_data); #endif std::vector> ConvertTextToTokenIds( const std::string &text, const std::string &voice = "") const override; private: - std::string data_dir_; // map unicode codepoint to an integer ID std::unordered_map token2id_; + OfflineTtsVitsModelMetaData meta_data_; }; } // namespace sherpa_onnx