|
import numpy as np |
|
import scipy.linalg |
|
from scipy.spatial.transform import Rotation as R |
|
import torch as th |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from src.warping import GeometricTimeWarper, MonotoneTimeWarper |
|
from src.utils import Net |
|
|
|
|
|
class GeometricWarper(nn.Module):
    '''
    Computes per-ear 3D displacement vectors from a rx/tx view tensor and warps a
    mono signal into a left/right pair with a GeometricTimeWarper.
    '''

    # Offsets used by the displacement computation, in the same units as the view
    # positions (presumably meters -- TODO confirm). Class-level so subclasses can
    # override them without touching the math below.
    MOUTH_OFFSET = (0.09, 0.0, -0.20)
    LEFT_EAR_OFFSET = (0.0, -0.08, -0.22)
    RIGHT_EAR_OFFSET = (0.0, 0.08, -0.22)

    def __init__(self, sampling_rate=48000):
        '''
        :param sampling_rate: forwarded to GeometricTimeWarper
        '''
        super().__init__()
        self.warper = GeometricTimeWarper(sampling_rate=sampling_rate)

    def _transmitter_mouth(self, view):
        '''
        Rotate MOUTH_OFFSET by the inverse of each frame's orientation quaternion.

        :param view: tensor of shape B x 7 x K; channels 3:7 are passed to
                     scipy's Rotation.from_quat (scalar-last x, y, z, w order)
        :return: tensor of shape B x 3 x K on view's device and dtype
                 (no gradient flows through it: the quaternions are detached)
        '''
        batch, _, frames = view.shape
        # B x 4 x K -> (B*K) x 4 numpy array for scipy.
        quat = view[:, 3:7, :].detach().transpose(1, 2).reshape(-1, 4).cpu().numpy()

        # scipy rejects zero-norm quaternions; nudge them to a valid (non-zero) one.
        # Deliberately NOT in place: for CPU input `quat` shares memory with `view`.
        zero_norm = (scipy.linalg.norm(quat, axis=1) == 0).astype(quat.dtype)
        quat = quat + zero_norm[:, None]

        rotations = R.from_quat(quat)
        mouth = rotations.apply(np.asarray(self.MOUTH_OFFSET), inverse=True)

        # Build the result directly on view's device/dtype rather than .cuda(),
        # which would always target the default GPU.
        mouth = th.as_tensor(mouth, dtype=view.dtype, device=view.device)
        return mouth.reshape(batch, frames, 3).transpose(1, 2).contiguous()

    def _3d_displacements(self, view):
        '''
        :param view: tensor of shape B x 7 x K; channels 0:3 are the position
        :return: tensor of shape B x 2 x 3 x K holding
                 position + rotated mouth offset - ear offset for the left (index 0)
                 and right (index 1) ear
        '''
        mouth = self._transmitter_mouth(view)
        position = view[:, 0:3, :]

        # new_tensor keeps the offsets on the same device and dtype as `view`.
        left_ear = view.new_tensor(self.LEFT_EAR_OFFSET)[None, :, None]
        right_ear = view.new_tensor(self.RIGHT_EAR_OFFSET)[None, :, None]

        displacement_left = position + mouth - left_ear
        displacement_right = position + mouth - right_ear
        return th.stack([displacement_left, displacement_right], dim=1)

    def _warpfield(self, view, seq_length):
        '''
        :param view: rx/tx view tensor of shape B x 7 x K
        :param seq_length: target signal length T
        :return: whatever GeometricTimeWarper.displacements2warpfield returns for
                 the displacements of view (presumably B x 2 x T -- TODO confirm)
        '''
        return self.warper.displacements2warpfield(self._3d_displacements(view), seq_length)

    def forward(self, mono, view):
        '''
        :param mono: input signal as tensor of shape B x 1 x T
        :param view: rx/tx position/orientation as tensor of shape B x 7 x K (K = T / 400)
        :return: warped: warped left/right ear signal as tensor of shape B x 2 x T
        '''
        # The same mono signal is fed to both ear channels.
        return self.warper(th.cat([mono, mono], dim=1), self._3d_displacements(view))
