diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py index bd73e520e0..20d7341db5 100755 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py +++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py @@ -20,7 +20,7 @@ from pathlib import Path import torch -from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomHdf5Writer +from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -69,7 +69,7 @@ def compute_fbank_wenetspeech_dev_test(): storage_path=f"{in_out_dir}/feats_{partition}", num_workers=num_workers, batch_duration=batch_duration, - storage_type=LilcomHdf5Writer, + storage_type=LilcomChunkyWriter, overwrite=True, ) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 9c07263a23..c9e30e7379 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -46,9 +46,6 @@ from icefall.utils import str2bool -set_caching_enabled(False) -torch.set_num_threads(1) - class _SeedWorkers: def __init__(self, seed: int): @@ -348,24 +345,18 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: cut_transforms=transforms, return_cuts=self.args.return_cuts, ) + valid_sampler = DynamicBucketingSampler( cuts_valid, max_duration=self.args.max_duration, - rank=0, - world_size=1, shuffle=False, ) logging.info("About to create dev dataloader") - from lhotse.dataset.iterable_dataset import IterableDatasetWrapper - - dev_iter_dataset = IterableDatasetWrapper( - dataset=validate, - sampler=valid_sampler, - ) valid_dl = DataLoader( - dev_iter_dataset, + validate, batch_size=None, + sampler=valid_sampler, num_workers=self.args.num_workers, persistent_workers=False, ) @@ -383,19 +374,13 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: sampler = DynamicBucketingSampler( cuts, max_duration=self.args.max_duration, - rank=0, - world_size=1, shuffle=False, ) - from lhotse.dataset.iterable_dataset import IterableDatasetWrapper - test_iter_dataset = IterableDatasetWrapper( - dataset=test, - sampler=sampler, - ) test_dl = DataLoader( - test_iter_dataset, + test, batch_size=None, + sampler=sampler, num_workers=self.args.num_workers, ) return test_dl diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py index 823b33ae59..bdd1f27bc6 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py @@ -651,83 +651,18 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - # Note: Please use "pip install webdataset==0.1.103" - # for installing the webdataset. - import glob - import os - - from lhotse import CutSet - from lhotse.dataset.webdataset import export_to_webdataset - # we need cut ids to display recognition results. args.return_cuts = True wenetspeech = WenetSpeechAsrDataModule(args) - dev = "dev" - test_net = "test_net" - test_meeting = "test_meeting" - - if not os.path.exists(f"{dev}/shared-0.tar"): - os.makedirs(dev) - dev_cuts = wenetspeech.valid_cuts() - export_to_webdataset( - dev_cuts, - output_path=f"{dev}/shared-%d.tar", - shard_size=300, - ) - - if not os.path.exists(f"{test_net}/shared-0.tar"): - os.makedirs(test_net) - test_net_cuts = wenetspeech.test_net_cuts() - export_to_webdataset( - test_net_cuts, - output_path=f"{test_net}/shared-%d.tar", - shard_size=300, - ) - - if not os.path.exists(f"{test_meeting}/shared-0.tar"): - os.makedirs(test_meeting) - test_meeting_cuts = wenetspeech.test_meeting_cuts() - export_to_webdataset( - test_meeting_cuts, - output_path=f"{test_meeting}/shared-%d.tar", - shard_size=300, - ) - - dev_shards = [ - str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) - ] - cuts_dev_webdataset = CutSet.from_webdataset( - dev_shards, - split_by_worker=True, - split_by_node=True, - shuffle_shards=True, - ) + dev_cuts = wenetspeech.valid_cuts() + dev_dl = wenetspeech.valid_dataloaders(dev_cuts) - test_net_shards = [ - str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar"))) - ] - cuts_test_net_webdataset = CutSet.from_webdataset( - test_net_shards, - split_by_worker=True, - split_by_node=True, - shuffle_shards=True, - ) - - test_meeting_shards = [ - str(path) - for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar"))) - ] - cuts_test_meeting_webdataset = CutSet.from_webdataset( - test_meeting_shards, - split_by_worker=True, - split_by_node=True, - shuffle_shards=True, - ) + test_net_cuts = wenetspeech.test_net_cuts() + test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) - dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset) - test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset) - test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset) + test_meeting_cuts = wenetspeech.test_meeting_cuts() + test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] test_dl = [dev_dl, test_net_dl, test_meeting_dl] diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index 32d5738b10..de12b2ff05 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -661,83 +661,18 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - # Note: Please use "pip install webdataset==0.1.103" - # for installing the webdataset. - import glob - import os - - from lhotse import CutSet - from lhotse.dataset.webdataset import export_to_webdataset - # we need cut ids to display recognition results. args.return_cuts = True wenetspeech = WenetSpeechAsrDataModule(args) - dev = "dev" - test_net = "test_net" - test_meeting = "test_meeting" - - if not os.path.exists(f"{dev}/shared-0.tar"): - os.makedirs(dev) - dev_cuts = wenetspeech.valid_cuts() - export_to_webdataset( - dev_cuts, - output_path=f"{dev}/shared-%d.tar", - shard_size=300, - ) - - if not os.path.exists(f"{test_net}/shared-0.tar"): - os.makedirs(test_net) - test_net_cuts = wenetspeech.test_net_cuts() - export_to_webdataset( - test_net_cuts, - output_path=f"{test_net}/shared-%d.tar", - shard_size=300, - ) - - if not os.path.exists(f"{test_meeting}/shared-0.tar"): - os.makedirs(test_meeting) - test_meeting_cuts = wenetspeech.test_meeting_cuts() - export_to_webdataset( - test_meeting_cuts, - output_path=f"{test_meeting}/shared-%d.tar", - shard_size=300, - ) - - dev_shards = [ - str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) - ] - cuts_dev_webdataset = CutSet.from_webdataset( - dev_shards, - split_by_worker=True, - split_by_node=True, - shuffle_shards=True, - ) + dev_cuts = wenetspeech.valid_cuts() + dev_dl = wenetspeech.valid_dataloaders(dev_cuts) - test_net_shards = [ - str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar"))) - ] - cuts_test_net_webdataset = CutSet.from_webdataset( - test_net_shards, - split_by_worker=True, - split_by_node=True, - shuffle_shards=True, - ) - - test_meeting_shards = [ - str(path) - for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar"))) - ] - cuts_test_meeting_webdataset = CutSet.from_webdataset( - test_meeting_shards, - split_by_worker=True, - split_by_node=True, - shuffle_shards=True, - ) + test_net_cuts = wenetspeech.test_net_cuts() + test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) - dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset) - test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset) - test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset) + test_meeting_cuts = wenetspeech.test_meeting_cuts() + test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] test_dl = [dev_dl, test_net_dl, test_meeting_dl]