-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdata.py
63 lines (54 loc) · 2.2 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from torch.utils import data
import torch.nn.functional as F
import torch
import numpy as np
import pickle
import os
from spec_utils import get_mspec
import random
class AutoVCDataset(data.Dataset):
    """Dataset yielding fixed-length mel-spectrogram crops with speaker embeddings.

    Each item is a ``(mspec, spk_emb)`` pair: the spectrogram is loaded from
    ``paths[index]``, cropped or padded to ``len_crop`` frames along its first
    axis, and optionally normalized as ``(x + shift) / scale``.

    Args:
        paths: Indexable collection of path objects exposing ``.suffix`` and
            ``.parent`` (e.g. ``pathlib.Path``). The parent directory's stem
            is used as the speaker id.
        spk_embs: Mapping from speaker id to that speaker's embedding.
        len_crop: Number of frames every returned spectrogram has.
        scale: Divisor applied during normalization.
        shift: Offset added before dividing by ``scale``.

    Normalization is only active when BOTH ``scale`` and ``shift`` are given;
    otherwise ``norm_mel`` and ``denorm_mel`` return their input unchanged.
    """

    def __init__(self, paths, spk_embs, len_crop, scale=None, shift=None) -> None:
        super().__init__()
        self.paths = paths
        self.spk_embs = spk_embs
        self.len_crop = len_crop
        # Keep the parameters as plain attributes and expose (de)normalization
        # as regular methods. Lambdas created here would make the dataset
        # unpicklable, which breaks DataLoader workers started with the
        # "spawn"/"forkserver" multiprocessing start methods.
        self._normalize = scale is not None and shift is not None
        self.scale = scale if self._normalize else None
        self.shift = shift if self._normalize else None

    def norm_mel(self, x):
        """Return ``(x + shift) / scale``, or ``x`` itself if normalization is off."""
        if not self._normalize:
            return x
        return (x + self.shift) / self.scale

    def denorm_mel(self, x):
        """Invert ``norm_mel``: return ``x * scale - shift``, or ``x`` if normalization is off."""
        if not self._normalize:
            return x
        return (x * self.scale) - self.shift

    def __len__(self) -> int:
        """Return the number of spectrogram paths in the dataset."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load, crop/pad and normalize the spectrogram at ``index``.

        Returns:
            Tuple ``(mspec, spk_emb)`` where ``mspec`` has ``len_crop`` frames
            and ``spk_emb`` is ``spk_embs[<parent directory stem>]``.

        Raises:
            KeyError: If the speaker id is missing from ``spk_embs`` (when
                ``spk_embs`` is a dict).
        """
        pth = self.paths[index]
        if pth.suffix == '.pt':
            # Pre-computed spectrogram, expected to be (N, n_mels). Force CPU so
            # tensors saved from a GPU are not restored onto CUDA inside worker
            # processes (and so pin_memory can operate on them).
            mspec = torch.load(str(pth), map_location='cpu')
        else:
            # Computed on the fly by the project helper; expected (N, n_mels).
            mspec = get_mspec(pth, is_hifigan=True)
        mspec = self.random_crop(mspec)

        # The speaker id is the stem of the directory containing the file.
        spk_id = pth.parent.stem
        spk_emb = self.spk_embs[spk_id]

        mspec = self.norm_mel(mspec)
        return mspec, spk_emb

    def random_crop(self, mspec):
        """Return ``mspec`` with exactly ``len_crop`` rows.

        Shorter inputs are padded at the end of the first axis with the
        spectrogram's minimum value; longer inputs are cropped to a window
        starting at a uniformly random offset; inputs that already have
        ``len_crop`` rows are returned unchanged.

        Args:
            mspec: 2-D tensor whose first axis is the frame axis.
        """
        n_frames, _ = mspec.shape
        clen = self.len_crop
        if n_frames < clen:
            n_pad = clen - n_frames
            # F.pad documents ``value`` as a Python scalar, so unwrap the
            # 0-dim tensor returned by ``min()``. The pad tuple is ordered
            # last-dim-first: (0, 0) leaves the mel axis untouched and
            # (0, n_pad) appends frames at the end.
            mspec = F.pad(mspec, (0, 0, 0, n_pad), value=mspec.min().item())
        elif n_frames > clen:
            # randint is inclusive on both ends, so the last valid start is
            # n_frames - clen.
            crop_start = random.randint(0, n_frames - clen)
            mspec = mspec[crop_start:crop_start + clen]
        return mspec
def get_loader(files, spk_embs, len_crop, batch_size=16,
               num_workers=8, shuffle=False, scale=None, shift=None):
    """Create a ``DataLoader`` over an ``AutoVCDataset`` built from the arguments.

    ``files``, ``spk_embs``, ``len_crop``, ``scale`` and ``shift`` are passed
    straight to ``AutoVCDataset``; ``batch_size``, ``num_workers`` and
    ``shuffle`` are passed to the loader. ``shuffle`` doubles as the
    training-mode switch: it also enables ``drop_last`` and ``pin_memory``.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        # Both follow ``shuffle``: on when training, off otherwise.
        "drop_last": shuffle,
        "pin_memory": shuffle,
    }
    return data.DataLoader(
        dataset=AutoVCDataset(files, spk_embs, len_crop, scale=scale, shift=shift),
        **loader_options,
    )