From 03d500a414dd4ad95ca31b60f8eb088b991b53bd Mon Sep 17 00:00:00 2001
From: yuekaiz
Date: Tue, 24 Dec 2024 15:06:21 +0800
Subject: [PATCH] clean infer codes

---
 .../f5-tts/{eval_infer_batch.py => infer.py}  | 263 ++++------
 egs/wenetspeech4tts/TTS/f5-tts/train.py       |   4 +-
 egs/wenetspeech4tts/TTS/infer_f5.sh           |  14 +-
 egs/wenetspeech4tts/TTS/local/compute_wer.sh  |  27 +
 .../TTS/local/offline-decode-files.py         | 495 ++++++++++++++++++
 5 files changed, 638 insertions(+), 165 deletions(-)
 rename egs/wenetspeech4tts/TTS/f5-tts/{eval_infer_batch.py => infer.py} (60%)
 create mode 100644 egs/wenetspeech4tts/TTS/local/compute_wer.sh
 create mode 100755 egs/wenetspeech4tts/TTS/local/offline-decode-files.py

diff --git a/egs/wenetspeech4tts/TTS/f5-tts/eval_infer_batch.py b/egs/wenetspeech4tts/TTS/f5-tts/infer.py
similarity index 60%
rename from egs/wenetspeech4tts/TTS/f5-tts/eval_infer_batch.py
rename to egs/wenetspeech4tts/TTS/f5-tts/infer.py
index d70df9626a..4db628a660 100644
--- a/egs/wenetspeech4tts/TTS/f5-tts/eval_infer_batch.py
+++ b/egs/wenetspeech4tts/TTS/f5-tts/infer.py
@@ -1,45 +1,81 @@
 import argparse
+import logging
 import math
 import os
 import random
 import time
+from pathlib import Path
 
-# import bigvan
-# sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
 import torch
 import torch.nn.functional as F
 import torchaudio
 from accelerate import Accelerator
-
-# from importlib.resources import files
-# import sys
-# sys.path.append(f"/home/yuekaiz/BigVGAN/")
-# from bigvgan import BigVGAN
 from bigvganinference import BigVGANInference
-
-# from f5_tts.eval.utils_eval import (
-#     get_inference_prompt,
-#     get_librispeech_test_clean_metainfo,
-#     get_seedtts_testset_metainfo,
-# )
-# from f5_tts.infer.utils_infer import load_vocoder
 from model.cfm import CFM
 from model.dit import DiT
 from model.modules import MelSpec
 from model.utils import convert_char_to_pinyin
 from tqdm import tqdm
-from train import get_tokenizer, load_pretrained_checkpoint
+from train import (
+    add_model_arguments,
+    get_model,
+    get_tokenizer,
+    load_F5_TTS_pretrained_checkpoint,
+)
 
 from icefall.checkpoint import load_checkpoint
 
 
-def load_vocoder(device):
-    # huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x
-    model = BigVGANInference.from_pretrained(
-        "./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        default="f5-tts/vocab.txt",
+        help="Path to the unique text tokens file",
+    )
+
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        default="/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt",
+        help="Path to the F5-TTS model checkpoint",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=0,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--nfe",
+        type=int,
+        default=16,
+        help="The number of steps for the neural ODE",
+    )
+
+    parser.add_argument(
+        "--manifest-file",
+        type=str,
+        default="/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst",
+        help="The manifest file in seed_tts_eval format",
     )
-    model = model.eval().to(device)
-    return model
+
+    parser.add_argument(
+        "--output-dir",
+        type=Path,
+        default="results",
+        help="The output directory to save the generated wavs",
+    )
+
+    parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
+
+    
add_model_arguments(parser) + return parser.parse_args() def get_inference_prompt( @@ -52,7 +88,7 @@ def get_inference_prompt( win_length=1024, n_mel_channels=100, hop_length=256, - mel_spec_type="vocos", + mel_spec_type="bigvgan", target_rms=0.1, use_truth_duration=False, infer_batch_size=1, @@ -209,151 +245,54 @@ def get_seedtts_testset_metainfo(metalst): f.close() metainfo = [] for line in lines: - if len(line.strip().split("|")) == 5: - utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|") - elif len(line.strip().split("|")) == 4: - utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") - gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav") + assert len(line.strip().split("|")) == 4 + utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") + utt = Path(utt).stem + gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav") if not os.path.isabs(prompt_wav): prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav)) return metainfo -accelerator = Accelerator() -device = f"cuda:{accelerator.process_index}" - - -# --------------------- Dataset Settings -------------------- # - -target_sample_rate = 24000 -n_mel_channels = 100 -hop_length = 256 -win_length = 1024 -n_fft = 1024 -target_rms = 0.1 - -# rel_path = str(files("f5_tts").joinpath("../../")) - - def main(): - # ---------------------- infer setting ---------------------- # - - parser = argparse.ArgumentParser(description="batch inference") - - parser.add_argument("-s", "--seed", default=None, type=int) - parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN") - parser.add_argument("-n", "--expname", required=True) - parser.add_argument("-c", "--ckptstep", default=15000, type=int) - parser.add_argument( - "-m", - "--mel_spec_type", - default="bigvgan", - type=str, - choices=["bigvgan", "vocos"], - ) - parser.add_argument( - "-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"] - ) - - parser.add_argument("-nfe", "--nfestep", default=32, type=int) - parser.add_argument("-o", "--odemethod", default="euler") - parser.add_argument("-ss", "--swaysampling", default=-1, type=float) - - parser.add_argument("-t", "--testset", required=True) - - args = parser.parse_args() - - seed = args.seed - dataset_name = args.dataset - exp_name = args.expname - ckpt_step = args.ckptstep - - ckpt_path = "/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt" - ckpt_path = "/home/yuekaiz/icefall_matcha/egs/wenetspeech4tts/TTS/exp/f5/checkpoint-15000.pt" + args = get_parser() - mel_spec_type = args.mel_spec_type - tokenizer = args.tokenizer - - nfe_step = args.nfestep - ode_method = args.odemethod - sway_sampling_coef = args.swaysampling - - testset = args.testset - - infer_batch_size = 1 # max frames. 
1 for ddp single inference (recommended)
-    cfg_strength = 2.0
-    speed = 1.0
-    use_truth_duration = False
-    no_ref_audio = False
-
-    model_cls = DiT
-    model_cfg = dict(
-        dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
-    )
-    metalst = "/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst"
-    metainfo = get_seedtts_testset_metainfo(metalst)
-
-    # path to save genereted wavs
-    output_dir = (
-        f"./"
-        f"results/{exp_name}_{ckpt_step}/{testset}/"
-        f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
-        f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
-        f"_cfg{cfg_strength}_speed{speed}"
-        f"{'_gt-dur' if use_truth_duration else ''}"
-        f"{'_no-ref-audio' if no_ref_audio else ''}"
-    )
+    accelerator = Accelerator()
+    device = f"cuda:{accelerator.process_index}"
 
+    metainfo = get_seedtts_testset_metainfo(args.manifest_file)
     prompts_all = get_inference_prompt(
         metainfo,
-        speed=speed,
-        tokenizer=tokenizer,
-        target_sample_rate=target_sample_rate,
-        n_mel_channels=n_mel_channels,
-        hop_length=hop_length,
-        mel_spec_type=mel_spec_type,
-        target_rms=target_rms,
-        use_truth_duration=use_truth_duration,
-        infer_batch_size=infer_batch_size,
+        speed=1.0,
+        tokenizer="pinyin",
+        target_sample_rate=24_000,
+        n_mel_channels=100,
+        hop_length=256,
+        mel_spec_type="bigvgan",
+        target_rms=0.1,
+        use_truth_duration=False,
+        infer_batch_size=1,
     )
 
-    vocoder = load_vocoder(device)
-
-    # Tokenizer
-    vocab_char_map, vocab_size = get_tokenizer("./f5-tts/vocab.txt")
-
-    # Model
-    model = CFM(
-        transformer=model_cls(
-            **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
-        ),
-        mel_spec_kwargs=dict(
-            n_fft=n_fft,
-            hop_length=hop_length,
-            win_length=win_length,
-            n_mel_channels=n_mel_channels,
-            target_sample_rate=target_sample_rate,
-            mel_spec_type=mel_spec_type,
-        ),
-        odeint_kwargs=dict(
-            method=ode_method,
-        ),
-        vocab_char_map=vocab_char_map,
-    ).to(device)
-
-    dtype = torch.float32 if mel_spec_type == "bigvgan" else None
-    # model = load_pretrained_checkpoint(model, ckpt_path)
-    _ = load_checkpoint(
-        ckpt_path,
-        model=model,
+    vocoder = BigVGANInference.from_pretrained(
+        "./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
     )
-    model = model.eval().to(device)
+    vocoder = vocoder.eval().to(device)
+
+    model = get_model(args).eval().to(device)
+    checkpoint = torch.load(args.model_path, map_location="cpu", weights_only=True)
+
+    if "model_state_dict" in checkpoint or "ema_model_state_dict" in checkpoint:
+        model = load_F5_TTS_pretrained_checkpoint(model, args.model_path)
+    else:
+        _ = load_checkpoint(
+            args.model_path,
+            model=model,
+        )
 
-    if not os.path.exists(output_dir) and accelerator.is_main_process:
-        os.makedirs(output_dir)
+    os.makedirs(args.output_dir, exist_ok=True)
 
-    # start batch inference
     accelerator.wait_for_everyone()
     start = time.time()
 
@@ -378,25 +317,23 @@ def main():
                 text=final_text_list,
                 duration=total_mel_lens,
                 lens=ref_mel_lens,
-                steps=nfe_step,
-                cfg_strength=cfg_strength,
-                sway_sampling_coef=sway_sampling_coef,
-                no_ref_audio=no_ref_audio,
-                seed=seed,
+                steps=args.nfe,
+                cfg_strength=2.0,
+                sway_sampling_coef=args.swaysampling,
+                no_ref_audio=False,
+                seed=args.seed,
             )
-            # Final result
             for i, gen in enumerate(generated):
                 gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
                 gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
-                if mel_spec_type == "vocos":
-                    generated_wave = vocoder.decode(gen_mel_spec).cpu()
-                elif mel_spec_type == "bigvgan":
-                    generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
+                generated_wave = 
vocoder(gen_mel_spec).squeeze(0).cpu()
+                target_rms = 0.1
+                target_sample_rate = 24_000
                 if ref_rms_list[i] < target_rms:
                     generated_wave = generated_wave * ref_rms_list[i] / target_rms
                 torchaudio.save(
-                    f"{output_dir}/{utts[i]}.wav",
+                    f"{args.output_dir}/{utts[i]}.wav",
                     generated_wave,
                     target_sample_rate,
                 )
@@ -408,4 +345,6 @@ def main():
 
 
 if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech4tts/TTS/f5-tts/train.py b/egs/wenetspeech4tts/TTS/f5-tts/train.py
index 3009235c4f..f6a0ce0e67 100755
--- a/egs/wenetspeech4tts/TTS/f5-tts/train.py
+++ b/egs/wenetspeech4tts/TTS/f5-tts/train.py
@@ -409,7 +409,7 @@ def get_model(params):
     return model
 
 
-def load_pretrained_checkpoint(
+def load_F5_TTS_pretrained_checkpoint(
     model, ckpt_path, device: str = "cpu", dtype=torch.float32
 ):
     # model = model.to(dtype)
@@ -937,7 +937,7 @@ def run(rank, world_size, args):
     logging.info("About to create model")
 
     model = get_model(params)
-    # model = load_pretrained_checkpoint(model, params.pretrained_model_path)
+    # model = load_F5_TTS_pretrained_checkpoint(model, params.pretrained_model_path)
     model = model.to(device)
 
     with open(f"{params.exp_dir}/model.txt", "w") as f:
diff --git a/egs/wenetspeech4tts/TTS/infer_f5.sh b/egs/wenetspeech4tts/TTS/infer_f5.sh
index a2decbd787..eee412e5a4 100644
--- a/egs/wenetspeech4tts/TTS/infer_f5.sh
+++ b/egs/wenetspeech4tts/TTS/infer_f5.sh
@@ -1,3 +1,15 @@
 export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
+#bigvganinference
+model_path=/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt
+manifest=/home/yuekaiz/HF/valle_wenetspeech4tts_demo/wenetspeech4tts.txt
+manifest=/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst
+# get wenetspeech4tts
+manifest_base_stem=$(basename $manifest)
+manifest_base_stem=${manifest_base_stem%.*}
+output_dir=./results/f5-tts-pretrained/$manifest_base_stem
 
-accelerate launch f5-tts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
+
+pip install sherpa-onnx bigvganinference lhotse kaldialign sentencepiece
+accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir || exit 1
+
+bash local/compute_wer.sh $output_dir $manifest
diff --git a/egs/wenetspeech4tts/TTS/local/compute_wer.sh b/egs/wenetspeech4tts/TTS/local/compute_wer.sh
new file mode 100644
index 0000000000..2a214cd676
--- /dev/null
+++ b/egs/wenetspeech4tts/TTS/local/compute_wer.sh
@@ -0,0 +1,27 @@
+wav_dir=$1
+wav_files=$(ls $wav_dir/*.wav)
+# wav_files=$(echo $wav_files | cut -d " " -f 1)
+# if wav_files is empty, then exit
+if [ -z "$wav_files" ]; then
+  exit 1
+fi
+label_file=$2
+model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
+
+if [ ! 
-d $model_path ]; then + pip install sherpa-onnx + wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 + tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local +fi + +python3 local/offline-decode-files.py \ + --tokens=$model_path/tokens.txt \ + --paraformer=$model_path/model.int8.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=24000 \ + --log-dir $wav_dir \ + --feature-dim=80 \ + --label $label_file \ + $wav_files diff --git a/egs/wenetspeech4tts/TTS/local/offline-decode-files.py b/egs/wenetspeech4tts/TTS/local/offline-decode-files.py new file mode 100755 index 0000000000..fa6cbdb3eb --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/offline-decode-files.py @@ -0,0 +1,495 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2023 by manyeyes +# Copyright (c) 2023 Xiaomi Corporation + +""" +This file demonstrates how to use sherpa-onnx Python API to transcribe +file(s) with a non-streaming model. + +(1) For paraformer + + ./python-api-examples/offline-decode-files.py \ + --tokens=/path/to/tokens.txt \ + --paraformer=/path/to/paraformer.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 \ + /path/to/0.wav \ + /path/to/1.wav + +(2) For transducer models from icefall + + ./python-api-examples/offline-decode-files.py \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 \ + /path/to/0.wav \ + /path/to/1.wav + +(3) For CTC models from NeMo + +python3 ./python-api-examples/offline-decode-files.py \ + --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \ + --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav + +(4) For Whisper models + +python3 ./python-api-examples/offline-decode-files.py \ + --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \ + --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \ + --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \ + --whisper-task=transcribe \ + --num-threads=1 \ + ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \ + ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \ + ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav + +(5) For CTC models from WeNet + +python3 ./python-api-examples/offline-decode-files.py \ + --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \ + --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav + +(6) For tdnn models of the yesno recipe from icefall + +python3 ./python-api-examples/offline-decode-files.py \ + --sample-rate=8000 \ + --feature-dim=23 \ + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav + +Please refer to 
+https://k2-fsa.github.io/sherpa/onnx/index.html +to install sherpa-onnx and to download non-streaming pre-trained models +used in this file. +""" +import argparse +import time +import wave +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import sherpa_onnx +import soundfile as sf + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--hotwords-file", + type=str, + default="", + help=""" + The file containing hotwords, one words/phrases per line, like + HELLO WORLD + 你好世界 + """, + ) + + parser.add_argument( + "--hotwords-score", + type=float, + default=1.5, + help=""" + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. + """, + ) + + parser.add_argument( + "--modeling-unit", + type=str, + default="", + help=""" + The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe. + Used only when hotwords-file is given. + """, + ) + + parser.add_argument( + "--bpe-vocab", + type=str, + default="", + help=""" + The path to the bpe vocabulary, the bpe vocabulary is generated by + sentencepiece, you can also export the bpe vocabulary through a bpe model + by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given + and modeling-unit is bpe or cjkchar+bpe. + """, + ) + + parser.add_argument( + "--encoder", + default="", + type=str, + help="Path to the encoder model", + ) + + parser.add_argument( + "--decoder", + default="", + type=str, + help="Path to the decoder model", + ) + + parser.add_argument( + "--joiner", + default="", + type=str, + help="Path to the joiner model", + ) + + parser.add_argument( + "--paraformer", + default="", + type=str, + help="Path to the model.onnx from Paraformer", + ) + + parser.add_argument( + "--nemo-ctc", + default="", + type=str, + help="Path to the model.onnx from NeMo CTC", + ) + + parser.add_argument( + "--wenet-ctc", + default="", + type=str, + help="Path to the model.onnx from WeNet CTC", + ) + + parser.add_argument( + "--tdnn-model", + default="", + type=str, + help="Path to the model.onnx for the tdnn model of the yesno recipe", + ) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--whisper-encoder", + default="", + type=str, + help="Path to whisper encoder model", + ) + + parser.add_argument( + "--whisper-decoder", + default="", + type=str, + help="Path to whisper decoder model", + ) + + parser.add_argument( + "--whisper-language", + default="", + type=str, + help="""It specifies the spoken language in the input audio file. + Example values: en, fr, de, zh, jp. + Available languages for multilingual models can be found at + https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + If not specified, we infer the language from the input audio file. + """, + ) + + parser.add_argument( + "--whisper-task", + default="transcribe", + choices=["transcribe", "translate"], + type=str, + help="""For multilingual models, if you specify translate, the output + will be in English. + """, + ) + + parser.add_argument( + "--whisper-tail-paddings", + default=-1, + type=int, + help="""Number of tail padding frames. + We have removed the 30-second constraint from whisper, so you need to + choose the amount of tail padding frames by yourself. 
+ Use -1 to use a default value for tail padding. + """, + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + parser.add_argument( + "--debug", + type=bool, + default=False, + help="True to show debug messages", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="""Sample rate of the feature extractor. Must match the one + expected by the model. Note: The input sound files can have a + different sample rate from this argument.""", + ) + + parser.add_argument( + "--feature-dim", + type=int, + default=80, + help="Feature dimension. Must match the one expected by the model", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to decode. Each file must be of WAVE" + "format with a single channel, and each sample has 16-bit, " + "i.e., int16_t. " + "The sample rate of the file can be arbitrary and does not need to " + "be 16 kHz", + ) + + parser.add_argument( + "--name", + type=str, + default="", + help="The directory containing the input sound files to decode", + ) + + parser.add_argument( + "--log-dir", + type=str, + default="", + help="The directory containing the input sound files to decode", + ) + + parser.add_argument( + "--label", + type=str, + default=None, + help="wav_base_name label", + ) + return parser.parse_args() + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and can be of type + 32-bit floating point PCM. Its sample rate does not need to be 24kHz. + + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, + which are normalized to the range [-1, 1]. + - Sample rate of the wave file. + """ + + samples, sample_rate = sf.read(wave_filename, dtype="float32") + assert ( + samples.ndim == 1 + ), f"Expected single channel, but got {samples.ndim} channels." + + samples_float32 = samples.astype(np.float32) + + return samples_float32, sample_rate + + +def normalize_text_alimeeting(text: str) -> str: + """ + Text normalization similar to M2MeT challenge baseline. 
+ See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl + """ + import re + + text = text.replace(" ", "") + text = text.replace("", "") + text = text.replace("<%>", "") + text = text.replace("<->", "") + text = text.replace("<$>", "") + text = text.replace("<#>", "") + text = text.replace("<_>", "") + text = text.replace("", "") + text = text.replace("`", "") + text = text.replace("&", "") + text = text.replace(",", "") + if re.search("[a-zA-Z]", text): + text = text.upper() + text = text.replace("A", "A") + text = text.replace("a", "A") + text = text.replace("b", "B") + text = text.replace("c", "C") + text = text.replace("k", "K") + text = text.replace("t", "T") + text = text.replace(",", "") + text = text.replace("丶", "") + text = text.replace("。", "") + text = text.replace("、", "") + text = text.replace("?", "") + return text + + +def main(): + args = get_args() + assert_file_exists(args.tokens) + assert args.num_threads > 0, args.num_threads + + assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model + + assert_file_exists(args.paraformer) + + recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( + paraformer=args.paraformer, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + + print("Started!") + start_time = time.time() + + streams, results = [], [] + total_duration = 0 + + for i, wave_filename in enumerate(args.sound_files): + assert_file_exists(wave_filename) + samples, sample_rate = read_wave(wave_filename) + duration = len(samples) / sample_rate + total_duration += duration + s = recognizer.create_stream() + s.accept_waveform(sample_rate, samples) + + streams.append(s) + if i % 10 == 0: + recognizer.decode_streams(streams) + results += [s.result.text for s in streams] + streams = [] + print(f"Processed {i} files") + # process the last batch + if streams: + recognizer.decode_streams(streams) + results += [s.result.text for s in streams] + end_time = time.time() + print("Done!") + + results_dict = {} + for wave_filename, result in zip(args.sound_files, results): + print(f"{wave_filename}\n{result}") + print("-" * 10) + wave_basename = Path(wave_filename).stem + results_dict[wave_basename] = result + + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + print(f"num_threads: {args.num_threads}") + print(f"decoding_method: {args.decoding_method}") + print(f"Wave duration: {total_duration:.3f} s") + print(f"Elapsed time: {elapsed_seconds:.3f} s") + print( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + if args.label: + from icefall.utils import store_transcripts, write_error_stats + + labels_dict = {} + with open(args.label, "r") as f: + for line in f: + # fields = line.strip().split(" ") + # fields = [item for item in fields if item] + # assert len(fields) == 4 + # prompt_text, prompt_audio, text, audio_path = fields + + fields = line.strip().split("|") + fields = [item for item in fields if item] + assert len(fields) == 4 + audio_path, prompt_text, prompt_audio, text = fields + labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text) + + final_results = [] + for key, value in results_dict.items(): + 
final_results.append((key, labels_dict[key], value)) + + store_transcripts( + filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results + ) + with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f: + write_error_stats(f, "test-set", final_results, enable_log=True) + + with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f: + print(f.readline()) # WER + print(f.readline()) # Detailed errors + + +if __name__ == "__main__": + main()
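
A minimal sketch of the manifest/label line format assumed by f5-tts/infer.py and local/offline-decode-files.py above; the four "|"-separated fields come from the parsing code in this patch, while the example line itself is hypothetical:

# Each manifest line: utt|prompt_text|prompt_wav|gt_text (seed_tts_eval style).
# infer.py names each generated wav after Path(utt).stem, and
# offline-decode-files.py --label looks up the same stem to pair the ASR
# transcript with the reference text when computing WER.
from pathlib import Path

line = "utt_0001|提示文本。|prompts/utt_0001.wav|目标合成文本。"  # hypothetical example
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
print(Path(utt).stem, gt_text)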