diff --git a/.github/scripts/baker_zh/TTS/run-matcha.sh b/.github/scripts/baker_zh/TTS/run-matcha.sh
new file mode 100755
index 0000000000..150f023aef
--- /dev/null
+++ b/.github/scripts/baker_zh/TTS/run-matcha.sh
@@ -0,0 +1,167 @@
+#!/usr/bin/env bash
+
+set -ex
+
+apt-get update
+apt-get install -y sox
+
+python3 -m pip install numba conformer==0.3.2 diffusers librosa
+python3 -m pip install jieba
+
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/baker_zh/TTS
+
+sed -i.bak s/600/8/g ./prepare.sh
+sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh
+sed -i.bak s/500/5/g ./prepare.sh
+git diff
+
+function prepare_data() {
+  # We have created a subset of the data for testing
+  #
+  mkdir -p download
+  pushd download
+  wget -q https://huggingface.co/csukuangfj/tmp-files/resolve/main/BZNSYP-samples.tar.bz2
+  tar xvf BZNSYP-samples.tar.bz2
+  mv BZNSYP-samples BZNSYP
+  rm BZNSYP-samples.tar.bz2
+  popd
+
+  ./prepare.sh
+  tree .
+}
+
+function train() {
+  pushd ./matcha
+  sed -i.bak s/1500/3/g ./train.py
+  git diff .
+  popd
+
+  ./matcha/train.py \
+    --exp-dir matcha/exp \
+    --num-epochs 1 \
+    --save-every-n 1 \
+    --num-buckets 2 \
+    --tokens data/tokens.txt \
+    --max-duration 20
+
+  ls -lh matcha/exp
+}
+
+function infer() {
+  curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2
+
+  ./matcha/infer.py \
+    --num-buckets 2 \
+    --epoch 1 \
+    --exp-dir ./matcha/exp \
+    --tokens data/tokens.txt \
+    --cmvn ./data/fbank/cmvn.json \
+    --vocoder ./generator_v2 \
+    --input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \
+    --output-wav ./generated.wav
+
+  ls -lh *.wav
+  soxi ./generated.wav
+  rm -v ./generated.wav
+  rm -v generator_v2
+}
+
+function export_onnx() {
+  pushd matcha/exp
+  curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-baker-matcha-zh-2024-12-27/resolve/main/epoch-2000.pt
+  popd
+
+  pushd data/fbank
+  rm -v *.json
+  curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-baker-matcha-zh-2024-12-27/resolve/main/cmvn.json
+  popd
+
+  ./matcha/export_onnx.py \
+    --exp-dir ./matcha/exp \
+    --epoch 2000 \
+    --tokens ./data/tokens.txt \
+    --cmvn ./data/fbank/cmvn.json
+
+  ls -lh *.onnx
+
+  if false; then
+    # The CI machine does not have enough memory to run it
+    #
+    curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
+    curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2
+    curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3
+    python3 ./matcha/export_onnx_hifigan.py
+  else
+    curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v1.onnx
+    curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v2.onnx
+    curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v3.onnx
+  fi
+
+  ls -lh *.onnx
+
+  python3 ./matcha/generate_lexicon.py
+
+  for v in v1 v2 v3; do
+    python3 ./matcha/onnx_pretrained.py \
+      --acoustic-model ./model-steps-6.onnx \
+      --vocoder ./hifigan_$v.onnx \
+      --tokens ./data/tokens.txt \
+      --lexicon ./lexicon.txt \
+      --input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \
+      --output-wav /icefall/generated-matcha-tts-steps-6-$v.wav
+  done
+
+  ls -lh /icefall/*.wav
+  soxi /icefall/generated-matcha-tts-steps-6-*.wav
+  cp ./model-steps-*.onnx /icefall
+
+  d=matcha-icefall-zh-baker
+  mkdir $d
+  cp -v data/tokens.txt $d
+  cp -v lexicon.txt $d
+  cp model-steps-3.onnx $d
+  pushd $d
+  curl -SL -O https://github.com/csukuangfj/cppjieba/releases/download/sherpa-onnx-2024-04-19/dict.tar.bz2
+  tar xvf dict.tar.bz2
+  rm dict.tar.bz2
+
+  curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/date.fst
+  curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/number.fst
+  curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/phone.fst
+
+cat >README.md <<'EOF'
+
+The training command is given below:
+```bash
+python3 ./matcha/train.py \
+  --exp-dir ./matcha/exp-1/ \
+  --num-workers 4 \
+  --world-size 1 \
+  --num-epochs 2000 \
+  --max-duration 1200 \
+  --bucketing-sampler 1 \
+  --start-epoch 1
+```
+
+To run inference, use:
+
+```bash
+# Download the HiFiGAN vocoder. We use HiFiGAN v2 below. You can select from v1, v2, or v3.
+
+wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2
+
+python3 ./matcha/infer.py \
+  --epoch 2000 \
+  --exp-dir ./matcha/exp-1 \
+  --vocoder ./generator_v2 \
+  --tokens ./data/tokens.txt \
+  --cmvn ./data/fbank/cmvn.json \
+  --input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \
+  --output-wav ./generated.wav
+```
+
+```bash
+soxi ./generated.wav
+```
+
+prints:
+```
+Input File     : './generated.wav'
+Channels       : 1
+Sample Rate    : 22050
+Precision      : 16-bit
+Duration       : 00:00:17.31 = 381696 samples ~ 1298.29 CDDA sectors
+File Size      : 763k
+Bit Rate       : 353k
+Sample Encoding: 16-bit Signed Integer PCM
+```
+
+https://github.com/user-attachments/assets/88d4e88f-ebc4-4f32-b216-16d46b966024
+
+
+To export the checkpoint to onnx:
+```bash
+python3 ./matcha/export_onnx.py \
+  --exp-dir ./matcha/exp-1 \
+  --epoch 2000 \
+  --tokens ./data/tokens.txt \
+  --cmvn ./data/fbank/cmvn.json
+```
+
+The above command generates the following files:
+```
+-rw-r--r-- 1 kuangfangjun root 72M Dec 27 18:53 model-steps-2.onnx
+-rw-r--r-- 1 kuangfangjun root 73M Dec 27 18:54 model-steps-3.onnx
+-rw-r--r-- 1 kuangfangjun root 73M Dec 27 18:54 model-steps-4.onnx
+-rw-r--r-- 1 kuangfangjun root 74M Dec 27 18:55 model-steps-5.onnx
+-rw-r--r-- 1 kuangfangjun root 74M Dec 27 18:57 model-steps-6.onnx
+```
+
+where the 2 in `model-steps-2.onnx` means it uses 2 steps for the ODE solver.
+
+**HINT**: If you get the following error while running `export_onnx.py`:
+
+```
+torch.onnx.errors.UnsupportedOperatorError: Exporting the operator
+'aten::scaled_dot_product_attention' to ONNX opset version 14 is not supported.
+```
+
+please use `torch>=2.2.0`.
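+
+Each exported acoustic model embeds its configuration as ONNX metadata.
+As a quick sanity check (a minimal sketch, assuming onnxruntime is
+installed; the keys are the ones written by `./matcha/export_onnx.py`),
+you can read the metadata back:
+
+```python
+import onnxruntime as ort
+
+session = ort.InferenceSession(
+    "./model-steps-2.onnx", providers=["CPUExecutionProvider"]
+)
+# custom_metadata_map holds the key-value pairs written by add_meta_data()
+meta = session.get_modelmeta().custom_metadata_map
+print(meta["num_ode_steps"], meta["sample_rate"], meta["language"])
+```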
+
+To export the HiFiGAN vocoder to onnx, please use:
+
+```bash
+wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
+wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2
+wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3
+
+python3 ./matcha/export_onnx_hifigan.py
+```
+
+The above command generates 3 files:
+
+  - hifigan_v1.onnx
+  - hifigan_v2.onnx
+  - hifigan_v3.onnx
+
+**HINT**: You can download pre-exported HiFiGAN ONNX models from
+<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28>
+
+To use the generated onnx files to generate speech from text, please run:
+
+```bash
+# First, generate ./lexicon.txt
+python3 ./matcha/generate_lexicon.py
+
+python3 ./matcha/onnx_pretrained.py \
+  --acoustic-model ./model-steps-4.onnx \
+  --vocoder ./hifigan_v2.onnx \
+  --tokens ./data/tokens.txt \
+  --lexicon ./lexicon.txt \
+  --input-text "在一个阳光明媚的夏天,小马、小羊和小狗它们一块儿在广阔的草地上,嬉戏玩耍,这时小猴来了,还带着它心爱的足球活蹦乱跳地跑前、跑后教小马、小羊、小狗踢足球。" \
+  --output-wav ./1.wav
+```
+
+```bash
+soxi ./1.wav
+```
+
+prints:
+```
+Input File     : './1.wav'
+Channels       : 1
+Sample Rate    : 22050
+Precision      : 16-bit
+Duration       : 00:00:16.37 = 360960 samples ~ 1227.76 CDDA sectors
+File Size      : 722k
+Bit Rate       : 353k
+Sample Encoding: 16-bit Signed Integer PCM
+```
+
+https://github.com/user-attachments/assets/578d04bb-fee8-47e5-9984-a868dcce610e
+EOF
diff --git a/egs/baker_zh/TTS/local/audio.py b/egs/baker_zh/TTS/local/audio.py
new file mode 120000
index 0000000000..b70d91c920
--- /dev/null
+++ b/egs/baker_zh/TTS/local/audio.py
@@ -0,0 +1 @@
+../matcha/audio.py
\ No newline at end of file
diff --git a/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py b/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py
new file mode 100755
index 0000000000..0720158f27
--- /dev/null
+++ b/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python3
+# Copyright    2021-2023  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the baker-zh dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+
+import torch
+from fbank import MatchaFbank, MatchaFbankConfig
+from lhotse import LilcomChunkyWriter, load_manifest
+
+from icefall.utils import get_executor
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--num-jobs",
+        type=int,
+        default=4,
+        help="""Number of parallel jobs to use for computing the features.
+        If it is less than 1, the number of CPU cores is used instead.
+        """,
+    )
+    return parser
+
+
+def compute_fbank_baker_zh(num_jobs: int):
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+
+    if num_jobs < 1:
+        num_jobs = os.cpu_count()
+
+    logging.info(f"num_jobs: {num_jobs}")
+    logging.info(f"src_dir: {src_dir}")
+    logging.info(f"output_dir: {output_dir}")
+    config = MatchaFbankConfig(
+        n_fft=1024,
+        n_mels=80,
+        sampling_rate=22050,
+        hop_length=256,
+        win_length=1024,
+        f_min=0,
+        f_max=8000,
+    )
+
+    prefix = "baker_zh"
+    suffix = "jsonl.gz"
+
+    extractor = MatchaFbank(config)
+
+    with get_executor() as ex:  # Initialize the executor only once.
+        cuts_filename = f"{prefix}_cuts.{suffix}"
+        logging.info(f"Processing {cuts_filename}")
+        cut_set = load_manifest(src_dir / cuts_filename).resample(22050)
+
+        cut_set = cut_set.compute_and_store_features(
+            extractor=extractor,
+            storage_path=f"{output_dir}/{prefix}_feats",
+            num_jobs=num_jobs if ex is None else 80,
+            executor=ex,
+            storage_type=LilcomChunkyWriter,
+        )
+
+        cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+    # Torch's multithreaded behavior needs to be disabled, or
+    # it wastes a lot of CPU and slows things down.
+    # Do this outside of main() in case it needs to take effect
+    # even when we are not invoking the main (e.g. when spawning subprocesses).
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    args = get_parser().parse_args()
+    compute_fbank_baker_zh(args.num_jobs)
diff --git a/egs/baker_zh/TTS/local/compute_fbank_statistics.py b/egs/baker_zh/TTS/local/compute_fbank_statistics.py
new file mode 120000
index 0000000000..fd1d8b52e1
--- /dev/null
+++ b/egs/baker_zh/TTS/local/compute_fbank_statistics.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/local/compute_fbank_statistics.py
\ No newline at end of file
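
The symlinked `compute_fbank_statistics.py` produces the global mel statistics
that `train.py`, `infer.py`, and `export_onnx.py` later in this patch read from
`data/fbank/cmvn.json`. Only two keys are consumed; a minimal sketch of how the
file is used (key names taken from the scripts below):

```python
import json

with open("data/fbank/cmvn.json") as f:
    stats = json.load(f)

# The only two fields consumed by train.py, infer.py, and export_onnx.py:
mel_mean = stats["fbank_mean"]
mel_std = stats["fbank_std"]
print(mel_mean, mel_std)
```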
diff --git a/egs/baker_zh/TTS/local/convert_text_to_tokens.py b/egs/baker_zh/TTS/local/convert_text_to_tokens.py
new file mode 100755
index 0000000000..bf59cb466a
--- /dev/null
+++ b/egs/baker_zh/TTS/local/convert_text_to_tokens.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python3
+
+import argparse
+import re
+from typing import List
+
+import jieba
+from lhotse import load_manifest
+from pypinyin import Style, lazy_pinyin, load_phrases_dict
+
+load_phrases_dict(
+    {
+        "行长": [["hang2"], ["zhang3"]],
+        "银行行长": [["yin2"], ["hang2"], ["hang2"], ["zhang3"]],
+    }
+)
+
+whitespace_re = re.compile(r"\s+")
+
+punctuations_re = [
+    (re.compile(x[0], re.IGNORECASE), x[1])
+    for x in [
+        (",", ","),
+        ("。", "."),
+        ("!", "!"),
+        ("?", "?"),
+        ("“", '"'),
+        ("”", '"'),
+        ("‘", "'"),
+        ("’", "'"),
+        (":", ":"),
+        ("、", ","),
+        ("B", "逼"),
+        ("P", "批"),
+    ]
+]
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        "--in-file",
+        type=str,
+        required=True,
+        help="Input cutset.",
+    )
+
+    parser.add_argument(
+        "--out-file",
+        type=str,
+        required=True,
+        help="Output cutset.",
+    )
+
+    return parser
+
+
+def normalize_white_spaces(text):
+    return whitespace_re.sub(" ", text)
+
+
+def normalize_punctuations(text):
+    for regex, replacement in punctuations_re:
+        text = re.sub(regex, replacement, text)
+    return text
+
+
+def split_text(text: str) -> List[str]:
+    """
+    Example input:  '你好呀,You are 一个好人。 去银行存钱?How about you?'
+    Example output: ['你好', '呀', ',', 'you are', '一个', '好人', '.', '去', '银行', '存钱', '?', 'how about you', '?']
+    """
+    text = text.lower()
+    text = normalize_white_spaces(text)
+    text = normalize_punctuations(text)
+    ans = []
+
+    for seg in jieba.cut(text):
+        if seg in ",.!?:\"'":
+            ans.append(seg)
+        elif seg == " " and len(ans) > 0:
+            if ord("a") <= ord(ans[-1][-1]) <= ord("z"):
+                ans[-1] += seg
+        elif ord("a") <= ord(seg[0]) <= ord("z"):
+            if len(ans) == 0:
+                ans.append(seg)
+                continue
+
+            if ans[-1][-1] == " ":
+                ans[-1] += seg
+                continue
+
+            ans.append(seg)
+        else:
+            ans.append(seg)
+
+    ans = [s.strip() for s in ans]
+    return ans
+
+
+def main():
+    args = get_parser().parse_args()
+    cuts = load_manifest(args.in_file)
+    for c in cuts:
+        assert len(c.supervisions) == 1, (len(c.supervisions), c.supervisions)
+        text = c.supervisions[0].normalized_text
+
+        text_list = split_text(text)
+        tokens = lazy_pinyin(text_list, style=Style.TONE3, tone_sandhi=True)
+
+        c.tokens = tokens
+
+    cuts.to_file(args.out_file)
+
+    print(f"saved to {args.out_file}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/baker_zh/TTS/local/fbank.py b/egs/baker_zh/TTS/local/fbank.py
new file mode 120000
index 0000000000..5bcf1fde57
--- /dev/null
+++ b/egs/baker_zh/TTS/local/fbank.py
@@ -0,0 +1 @@
+../matcha/fbank.py
\ No newline at end of file
diff --git a/egs/baker_zh/TTS/local/generate_tokens.py b/egs/baker_zh/TTS/local/generate_tokens.py
new file mode 100755
index 0000000000..b2abe1a71a
--- /dev/null
+++ b/egs/baker_zh/TTS/local/generate_tokens.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python3
+
+"""
+This file generates the file tokens.txt.
+
+Usage:
+
+python3 ./local/generate_tokens.py --tokens data/tokens.txt
+"""
+
+
+import argparse
+from typing import List
+
+from pypinyin import Style, lazy_pinyin, pinyin_dict
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        required=True,
+        help="Path to save tokens.txt.",
+    )
+
+    return parser
+
+
+def generate_token_list() -> List[str]:
+    token_set = set()
+
+    word_dict = pinyin_dict.pinyin_dict
+    for key in word_dict:
+        if not (0x4E00 <= key <= 0x9FFF):
+            continue
+
+        w = chr(key)
+        t = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0]
+        token_set.add(t)
+
+    no_digit = set()
+    for t in token_set:
+        if t[-1] not in "1234":
+            no_digit.add(t)
+        else:
+            no_digit.add(t[:-1])
+
+    no_digit.add("dei")
+    no_digit.add("tou")
+    no_digit.add("dia")
+
+    for t in no_digit:
+        token_set.add(t)
+        for i in range(1, 5):
+            token_set.add(f"{t}{i}")
+
+    ans = list(token_set)
+    ans.sort()
+
+    punctuations = list(",.!?:\"'")
+    ans = punctuations + ans
+
+    # Use ID 0 for the blank token " "
+    # Use ID 1 for the padding token "_"
+    ans.insert(0, " ")
+    ans.insert(1, "_")
+
+    return ans
+
+
+def main():
+    args = get_parser().parse_args()
+    token_list = generate_token_list()
+    with open(args.tokens, "w", encoding="utf-8") as f:
+        for idx, token in enumerate(token_list):
+            f.write(f"{token} {idx}\n")
+
+
+if __name__ == "__main__":
+    main()
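
Because `generate_token_list()` prepends the punctuation list and then inserts
the blank and padding entries at positions 0 and 1, the head of the generated
`tokens.txt` is fixed; only the pinyin tokens after it depend on the sorted
set. Illustrative first lines (the blank token is a literal space before the
ID):

```
  0
_ 1
, 2
. 3
! 4
? 5
: 6
" 7
' 8
```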
diff --git a/egs/baker_zh/TTS/local/validate_manifest.py b/egs/baker_zh/TTS/local/validate_manifest.py
new file mode 100755
index 0000000000..4e31028f7f
--- /dev/null
+++ b/egs/baker_zh/TTS/local/validate_manifest.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+# Copyright    2022-2023  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script checks the following assumptions of the generated manifest:
+
+- Single supervision per cut
+
+We will add more checks later if needed.
+
+Usage example:
+
+  python3 ./local/validate_manifest.py \
+    ./data/spectrogram/baker_zh_cuts_all.jsonl.gz
+
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.dataset.speech_synthesis import validate_for_tts
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "manifest",
+        type=Path,
+        help="Path to the manifest file",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+
+    manifest = args.manifest
+    logging.info(f"Validating {manifest}")
+
+    assert manifest.is_file(), f"{manifest} does not exist"
+    cut_set = load_manifest_lazy(manifest)
+    assert isinstance(cut_set, CutSet), type(cut_set)
+
+    validate_for_tts(cut_set)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
diff --git a/egs/baker_zh/TTS/matcha/__init__.py b/egs/baker_zh/TTS/matcha/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/egs/baker_zh/TTS/matcha/audio.py b/egs/baker_zh/TTS/matcha/audio.py
new file mode 120000
index 0000000000..62d3959d66
--- /dev/null
+++ b/egs/baker_zh/TTS/matcha/audio.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/matcha/audio.py
\ No newline at end of file
diff --git a/egs/baker_zh/TTS/matcha/export_onnx.py b/egs/baker_zh/TTS/matcha/export_onnx.py
new file mode 100755
index 0000000000..28efbfe614
--- /dev/null
+++ b/egs/baker_zh/TTS/matcha/export_onnx.py
@@ -0,0 +1,207 @@
+#!/usr/bin/env python3
+# Copyright    2024  Xiaomi Corp.  (authors: Fangjun Kuang)
+
+"""
+This script exports a Matcha-TTS model to ONNX.
+Note that the model outputs fbank. You need to use a vocoder to convert
+it to audio. See also ./export_onnx_hifigan.py
+
+python3 ./matcha/export_onnx.py \
+  --exp-dir ./matcha/exp-1 \
+  --epoch 2000 \
+  --tokens ./data/tokens.txt \
+  --cmvn ./data/fbank/cmvn.json
+"""
+
+import argparse
+import json
+import logging
+from pathlib import Path
+from typing import Any, Dict
+
+import onnx
+import torch
+from tokenizer import Tokenizer
+from train import get_model, get_params
+
+from icefall.checkpoint import load_checkpoint
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=2000,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=Path,
+        default="matcha/exp-new-3",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=Path,
+        default="data/tokens.txt",
+    )
+
+    parser.add_argument(
+        "--cmvn",
+        type=str,
+        default="data/fbank/cmvn.json",
+        help="""Path to cmvn.json.""",
+    )
+
+    return parser
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, Any]):
+    """Add meta data to an ONNX model. It is changed in-place.
+
+    Args:
+      filename:
+        Filename of the ONNX model to be changed.
+      meta_data:
+        Key-value pairs.
+    """
+    model = onnx.load(filename)
+
+    while len(model.metadata_props):
+        model.metadata_props.pop()
+
+    for key, value in meta_data.items():
+        meta = model.metadata_props.add()
+        meta.key = key
+        meta.value = str(value)
+
+    onnx.save(model, filename)
+
+
+class ModelWrapper(torch.nn.Module):
+    def __init__(self, model, num_steps: int = 5):
+        super().__init__()
+        self.model = model
+        self.num_steps = num_steps
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lengths: torch.Tensor,
+        noise_scale: torch.Tensor,
+        length_scale: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+          x: (batch_size, num_tokens), torch.int64
+          x_lengths: (batch_size,), torch.int64
+          noise_scale: (1,), torch.float32
+          length_scale: (1,), torch.float32
+        Returns:
+          mel: (batch_size, feat_dim, num_frames), torch.float32
+        """
+        mel = self.model.synthesise(
+            x=x,
+            x_lengths=x_lengths,
+            n_timesteps=self.num_steps,
+            temperature=noise_scale,
+            length_scale=length_scale,
+        )["mel"]
+        # mel: (batch_size, feat_dim, num_frames)
+
+        return mel
+
+
+@torch.inference_mode()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    params = get_params()
+
+    params.update(vars(args))
+
+    tokenizer = Tokenizer(params.tokens)
+    params.pad_id = tokenizer.pad_id
+    params.vocab_size = tokenizer.vocab_size
+    params.model_args.n_vocab = params.vocab_size
+
+    with open(params.cmvn) as f:
+        stats = json.load(f)
+        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.data_args.data_statistics.mel_std = stats["fbank_std"]
+
+        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.model_args.data_statistics.mel_std = stats["fbank_std"]
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+
+    for num_steps in [2, 3, 4, 5, 6]:
+        logging.info(f"num_steps: {num_steps}")
+        wrapper = ModelWrapper(model, num_steps=num_steps)
+        wrapper.eval()
+
+        # Use a large value so the rotary position embedding in the text
+        # encoder has a large initial length
+        x = torch.ones(1, 1000, dtype=torch.int64)
+        x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
+        noise_scale = torch.tensor([1.0])
+        length_scale = torch.tensor([1.0])
+
+        opset_version = 14
+        filename = f"model-steps-{num_steps}.onnx"
+        torch.onnx.export(
+            wrapper,
+            (x, x_lengths, noise_scale, length_scale),
+            filename,
+            opset_version=opset_version,
+            input_names=["x", "x_length", "noise_scale", "length_scale"],
+            output_names=["mel"],
+            dynamic_axes={
+                "x": {0: "N", 1: "L"},
+                "x_length": {0: "N"},
+                "mel": {0: "N", 2: "L"},
+            },
+        )
+
+        meta_data = {
+            "model_type": "matcha-tts",
+            "language": "Chinese",
+            "has_espeak": 0,
+            "n_speakers": 1,
+            "jieba": 1,
+            "sample_rate": 22050,
+            "version": 1,
+            "pad_id": params.pad_id,
+            "model_author": "icefall",
+            "maintainer": "k2-fsa",
+            "dataset": "baker-zh",
+            "use_eos_bos": 0,
+            "dataset_url": "https://www.data-baker.com/open_source.html",
"dataset_comment": "The dataset is for non-commercial use only.", + "num_ode_steps": num_steps, + } + add_meta_data(filename=filename, meta_data=meta_data) + print(meta_data) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/baker_zh/TTS/matcha/export_onnx_hifigan.py b/egs/baker_zh/TTS/matcha/export_onnx_hifigan.py new file mode 120000 index 0000000000..d0b8af15bc --- /dev/null +++ b/egs/baker_zh/TTS/matcha/export_onnx_hifigan.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/export_onnx_hifigan.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/fbank.py b/egs/baker_zh/TTS/matcha/fbank.py new file mode 120000 index 0000000000..3cfb7fe3f4 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/fbank.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/fbank.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/generate_lexicon.py b/egs/baker_zh/TTS/matcha/generate_lexicon.py new file mode 100755 index 0000000000..f26f28e919 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/generate_lexicon.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 + +import jieba +from pypinyin import Style, lazy_pinyin, load_phrases_dict, phrases_dict, pinyin_dict +from tokenizer import Tokenizer + +load_phrases_dict( + { + "行长": [["hang2"], ["zhang3"]], + "银行行长": [["yin2"], ["hang2"], ["hang2"], ["zhang3"]], + } +) + + +def main(): + filename = "lexicon.txt" + tokens = "./data/tokens.txt" + tokenizer = Tokenizer(tokens) + + word_dict = pinyin_dict.pinyin_dict + phrases = phrases_dict.phrases_dict + + i = 0 + with open(filename, "w", encoding="utf-8") as f: + for key in word_dict: + if not (0x4E00 <= key <= 0x9FFF): + continue + + w = chr(key) + tokens = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0] + + f.write(f"{w} {tokens}\n") + + for key in phrases: + tokens = lazy_pinyin(key, style=Style.TONE3, tone_sandhi=True) + tokens = " ".join(tokens) + + f.write(f"{key} {tokens}\n") + + +if __name__ == "__main__": + main() diff --git a/egs/baker_zh/TTS/matcha/hifigan b/egs/baker_zh/TTS/matcha/hifigan new file mode 120000 index 0000000000..c0a91072c0 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/hifigan @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/hifigan \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/infer.py b/egs/baker_zh/TTS/matcha/infer.py new file mode 100755 index 0000000000..b90c2fdbd8 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/infer.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. 
diff --git a/egs/baker_zh/TTS/matcha/infer.py b/egs/baker_zh/TTS/matcha/infer.py
new file mode 100755
index 0000000000..b90c2fdbd8
--- /dev/null
+++ b/egs/baker_zh/TTS/matcha/infer.py
@@ -0,0 +1,342 @@
+#!/usr/bin/env python3
+# Copyright    2024  Xiaomi Corp.  (authors: Fangjun Kuang)
+"""
+python3 ./matcha/infer.py \
+  --epoch 2000 \
+  --exp-dir ./matcha/exp-1 \
+  --vocoder ./generator_v2 \
+  --tokens ./data/tokens.txt \
+  --cmvn ./data/fbank/cmvn.json \
+  --input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \
+  --output-wav ./generated.wav
+"""
+
+import argparse
+import datetime as dt
+import json
+import logging
+from pathlib import Path
+
+import soundfile as sf
+import torch
+import torch.nn as nn
+from hifigan.config import v1, v2, v3
+from hifigan.denoiser import Denoiser
+from hifigan.models import Generator as HiFiGAN
+from local.convert_text_to_tokens import split_text
+from pypinyin import Style, lazy_pinyin
+from tokenizer import Tokenizer
+from train import get_model, get_params
+from tts_datamodule import BakerZhTtsDataModule
+
+from icefall.checkpoint import load_checkpoint
+from icefall.utils import AttributeDict, setup_logger
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=4000,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=Path,
+        default="matcha/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--vocoder",
+        type=Path,
+        default="./generator_v1",
+        help="Path to the vocoder",
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=Path,
+        default="data/tokens.txt",
+    )
+
+    parser.add_argument(
+        "--cmvn",
+        type=str,
+        default="data/fbank/cmvn.json",
+        help="""Path to cmvn.json.""",
+    )
+
+    # The following arguments are used for inference on single text
+    parser.add_argument(
+        "--input-text",
+        type=str,
+        required=False,
+        help="The text to generate speech for",
+    )
+
+    parser.add_argument(
+        "--output-wav",
+        type=str,
+        required=False,
+        help="The filename of the wave to save the generated speech",
+    )
+
+    parser.add_argument(
+        "--sampling-rate",
+        type=int,
+        default=22050,
+        help="The sampling rate of the generated speech (default: 22050 for baker_zh)",
+    )
+
+    return parser
+
+
+def load_vocoder(checkpoint_path: Path) -> nn.Module:
+    checkpoint_path = str(checkpoint_path)
+    if checkpoint_path.endswith("v1"):
+        h = AttributeDict(v1)
+    elif checkpoint_path.endswith("v2"):
+        h = AttributeDict(v2)
+    elif checkpoint_path.endswith("v3"):
+        h = AttributeDict(v3)
+    else:
+        raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
+
+    hifigan = HiFiGAN(h).to("cpu")
+    hifigan.load_state_dict(
+        torch.load(checkpoint_path, map_location="cpu")["generator"]
+    )
+    _ = hifigan.eval()
+    hifigan.remove_weight_norm()
+    return hifigan
+
+
+def to_waveform(
+    mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module
+) -> torch.Tensor:
+    audio = vocoder(mel).clamp(-1, 1)
+    audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
+    return audio
+
+
+def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict:
+    text = split_text(text)
+    tokens = lazy_pinyin(text, style=Style.TONE3, tone_sandhi=True)
+
+    x = tokenizer.texts_to_token_ids([tokens])
+    x = torch.tensor(x, dtype=torch.long, device=device)
+    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
+    return {"x_orig": text, "x": x, "x_lengths": x_lengths}
+
+
+def synthesize(
+    model: nn.Module,
+    tokenizer: Tokenizer,
+    n_timesteps: int,
+    text: str,
+    length_scale: float,
+    temperature: float,
+    device: str = "cpu",
+    spks=None,
+) -> dict:
+    text_processed = process_text(text=text, tokenizer=tokenizer, device=device)
+    start_t = dt.datetime.now()
+    output = model.synthesise(
+        text_processed["x"],
+        text_processed["x_lengths"],
+        n_timesteps=n_timesteps,
+        temperature=temperature,
+        spks=spks,
+        length_scale=length_scale,
+    )
+    # merge everything to one dict
+    output.update({"start_t": start_t, **text_processed})
+    return output
+
+
+def infer_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    vocoder: nn.Module,
+    denoiser: nn.Module,
+    tokenizer: Tokenizer,
+) -> None:
+    """Decode dataset.
+    The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      vocoder:
+        The vocoder used to convert mel spectrograms to waveforms.
+      denoiser:
+        Used to reduce the background noise introduced by the vocoder.
+      tokenizer:
+        Used to convert text to tokens.
+    """
+
+    device = next(model.parameters()).device
+    num_cuts = 0
+    log_interval = 5
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    for batch_idx, batch in enumerate(dl):
+        batch_size = len(batch["tokens"])
+
+        texts = [c.supervisions[0].normalized_text for c in batch["cut"]]
+
+        audio = batch["audio"]
+        audio_lens = batch["audio_lens"].tolist()
+        cut_ids = [cut.id for cut in batch["cut"]]
+
+        for i in range(batch_size):
+            output = synthesize(
+                model=model,
+                tokenizer=tokenizer,
+                n_timesteps=params.n_timesteps,
+                text=texts[i],
+                length_scale=params.length_scale,
+                temperature=params.temperature,
+                device=device,
+            )
+            output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
+
+            sf.write(
+                file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav",
+                data=output["waveform"],
+                samplerate=params.data_args.sampling_rate,
+                subtype="PCM_16",
+            )
+            sf.write(
+                file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav",
+                data=audio[i].numpy(),
+                samplerate=params.data_args.sampling_rate,
+                subtype="PCM_16",
+            )
+
+        num_cuts += batch_size
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+
+
+@torch.inference_mode()
+def main():
+    parser = get_parser()
+    BakerZhTtsDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    params.suffix = f"epoch-{params.epoch}"
+
+    params.res_dir = params.exp_dir / "infer" / params.suffix
+    params.save_wav_dir = params.res_dir / "wav"
+    params.save_wav_dir.mkdir(parents=True, exist_ok=True)
+
+    setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
+    logging.info("Infer started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+    logging.info(f"Device: {device}")
+
+    tokenizer = Tokenizer(params.tokens)
+    params.vocab_size = tokenizer.vocab_size
+    params.model_args.n_vocab = params.vocab_size
+
+    with open(params.cmvn) as f:
+        stats = json.load(f)
+        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.data_args.data_statistics.mel_std = stats["fbank_std"]
+
+        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.model_args.data_statistics.mel_std = stats["fbank_std"]
+
+    # Number of ODE solver steps
+    params.n_timesteps = 2
+
+    # Changes to the speaking rate
+    params.length_scale = 1.0
+
+    # Sampling temperature
+    params.temperature = 0.667
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    model.to(device)
+    model.eval()
+
+    # We need cut ids to organize the tts results.
+    args.return_cuts = True
+    baker_zh = BakerZhTtsDataModule(args)
+
+    test_cuts = baker_zh.test_cuts()
+    test_dl = baker_zh.test_dataloaders(test_cuts)
+
+    if not Path(params.vocoder).is_file():
+        raise ValueError(f"{params.vocoder} does not exist")
+
+    vocoder = load_vocoder(params.vocoder)
+    vocoder.to(device)
+
+    denoiser = Denoiser(vocoder, mode="zeros")
+    denoiser.to(device)
+
+    if params.input_text is not None and params.output_wav is not None:
+        logging.info("Synthesizing a single text")
+        output = synthesize(
+            model=model,
+            tokenizer=tokenizer,
+            n_timesteps=params.n_timesteps,
+            text=params.input_text,
+            length_scale=params.length_scale,
+            temperature=params.temperature,
+            device=device,
+        )
+        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
+
+        sf.write(
+            file=params.output_wav,
+            data=output["waveform"],
+            samplerate=params.sampling_rate,
+            subtype="PCM_16",
+        )
+    else:
+        logging.info("Decoding the test set")
+        infer_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            vocoder=vocoder,
+            denoiser=denoiser,
+            tokenizer=tokenizer,
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/baker_zh/TTS/matcha/model.py b/egs/baker_zh/TTS/matcha/model.py
new file mode 120000
index 0000000000..8a1b812a94
--- /dev/null
+++ b/egs/baker_zh/TTS/matcha/model.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/matcha/model.py
\ No newline at end of file
diff --git a/egs/baker_zh/TTS/matcha/models b/egs/baker_zh/TTS/matcha/models
new file mode 120000
index 0000000000..09a862665f
--- /dev/null
+++ b/egs/baker_zh/TTS/matcha/models
@@ -0,0 +1 @@
+../../../ljspeech/TTS/matcha/models
\ No newline at end of file
diff --git a/egs/baker_zh/TTS/matcha/monotonic_align b/egs/baker_zh/TTS/matcha/monotonic_align
new file mode 120000
index 0000000000..d0a0dd6b5f
--- /dev/null
+++ b/egs/baker_zh/TTS/matcha/monotonic_align
@@ -0,0 +1 @@
+../../../ljspeech/TTS/matcha/monotonic_align
\ No newline at end of file
diff --git a/egs/baker_zh/TTS/matcha/onnx_pretrained.py b/egs/baker_zh/TTS/matcha/onnx_pretrained.py
new file mode 100755
index 0000000000..f6b7f7caec
--- /dev/null
+++ b/egs/baker_zh/TTS/matcha/onnx_pretrained.py
@@ -0,0 +1,316 @@
+#!/usr/bin/env python3
+# Copyright    2024  Xiaomi Corp.  (authors: Fangjun Kuang)
+
+"""
+python3 ./matcha/onnx_pretrained.py \
+  --acoustic-model ./model-steps-4.onnx \
+  --vocoder ./hifigan_v2.onnx \
+  --tokens ./data/tokens.txt \
+  --lexicon ./lexicon.txt \
+  --input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \
+  --output-wav ./b.wav
+"""
+
+import argparse
+import datetime as dt
+import logging
+import re
+from typing import Dict, List
+
+import jieba
+import onnxruntime as ort
+import soundfile as sf
+import torch
+from utils import intersperse
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--acoustic-model",
+        type=str,
+        required=True,
+        help="Path to the acoustic model",
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        required=True,
+        help="Path to the tokens.txt",
+    )
+
+    parser.add_argument(
+        "--lexicon",
+        type=str,
+        required=True,
+        help="Path to the lexicon.txt",
+    )
+
+    parser.add_argument(
+        "--vocoder",
+        type=str,
+        required=True,
+        help="Path to the vocoder",
+    )
+
+    parser.add_argument(
+        "--input-text",
+        type=str,
+        required=True,
+        help="The text to generate speech for",
+    )
+
+    parser.add_argument(
+        "--output-wav",
+        type=str,
+        required=True,
+        help="The filename of the wave to save the generated speech",
+    )
+
+    return parser
+
+
+class OnnxHifiGANModel:
+    def __init__(
+        self,
+        filename: str,
+    ):
+        session_opts = ort.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 1
+
+        self.session_opts = session_opts
+        self.model = ort.InferenceSession(
+            filename,
+            sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
+        )
+
+        for i in self.model.get_inputs():
+            print(i)
+
+        print("-----")
+
+        for i in self.model.get_outputs():
+            print(i)
+
+    def __call__(self, x: torch.Tensor):
+        assert x.ndim == 3, x.shape
+        assert x.shape[0] == 1, x.shape
+
+        audio = self.model.run(
+            [self.model.get_outputs()[0].name],
+            {
+                self.model.get_inputs()[0].name: x.numpy(),
+            },
+        )[0]
+        # audio: (batch_size, num_samples)
+
+        return torch.from_numpy(audio)
+
+
+class OnnxModel:
+    def __init__(
+        self,
+        filename: str,
+    ):
+        session_opts = ort.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 2
+
+        self.session_opts = session_opts
+        self.model = ort.InferenceSession(
+            filename,
+            sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
+        )
+
+        logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
+        metadata = self.model.get_modelmeta().custom_metadata_map
+        self.sample_rate = int(metadata["sample_rate"])
+
+        for i in self.model.get_inputs():
+            print(i)
+
+        print("-----")
+
+        for i in self.model.get_outputs():
+            print(i)
+
+    def __call__(self, x: torch.Tensor):
+        assert x.ndim == 2, x.shape
+        assert x.shape[0] == 1, x.shape
+
+        x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
+        print("x_lengths", x_lengths)
+        print("x", x.shape)
+
+        noise_scale = torch.tensor([1.0], dtype=torch.float32)
+        length_scale = torch.tensor([1.0], dtype=torch.float32)
+
+        mel = self.model.run(
+            [self.model.get_outputs()[0].name],
+            {
+                self.model.get_inputs()[0].name: x.numpy(),
+                self.model.get_inputs()[1].name: x_lengths.numpy(),
+                self.model.get_inputs()[2].name: noise_scale.numpy(),
+                self.model.get_inputs()[3].name: length_scale.numpy(),
+            },
+        )[0]
+        # mel: (batch_size, feat_dim, num_frames)
+
+        return torch.from_numpy(mel)
+
+
+def read_tokens(filename: str) -> Dict[str, int]:
+    token2id = dict()
+    with open(filename, encoding="utf-8") as f:
+        for line in f.readlines():
+            info = line.rstrip().split()
+            if len(info) == 1:
+                # case of space
+                token = " "
+                idx = int(info[0])
+            else:
+                token, idx = info[0], int(info[1])
+            assert token not in token2id, token
+            token2id[token] = idx
+    return token2id
+
+
+def read_lexicon(filename: str) -> Dict[str, List[str]]:
+    word2token = dict()
+    with open(filename, encoding="utf-8") as f:
+        for line in f.readlines():
+            info = line.rstrip().split()
+            w = info[0]
+            tokens = info[1:]
+            word2token[w] = tokens
+    return word2token
+
+
+def convert_word_to_tokens(word2tokens: Dict[str, List[str]], word: str) -> List[str]:
+    if word in word2tokens:
+        return word2tokens[word]
+
+    if len(word) == 1:
+        return []
+
+    ans = []
+    for w in word:
+        t = convert_word_to_tokens(word2tokens, w)
+        ans.extend(t)
+    return ans
+
+
+def normalize_text(text):
+    whitespace_re = re.compile(r"\s+")
+
+    punctuations_re = [
+        (re.compile(x[0], re.IGNORECASE), x[1])
+        for x in [
+            (",", ","),
+            ("。", "."),
+            ("!", "!"),
+            ("?", "?"),
+            ("“", '"'),
+            ("”", '"'),
+            ("‘", "'"),
+            ("’", "'"),
+            (":", ":"),
+            ("、", ","),
+        ]
+    ]
+
+    text = whitespace_re.sub(" ", text)
+    for regex, replacement in punctuations_re:
+        text = re.sub(regex, replacement, text)
+    return text
+
+
+@torch.no_grad()
+def main():
+    params = get_parser().parse_args()
+    logging.info(vars(params))
+    token2id = read_tokens(params.tokens)
+    word2tokens = read_lexicon(params.lexicon)
+
+    text = normalize_text(params.input_text)
+    seg = jieba.cut(text)
+    tokens = []
+    for s in seg:
+        if s in token2id:
+            tokens.append(s)
+            continue
+
+        t = convert_word_to_tokens(word2tokens, s)
+        if t:
+            tokens.extend(t)
+
+    model = OnnxModel(params.acoustic_model)
+    vocoder = OnnxHifiGANModel(params.vocoder)
+
+    x = []
+    for t in tokens:
+        if t in token2id:
+            x.append(token2id[t])
+
+    x = intersperse(x, item=token2id["_"])
+
+    x = torch.tensor(x, dtype=torch.int64).unsqueeze(0)
+
+    start_t = dt.datetime.now()
+    mel = model(x)
+    end_t = dt.datetime.now()
+
+    start_t2 = dt.datetime.now()
+    audio = vocoder(mel)
+    end_t2 = dt.datetime.now()
+
+    print("audio", audio.shape)  # (1, 1, num_samples)
+    audio = audio.squeeze()
+
+    sample_rate = model.sample_rate
+
+    t = (end_t - start_t).total_seconds()
+    t2 = (end_t2 - start_t2).total_seconds()
+    rtf_am = t * sample_rate / audio.shape[-1]
+    rtf_vocoder = t2 * sample_rate / audio.shape[-1]
+    print("RTF for acoustic model ", rtf_am)
+    print("RTF for vocoder", rtf_vocoder)
+
+    # skip denoiser
+    sf.write(params.output_wav, audio, sample_rate, "PCM_16")
+    logging.info(f"Saved to {params.output_wav}")
+
+
+if __name__ == "__main__":
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
+
+"""
+|HifiGAN |RTF  |#Parameters (M)|
+|--------|-----|---------------|
+|v1      |0.818|13.926         |
+|v2      |0.101| 0.925         |
+|v3      |0.118| 1.462         |
+
+|Num steps|Acoustic Model RTF|
+|---------|------------------|
+|    2    |       0.039      |
+|    3    |       0.047      |
+|    4    |       0.071      |
+|    5    |       0.076      |
+|    6    |       0.103      |
+"""
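
Both `onnx_pretrained.py` above and `tokenizer.py` below call `intersperse()`
from `utils.py`, which is not part of this patch. A minimal sketch of its
expected behavior, assuming the usual Matcha-TTS/VITS definition (the padding
item is placed before, between, and after the tokens):

```python
from typing import List


def intersperse(sequence: List[int], item: int) -> List[int]:
    # [t0, t1] -> [item, t0, item, t1, item]
    result = [item] * (len(sequence) * 2 + 1)
    result[1::2] = sequence
    return result


assert intersperse([5, 7], item=0) == [0, 5, 0, 7, 0]
```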
diff --git a/egs/baker_zh/TTS/matcha/tokenizer.py b/egs/baker_zh/TTS/matcha/tokenizer.py
new file mode 100644
index 0000000000..dda82c29da
--- /dev/null
+++ b/egs/baker_zh/TTS/matcha/tokenizer.py
@@ -0,0 +1,119 @@
+# Copyright    2024  Xiaomi Corp.  (authors: Fangjun Kuang)
+
+import logging
+from typing import Dict, List
+
+try:
+    from piper_phonemize import phonemize_espeak
+except Exception as ex:
+    raise RuntimeError(
+        f"{ex}\nPlease run\n"
+        "pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html"
+    )
+
+from utils import intersperse
+
+
+# This tokenizer supports both English and Chinese.
+# We assume you have used
+# ../local/convert_text_to_tokens.py
+# to process your text.
+class Tokenizer(object):
+    def __init__(self, tokens: str):
+        """
+        Args:
+          tokens: the file that maps tokens to ids
+        """
+        # Parse token file
+        self.token2id: Dict[str, int] = {}
+        with open(tokens, "r", encoding="utf-8") as f:
+            for line in f.readlines():
+                info = line.rstrip().split()
+                if len(info) == 1:
+                    # case of space
+                    token = " "
+                    id = int(info[0])
+                else:
+                    token, id = info[0], int(info[1])
+                assert token not in self.token2id, token
+                self.token2id[token] = id
+
+        # Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md
+        self.pad_id = self.token2id["_"]  # padding
+        self.space_id = self.token2id[" "]  # word separator (whitespace)
+
+        self.vocab_size = len(self.token2id)
+
+    def texts_to_token_ids(
+        self,
+        sentence_list: List[List[str]],
+        intersperse_blank: bool = True,
+        lang: str = "en-us",
+    ) -> List[List[int]]:
+        """
+        Args:
+          sentence_list:
+            A list of sentences.
+          intersperse_blank:
+            Whether to intersperse blanks in the token sequence.
+          lang:
+            Language argument passed to phonemize_espeak().
+
+        Returns:
+          Return a list of token id list [utterance][token_id]
+        """
+        token_ids_list = []
+
+        for sentence in sentence_list:
+            tokens_list = []
+            for word in sentence:
+                if word in self.token2id:
+                    tokens_list.append(word)
+                    continue
+
+                tmp_tokens_list = phonemize_espeak(word, lang)
+                for t in tmp_tokens_list:
+                    tokens_list.extend(t)
+
+            token_ids = []
+            for t in tokens_list:
+                if t not in self.token2id:
+                    logging.warning(f"Skip OOV {t} {sentence}")
+                    continue
+
+                if t == " " and len(token_ids) > 0 and token_ids[-1] == self.space_id:
+                    continue
+
+                token_ids.append(self.token2id[t])
+
+            if intersperse_blank:
+                token_ids = intersperse(token_ids, self.pad_id)
+
+            token_ids_list.append(token_ids)
+
+        return token_ids_list
+
+
+def test_tokenizer():
+    import jieba
+    from pypinyin import Style, lazy_pinyin
+
+    tokenizer = Tokenizer("data/tokens.txt")
+    text1 = "今天is Monday, tomorrow is 星期二"
+    text2 = "你好吗? 我很好, how about you?"
+
+    text1 = list(jieba.cut(text1))
+    text2 = list(jieba.cut(text2))
+    tokens1 = lazy_pinyin(text1, style=Style.TONE3, tone_sandhi=True)
+    tokens2 = lazy_pinyin(text2, style=Style.TONE3, tone_sandhi=True)
+    print(tokens1)
+    print(tokens2)
+
+    ids = tokenizer.texts_to_token_ids([tokens1, tokens2])
+    print(ids)
+
+
+if __name__ == "__main__":
+    test_tokenizer()
diff --git a/egs/baker_zh/TTS/matcha/train.py b/egs/baker_zh/TTS/matcha/train.py
new file mode 100755
index 0000000000..ed2ba49b90
--- /dev/null
+++ b/egs/baker_zh/TTS/matcha/train.py
@@ -0,0 +1,717 @@
+#!/usr/bin/env python3
+# Copyright    2024  Xiaomi Corp.  (authors: Fangjun Kuang)
+
+
+import argparse
+import json
+import logging
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Union
+
+import k2
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from lhotse.utils import fix_random_seed
+from model import fix_len_compatibility
+from models.matcha_tts import MatchaTTS
+from tokenizer import Tokenizer
+from torch.cuda.amp import GradScaler, autocast
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.utils.tensorboard import SummaryWriter
+from tts_datamodule import BakerZhTtsDataModule
+from utils import MetricsTracker
+
+from icefall.checkpoint import load_checkpoint, save_checkpoint
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, setup_logger, str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12335,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=1000,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=Path,
+        default="matcha/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        default="data/tokens.txt",
+        help="""Path to tokens.txt.""",
+    )
+
+    parser.add_argument(
+        "--cmvn",
+        type=str,
+        default="data/fbank/cmvn.json",
+        help="""Path to cmvn.json.""",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=10,
+        help="""Save a checkpoint periodically, after processing this number
+        of epochs. We save a checkpoint to exp-dir/ whenever
+        params.cur_epoch % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
+        Since training takes around 1000 epochs, we suggest using a large
+        save_every_n to save disk space.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    return parser
+
+
+def get_data_statistics():
+    return AttributeDict(
+        {
+            "mel_mean": 0,
+            "mel_std": 1,
+        }
+    )
+
+
+def _get_data_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "name": "baker-zh",
+            "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
+            "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
+            # "batch_size": 64,
+            # "num_workers": 1,
+            # "pin_memory": False,
+            "cleaners": ["english_cleaners2"],
+            "add_blank": True,
+            "n_spks": 1,
+            "n_fft": 1024,
+            "n_feats": 80,
+            "sampling_rate": 22050,
+            "hop_length": 256,
+            "win_length": 1024,
+            "f_min": 0,
+            "f_max": 8000,
+            "seed": 1234,
+            "load_durations": False,
+            "data_statistics": get_data_statistics(),
+        }
+    )
+    return params
+
+
+def _get_model_params() -> AttributeDict:
+    n_feats = 80
+    filter_channels_dp = 256
+    encoder_params_p_dropout = 0.1
+    params = AttributeDict(
+        {
+            "n_spks": 1,  # for baker-zh.
+            "spk_emb_dim": 64,
+            "n_feats": n_feats,
+            "out_size": None,  # or use 172
+            "prior_loss": True,
+            "use_precomputed_durations": False,
+            "data_statistics": get_data_statistics(),
+            "encoder": AttributeDict(
+                {
+                    "encoder_type": "RoPE Encoder",  # not used
+                    "encoder_params": AttributeDict(
+                        {
+                            "n_feats": n_feats,
+                            "n_channels": 192,
+                            "filter_channels": 768,
+                            "filter_channels_dp": filter_channels_dp,
+                            "n_heads": 2,
+                            "n_layers": 6,
+                            "kernel_size": 3,
+                            "p_dropout": encoder_params_p_dropout,
+                            "spk_emb_dim": 64,
+                            "n_spks": 1,
+                            "prenet": True,
+                        }
+                    ),
+                    "duration_predictor_params": AttributeDict(
+                        {
+                            "filter_channels_dp": filter_channels_dp,
+                            "kernel_size": 3,
+                            "p_dropout": encoder_params_p_dropout,
+                        }
+                    ),
+                }
+            ),
+            "decoder": AttributeDict(
+                {
+                    "channels": [256, 256],
+                    "dropout": 0.05,
+                    "attention_head_dim": 64,
+                    "n_blocks": 1,
+                    "num_mid_blocks": 2,
+                    "num_heads": 2,
+                    "act_fn": "snakebeta",
+                }
+            ),
+            "cfm": AttributeDict(
+                {
+                    "name": "CFM",
+                    "solver": "euler",
+                    "sigma_min": 1e-4,
+                }
+            ),
+            "optimizer": AttributeDict(
+                {
+                    "lr": 1e-4,
+                    "weight_decay": 0.0,
+                }
+            ),
+        }
+    )
+
+    return params
+
+
+def get_params():
+    params = AttributeDict(
+        {
+            "model_args": _get_model_params(),
+            "data_args": _get_data_params(),
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": -1,  # 0
+            "log_interval": 10,
+            "valid_interval": 1500,
+            "env_info": get_env_info(),
+        }
+    )
+    return params
+
+
+def get_model(params):
+    m = MatchaTTS(**params.model_args)
+    return m
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict, model: nn.Module
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(filename, model=model)
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    return saved_params
+
+
+def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params):
+    """Parse batch data"""
+    mel_mean = params.data_args.data_statistics.mel_mean
+    mel_std_inv = 1 / params.data_args.data_statistics.mel_std
+    for i in range(batch["features"].shape[0]):
+        n = batch["features_lens"][i]
+        batch["features"][i : i + 1, :n, :] = (
+            batch["features"][i : i + 1, :n, :] - mel_mean
+        ) * mel_std_inv
+        batch["features"][i : i + 1, n:, :] = 0
+
+    audio = batch["audio"].to(device)
+    features = batch["features"].to(device)
+    audio_lens = batch["audio_lens"].to(device)
+    features_lens = batch["features_lens"].to(device)
+    tokens = batch["tokens"]
+
+    tokens = tokenizer.texts_to_token_ids(tokens, intersperse_blank=True)
+    tokens = k2.RaggedTensor(tokens)
+    row_splits = tokens.shape.row_splits(1)
+    tokens_lens = row_splits[1:] - row_splits[:-1]
+    tokens = tokens.to(device)
+    tokens_lens = tokens_lens.to(device)
+    # a tensor of shape (B, T)
+    tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
+
+    max_feature_length = fix_len_compatibility(features.shape[1])
+    if max_feature_length > features.shape[1]:
+        pad = max_feature_length - features.shape[1]
+        features = torch.nn.functional.pad(features, (0, 0, 0, pad))
+
+    # features_lens[features_lens.argmax()] += pad
+
+    return audio, audio_lens, features, features_lens.long(), tokens, tokens_lens.long()
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    tokenizer: Tokenizer,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+    rank: int = 0,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses
+
+    # used to summary the stats over iterations
+    tot_loss = MetricsTracker()
+
+    with torch.no_grad():
+        for batch_idx, batch in enumerate(valid_dl):
+            (
+                audio,
+                audio_lens,
+                features,
+                features_lens,
+                tokens,
+                tokens_lens,
+            ) = prepare_input(batch, tokenizer, device, params)
+
+            losses = get_losses(
+                {
+                    "x": tokens,
+                    "x_lengths": tokens_lens,
+                    "y": features.permute(0, 2, 1),
+                    "y_lengths": features_lens,
+                    "spks": None,  # should change it for multi-speakers
+                    "durations": None,
+                }
+            )
+
+            batch_size = len(batch["tokens"])
+
+            loss_info = MetricsTracker()
+            loss_info["samples"] = batch_size
+
+            s = 0
+
+            for key, value in losses.items():
+                v = value.detach().item()
+                loss_info[key] = v * batch_size
+                s += v * batch_size
+
+            loss_info["tot_loss"] = s
+
+            # summary stats
+            tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(device)
+
+    loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    tokenizer: Tokenizer,
+    optimizer: Optimizer,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      tokenizer:
+        Used to convert text to token ids.
+      optimizer:
+        The optimizer.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mixed precision training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+    """
+    model.train()
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses
+
+    # used to track the stats over iterations in one epoch
+    tot_loss = MetricsTracker()
+
+    saved_bad_model = False
+
+    def save_bad_model(suffix: str = ""):
+        save_checkpoint(
+            filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+            model=model,
+            params=params,
+            optimizer=optimizer,
+            scaler=scaler,
+            rank=0,
+        )
+
+    for batch_idx, batch in enumerate(train_dl):
+        params.batch_idx_train += 1
+        # audio: (N, T), float32
+        # features: (N, T, C), float32
+        # audio_lens, (N,), int32
+        # features_lens, (N,), int32
+        # tokens: List[List[str]], len(tokens) == N
+
+        batch_size = len(batch["tokens"])
+
+        (
+            audio,
+            audio_lens,
+            features,
+            features_lens,
+            tokens,
+            tokens_lens,
+        ) = prepare_input(batch, tokenizer, device, params)
+        try:
+            with autocast(enabled=params.use_fp16):
+                losses = get_losses(
+                    {
+                        "x": tokens,
+                        "x_lengths": tokens_lens,
+                        "y": features.permute(0, 2, 1),
+                        "y_lengths": features_lens,
+                        "spks": None,  # should change it for multi-speakers
+                        "durations": None,
+                    }
+                )
+
+                loss = sum(losses.values())
+
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+
+            loss_info = MetricsTracker()
+            loss_info["samples"] = batch_size
+
+            s = 0
+
+            for key, value in losses.items():
+                v = value.detach().item()
+                loss_info[key] = v * batch_size
+                s += v * batch_size
+
+            loss_info["tot_loss"] = s
+
+            tot_loss = tot_loss + loss_info
+        except:  # noqa
+            save_bad_model()
+            raise
+
+        if params.batch_idx_train % 100 == 0 and params.use_fp16:
+            # If the grad scale was less than 1, try increasing it.
+            # The _growth_interval of the grad scaler is configurable,
+            # but we can't configure it to have different
+            # behavior depending on the current grad scale.
+            cur_grad_scale = scaler._scale.item()
+
+            if cur_grad_scale < 8.0 or (
+                cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
+            ):
+                scaler.update(cur_grad_scale * 2.0)
+            if cur_grad_scale < 0.01:
+                if not saved_bad_model:
+                    save_bad_model(suffix="-first-warning")
+                    saved_bad_model = True
+                logging.warning(f"Grad scale is small: {cur_grad_scale}")
+            if cur_grad_scale < 1.0e-05:
+                save_bad_model()
+                raise RuntimeError(
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )
+
+        if params.batch_idx_train % params.log_interval == 0:
+            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+                f"global_batch_idx: {params.batch_idx_train}, "
+                f"batch size: {batch_size}, "
+                f"loss[{loss_info}], tot_loss[{tot_loss}], "
+                + (f"grad_scale: {cur_grad_scale}" if params.use_fp16 else "")
+            )
+
+            if tb_writer is not None:
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                if params.use_fp16:
+                    tb_writer.add_scalar(
+                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                    )
+
+        if params.batch_idx_train % params.valid_interval == 1:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                tokenizer=tokenizer,
+                valid_dl=valid_dl,
+                world_size=world_size,
+                rank=rank,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                "Maximum memory allocated so far is "
+                f"{torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    tokenizer = Tokenizer(params.tokens)
+    params.pad_id = tokenizer.pad_id
+    params.vocab_size = tokenizer.vocab_size
+    params.model_args.n_vocab = params.vocab_size
+
+    with open(params.cmvn) as f:
+        stats = json.load(f)
+        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.data_args.data_statistics.mel_std = stats["fbank_std"]
+
+        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.model_args.data_statistics.mel_std = stats["fbank_std"]
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of parameters: {num_param}")
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+    model.to(device)
+
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer)
+
+    logging.info("About to create datamodule")
+
+    baker_zh = BakerZhTtsDataModule(args)
+
+    train_cuts = baker_zh.train_cuts()
+    train_dl = baker_zh.train_dataloaders(train_cuts)
+
+    valid_cuts = baker_zh.valid_cuts()
+    valid_dl = baker_zh.valid_dataloaders(valid_cuts)
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        logging.info(f"Start epoch {epoch}")
+        fix_random_seed(params.seed + epoch - 1)
+        # Note: a plain `"sampler" in train_dl` membership test would fall
+        # back to iterating the DataLoader; check the attribute instead.
+        if train_dl.sampler is not None:
+            train_dl.sampler.set_epoch(epoch - 1)
+
+        params.cur_epoch = epoch
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            tokenizer=tokenizer,
+            optimizer=optimizer,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
+            filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+            save_checkpoint(
+                filename=filename,
+                params=params,
+                model=model,
+                optimizer=optimizer,
+                scaler=scaler,
+                rank=rank,
+            )
+            if rank == 0:
+                if params.best_train_epoch == params.cur_epoch:
+                    best_train_filename = params.exp_dir / "best-train-loss.pt"
+                    copyfile(src=filename, dst=best_train_filename)
+
+                if params.best_valid_epoch == params.cur_epoch:
+                    best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+                    copyfile(src=filename, dst=best_valid_filename)
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def main():
+    parser = get_parser()
+    BakerZhTtsDataModule.add_arguments(parser)
+    args = parser.parse_args()
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+if __name__ == "__main__":
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    main()
diff --git a/egs/baker_zh/TTS/matcha/tts_datamodule.py b/egs/baker_zh/TTS/matcha/tts_datamodule.py
new file mode 100644
index 0000000000..d2bdfb96c5
--- /dev/null
+++ b/egs/baker_zh/TTS/matcha/tts_datamodule.py
@@ -0,0 +1,340 @@
+# Copyright      2021  Piotr Żelasko
+# Copyright 2022-2023  Xiaomi Corporation  (Authors: Mingshuang Luo,
+#                                                    Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
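+
+# Typical usage (a sketch only; it mirrors how ./matcha/train.py consumes
+# this datamodule, and the variable names are illustrative):
+#
+#   parser = argparse.ArgumentParser()
+#   BakerZhTtsDataModule.add_arguments(parser)
+#   args = parser.parse_args()
+#   data_module = BakerZhTtsDataModule(args)
+#   train_dl = data_module.train_dataloaders(data_module.train_cuts())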
+
+
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from fbank import MatchaFbank, MatchaFbankConfig
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
+    CutConcatenate,
+    CutMix,
+    DynamicBucketingSampler,
+    PrecomputedFeatures,
+    SimpleCutSampler,
+    SpeechSynthesisDataset,
+)
+from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
+    AudioSamples,
+    OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
+class BakerZhTtsDataModule:
+    """
+    DataModule for TTS experiments.
+    It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+    and test-other).
+
+    It contains all the common data pipeline modules used in TTS
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - on-the-fly feature extraction
+
+    This class should be derived for specific corpora used in TTS tasks.
+    """
+
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(
+            title="TTS data related options",
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
+        )
+
+        group.add_argument(
+            "--manifest-dir",
+            type=Path,
+            default=Path("data/fbank"),
+            help="Path to directory with train/valid/test cuts.",
+        )
+        group.add_argument(
+            "--max-duration",
+            type=int,
+            default=200,
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
+        )
+        group.add_argument(
+            "--bucketing-sampler",
+            type=str2bool,
+            default=True,
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
+        )
+        group.add_argument(
+            "--num-buckets",
+            type=int,
+            default=30,
+            help="The number of buckets for the DynamicBucketingSampler "
+            "(you might want to increase it for larger datasets).",
+        )
+
+        group.add_argument(
+            "--on-the-fly-feats",
+            type=str2bool,
+            default=False,
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
+        )
+        group.add_argument(
+            "--shuffle",
+            type=str2bool,
+            default=True,
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
+        )
+        group.add_argument(
+            "--drop-last",
+            type=str2bool,
+            default=True,
+            help="Whether to drop last batch. Used by sampler.",
+        )
+        group.add_argument(
+            "--return-cuts",
+            type=str2bool,
+            default=False,
+            help="When enabled, each batch will have the "
+            "field: batch['cut'] with the cuts that "
+            "were used to construct it.",
+        )
+        group.add_argument(
+            "--num-workers",
+            type=int,
+            default=2,
+            help="The number of training dataloader workers that "
+            "collect the batches.",
+        )
+
+        group.add_argument(
+            "--input-strategy",
+            type=str,
+            default="PrecomputedFeatures",
+            help="AudioSamples or PrecomputedFeatures",
+        )
+
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
+        logging.info("About to create train dataset")
+        if self.args.on_the_fly_feats:
+            sampling_rate = 22050
+            config = MatchaFbankConfig(
+                n_fft=1024,
+                n_mels=80,
+                sampling_rate=sampling_rate,
+                hop_length=256,
+                win_length=1024,
+                f_min=0,
+                f_max=8000,
+            )
+            train = SpeechSynthesisDataset(
+                return_text=False,
+                return_tokens=True,
+                feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
+                return_cuts=self.args.return_cuts,
+            )
+        else:
+            train = SpeechSynthesisDataset(
+                return_text=False,
+                return_tokens=True,
+                feature_input_strategy=eval(self.args.input_strategy)(),
+                return_cuts=self.args.return_cuts,
+            )
+
+        if self.args.bucketing_sampler:
+            logging.info("Using DynamicBucketingSampler.")
+            train_sampler = DynamicBucketingSampler(
+                cuts_train,
+                max_duration=self.args.max_duration,
+                shuffle=self.args.shuffle,
+                num_buckets=self.args.num_buckets,
+                buffer_size=self.args.num_buckets * 2000,
+                shuffle_buffer_size=self.args.num_buckets * 5000,
+                drop_last=self.args.drop_last,
+            )
+        else:
+            logging.info("Using SimpleCutSampler.")
+            train_sampler = SimpleCutSampler(
+                cuts_train,
+                max_duration=self.args.max_duration,
+                shuffle=self.args.shuffle,
+            )
+        logging.info("About to create train dataloader")
+
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
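+        # Each dataloader worker then re-seeds itself with seed + worker_id
+        # (see _SeedWorkers above), so the workers draw different random
+        # streams instead of duplicating each other's shuffling.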
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=True, + pin_memory=True, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = MatchaFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=sampling_rate, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=True, + pin_memory=True, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = MatchaFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=sampling_rate, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "baker_zh_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / "baker_zh_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "baker_zh_cuts_test.jsonl.gz" + ) diff --git a/egs/baker_zh/TTS/matcha/utils.py b/egs/baker_zh/TTS/matcha/utils.py new file mode 120000 index 0000000000..ceaaea1963 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/utils.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/utils.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/prepare.sh b/egs/baker_zh/TTS/prepare.sh new file mode 100755 index 0000000000..e15e3d8501 --- /dev/null +++ b/egs/baker_zh/TTS/prepare.sh @@ -0,0 +1,151 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + 
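+# Example invocation (the stages are documented inline below; --stage and
+# --stop-stage are handled by shared/parse_options.sh):
+#
+#   ./prepare.sh --stage 0 --stop-stage 7
+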
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+
+dl_dir=$PWD/download
+mkdir -p $dl_dir
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+  log "Stage -1: build monotonic_align lib (used by ./matcha)"
+  for recipe in matcha; do
+    if [ ! -d $recipe/monotonic_align/build ]; then
+      cd $recipe/monotonic_align
+      python3 setup.py build_ext --inplace
+      cd ../../
+    else
+      log "monotonic_align lib for $recipe already built"
+    fi
+  done
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download data"
+
+  # The directory $dl_dir/BZNSYP contains the following 3 directories:
+
+  # ls -lh $dl_dir/BZNSYP/
+  # total 0
+  # drwxr-xr-x 10002 kuangfangjun root 0 Jan  4  2019 PhoneLabeling
+  # drwxr-xr-x     3 kuangfangjun root 0 Jan 31  2019 ProsodyLabeling
+  # drwxr-xr-x 10003 kuangfangjun root 0 Aug 26 17:45 Wave
+
+  # If you have trouble accessing huggingface.co, please use
+  #
+  #  cd $dl_dir
+  #  wget https://huggingface.co/openspeech/BZNSYP/resolve/main/BZNSYP.tar.bz2
+  #  tar xf BZNSYP.tar.bz2
+  #  cd ..
+
+  # If you have pre-downloaded it to /path/to/BZNSYP, you can create a symlink
+  #
+  #  ln -sfv /path/to/BZNSYP $dl_dir/BZNSYP
+  #
+  if [ ! -d $dl_dir/BZNSYP/Wave ]; then
+    lhotse download baker-zh $dl_dir
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare baker-zh manifest"
+  # We assume that you have downloaded the baker corpus
+  # to $dl_dir/BZNSYP
+  mkdir -p data/manifests
+  if [ ! -e data/manifests/.baker-zh.done ]; then
+    lhotse prepare baker-zh $dl_dir/BZNSYP data/manifests
+    touch data/manifests/.baker-zh.done
+  fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Generate tokens.txt"
+  if [ ! -e data/tokens.txt ]; then
+    python3 ./local/generate_tokens.py --tokens data/tokens.txt
+  fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Generate raw cutset"
+  if [ ! -e data/manifests/baker_zh_cuts_raw.jsonl.gz ]; then
+    lhotse cut simple \
+      -r ./data/manifests/baker_zh_recordings_all.jsonl.gz \
+      -s ./data/manifests/baker_zh_supervisions_all.jsonl.gz \
+      ./data/manifests/baker_zh_cuts_raw.jsonl.gz
+  fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Convert text to tokens"
+  if [ ! -e data/manifests/baker_zh_cuts.jsonl.gz ]; then
+    python3 ./local/convert_text_to_tokens.py \
+      --in-file ./data/manifests/baker_zh_cuts_raw.jsonl.gz \
+      --out-file ./data/manifests/baker_zh_cuts.jsonl.gz
+  fi
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Generate fbank (used by ./matcha)"
+  mkdir -p data/fbank
+  if [ ! -e data/fbank/.baker-zh.done ]; then
+    ./local/compute_fbank_baker_zh.py
+    touch data/fbank/.baker-zh.done
+  fi
+
+  if [ ! -e data/fbank/.baker-zh-validated.done ]; then
+    log "Validating data/fbank for baker-zh (used by ./matcha)"
+    python3 ./local/validate_manifest.py \
+      data/fbank/baker_zh_cuts.jsonl.gz
+    touch data/fbank/.baker-zh-validated.done
+  fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Split the baker-zh cuts into train, valid and test sets (used by ./matcha)"
+  if [ ! -e data/fbank/.baker_zh_split.done ]; then
+    # Hold out the last 600 cuts: the first 100 of them for validation
+    # and the remaining 500 for test; the rest is the training set.
+    lhotse subset --last 600 \
+      data/fbank/baker_zh_cuts.jsonl.gz \
+      data/fbank/baker_zh_cuts_validtest.jsonl.gz
+    lhotse subset --first 100 \
+      data/fbank/baker_zh_cuts_validtest.jsonl.gz \
+      data/fbank/baker_zh_cuts_valid.jsonl.gz
+    lhotse subset --last 500 \
+      data/fbank/baker_zh_cuts_validtest.jsonl.gz \
+      data/fbank/baker_zh_cuts_test.jsonl.gz
+
+    rm data/fbank/baker_zh_cuts_validtest.jsonl.gz
+
+    n=$(( $(gunzip -c data/fbank/baker_zh_cuts.jsonl.gz | wc -l) - 600 ))
+
+    lhotse subset --first $n \
+      data/fbank/baker_zh_cuts.jsonl.gz \
+      data/fbank/baker_zh_cuts_train.jsonl.gz
+
+    touch data/fbank/.baker_zh_split.done
+  fi
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Compute fbank mean and std (used by ./matcha)"
+  if [ ! -f ./data/fbank/cmvn.json ]; then
+    ./local/compute_fbank_statistics.py ./data/fbank/baker_zh_cuts_train.jsonl.gz ./data/fbank/cmvn.json
+  fi
+fi
diff --git a/egs/baker_zh/TTS/shared b/egs/baker_zh/TTS/shared
new file mode 120000
index 0000000000..4cbd91a7e9
--- /dev/null
+++ b/egs/baker_zh/TTS/shared
@@ -0,0 +1 @@
+../../../icefall/shared
\ No newline at end of file
diff --git a/egs/ljspeech/TTS/README.md b/egs/ljspeech/TTS/README.md
index 39280437b6..c9cfc22fd4 100644
--- a/egs/ljspeech/TTS/README.md
+++ b/egs/ljspeech/TTS/README.md
@@ -166,7 +166,7 @@ To export the checkpoint to onnx:
     --tokens ./data/tokens.txt
 ```
 
-The above command generate the following files:
+The above command generates the following files:
 
   - model-steps-2.onnx
   - model-steps-3.onnx
diff --git a/egs/ljspeech/TTS/matcha/export_onnx.py b/egs/ljspeech/TTS/matcha/export_onnx.py
index 623517431c..39709cc36e 100755
--- a/egs/ljspeech/TTS/matcha/export_onnx.py
+++ b/egs/ljspeech/TTS/matcha/export_onnx.py
@@ -93,14 +93,14 @@ def forward(
         self,
         x: torch.Tensor,
         x_lengths: torch.Tensor,
-        temperature: torch.Tensor,
+        noise_scale: torch.Tensor,
         length_scale: torch.Tensor,
     ) -> torch.Tensor:
         """
         Args:
           x: (batch_size, num_tokens), torch.int64
           x_lengths: (batch_size,), torch.int64
-          temperature: (1,), torch.float32
+          noise_scale: (1,), torch.float32
           length_scale (1,), torch.float32
         Returns:
           audio: (batch_size, num_samples)
@@ -110,7 +110,7 @@ def forward(
             x=x,
             x_lengths=x_lengths,
             n_timesteps=self.num_steps,
-            temperature=temperature,
+            temperature=noise_scale,
             length_scale=length_scale,
         )["mel"]
         # mel: (batch_size, feat_dim, num_frames)
@@ -127,7 +127,6 @@ def main():
     params.update(vars(args))
 
     tokenizer = Tokenizer(params.tokens)
-    params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
     params.model_args.n_vocab = params.vocab_size
 
@@ -153,14 +152,14 @@ def main():
     # encoder has a large initial length
     x = torch.ones(1, 1000, dtype=torch.int64)
     x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
-    temperature = torch.tensor([1.0])
+    noise_scale = torch.tensor([1.0])
     length_scale = torch.tensor([1.0])
 
     opset_version = 14
     filename = f"model-steps-{num_steps}.onnx"
     torch.onnx.export(
         wrapper,
-        (x, x_lengths, temperature, length_scale),
+        (x, x_lengths, noise_scale, length_scale),
         filename,
         opset_version=opset_version,
         input_names=["x", "x_length", "noise_scale", "length_scale"],
diff --git a/egs/ljspeech/TTS/matcha/onnx_pretrained.py b/egs/ljspeech/TTS/matcha/onnx_pretrained.py
index 6d92b16ebf..19e9b49cb1 100755
--- a/egs/ljspeech/TTS/matcha/onnx_pretrained.py
+++ b/egs/ljspeech/TTS/matcha/onnx_pretrained.py
@@ -132,7 +132,7 @@ def __call__(self, x: torch.tensor):
         print("x_lengths", x_lengths)
         print("x", x.shape)
- temperature = torch.tensor([1.0], dtype=torch.float32) + noise_scale = torch.tensor([1.0], dtype=torch.float32) length_scale = torch.tensor([1.0], dtype=torch.float32) mel = self.model.run( @@ -140,7 +140,7 @@ def __call__(self, x: torch.tensor): { self.model.get_inputs()[0].name: x.numpy(), self.model.get_inputs()[1].name: x_lengths.numpy(), - self.model.get_inputs()[2].name: temperature.numpy(), + self.model.get_inputs()[2].name: noise_scale.numpy(), self.model.get_inputs()[3].name: length_scale.numpy(), }, )[0]