From 499a689860d832edc3b5c744ef5d62eeddc96d3b Mon Sep 17 00:00:00 2001 From: juan montesinos Date: Mon, 4 Jul 2022 15:22:47 +0200 Subject: [PATCH] VoViT core --- vovit/__init__.py | 19 +- vovit/core/__init__.py | 13 + vovit/core/kabsch.py | 119 ++++++ vovit/core/landmark_estimator/TDDFA_GPU.py | 37 +- vovit/core/landmark_estimator/bfm/bfm.py | 1 + .../landmark_estimator/utils/functions.py | 1 + .../landmark_estimator/utils/tddfa_util.py | 5 +- vovit/core/models/__init__.py | 2 + vovit/core/models/modules/__init__.py | 0 vovit/core/models/modules/gconv.py | 80 ++++ vovit/core/models/modules/graph.py | 158 ++++++++ vovit/core/models/modules/spec2vec.py | 182 +++++++++ vovit/core/models/modules/st_gcn.py | 230 +++++++++++ vovit/core/models/production_model.py | 380 ++++++++++++++++++ vovit/core/models/transformers/__init__.py | 48 +++ vovit/core/models/transformers/av_sttrans.py | 130 ++++++ vovit/core/models/utils.py | 12 + vovit/core/speech_mean_face.npy | Bin 0 -> 1760 bytes vovit/utils.py | 29 ++ 19 files changed, 1428 insertions(+), 18 deletions(-) create mode 100644 vovit/core/__init__.py create mode 100644 vovit/core/kabsch.py create mode 100644 vovit/core/models/__init__.py create mode 100644 vovit/core/models/modules/__init__.py create mode 100644 vovit/core/models/modules/gconv.py create mode 100644 vovit/core/models/modules/graph.py create mode 100644 vovit/core/models/modules/spec2vec.py create mode 100644 vovit/core/models/modules/st_gcn.py create mode 100644 vovit/core/models/production_model.py create mode 100644 vovit/core/models/transformers/__init__.py create mode 100644 vovit/core/models/transformers/av_sttrans.py create mode 100644 vovit/core/models/utils.py create mode 100644 vovit/core/speech_mean_face.npy diff --git a/vovit/__init__.py b/vovit/__init__.py index 24a3061..07f38d7 100644 --- a/vovit/__init__.py +++ b/vovit/__init__.py @@ -1,4 +1,5 @@ import os +import yaml import torch from einops import rearrange @@ -20,12 +21,15 @@ def 
__init__(self, *, model_name: str, debug: dict, pretrained: bool = True, if self.extract_landmarks: from .core.landmark_estimator.TDDFA_GPU import TDDFA - self.face_extractor = TDDFA() + cfg = yaml.load(open(utils.DEFAULT_CFG_PATH), Loader=yaml.SafeLoader) + cfg['checkpoint_fp'] = os.path.join(utils.LANDMARK_LIB_PATH, 'weights', 'mb1_120x120.pth') + cfg['bfm_fp'] = os.path.join(utils.LANDMARK_LIB_PATH, 'configs', 'bfm_noneck_v3.pkl') + self.face_extractor = TDDFA(**cfg) self.register_buffer('mean_face', torch.from_numpy(np_load(os.path.join(core_path, 'speech_mean_face.npy'))).float(), persistent=False) - def forward(self, mixture, visuals): + def forward(self, mixture, visuals, extract_landmarks=False): """ :param mixture: torch.Tensor of shape (B,N) :param visuals: torch.Tensor of shape (B,C,H,W) BGR format required @@ -35,7 +39,7 @@ def forward(self, mixture, visuals): raise NotImplementedError else: cropped_video = visuals - if self.extract_landmarks: + if extract_landmarks: ld = self.face_extractor(cropped_video) avg = (ld[:-2] + ld[1:-1] + ld[2:]) / 3 ld[:-2] = avg @@ -58,10 +62,19 @@ def forward_unlimited(self, mixture, visuals): Allows to run inference in an unlimited duration samples (up to gpu memory constrains) The results will be trimmed to multiples of 2 seconds (e.g. if your audio is 8.5 seconds long, the result will be trimmed to 8 seconds) + Args: + visuals: raw video if self.extract_landmarks is True, precomputed_landmarks otherwise. 
+ lanmarks are uint16 tensors of shape (T,3,68) + raw video are uint8 RGB tensors of shape (T,H,W,3) (values between 0-255) + mixture: tensor of shape (N) """ fps = VIDEO_FRAMERATE length = self.vovit.avse.av_se.ap._audio_length n_chunks = visuals.shape[0] // (fps * 2) + if self.extract_landmarks: + visuals = self.face_extractor(visuals) + avg = (visuals[:-2] + visuals[1:-1] + visuals[2:]) / 3 + visuals[:-2] = avg visuals = visuals[:n_chunks * fps * 2].view(n_chunks, fps * 2, 3, 68) mixture = mixture[:n_chunks * length].view(n_chunks, -1) pred = self.forward(mixture, visuals) diff --git a/vovit/core/__init__.py b/vovit/core/__init__.py new file mode 100644 index 0000000..3740982 --- /dev/null +++ b/vovit/core/__init__.py @@ -0,0 +1,13 @@ +AUDIO_SAMPLERATE = 16384 +VIDEO_FRAMERATE = 25 +N_FFT = 1022 +HOP_LENGTH = 256 +SP_FREQ_SHAPE = N_FFT // 2 + 1 + +fourier_defaults = {"audio_samplerate": AUDIO_SAMPLERATE, + "n_fft": N_FFT, + "sp_freq_shape": SP_FREQ_SHAPE, + "hop_length": HOP_LENGTH} +core_path = __path__[0] + +from .models import VoViT diff --git a/vovit/core/kabsch.py b/vovit/core/kabsch.py new file mode 100644 index 0000000..5846d37 --- /dev/null +++ b/vovit/core/kabsch.py @@ -0,0 +1,119 @@ +from typing import Union +import torch + + +def rigid_transform_3D(target_face: torch.tensor, mean_face: torch.tensor) -> torch.tensor: + """ + Compute a rigid transformation between two sets of landmarks by using Kabsch algorithm. + The Kabsch algorithm, named after Wolfgang Kabsch, is a method for calculating the optimal rotation matrix + that minimizes the RMSD (root mean squared deviation) between two paired sets of points. 
+ args: + target_face: NumPy array of shape (3,N) + mean_face: NumPy array of shape (3,N) + + returns: + R: NumPy array of shape (3,3) + t: NumPy array of shape (3,1) + + source: + https://en.wikipedia.org/wiki/Kabsch_algorithm + """ + # Geometric transformations in 3D + # https://cseweb.ucsd.edu/classes/wi18/cse167-a/lec3.pdf + + # Affine transformation (theoretical) + # http://learning.aols.org/aols/3D_Affine_Coordinate_Transformations.pdf + + # Implementation from http://nghiaho.com/?page_id=671 + # + assert target_face.shape == mean_face.shape + assert target_face.shape[0] == 3, "3D rigid transform only" + + # find mean column wise + centroid_A = torch.mean(target_face, dim=1) + centroid_B = torch.mean(mean_face, dim=1) + + # ensure centroids are 3x1 + centroid_A = centroid_A.reshape(-1, 1) + centroid_B = centroid_B.reshape(-1, 1) + + # subtract mean + Am = target_face - centroid_A + Bm = mean_face - centroid_B + + H = Am @ Bm.T + # H = (Am.cpu() @ Bm.T.cpu()) + + # find rotation + U, S, Vt = torch.linalg.svd(H) # torch.svd differs from torch.linalg.svd + # https://pytorch.org/docs/stable/generated/torch.svd.html + R = Vt.T @ U.T + + # special reflection case + if torch.linalg.det(R) < 0: + print("det(R) < R, reflection detected!, correcting for it ...") + Vt[2, :] *= -1 + R = Vt.T @ U.T + + t = -R @ centroid_A + centroid_B + + return R, t + + +def apply_transformation(R, t, landmarks: torch.tensor) -> torch.tensor: + """ + Apply a rigid transformation to a set of landmarks. 
+ args: + R: NumPy array of shape (3,3) + t: NumPy array of shape (3,1) + landmarks: NumPy array of shape (3,N) + """ + assert landmarks.shape[0] == 3, "landmarks must be 3D" + assert R.shape == (3, 3), "R must be 3x3" + assert t.shape == (3, 1), "t must be 3x1" + + # apply transformation + transformed_landmarks = R @ landmarks + t + + return transformed_landmarks + + +def register_sequence_of_landmarks(target_sequence: torch.tensor, mean_face: torch.tensor, per_frame=False, + display_sequence: Union[torch.tensor, None] = None) -> torch.tensor: + """ + Register a sequence of landmarks to a mean face. + Computational complexity: O(3*N*T) + args: + target_face: NumPy array of shape (T,3,N) + mean_face: NumPy array of shape (3,N) + per_frame: either to estimate the transformation per frame or given the mean face. + display_sequence: (optional) NumPy array of shape (T',3,N'). Optional array to estimate the transformation + on some of the landmarks. + + returns: + registered_sequence: NumPy array of shape (T,3,N) + + example: + Computing the transformation ignoring landmarks from 48 onwards but + estimating the transformation for all of them + >>> registered_sequence = register_sequence_of_landmarks(landmarks[..., :48], + >>> mean_face[:, :48], + >>> display_sequence=landmarks) + """ + if display_sequence is None: + display_sequence = target_sequence + + if not per_frame: + # Estimates the mean face + target_mean_face = torch.mean(target_sequence, dim=0) + # compute rigid transformation + R, t = rigid_transform_3D(target_mean_face, mean_face) + + # apply transformation + registered_sequence = [] + for x, y in zip(target_sequence, display_sequence): + if per_frame: + R, t = rigid_transform_3D(x, mean_face) + registered_sequence.append(apply_transformation(R, t, y)) + + return torch.stack(registered_sequence) diff --git a/vovit/core/landmark_estimator/TDDFA_GPU.py b/vovit/core/landmark_estimator/TDDFA_GPU.py index 35092cd..e35c20c 100644 --- 
a/vovit/core/landmark_estimator/TDDFA_GPU.py +++ b/vovit/core/landmark_estimator/TDDFA_GPU.py @@ -7,13 +7,13 @@ from torch import nn from torchvision.transforms import Compose -import models -from bfm import BFMModel -from utils.io import _load -from utils.functions import ( +from . import models +from .bfm import BFMModel +from .utils.io import _load +from .utils.functions import ( crop_video, reshape_fortran, parse_roi_box_from_bbox, ) -from utils.tddfa_util import ( +from .utils.tddfa_util import ( load_model, _batched_parse_param, batched_similar_transform, ToTensorGjz, NormalizeGjz ) @@ -25,6 +25,7 @@ class TDDFA(nn.Module): """TDDFA: named Three-D Dense Face Alignment (TDDFA)""" def __init__(self, **kvs): + super(TDDFA, self).__init__() self.size = kvs.get('size', 120) # load BFM @@ -48,7 +49,6 @@ def __init__(self, **kvs): ) model = load_model(model, kvs.get('checkpoint_fp')) - self.model = model # data normalization @@ -59,12 +59,8 @@ def __init__(self, **kvs): # params normalization config r = _load(param_mean_std_fp) - self.param_mean = torch.from_numpy(r.get('mean')) - self.param_std = torch.from_numpy(r.get('std')) - self.param_mean = self.param_mean - self.param_std = self.param_std - - + self.register_buffer('param_mean', torch.from_numpy(r.get('mean')), persistent=False) + self.register_buffer('param_std', torch.from_numpy(r.get('std')), persistent=False) def batched_inference(self, video_ori, bbox, **kvs): """The main call of TDDFA, given image and box / landmark, return 3DMM params and roi_box @@ -75,7 +71,8 @@ def batched_inference(self, video_ori, bbox, **kvs): """ roi_box = parse_roi_box_from_bbox(bbox) video = crop_video(video_ori, roi_box) - img = torch.nn.functional.interpolate(video, size=(self.size, self.size), mode='bilinear', align_corners=False) + img = torch.nn.functional.interpolate(video.float(), size=(self.size, self.size), mode='bilinear', + align_corners=False) inp = self.transform_normalize(img) param = self.model(inp) @@ -96,3 
+93,17 @@ def batched_recon_vers(self, param, roi_box, **kvs): pts3d = batched_similar_transform(pts3d, roi_box, size) return pts3d + + def forward(self, video): + """ + :param video: RGB Video of shape (T,H,W,C) uint8 (values between 0-255). Video has to be cropped around the face + accurately (mainly to reduce GPU memory requirements). + :return: + """ + T, H, W, C = video.shape + assert C == 3, 'Video has to be RGB' + video = video.flip(-1) # BGR conversion + video = video.permute(0, 3, 1, 2) # T H W C -> T C H W + param, box_roi = self.batched_inference(video, [0, 0, W, H]) + pts = self.batched_recon_vers(param, box_roi) + return pts \ No newline at end of file diff --git a/vovit/core/landmark_estimator/bfm/bfm.py b/vovit/core/landmark_estimator/bfm/bfm.py index 30cca8d..69f5eaf 100644 --- a/vovit/core/landmark_estimator/bfm/bfm.py +++ b/vovit/core/landmark_estimator/bfm/bfm.py @@ -22,6 +22,7 @@ def _to_ctype(arr): class BFMModel(torch.nn.Module): def __init__(self, bfm_fp, shape_dim=40, exp_dim=10): + super(BFMModel, self).__init__() bfm = _load(bfm_fp) if osp.split(bfm_fp)[-1] == 'bfm_noneck_v3.pkl': self.tri = _load(make_abs_path('../configs/tri.pkl')) # this tri/face is re-built for bfm_noneck_v3 diff --git a/vovit/core/landmark_estimator/utils/functions.py b/vovit/core/landmark_estimator/utils/functions.py index 9e2bd62..38252c2 100644 --- a/vovit/core/landmark_estimator/utils/functions.py +++ b/vovit/core/landmark_estimator/utils/functions.py @@ -57,6 +57,7 @@ def reshape_fortran(x, shape): if len(x.shape) > 0: x = x.permute(*reversed(range(len(x.shape)))) return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape)))) + def crop_video(video, roi_box): bs, c, h, w = video.shape diff --git a/vovit/core/landmark_estimator/utils/tddfa_util.py b/vovit/core/landmark_estimator/utils/tddfa_util.py index 7a75d84..04b7e57 100644 --- a/vovit/core/landmark_estimator/utils/tddfa_util.py +++ b/vovit/core/landmark_estimator/utils/tddfa_util.py @@ -53,8 +53,9 
@@ def __repr__(self): class NormalizeGjz(torch.nn.Module): def __init__(self, mean, std): - self.mean = mean - self.std = std + super(NormalizeGjz, self).__init__() + self.register_buffer('mean', torch.tensor(mean), persistent=False) + self.register_buffer('std', torch.tensor(std), persistent=False) def __call__(self, tensor): tensor.sub_(self.mean).div_(self.std) diff --git a/vovit/core/models/__init__.py b/vovit/core/models/__init__.py new file mode 100644 index 0000000..9340cb8 --- /dev/null +++ b/vovit/core/models/__init__.py @@ -0,0 +1,2 @@ +from .. import fourier_defaults, VIDEO_FRAMERATE +from .production_model import VoViT diff --git a/vovit/core/models/modules/__init__.py b/vovit/core/models/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vovit/core/models/modules/gconv.py b/vovit/core/models/modules/gconv.py new file mode 100644 index 0000000..6b223ab --- /dev/null +++ b/vovit/core/models/modules/gconv.py @@ -0,0 +1,80 @@ +# The based unit of graph convolutional networks. + +import torch +import torch.nn as nn + + +class ConvTemporalGraphical(nn.Module): + r"""The basic module for applying a graph convolution. + + Args: + in_channels (int): Number of channels in the input sequence data + out_channels (int): Number of channels produced by the convolution + kernel_size (int): Size of the graph convolving kernel + t_kernel_size (int): Size of the temporal convolving kernel + t_stride (int, optional): Stride of the temporal convolution. Default: 1 + t_padding (int, optional): Temporal zero-padding added to both sides of + the input. Default: 0 + t_dilation (int, optional): Spacing between temporal kernel elements. + Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. 
+ Default: ``True`` + + Shape: + - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format + - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format + - Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format + - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format + + where + :math:`N` is a batch size, + :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`, + :math:`T_{in}/T_{out}` is a length of input/output sequence, + :math:`V` is the number of graph nodes. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + t_kernel_size=1, + t_stride=1, + t_padding=0, + t_dilation=1, + bias=True, + ): + + super().__init__() + + self.kernel_size = kernel_size + self.conv = nn.Conv2d(in_channels, + out_channels * kernel_size, + kernel_size=(t_kernel_size, 1), + padding=(t_padding, 0), + stride=(t_stride, 1), + dilation=(t_dilation, 1), + bias=bias) + + def forward(self, x, A): + if A.ndim == 4: + assert A.size(1) == self.kernel_size + elif A.ndim == 5: + assert A.size(2) == self.kernel_size + else: + assert A.size(0) == self.kernel_size + + x = self.conv(x) # B,channels=3,T,J + + n, kc, t, v = x.size() + x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v) + if A.ndimension() == 3: + # static or dynamic + x = torch.einsum('nkctv,kvw->nctw', (x, A)) + elif A.ndimension() == 4: + # Categorical + x = torch.einsum('nkctv,nkvw->nctw', (x, A)) + elif A.ndimension() == 5: + x = torch.einsum('nkctv,ntkvw->nctw', (x, A)) + else: + raise Exception('Adjacency matrix dimensionalty is %d but should be 3,4 or 5' % A.ndimension()) + return x.contiguous(), A diff --git a/vovit/core/models/modules/graph.py b/vovit/core/models/modules/graph.py new file mode 100644 index 0000000..e9e0948 --- /dev/null +++ b/vovit/core/models/modules/graph.py @@ -0,0 +1,158 @@ +import numpy as np + + +class Graph(): + """ The Graph to model the skeletons extracted by the openpose + + 
Args: + strategy (string): must be one of the follow candidates + - uniform: Uniform Labeling + - distance: Distance Partitioning + - spatial: Spatial Configuration + For more information, please refer to the section 'Partition Strategies' + in our paper (https://arxiv.org/abs/1801.07455). + + layout (string): must be one of the follow candidates + - acappella: ACAPELLA + + max_hop (int): the maximal distance between two connected nodes + dilation (int): controls the spacing between the kernel points + + """ + + def __init__(self, + layout='acappella', + strategy='uniform', + max_hop=1, + dilation=1): + self.max_hop = max_hop + self.dilation = dilation + + self.get_edge(layout) + self.hop_dis = get_hop_distance(self.num_node, + self.edge, + max_hop=max_hop) + self.get_adjacency(strategy) + + def __str__(self): + return self.A + + def get_edge(self, layout): + if layout == 'acappella': + self.num_node = 68 + self_link = [(i, i) for i in range(self.num_node)] + all = [(i, i + 1) for i in range(68)] + + face = all[slice(0, 16)] + eyebrown1 = all[slice(17, 21)] + eyebrown2 = all[slice(22, 26)] + nose = all[slice(27, 30)] + nostril = all[slice(31, 35)] + eye1 = all[slice(36, 41)] + eye2 = all[slice(42, 47)] + lips = all[slice(48, 59)] + teeth = all[slice(60, 67)] + self.edge = self_link + face + eye1 + eye2 + eyebrown1 + eyebrown2 + nose + nostril + lips + teeth + self.center = 0 + # ORIGINAL SOURCECODE + # https://github.com/1adrianb/face-alignment/blob/master/examples/detect_landmarks_in_image.py + # pred_types = {'face': pred_type(slice(0, 17), (0.682, 0.780, 0.909, 0.5)), + # 'eyebrow1': pred_type(slice(17, 22), (1.0, 0.498, 0.055, 0.4)), + # 'eyebrow2': pred_type(slice(22, 27), (1.0, 0.498, 0.055, 0.4)), + # 'nose': pred_type(slice(27, 31), (0.345, 0.239, 0.443, 0.4)), + # 'nostril': pred_type(slice(31, 36), (0.345, 0.239, 0.443, 0.4)), + # 'eye1': pred_type(slice(36, 42), (0.596, 0.875, 0.541, 0.3)), + # 'eye2': pred_type(slice(42, 48), (0.596, 0.875, 0.541, 
0.3)), + # 'lips': pred_type(slice(48, 60), (0.596, 0.875, 0.541, 0.3)), + # 'teeth': pred_type(slice(60, 68), (0.596, 0.875, 0.541, 0.4)) + # } + # Voxceleb distribution + # slice(0, 17), 'face contour' + # slice(17, 22), 'right eyebrow' + # slice(22, 27), 'left eyebrow' + # slice(27, 36), 'nose' + # slice(36, 42), 'right eye' + # slice(42, 48), 'left eye' + # slice(48, 69), 'mouth' + else: + raise ValueError("Do Not Exist This Layout.") + return self.edge, self.num_node + + def get_adjacency(self, strategy): + valid_hop = range(0, self.max_hop + 1, self.dilation) + adjacency = np.zeros((self.num_node, self.num_node)) + for hop in valid_hop: + adjacency[self.hop_dis == hop] = 1 + normalize_adjacency = normalize_digraph(adjacency) + + if strategy == 'uniform': + A = np.zeros((1, self.num_node, self.num_node)) + A[0] = normalize_adjacency + self.A = A + elif strategy == 'distance': + A = np.zeros((len(valid_hop), self.num_node, self.num_node)) + for i, hop in enumerate(valid_hop): + A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis == + hop] + self.A = A + elif strategy == 'spatial': + A = [] + for hop in valid_hop: + a_root = np.zeros((self.num_node, self.num_node)) + a_close = np.zeros((self.num_node, self.num_node)) + a_further = np.zeros((self.num_node, self.num_node)) + for i in range(self.num_node): + for j in range(self.num_node): + if self.hop_dis[j, i] == hop: + if self.hop_dis[j, self.center] == self.hop_dis[i, self.center]: + a_root[j, i] = normalize_adjacency[j, i] + elif self.hop_dis[j, self.center] > self.hop_dis[i, self.center]: + a_close[j, i] = normalize_adjacency[j, i] + else: + a_further[j, i] = normalize_adjacency[j, i] + if hop == 0: + A.append(a_root) + else: + A.append(a_root + a_close) + A.append(a_further) + A = np.stack(A) + self.A = A + else: + raise ValueError("Do Not Exist This Strategy") + + +def get_hop_distance(num_node, edge, max_hop=1): + A = np.zeros((num_node, num_node)) + for i, j in edge: + A[j, i] = 1 + A[i, j] = 1 + 
+ # compute hop steps + hop_dis = np.zeros((num_node, num_node)) + np.inf + transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)] + arrive_mat = (np.stack(transfer_mat) > 0) + for d in range(max_hop, -1, -1): + hop_dis[arrive_mat[d]] = d + return hop_dis + + +def normalize_digraph(A): + Dl = np.sum(A, 0) + num_node = A.shape[0] + Dn = np.zeros((num_node, num_node)) + for i in range(num_node): + if Dl[i] > 0: + Dn[i, i] = Dl[i] ** (-1) + AD = np.dot(A, Dn) + return AD + + +def normalize_undigraph(A): + Dl = np.sum(A, 0) + num_node = A.shape[0] + Dn = np.zeros((num_node, num_node)) + for i in range(num_node): + if Dl[i] > 0: + Dn[i, i] = Dl[i] ** (-0.5) + DAD = np.dot(np.dot(Dn, A), Dn) + return DAD diff --git a/vovit/core/models/modules/spec2vec.py b/vovit/core/models/modules/spec2vec.py new file mode 100644 index 0000000..87e39aa --- /dev/null +++ b/vovit/core/models/modules/spec2vec.py @@ -0,0 +1,182 @@ +import torch.nn.functional as F +from einops import rearrange +from torch import nn + + +class Spec2Vec(nn.Module): + def __init__(self, last_shape=8): + super(Spec2Vec, self).__init__() + + # Audio model layers , name of layers as per table 1 given in paper. 
+ + self.conv1 = nn.Conv2d( + 2, + 96, + kernel_size=(1, 7), + padding=self.get_padding((1, 7), (1, 1)), + dilation=(1, 1), + ) + + self.conv2 = nn.Conv2d( + 96, + 96, + kernel_size=(7, 1), + padding=self.get_padding((7, 1), (1, 1)), + dilation=(1, 1), + ) + + self.conv3 = nn.Conv2d( + 96, + 96, + kernel_size=(5, 5), + padding=self.get_padding((5, 5), (1, 1)), + dilation=(1, 1), + ) + + self.conv4 = nn.Conv2d( + 96, + 96, + kernel_size=(5, 5), + padding=self.get_padding((5, 5), (2, 1)), + dilation=(2, 1), + ) + + self.conv5 = nn.Conv2d( + 96, + 96, + kernel_size=(5, 5), + padding=self.get_padding((5, 5), (4, 1)), + dilation=(4, 1), + ) + + self.conv6 = nn.Conv2d( + 96, + 96, + kernel_size=(5, 5), + padding=self.get_padding((5, 5), (8, 1)), + dilation=(8, 1), + ) + + self.conv7 = nn.Conv2d( + 96, + 96, + kernel_size=(5, 5), + padding=self.get_padding((5, 5), (16, 1)), + dilation=(16, 1), + ) + + self.conv8 = nn.Conv2d( + 96, + 96, + kernel_size=(5, 5), + padding=self.get_padding((5, 5), (32, 1)), + dilation=(32, 1), + ) + + self.conv9 = nn.Conv2d( + 96, + 96, + kernel_size=(5, 5), + padding=self.get_padding((5, 5), (1, 1)), + dilation=(1, 1), + ) + + self.conv10 = nn.Conv2d( + 96, + 96, + kernel_size=(5, 5), + padding=self.get_padding((5, 5), (2, 2)), + dilation=(2, 2), + ) + + self.conv11 = nn.Conv2d( + 96, + 96, + kernel_size=(5, 5), + padding=self.get_padding((5, 5), (4, 4)), + dilation=(4, 4), + ) + + self.conv12 = nn.Conv2d( + 96, + 96, + kernel_size=(5, 5), + padding=self.get_padding((5, 5), (8, 8)), + dilation=(8, 8), + ) + + self.conv13 = nn.Conv2d( + 96, + 96, + kernel_size=(5, 5), + padding=self.get_padding((5, 5), (16, 16)), + dilation=(16, 16), + ) + + self.conv14 = nn.Conv2d( + 96, + 96, + kernel_size=(5, 5), + padding=self.get_padding((5, 5), (32, 32)), + dilation=(32, 32), + ) + + self.conv15 = nn.Conv2d( + 96, + last_shape, + kernel_size=(1, 1), + padding=self.get_padding((1, 1), (1, 1)), + dilation=(1, 1), + ) + + # Batch normalization layers + + 
self.batch_norm1 = nn.BatchNorm2d(96) + self.batch_norm2 = nn.BatchNorm2d(96) + self.batch_norm3 = nn.BatchNorm2d(96) + self.batch_norm4 = nn.BatchNorm2d(96) + self.batch_norm5 = nn.BatchNorm2d(96) + self.batch_norm6 = nn.BatchNorm2d(96) + self.batch_norm7 = nn.BatchNorm2d(96) + self.batch_norm8 = nn.BatchNorm2d(96) + self.batch_norm9 = nn.BatchNorm2d(96) + self.batch_norm10 = nn.BatchNorm2d(96) + self.batch_norm11 = nn.BatchNorm2d(96) + self.batch_norm11 = nn.BatchNorm2d(96) + self.batch_norm12 = nn.BatchNorm2d(96) + self.batch_norm13 = nn.BatchNorm2d(96) + self.batch_norm14 = nn.BatchNorm2d(96) + self.batch_norm15 = nn.BatchNorm2d(last_shape) + + def get_padding(self, kernel_size, dilation): + padding = ( + ((dilation[0]) * (kernel_size[0] - 1)) // 2, + ((dilation[1]) * (kernel_size[1] - 1)) // 2, + ) + return padding + + def forward(self, input_audio): + # input audio will be (2,256,256) + + output_layer = F.leaky_relu(self.batch_norm1(self.conv1(input_audio)), negative_slope=0.1) + output_layer = F.leaky_relu(self.batch_norm2(self.conv2(output_layer)), negative_slope=0.1) + output_layer = F.leaky_relu(self.batch_norm3(self.conv3(output_layer)), negative_slope=0.1) + output_layer = F.leaky_relu(self.batch_norm4(self.conv4(output_layer)), negative_slope=0.1) + output_layer = F.leaky_relu(self.batch_norm5(self.conv5(output_layer)), negative_slope=0.1) + output_layer = F.leaky_relu(self.batch_norm6(self.conv6(output_layer)), negative_slope=0.1) + output_layer = F.leaky_relu(self.batch_norm7(self.conv7(output_layer)), negative_slope=0.1) + output_layer = F.leaky_relu(self.batch_norm8(self.conv8(output_layer)), negative_slope=0.1) + output_layer = F.leaky_relu(self.batch_norm9(self.conv9(output_layer)), negative_slope=0.1) + output_layer = F.leaky_relu(self.batch_norm10(self.conv10(output_layer)), negative_slope=0.1) + output_layer = F.leaky_relu(self.batch_norm11(self.conv11(output_layer)), negative_slope=0.1) + output_layer = 
F.leaky_relu(self.batch_norm12(self.conv12(output_layer)), negative_slope=0.1) + output_layer = F.leaky_relu(self.batch_norm13(self.conv13(output_layer)), negative_slope=0.1) + output_layer = F.leaky_relu(self.batch_norm14(self.conv14(output_layer)), negative_slope=0.1) + output_layer = F.leaky_relu(self.batch_norm15(self.conv15(output_layer)), negative_slope=0.1) + + # output_layer will be (N,8,256,256) + # we want it to be (N,8*256,256,1) + output_layer = rearrange(output_layer, 'b c t f -> b c f t ') + # output_layer = torch.permute(output_layer, [0,1,3,2]).reshape(10,8*256,256) + output_layer = rearrange(output_layer, 'b c f t -> b (c f) t') + return output_layer.unsqueeze(-1) # b (c f) t 1 diff --git a/vovit/core/models/modules/st_gcn.py b/vovit/core/models/modules/st_gcn.py new file mode 100644 index 0000000..4e571e8 --- /dev/null +++ b/vovit/core/models/modules/st_gcn.py @@ -0,0 +1,230 @@ +import torch +import torch.nn as nn + +from .gconv import ConvTemporalGraphical +from .graph import Graph + + +def init_eiw(x): + B, T, J = x.shape + x = x.unsqueeze(2).expand(B, T, J, J) + x = torch.min(x, x.transpose(2, 3)) + x = x.unsqueeze(2).expand(B, T, 3, J, J) + return x + + +class FiLM(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.bias = nn.Linear(in_channels, out_channels) + self.scale = nn.Linear(in_channels, out_channels) + + def forward(self, x, c, *args): + return x * self.scale(c).view(*args) + self.bias(c).view(*args) + + + + + +class ST_GCN(nn.Module): + r"""Spatial temporal graph convolutional networks. 
+ + Args: + in_channels (int): Number of channels in the input data + num_class (int): Number of classes for the classification task + graph_cfg (dict): The arguments for building the graph + edge_importance_weighting (bool): If ``True``, adds a learnable + importance weighting to the edges of the graph + **kwargs (optional): Other parameters for graph convolution units + + Shape: + - Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})` + - Output: :math:`(N, num_class)` where + :math:`N` is a batch size, + :math:`T_{in}` is a length of input sequence, [X,Y,C] + :math:`V_{in}` is the number of graph nodes, NUMBER OF JOINTS + :math:`M_{in}` is the number of instance in a frame. NUMBER OF PEOPLE + """ + + def __init__(self, + in_channels, + dilated, + graph_cfg, + temporal_downsample:bool, + input_type='x', + **kwargs): + super().__init__() + self.temporal_downsample = temporal_downsample + # load graph + self.graph = Graph(**graph_cfg) + A = torch.tensor(self.graph.A, + dtype=torch.float32, + requires_grad=False) + self.register_buffer('A', A) + self.input_type = input_type + self.dilated = dilated + # build networks + spatial_kernel_size = A.size(0) + temporal_kernel_size = 5 + kernel_size = (temporal_kernel_size, spatial_kernel_size) + self.data_bn = nn.BatchNorm1d(in_channels * A.size(1)) + if kwargs.get('bn_momentum') is not None: + del kwargs['bn_momentum'] + kwargs['edge_importance_weighting'] = kwargs.get('edge_importance_weighting') + kwargs['A'] = A + kwargs['dilation'] = 2 if dilated else 1 + kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'} + + self.st_gcn_networks = nn.ModuleList(( + st_gcn_block(in_channels, 32, + kernel_size, 1, + residual=False, **kwargs0), + st_gcn_block(32, 32, kernel_size, 1, **kwargs), + st_gcn_block(32, 64, kernel_size, 1, **kwargs), + st_gcn_block(64, 64, kernel_size, 1, **kwargs), + st_gcn_block(64, 128, kernel_size, 1, **kwargs), + st_gcn_block(128, 128, kernel_size, 1, **kwargs), + st_gcn_block(128, 256, 
kernel_size, 1, **kwargs), + st_gcn_block(256, 256, kernel_size, 1, **kwargs), + )) + + if self.temporal_downsample: + self.st_gcn_networks = nn.ModuleList(( + st_gcn_block(in_channels, 32, + kernel_size, 1, + residual=False, **kwargs0), + st_gcn_block(32, 32, kernel_size, 1, **kwargs), + st_gcn_block(32, 64, kernel_size, 2, **kwargs), + st_gcn_block(64, 64, kernel_size, 1, **kwargs), + st_gcn_block(64, 128, kernel_size, 2, **kwargs), + st_gcn_block(128, 128, kernel_size, 1, **kwargs), + st_gcn_block(128, 256, kernel_size, 2, **kwargs), + st_gcn_block(256, 256, kernel_size, 1, **kwargs), + )) + + def forward(self, x, *args): + args = list(args) + if x.shape[1] == 3: + args.append(x[:, 2, ...]) + x = self.extract_feature(x, *args) + + return x + + def extract_feature(self, x, *args): + N, C, T, V, M = x.size() + x = x.permute(0, 4, 3, 1, 2).contiguous() + x = x.view(N * M, V * C, T) + x = self.data_bn(x) + x = x.view(N, M, V, C, T) + x = x.permute(0, 1, 3, 4, 2).contiguous() + x = x.view(N * M, C, T, V) + + # forwad + + for gcn in self.st_gcn_networks: + x, _ = gcn(x, self.A, *args) + return x + + + +class st_gcn_block(nn.Module): + r"""Applies a spatial temporal graph convolution over an input graph sequence. + + Args: + in_channels (int): Number of channels in the input sequence data + out_channels (int): Number of channels produced by the convolution + kernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernel + stride (int, optional): Stride of the temporal convolution. Default: 1 + dropout (int, optional): Dropout rate of the final output. Default: 0 + residual (bool, optional): If ``True``, applies a residual mechanism. 
Default: ``True`` + + Shape: + - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format + - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format + - Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format + - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format + + where + :math:`N` is a batch size, + :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`, + :math:`T_{in}/T_{out}` is a length of input/output sequence, + :math:`V` is the number of graph nodes. + + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + dropout=0, + residual=True, + dilation=1, + edge_importance_weighting='static', + A=None, + activation='relu'): + super().__init__() + + assert len(kernel_size) == 2 + assert kernel_size[0] % 2 == 1 + + padding = (dilation * (kernel_size[0] - 1)) // 2 + padding = (padding, 0) + + self.ctype = edge_importance_weighting + self.activation = activation + if edge_importance_weighting == 'static': + self.edge_importance = 1. 
+ self.edge_importance_weighting = True + elif edge_importance_weighting == 'dynamic': + self.edge_importance_weighting = True + self.edge_importance = nn.Parameter(torch.ones(A.shape)) + + else: + raise ValueError('edge_importance_weighting (%s) not implemented') + + self.gcn = ConvTemporalGraphical(in_channels, out_channels, + kernel_size[1]) + + self.tcn = nn.Sequential( + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d( + out_channels, + out_channels, + (kernel_size[0], 1), + (stride, 1), + padding, + dilation=(dilation, 1), + ), + nn.BatchNorm2d(out_channels), + nn.Dropout(dropout, inplace=True), + ) + + if not residual: + self.residual = lambda x: 0 + + elif (in_channels == out_channels) and (stride == 1): + self.residual = lambda x: x + + else: + self.residual = nn.Sequential( + nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=(stride, 1)), + nn.BatchNorm2d(out_channels), + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x, A): + if self.edge_importance_weighting: + A = A * self.edge_importance + res = self.residual(x) + x, A = self.gcn(x, A) + x = self.tcn(x) + res + if self.activation == 'relu': + return self.relu(x), A + else: + return x, A diff --git a/vovit/core/models/production_model.py b/vovit/core/models/production_model.py new file mode 100644 index 0000000..fb7b2ee --- /dev/null +++ b/vovit/core/models/production_model.py @@ -0,0 +1,380 @@ +import os +import inspect +from copy import copy + +import torch +from einops import rearrange +from torch import nn, istft +from torchaudio.functional import spectrogram + +from . 
def copyupdt(original: dict, *args):
    """Return a shallow copy of ``original`` overlaid with every dict in ``args``.

    Later dictionaries win over earlier ones. The copy is shallow, so nested
    values are shared with ``original``.

    :param original: base dictionary (left untouched)
    :param args: dictionaries applied on top of the copy, in order
    :return: the merged dictionary
    """
    assert isinstance(original, dict)
    merged = copy(original)
    for overrides in args:
        assert isinstance(overrides, dict)
        for key, value in overrides.items():
            merged[key] = value
    return merged
class AudioPreprocessor(nn.Module):
    """STFT front end and back end used around the separation network.

    Spectrograms are handled as real tensors whose last axis holds the
    (real, imaginary) pair, laid out ``b f t c`` unless stated otherwise.
    """

    def __init__(self, *,
                 debug: dict,
                 audio_length: int, audio_samplerate: int,
                 n_fft: int, hop_length: int, sp_freq_shape: int,
                 downsample_coarse: bool):
        """Store the STFT configuration and register the Hann window buffer.

        :param debug: stored as-is on the instance
        :param audio_length: number of samples requested from the inverse STFT
        :param audio_samplerate: stored as-is on the instance
        :param n_fft: FFT size; also used as window length
        :param hop_length: STFT hop size
        :param sp_freq_shape: stored as-is on the instance
        :param downsample_coarse: when true, the network sees every second
            frequency bin and the predicted mask is upsampled by 2 in frequency
        """
        super(AudioPreprocessor, self).__init__()

        self.downsample_coarse = downsample_coarse
        self.debug = debug

        self._audio_samplerate = audio_samplerate
        self._audio_length = audio_length
        self._n_fft = n_fft
        self._sp_freq_shape = sp_freq_shape
        self._hop_length = hop_length
        # Same default as ``preprocess_audio`` so ``get_inference_mask`` does
        # not depend on ``preprocess_audio`` having run first.
        self.n_sources = 2
        self.register_buffer('_window', torch.hann_window(self._n_fft), persistent=False)

    def wav2sp(self, x):
        """Compute the spectrogram of ``x`` and return it in ``x``'s dtype."""
        # CUDNN does not support half complex numbers for non-power2 windows
        # Casting to float32 is a workaround
        dtype = x.dtype
        x = x.float()
        # NOTE(review): ``return_complex=False`` depends on the installed
        # torchaudio still accepting that argument - confirm before upgrading.
        s = spectrogram(x, pad=0, window=self._window.float(), win_length=self._n_fft,
                        n_fft=self._n_fft, hop_length=self._hop_length,
                        power=None, normalized=False, return_complex=False)
        return s.to(dtype)

    def istft(self, x):
        """Invert a spectrogram to a waveform of ``audio_length`` samples.

        :param x: complex tensor, or real tensor whose last axis is (real, imag)
        """
        if not x.is_complex():
            # torch.istft rejects real-valued (..., 2) input from torch 2.0 on;
            # viewing it as complex is numerically the same on older versions.
            x = torch.view_as_complex(x.float().contiguous())
        return istft(x, n_fft=self._n_fft, hop_length=self._hop_length, length=self._audio_length,
                     window=self._window.float())

    def sp2wav(self, inference_mask, mixture, compute_wav):
        """Apply the complex mask to the mixture and optionally invert it.

        :param inference_mask: complex mask, ``b f t c``
        :param mixture: mixture spectrogram with the shape of the (upsampled) mask
        :param compute_wav: when false the inverse STFT is skipped
        :return: ``(estimated_wav or None, estimated_sp)``
        """
        if self.downsample_coarse:
            # The mask, not the spectrogram, is upsampled back to full
            # frequency resolution. ``interpolate`` is the non-deprecated
            # equivalent of ``upsample``.
            inference_mask = torch.nn.functional.interpolate(
                rearrange(inference_mask, 'b f t c -> b c f t'),
                scale_factor=(2, 1), mode='nearest').squeeze(1)
            inference_mask = rearrange(inference_mask, 'b c f t -> b f t c')
        estimated_sp = complex_product(inference_mask, mixture)
        if not compute_wav:
            return None, estimated_sp
        estimated_wav = self.istft(estimated_sp)
        return estimated_wav, estimated_sp

    def preprocess_audio(self, *src: list, n_sources=2):
        """Turn the mixture waveform into the tensors fed to the network.

        Only ``src[0]`` is read. Its spectrogram is divided by ``n_sources``,
        which is also remembered on the instance for ``get_inference_mask``.

        :return: dict with ``'mixture'`` (``b c f t``, possibly frequency
            downsampled) and ``'sp_mix_raw'`` (full-resolution ``b f t c``)
        """

        self.n_sources = n_sources
        # Inference in case of a real sample
        sp_mix_raw = self.wav2sp(src[0]).contiguous() / self.n_sources

        if self.downsample_coarse:
            # Contiguous required to address memory problems in certain gpus
            sp_mix = sp_mix_raw[:, ::2, ...].contiguous()  # BxFxTx2
        else:
            # Previously ``sp_mix`` was left unbound on this path.
            sp_mix = sp_mix_raw
        x = rearrange(sp_mix, 'b f t c -> b c f t')
        output = {'mixture': x, 'sp_mix_raw': sp_mix_raw}

        return output

    def get_inference_mask(self, logits_mask, sp_mix):
        """Scale the predicted mask by ``n_sources`` and apply it to ``sp_mix``.

        The scaling undoes the division done in ``preprocess_audio``.

        :param logits_mask: network output, ``b c f t``
        :param sp_mix: mixture fed to the network, ``b c f t``
        :return: ``(inference_mask, target_sp)``, both ``b f t c``
        """
        sp_mix = rearrange(sp_mix, 'b c f t -> b f t c')
        inference_mask = rearrange(logits_mask, 'b c f t -> b f t c')
        inference_mask = self.n_sources * inference_mask
        target_sp = complex_product(inference_mask, sp_mix)
        return inference_mask, target_sp
n is a multiplier for larger tracks 8s track--> n=2 + """ + super(AudioVisualNetwork, self).__init__() + self.audio_processor = self.ap = AudioPreprocessor(**audio_kwargs) + # Placeholder + self._n = n + self.feat_num = 0 + + # Flags + self.video_temporal_features = video_temporal_features + self.landmarks_enabled = landmarks_enabled + + self._define_graph_network(kwargs) + self._define_audio_network(kwargs) + + def _define_audio_network(self, kwargs): + # Defining audio model /Stratum + self.audio_network = SAplusF(input_dim=self.feat_num, **kwargs) + + def _define_graph_network(self, kwargs): + if self.landmarks_enabled: + self.feat_num += 256 + # Graph convolutional network for skeleton analysis + if kwargs['skeleton_pooling'] == 'AdaptativeAP': + self.pool = nn.AdaptiveAvgPool2d((None, 1)) + elif kwargs['skeleton_pooling'] == 'AdaptativeMP': + self.pool = nn.AdaptiveMaxPool2d((None, 1)) + elif kwargs['skeleton_pooling'] == 'linear': + self.pool = nn.Linear(self.graph_net.heads[0].graph.num_node, 1, bias=False) + else: + raise ValueError( + 'VnNet pooling type: %s not implemented. 
Choose between AdaptativeMP,AdaptativeMP or linear' % + kwargs['skeleton_pooling']) + + if kwargs['graph_kwargs']['graph_cfg']['layout'] == 'upperbody_with_hands': + in_channels = 3 + elif kwargs['graph_kwargs']['graph_cfg']['layout'] == 'acappella': + in_channels = 2 + else: + raise NotImplementedError + flag = self.video_temporal_features < 50 + + self.graph_net = ST_GCN(in_channels=in_channels, **kwargs['graph_kwargs'], temporal_downsample=flag) + + def forward(self, inputs: dict, + compute_wav=True): + # Placeholder + output = {'logits_mask': None, + 'inference_mask': None, + 'loss_mask': None, + 'gt_mask': None, + 'separation_loss': None, + 'alignment_loss': None, + 'estimated_sp': None, + 'estimated_wav': None} + + landmarks = inputs['landmarks'] + + # ========================================== + + audio_feats = self.audio_processor.preprocess_audio(inputs['src']) + + """ + mixture: ready to fed the network + sources raw: list of all the independent sources before downsampling + weight: gradient penalty term for the loss + sp_mix_raw: mixture spectrogram before downsampling + """ + + # ========================================== + # Generating visual features + visual_features = self.forward_visual(landmarks) + + # ========================================== + + # Computing audiovisual prediction + pred = self.forward_audiovisual(audio_feats, visual_features) + + # ========================================== + + logits_mask = pred + output['logits_mask'] = logits_mask + inference_mask, target_sp = self.audio_processor.get_inference_mask(logits_mask, audio_feats['mixture']) + output['inference_mask'] = inference_mask + # Upsampling must be carried out on the mask, NOT the spectrogram + # https://www.juanmontesinos.com/posts/2021/02/08/bss-masking/ + + estimated_wav, estimated_sp = self.ap.sp2wav(inference_mask, + audio_feats['sp_mix_raw'], + compute_wav) + output['estimated_sp'] = estimated_sp + output['estimated_wav'] = estimated_wav + + output['mix_sp'] = 
class SAplusF(nn.Module):
    """Spectrogram encoder (``Spec2Vec``) followed by the audio-visual fusion module."""

    def __init__(self, *,
                 fusion_module: str,
                 input_dim: int,
                 last_shape=8,
                 **kwargs):
        """
        :param fusion_module: forwarded to ``_set_fusion_module`` (unused there)
        :param input_dim: forwarded, scaled by ``last_shape + 1`` (unused there)
        :param last_shape: argument given to ``Spec2Vec``
        :param kwargs: must contain every keyword-only argument of
            ``AVSpectralTransformer.__init__``
        """
        super().__init__()
        self.audio_net = Spec2Vec(last_shape)
        fused_dim = input_dim * (last_shape + 1)
        self.fusion_module = self._set_fusion_module(fusion_module, fused_dim, **kwargs)

    def _set_fusion_module(self, *args, **kwargs):
        """Build an ``AVSpectralTransformer`` from the matching entries of ``kwargs``.

        Positional arguments are accepted but ignored.
        """
        # Keep only what the transformer declares as keyword-only arguments;
        # a missing entry raises KeyError.
        wanted = inspect.getfullargspec(AVSpectralTransformer.__init__).kwonlyargs
        transformer_kw = {name: kwargs[name] for name in wanted}
        return AVSpectralTransformer(**transformer_kw)

    def forward(self, audio: dict, video_feats):
        """
        :param audio: dict whose ``'mixture'`` entry is the mixture spectrogram B C F T
        :param video_feats: video features of the tgt speaker
        :return: output of the fusion module (the complex mask)
        """
        # Author's shape notes: mixture (N,2,256,256), video feats [N,256,256],
        # audio net output [N, 2048, 256, 1] - NOTE(review): not checked here.
        time_major = rearrange(audio['mixture'], 'b c f t -> b c t f')
        encoded = self.audio_net(time_major)
        encoded = rearrange(encoded.squeeze(-1), 'b feats t -> b t feats')
        return self.fusion_module(video_feats, encoded, audio)
def inference(self, inputs: dict, n_iter=1):
    """Run the separation model, then refine its spectrogram ``n_iter`` times.

    Each pass feeds the magnitude of the current estimate to ``self.unet``
    and multiplies the estimate by the mask it returns.

    :param inputs: dictionary forwarded to the separation model
    :param n_iter: number of refinement passes, at least 1
    :return: the separation model's output dict extended with
        ``'estimated_sp_<i>'`` for every pass, ``'ref_mask'`` (last mask),
        ``'ref_est_sp'`` (last estimate) and ``'ref_est_wav'``
    :raises ValueError: if ``n_iter`` is smaller than 1
    """
    if n_iter < 1:
        # With zero passes no mask exists; this used to fail later with an
        # UnboundLocalError on ``mask``.
        raise ValueError('n_iter must be >= 1, got %r' % (n_iter,))

    # NOTE(review): the source's indentation was lost; the whole refinement
    # is assumed to run under no_grad - confirm against the upstream file.
    with torch.no_grad():
        output = self.forward_avse(inputs, compute_istft=False)
        estimated_sp = output['estimated_sp']
        for i in range(n_iter):
            # Magnitude over the trailing (real, imag) axis.
            estimated_sp_mag = estimated_sp.norm(dim=-1)
            # Add a channel axis for the U-Net, then drop it again.
            mask = self.unet(estimated_sp_mag[:, None].contiguous())[:, 0]
            estimated_sp = mask.unsqueeze(-1) * estimated_sp
            output[f'estimated_sp_{i}'] = estimated_sp
        output['ref_mask'] = mask
        output['ref_est_sp'] = estimated_sp
        output['ref_est_wav'] = self.av_se.ap.istft(output['ref_est_sp'])
    return output
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    """Build the sinusoid position encoding table.

    Entry ``(pos, j)`` starts as ``pos / 10000 ** (2 * (j // 2) / d_hid)``;
    even columns then take its sine and odd columns its cosine.

    :param n_position: number of positions (rows)
    :param d_hid: encoding size (columns)
    :param padding_idx: optional row that is zeroed out
    :return: ``torch.FloatTensor`` of shape ``(n_position, d_hid)``
    """
    # One divisor per column; it does not depend on the position.
    divisors = [np.power(10000, 2 * (col // 2) / d_hid) for col in range(d_hid)]
    angles = [[row / divisor for divisor in divisors] for row in range(n_position)]
    table = np.array(angles)

    even, odd = slice(0, None, 2), slice(1, None, 2)
    table[:, even] = np.sin(table[:, even])  # dim 2i
    table[:, odd] = np.cos(table[:, odd])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for padding dimension
        table[padding_idx] = 0.

    return torch.FloatTensor(table)
def generate_square_subsequent_mask(self, sz: int, device: torch.device):
    r"""Generate a square mask for the sequence.

    Entries above the main diagonal are filled with float('-inf'); entries
    on and below it are filled with float(0.0).

    :param sz: side length of the mask
    :param device: device the mask is created on
    :return: float32 tensor of shape ``(sz, sz)``
    """
    # True on and below the diagonal.
    visible = torch.ones(sz, sz, device=device).tril() == 1
    mask = torch.zeros(sz, sz, dtype=torch.float32, device=device)
    return mask.masked_fill(~visible, float('-inf'))
class TransformerEncoderLayer(TorchTFL):
    """Encoder layer whose self-attention is always called with ``need_weights=True``.

    NOTE(review): presumably this keeps attention on the non-fused code path;
    the source does not state the motivation - confirm.
    """

    def _sa_block(self, x,
                  attn_mask, key_padding_mask, is_causal=False):
        """Self-attention sub-block followed by its dropout.

        :param x: input sequence
        :param attn_mask: attention mask forwarded to the attention module
        :param key_padding_mask: padding mask forwarded to the attention module
        :param is_causal: accepted because newer torch versions pass it as a
            keyword to this hook (the original signature raised ``TypeError``
            there). It is not forwarded: ``attn_mask`` is always passed
            explicitly, and older ``MultiheadAttention`` does not accept it.
        """
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=True)[0]
        return self.dropout1(x)
class AVSpectralTransformer(BaseFusionModule):
    """Fusion module: spectro-temporal encoder, transformer decoder stack and mask head."""

    def __init__(self, *, n_temp_feats, d_model: int, nhead: int, num_encoder_layers: int, num_decoder_layers: int,
                 dim_feedforward: int, dropout: float, spec_freq_size: int):
        """
        :param n_temp_feats: forwarded to ``SpectroTemporalEncoder``
        :param d_model: transformer width, also the encoder's ``n_channels``
        :param nhead: number of attention heads
        :param num_encoder_layers: depth of the spectro-temporal encoder
        :param num_decoder_layers: depth of the stack stored as ``decoder``
        :param dim_feedforward: feed-forward width of the transformer layers
        :param dropout: dropout probability of the transformer layers
        :param spec_freq_size: output size of the mask head
        """
        super().__init__()
        self.encoder = SpectroTemporalEncoder(
            n_temp_feats=n_temp_feats,
            n_channels=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        # The "decoder" is built with the same helper as the encoder stacks.
        self.decoder = build_encoder(d_model, nhead, dim_feedforward, dropout, num_decoder_layers)
        hidden = 2 * d_model
        self.feats2mask = nn.Sequential(
            nn.LazyLinear(hidden),
            nn.LeakyReLU(0.5),
            nn.Linear(hidden, spec_freq_size),
        )
        self.d_model = d_model
        self.spec_freq_size = spec_freq_size

    def forward(self, v_feats: torch.Tensor, a_feats: torch.Tensor, *args) -> torch.Tensor:
        """
        :param v_feats: visual features, concatenated with ``a_feats`` on the last axis
        :param a_feats: audio features, BxTxC
        :param args: ignored
        :return: mask reshaped to ``(B, 2, spec_freq_size // 2, T)``
        """
        batch, steps, _ = a_feats.shape
        fused = torch.cat([v_feats, a_feats], dim=-1)

        hidden = self.decoder(self.encoder(fused))
        logits = self.feats2mask(hidden)
        logits = rearrange(logits, 't b c -> b c t ')
        return logits.reshape(batch, 2, self.spec_freq_size // 2, steps)
def load_weights(path):
    """Load a checkpoint and drop the window / filter-bank entries.

    :param path: checkpoint file readable by ``torch.load``
    :return: the loaded state dict without the listed keys
    """
    # Identity map_location: storages are left where deserialisation puts them.
    state_dict = torch.load(path, map_location=lambda storage, loc: storage)
    unwanted = (
        "audio_processor._window", "audio_processor.sp2mel.fb", "audio_processor.mel2sp.fb",
        "ap._window", "ap.sp2mel.fb", "ap.mel2sp.fb",
        "audio_processor.wav2sp.window", "ap.wav2sp.window",
    )
    for name in unwanted:
        state_dict.pop(name, None)
    return state_dict
zbbquH#M24hUe=xH>b*4U+-?>7Gg;%i8rAv6iQDz3HbyrXnA!!&e+wU%`+>>R>8Ow(K^_vJlb5^$-rq*j(evOmx1r;!rI@A8P625q|z*PXI=Fq-ldEoLN z&B!JHWar`*qqTQ_p6HW9{5k=C${bpREhm{!3n2WPOa6&9BgV=U|H1-pE2q}?0iv;gt@>7u+GuZkZ`PKkdu+59sNzcm;;U_1qYs`u7IlyV$ zhg0}p7jt%DdpX<47&rNUiJR)ziszdy+PfA*%vBeD*1_C`6{$MqL8C5s8pPubgK%v! z$j4Dpg2z+j)BZ`=#@6Ee*r}wSmPXI(bA+GGq~|!Bn4{S1oi_Z*#+&(ZiQxaDsGB9u z%>lQ`y@;9MW6|+`(@;Wg{GQGUhP>E24QS1Br;Yf!~Ls3K?V7|jhF zCsX|wtfb!!mHam3pt+_Sc$>w8^{|1TSX$OUR4A{KHT=x%cf3a np.ndarray: """ Cast an audio array in integer format into float scaling properly .