diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py
index 2cb476e208..8dd6725365 100755
--- a/egs/aishell/ASR/conformer_ctc/decode.py
+++ b/egs/aishell/ASR/conformer_ctc/decode.py
@@ -366,13 +366,14 @@ def decode_dataset(
     num_cuts = 0

-    try:
-        num_batches = len(dl)
-    except TypeError:
-        num_batches = "?"
+    # try:
+    #     num_batches = len(dl)
+    # except TypeError:
+    #     num_batches = "?"

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
+        batch = batch[0]
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

@@ -399,9 +400,8 @@
         num_cuts += len(batch["supervisions"]["text"])

         if batch_idx % 100 == 0:
-            batch_str = f"{batch_idx}/{num_batches}"
-
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            # batch_str = f"{batch_idx}/{num_batches}"
+            logging.info(f"batch {batch_idx}, cuts processed until now is {num_cuts}")

     return results

@@ -547,20 +547,19 @@ def main():
     test_sets = ["test"]
     test_dls = [test_dl]

-    for test_set, test_dl in zip(test_sets, test_dls):
-        results_dict = decode_dataset(
-            dl=test_dl,
-            params=params,
-            model=model,
-            HLG=HLG,
-            H=H,
-            lexicon=lexicon,
-            sos_id=sos_id,
-            eos_id=eos_id,
-        )
-
-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+    # for test_set, test_dl in zip(test_sets, test_dls):
+    results_dict = decode_dataset(
+        dl=test_dl,
+        params=params,
+        model=model,
+        HLG=HLG,
+        H=H,
+        lexicon=lexicon,
+        sos_id=sos_id,
+        eos_id=eos_id,
+    )
+    save_results(params=params, test_set_name=test_sets[0], results_dict=results_dict)

     logging.info("Done!")

diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
index aacbd153de..0571dddb78 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -23,6 +23,7 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional

+from lhotse.cut import MonoCut
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
     CutConcatenate,
@@ -180,7 +181,34 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
             help="When enabled, select noise from MUSAN and mix it"
             "with training dataset. ",
         )

+    def to_dict(self, obj):
+        """
+        Recursively convert an object and its nested objects to dictionaries.
+        """
+        if isinstance(obj, (str, int, float, bool, type(None))):
+            return obj
+        elif isinstance(obj, list):
+            return [self.to_dict(item) for item in obj]
+        elif isinstance(obj, dict):
+            return {key: self.to_dict(value) for key, value in obj.items()}
+        elif hasattr(obj, "__dict__"):
+            return {key: self.to_dict(value) for key, value in obj.__dict__.items()}
+        else:
+            raise TypeError(f"Unsupported type: {type(obj)}")
+
+    def my_collate_fn(self, batch):
+        """
+        Convert each MonoCut in the batch to a dict.
+ """ + return_batch = [] + for item in batch: + if isinstance(item, MonoCut): + processed_item = self.to_dict(item) + return_batch.append(processed_item) + elif isinstance(item, dict): + return_batch.append(item) + return return_batch + def train_dataloaders( self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None ) -> DataLoader: @@ -354,9 +382,10 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: ) test_dl = DataLoader( test, - batch_size=None, + batch_size=100, # specified to some value sampler=sampler, - num_workers=self.args.num_workers, + num_workers=4, # if larger, it will be more time-consuming for decoding, may stuck + collate_fn=self.my_collate_fn ) return test_dl