|
|
|
|
|
class Warpnet(nn.Module):
    '''
    Warps a mono signal into a left/right pair using the sum of a geometric
    warpfield and a learned (neural) correction warpfield.
    '''

    def __init__(self, layers=4, channels=64, view_dim=7, sampling_rate=48000):
        '''
        :param layers: number of causal Conv1d layers in the learned warpfield network
        :param channels: hidden channel count of those layers
        :param view_dim: number of channels of the view tensor
        :param sampling_rate: forwarded to the internal GeometricWarper
                              (previously fixed at its default of 48000)
        '''
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Conv1d(view_dim if i == 0 else channels, channels, kernel_size=2)
            for i in range(layers)
        ])
        # With zero conv layers the projection sees the raw view channels.
        self.linear = nn.Conv1d(channels if layers > 0 else view_dim, 2, kernel_size=1)
        self.neural_warper = MonotoneTimeWarper()
        self.geometric_warper = GeometricWarper(sampling_rate=sampling_rate)

    def neural_warpfield(self, view, seq_length):
        '''
        :param view: view tensor of shape B x view_dim x K
        :param seq_length: target length T of the returned warpfield
        :return: learned warpfield of shape B x 2 x T
        '''
        warpfield = view
        for layer in self.layers:
            # Left-pad by 1 before a kernel-2 conv: output length equals input
            # length and frame k depends only on frames k-1 and k (causal).
            warpfield = F.pad(warpfield, pad=[1, 0])
            warpfield = F.relu(layer(warpfield))
        warpfield = self.linear(warpfield)
        # Upsample from K view frames to T samples (default 'nearest' mode).
        warpfield = F.interpolate(warpfield, size=seq_length)
        return warpfield

    def forward(self, mono, view):
        '''
        :param mono: input signal as tensor of shape B x 1 x T
        :param view: rx/tx position/orientation as tensor of shape B x 7 x K (K = T / 400)
        :return: warped: warped left/right ear signal as tensor of shape B x 2 x T
        '''
        seq_length = mono.shape[-1]
        geometric_warpfield = self.geometric_warper._warpfield(view, seq_length)
        neural_warpfield = self.neural_warpfield(view, seq_length)
        warpfield = geometric_warpfield + neural_warpfield

        # Equals min(warpfield, 0): no positive warp values survive (presumably so
        # the warp only looks into the past -- TODO confirm with MonotoneTimeWarper).
        warpfield = -F.relu(-warpfield)

        # The same mono signal is fed to both ear channels.
        warped = self.neural_warper(th.cat([mono, mono], dim=1), warpfield)
        return warped
|
|
|
class BinauralNetwork(Net):
    '''
    Binaural synthesis network; currently a thin wrapper around a Warpnet that
    warps the mono input into left/right ear signals.
    '''

    def __init__(self,
                 view_dim=7,
                 warpnet_layers=4,
                 warpnet_channels=64,
                 model_name='binaural_network',
                 use_cuda=True):
        '''
        :param view_dim: number of channels of the view tensor (now forwarded to
                         the Warpnet; it was previously accepted but ignored)
        :param warpnet_layers: number of conv layers in the Warpnet
        :param warpnet_channels: hidden channel count of the Warpnet
        :param model_name: passed to Net
        :param use_cuda: passed to Net; the model is moved to CUDA when
                         self.use_cuda is truthy (presumably set by Net -- TODO confirm)
        '''
        super().__init__(model_name, use_cuda)
        self.warper = Warpnet(warpnet_layers, warpnet_channels, view_dim=view_dim)
        if self.use_cuda:
            self.cuda()

    def forward(self, mono, view):
        '''
        :param mono: the input signal as a B x 1 x T tensor
        :param view: the receiver/transmitter position/orientation as a B x 7 x K
                     tensor (K = T / 400)
        :return: warped: the binaural output as a B x 2 x T tensor
        '''
        return self.warper(mono, view)
|
|