Skip to content

Commit

Permalink
clean infer codes
Browse files Browse the repository at this point in the history
  • Loading branch information
yuekaiz committed Dec 24, 2024
1 parent 3ba6feb commit 03d500a
Show file tree
Hide file tree
Showing 5 changed files with 638 additions and 165 deletions.
Original file line number Diff line number Diff line change
@@ -1,45 +1,81 @@
import argparse
import logging
import math
import os
import random
import time
from pathlib import Path

# import bigvan
# sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
import torch
import torch.nn.functional as F
import torchaudio
from accelerate import Accelerator

# from importlib.resources import files
# import sys
# sys.path.append(f"/home/yuekaiz/BigVGAN/")
# from bigvgan import BigVGAN
from bigvganinference import BigVGANInference

# from f5_tts.eval.utils_eval import (
# get_inference_prompt,
# get_librispeech_test_clean_metainfo,
# get_seedtts_testset_metainfo,
# )
# from f5_tts.infer.utils_infer import load_vocoder
from model.cfm import CFM
from model.dit import DiT
from model.modules import MelSpec
from model.utils import convert_char_to_pinyin
from tqdm import tqdm
from train import get_tokenizer, load_pretrained_checkpoint
from train import (
add_model_arguments,
get_model,
get_tokenizer,
load_F5_TTS_pretrained_checkpoint,
)

from icefall.checkpoint import load_checkpoint


def load_vocoder(device):
# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x
model = BigVGANInference.from_pretrained(
"./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument(
"--tokens",
type=str,
default="f5-tts/vocab.txt",
help="Path to the unique text tokens file",
)

parser.add_argument(
"--model-path",
type=str,
default="/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt",
help="Path to the unique text tokens file",
)

parser.add_argument(
"--seed",
type=int,
default=0,
help="The seed for random generators intended for reproducibility",
)

parser.add_argument(
"--nfe",
type=int,
default=16,
help="The number of steps for the neural ODE",
)

parser.add_argument(
"--manifest-file",
type=str,
default="/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst",
help="The manifest file in seed_tts_eval format",
)
model = model.eval().to(device)
return model

parser.add_argument(
"--output-dir",
type=Path,
default="results",
help="The output directory to save the generated wavs",
)

parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
add_model_arguments(parser)
return parser.parse_args()


def get_inference_prompt(
Expand All @@ -52,7 +88,7 @@ def get_inference_prompt(
win_length=1024,
n_mel_channels=100,
hop_length=256,
mel_spec_type="vocos",
mel_spec_type="bigvgan",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
Expand Down Expand Up @@ -209,151 +245,54 @@ def get_seedtts_testset_metainfo(metalst):
f.close()
metainfo = []
for line in lines:
if len(line.strip().split("|")) == 5:
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
elif len(line.strip().split("|")) == 4:
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
assert len(line.strip().split("|")) == 4
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
utt = Path(utt).stem
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
if not os.path.isabs(prompt_wav):
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
return metainfo


accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"


# --------------------- Dataset Settings -------------------- #

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
target_rms = 0.1

# rel_path = str(files("f5_tts").joinpath("../../"))


def main():
# ---------------------- infer setting ---------------------- #

parser = argparse.ArgumentParser(description="batch inference")

parser.add_argument("-s", "--seed", default=None, type=int)
parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
parser.add_argument("-n", "--expname", required=True)
parser.add_argument("-c", "--ckptstep", default=15000, type=int)
parser.add_argument(
"-m",
"--mel_spec_type",
default="bigvgan",
type=str,
choices=["bigvgan", "vocos"],
)
parser.add_argument(
"-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"]
)

parser.add_argument("-nfe", "--nfestep", default=32, type=int)
parser.add_argument("-o", "--odemethod", default="euler")
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)

parser.add_argument("-t", "--testset", required=True)

args = parser.parse_args()

seed = args.seed
dataset_name = args.dataset
exp_name = args.expname
ckpt_step = args.ckptstep

ckpt_path = "/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"
ckpt_path = "/home/yuekaiz/icefall_matcha/egs/wenetspeech4tts/TTS/exp/f5/checkpoint-15000.pt"
args = get_parser()

mel_spec_type = args.mel_spec_type
tokenizer = args.tokenizer

nfe_step = args.nfestep
ode_method = args.odemethod
sway_sampling_coef = args.swaysampling

testset = args.testset

infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
cfg_strength = 2.0
speed = 1.0
use_truth_duration = False
no_ref_audio = False

model_cls = DiT
model_cfg = dict(
dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
)
metalst = "/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst"
metainfo = get_seedtts_testset_metainfo(metalst)

