Skip to content

Commit

Permalink
Modified aishell/ASR/conformer_ctc/decode.py and asr_datamodule.py for batch-way decoding, which is faster.
Browse files Browse the repository at this point in the history
  • Loading branch information
czl66 committed Dec 24, 2024
1 parent 19ce1a4 commit 448c28b
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 22 deletions.
39 changes: 19 additions & 20 deletions egs/aishell/ASR/conformer_ctc/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,13 +366,14 @@ def decode_dataset(

num_cuts = 0

try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
# try:
# num_batches = len(dl)
# except TypeError:
# num_batches = "?"

results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
batch = batch[0]
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

Expand All @@ -399,9 +400,8 @@ def decode_dataset(
num_cuts += len(batch["supervisions"]["text"])

if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"

logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
# batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_idx}, cuts processed until now is {num_cuts}")
return results


Expand Down Expand Up @@ -547,20 +547,19 @@ def main():

test_sets = ["test"]
test_dls = [test_dl]
# for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
HLG=HLG,
H=H,
lexicon=lexicon,
sos_id=sos_id,
eos_id=eos_id,
)

for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
HLG=HLG,
H=H,
lexicon=lexicon,
sos_id=sos_id,
eos_id=eos_id,
)

save_results(params=params, test_set_name=test_set, results_dict=results_dict)
save_results(params=params, test_set_name=test_sets[0], results_dict=results_dict)

logging.info("Done!")

Expand Down
33 changes: 31 additions & 2 deletions egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pathlib import Path
from typing import Any, Dict, List, Optional

from lhotse.cut import MonoCut
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
Expand Down Expand Up @@ -180,7 +181,34 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def to_dict(self, obj):
    """Recursively convert an object and all of its nested objects to
    plain built-in types (dicts, lists, and scalars).

    Args:
        obj: A primitive (str/int/float/bool/None), a list, a dict, or
            any object exposing ``__dict__`` (e.g. a lhotse MonoCut).

    Returns:
        The same data expressed entirely with built-in types: primitives
        pass through unchanged, lists/dicts are converted element-wise,
        and arbitrary objects become ``{attr: value}`` dicts.

    Raises:
        TypeError: If a value of an unsupported type (e.g. a tuple or
            set) is encountered anywhere in the structure.
    """

    # NOTE(fix): the original recursed via the bare name ``to_dict``,
    # which is not in scope inside a method body and raised NameError
    # for any non-scalar input.  A local helper makes the recursion
    # self-contained (and independent of how the method is bound).
    def _convert(value):
        if isinstance(value, (str, int, float, bool, type(None))):
            # Primitives are already JSON-like; return as-is.
            return value
        elif isinstance(value, list):
            return [_convert(item) for item in value]
        elif isinstance(value, dict):
            return {key: _convert(val) for key, val in value.items()}
        elif hasattr(value, '__dict__'):
            # Generic object: expose its attribute dict, converted deeply.
            return {key: _convert(val) for key, val in value.__dict__.items()}
        else:
            raise TypeError(f"Unsupported type: {type(value)}")

    return _convert(obj)

def my_collate_fn(self, batch):
    """Normalize a sampled batch for the DataLoader.

    Each ``MonoCut`` element is converted to a plain ``dict`` via
    :meth:`to_dict`; elements that are already dicts are passed through
    unchanged.  Elements of any other type are silently dropped,
    preserving the prior behavior of this collate function.

    Args:
        batch: An iterable of items produced by the dataset/sampler.

    Returns:
        A list containing only dicts.
    """
    converted = []
    for element in batch:
        if isinstance(element, MonoCut):
            converted.append(self.to_dict(element))
            continue
        if isinstance(element, dict):
            converted.append(element)
        # Any other type falls through and is intentionally skipped.
    return converted

def train_dataloaders(
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
) -> DataLoader:
Expand Down Expand Up @@ -354,9 +382,10 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
)
test_dl = DataLoader(
test,
batch_size=None,
batch_size=100, # specified to some value
sampler=sampler,
num_workers=self.args.num_workers,
num_workers=4, # if larger, it will be more time-consuming for decoding, may stuck
collate_fn=self.my_collate_fn
)
return test_dl

Expand Down

0 comments on commit 448c28b

Please sign in to comment.