Fix wenetspeech decoding speed (#953)
pkufool authored Mar 21, 2023
1 parent 7948624 commit d74822d
Showing 4 changed files with 19 additions and 164 deletions.
4 changes: 2 additions & 2 deletions egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py

@@ -20,7 +20,7 @@
 from pathlib import Path
 
 import torch
-from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomHdf5Writer
+from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
 
 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
@@ -69,7 +69,7 @@ def compute_fbank_wenetspeech_dev_test():
             storage_path=f"{in_out_dir}/feats_{partition}",
             num_workers=num_workers,
             batch_duration=batch_duration,
-            storage_type=LilcomHdf5Writer,
+            storage_type=LilcomChunkyWriter,
             overwrite=True,
         )
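
The feature-extraction half of the fix is the writer switch above: LilcomChunkyWriter packs each split's compressed features into one chunky archive that lhotse can slice with cheap range reads, and it is generally faster to read back at decoding time than the per-split HDF5 files that LilcomHdf5Writer produces. A minimal sketch of the batched extraction call with the new writer; the manifest path, output path, worker count, and batch duration are illustrative assumptions, not values from this commit:

    import torch
    from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter

    # Same caveat as in the script: torch's multithreading wastes CPU here.
    torch.set_num_threads(1)

    # Hypothetical raw manifest for the DEV split.
    cuts = CutSet.from_file("data/manifests/cuts_DEV_raw.jsonl.gz")

    extractor = KaldifeatFbank(KaldifeatFbankConfig(device="cuda"))
    cuts = cuts.compute_and_store_features_batch(
        extractor=extractor,
        storage_path="data/fbank/feats_DEV",  # hypothetical output prefix
        num_workers=2,
        batch_duration=600.0,
        storage_type=LilcomChunkyWriter,  # one chunky archive instead of HDF5
        overwrite=True,
    )
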
25 changes: 5 additions & 20 deletions egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py

@@ -46,9 +46,6 @@
 
 from icefall.utils import str2bool
 
-set_caching_enabled(False)
-torch.set_num_threads(1)
-
 
 class _SeedWorkers:
     def __init__(self, seed: int):
@@ -348,24 +345,18 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
             cut_transforms=transforms,
             return_cuts=self.args.return_cuts,
         )
 
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
-            rank=0,
-            world_size=1,
             shuffle=False,
         )
         logging.info("About to create dev dataloader")
 
-        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
-
-        dev_iter_dataset = IterableDatasetWrapper(
-            dataset=validate,
-            sampler=valid_sampler,
-        )
         valid_dl = DataLoader(
-            dev_iter_dataset,
+            validate,
             batch_size=None,
+            sampler=valid_sampler,
             num_workers=self.args.num_workers,
             persistent_workers=False,
         )
@@ -383,19 +374,13 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         sampler = DynamicBucketingSampler(
             cuts,
             max_duration=self.args.max_duration,
-            rank=0,
-            world_size=1,
             shuffle=False,
         )
-        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
 
-        test_iter_dataset = IterableDatasetWrapper(
-            dataset=test,
-            sampler=sampler,
-        )
         test_dl = DataLoader(
-            test_iter_dataset,
+            test,
             batch_size=None,
+            sampler=sampler,
            num_workers=self.args.num_workers,
         )
         return test_dl
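
The datamodule change above removes the IterableDatasetWrapper indirection: rather than fusing dataset and sampler into an iterable dataset pinned to rank=0 and world_size=1, the lhotse sampler is handed straight to torch's DataLoader with batch_size=None, the standard lhotse map-style pattern. A self-contained sketch of that pattern; the manifest path, max_duration, and worker count are assumed values:

    from lhotse import load_manifest_lazy
    from lhotse.dataset import DynamicBucketingSampler, K2SpeechRecognitionDataset
    from torch.utils.data import DataLoader

    # Hypothetical lazy cut manifest; in the recipe this comes from valid_cuts().
    cuts_valid = load_manifest_lazy("data/fbank/cuts_DEV.jsonl.gz")

    validate = K2SpeechRecognitionDataset(return_cuts=True)
    valid_sampler = DynamicBucketingSampler(
        cuts_valid,
        max_duration=200.0,  # seconds of audio per batch; an assumed value
        shuffle=False,
    )
    # The sampler yields whole CutSet mini-batches, so batch_size=None;
    # no wrapper dataset and no forced single-rank sampler arguments.
    valid_dl = DataLoader(
        validate,
        batch_size=None,
        sampler=valid_sampler,
        num_workers=2,
        persistent_workers=False,
    )

    for batch in valid_dl:
        feats = batch["inputs"]  # (N, T, C) padded fbank features
        break
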
77 changes: 6 additions & 71 deletions egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py

@@ -651,83 +651,18 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    # Note: Please use "pip install webdataset==0.1.103"
-    # for installing the webdataset.
-    import glob
-    import os
-
-    from lhotse import CutSet
-    from lhotse.dataset.webdataset import export_to_webdataset
-
     # we need cut ids to display recognition results.
     args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)
 
-    dev = "dev"
-    test_net = "test_net"
-    test_meeting = "test_meeting"
-
-    if not os.path.exists(f"{dev}/shared-0.tar"):
-        os.makedirs(dev)
-        dev_cuts = wenetspeech.valid_cuts()
-        export_to_webdataset(
-            dev_cuts,
-            output_path=f"{dev}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test_net}/shared-0.tar"):
-        os.makedirs(test_net)
-        test_net_cuts = wenetspeech.test_net_cuts()
-        export_to_webdataset(
-            test_net_cuts,
-            output_path=f"{test_net}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test_meeting}/shared-0.tar"):
-        os.makedirs(test_meeting)
-        test_meeting_cuts = wenetspeech.test_meeting_cuts()
-        export_to_webdataset(
-            test_meeting_cuts,
-            output_path=f"{test_meeting}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    dev_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
-    ]
-    cuts_dev_webdataset = CutSet.from_webdataset(
-        dev_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
+    dev_cuts = wenetspeech.valid_cuts()
+    dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
 
-    test_net_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
-    ]
-    cuts_test_net_webdataset = CutSet.from_webdataset(
-        test_net_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
+    test_net_cuts = wenetspeech.test_net_cuts()
+    test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
 
-    test_meeting_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
-    ]
-    cuts_test_meeting_webdataset = CutSet.from_webdataset(
-        test_meeting_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
-    test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
-    test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
+    test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
 
     test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
     test_dl = [dev_dl, test_net_dl, test_meeting_dl]

77 changes: 6 additions & 71 deletions egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py

@@ -661,83 +661,18 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    # Note: Please use "pip install webdataset==0.1.103"
-    # for installing the webdataset.
-    import glob
-    import os
-
-    from lhotse import CutSet
-    from lhotse.dataset.webdataset import export_to_webdataset
-
     # we need cut ids to display recognition results.
     args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)
 
-    dev = "dev"
-    test_net = "test_net"
-    test_meeting = "test_meeting"
-
-    if not os.path.exists(f"{dev}/shared-0.tar"):
-        os.makedirs(dev)
-        dev_cuts = wenetspeech.valid_cuts()
-        export_to_webdataset(
-            dev_cuts,
-            output_path=f"{dev}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test_net}/shared-0.tar"):
-        os.makedirs(test_net)
-        test_net_cuts = wenetspeech.test_net_cuts()
-        export_to_webdataset(
-            test_net_cuts,
-            output_path=f"{test_net}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test_meeting}/shared-0.tar"):
-        os.makedirs(test_meeting)
-        test_meeting_cuts = wenetspeech.test_meeting_cuts()
-        export_to_webdataset(
-            test_meeting_cuts,
-            output_path=f"{test_meeting}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    dev_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
-    ]
-    cuts_dev_webdataset = CutSet.from_webdataset(
-        dev_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
+    dev_cuts = wenetspeech.valid_cuts()
+    dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
 
-    test_net_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
-    ]
-    cuts_test_net_webdataset = CutSet.from_webdataset(
-        test_net_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
+    test_net_cuts = wenetspeech.test_net_cuts()
+    test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
 
-    test_meeting_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
-    ]
-    cuts_test_meeting_webdataset = CutSet.from_webdataset(
-        test_meeting_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
-    test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
-    test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
+    test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
 
     test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
     test_dl = [dev_dl, test_net_dl, test_meeting_dl]
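
The two decode.py diffs above are identical. Both drop the detour through webdataset: previously each evaluation set was exported to tar shards on first use and re-imported with CutSet.from_webdataset (with shard shuffling at test time, and with the recipe pinned to webdataset==0.1.103); now the datamodule's cut sets feed the dataloaders directly, which is presumably where the decoding speedup in the commit title comes from. A sketch of the resulting data setup in main(); args stands for the script's parsed command-line namespace, and the import reflects how these recipe scripts load their local datamodule:

    # The datamodule lives next to decode.py in each recipe directory.
    from asr_datamodule import WenetSpeechAsrDataModule

    # we need cut ids to display recognition results.
    args.return_cuts = True
    wenetspeech = WenetSpeechAsrDataModule(args)

    dev_dl = wenetspeech.valid_dataloaders(wenetspeech.valid_cuts())
    test_net_dl = wenetspeech.test_dataloaders(wenetspeech.test_net_cuts())
    test_meeting_dl = wenetspeech.test_dataloaders(wenetspeech.test_meeting_cuts())

    test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
    test_dl = [dev_dl, test_net_dl, test_meeting_dl]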