# path to save genereted wavs
output_dir = (
f"./"
f"results/{exp_name}_{ckpt_step}/{testset}/"
f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
f"_cfg{cfg_strength}_speed{speed}"
f"{'_gt-dur' if use_truth_duration else ''}"
f"{'_no-ref-audio' if no_ref_audio else ''}"
)
accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"

metainfo = get_seedtts_testset_metainfo(args.manifest_file)
prompts_all = get_inference_prompt(
metainfo,
speed=speed,
tokenizer=tokenizer,
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
hop_length=hop_length,
mel_spec_type=mel_spec_type,
target_rms=target_rms,
use_truth_duration=use_truth_duration,
infer_batch_size=infer_batch_size,
speed=1.0,
tokenizer="pinyin",
target_sample_rate=24_000,
n_mel_channels=100,
hop_length=256,
mel_spec_type="bigvgan",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
)

vocoder = load_vocoder(device)

# Tokenizer
vocab_char_map, vocab_size = get_tokenizer("./f5-tts/vocab.txt")

# Model
model = CFM(
transformer=model_cls(
**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
),
mel_spec_kwargs=dict(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
),
odeint_kwargs=dict(
method=ode_method,
),
vocab_char_map=vocab_char_map,
).to(device)

dtype = torch.float32 if mel_spec_type == "bigvgan" else None
# model = load_pretrained_checkpoint(model, ckpt_path)
_ = load_checkpoint(
ckpt_path,
model=model,
vocoder = BigVGANInference.from_pretrained(
"./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
)
model = model.eval().to(device)
vocoder = vocoder.eval().to(device)

model = get_model(args).eval().to(device)
checkpoint = torch.load(args.model_path, map_location="cpu", weights_only=True)

if "model_state_dict" or "ema_model_state_dict" in checkpoint:
model = load_F5_TTS_pretrained_checkpoint(model, args.model_path)
else:
_ = load_checkpoint(
args.model_path,
model=model,
)

if not os.path.exists(output_dir) and accelerator.is_main_process:
os.makedirs(output_dir)
os.makedirs(args.output_dir, exist_ok=True)

# start batch inference
accelerator.wait_for_everyone()
start = time.time()

Expand All @@ -378,25 +317,23 @@ def main():
text=final_text_list,
duration=total_mel_lens,
lens=ref_mel_lens,
steps=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
no_ref_audio=no_ref_audio,
seed=seed,
steps=args.nfe,
cfg_strength=2.0,
sway_sampling_coef=args.swaysampling,
no_ref_audio=False,
seed=args.seed,
)
# Final result
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
if mel_spec_type == "vocos":
generated_wave = vocoder.decode(gen_mel_spec).cpu()
elif mel_spec_type == "bigvgan":
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()

generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
target_rms = 0.1
target_sample_rate = 24_000
if ref_rms_list[i] < target_rms:
generated_wave = generated_wave * ref_rms_list[i] / target_rms
torchaudio.save(
f"{output_dir}/{utts[i]}.wav",
f"{args.output_dir}/{utts[i]}.wav",
generated_wave,
target_sample_rate,
)
Expand All @@ -408,4 +345,6 @@ def main():


if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
4 changes: 2 additions & 2 deletions egs/wenetspeech4tts/TTS/f5-tts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def get_model(params):
return model


def load_pretrained_checkpoint(
def load_F5_TTS_pretrained_checkpoint(
model, ckpt_path, device: str = "cpu", dtype=torch.float32
):
# model = model.to(dtype)
Expand Down Expand Up @@ -937,7 +937,7 @@ def run(rank, world_size, args):
logging.info("About to create model")

model = get_model(params)
# model = load_pretrained_checkpoint(model, params.pretrained_model_path)
# model = load_F5_TTS_pretrained_checkpoint(model, params.pretrained_model_path)
model = model.to(device)

with open(f"{params.exp_dir}/model.txt", "w") as f:
Expand Down
14 changes: 13 additions & 1 deletion egs/wenetspeech4tts/TTS/infer_f5.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
#bigvganinference
model_path=/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt
manifest=/home/yuekaiz/HF/valle_wenetspeech4tts_demo/wenetspeech4tts.txt
manifest=/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst
# get wenetspeech4tts
manifest_base_stem=$(basename $manifest)
mainfest_base_stem=${manifest_base_stem%.*}
output_dir=./results/f5-tts-pretrained/$mainfest_base_stem

accelerate launch f5-tts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16

pip install sherpa-onnx bigvganinference lhotse kaldialign sentencepiece
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir || exit 1

bash local/compute_wer.sh $output_dir $manifest
Loading

0 comments on commit 03d500a

Please sign in to comment.