Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
from torchlibrosa.stft import STFT, LogmelFilterBank | |
from timm.models.layers import to_2tuple | |
from .vision_transformer import VisionTransformer as _VisionTransformer | |
def conv3x3(in_channels, out_channels, stride=1):
    """Return a 3x3 ``nn.Conv2d`` with 1-pixel zero padding and no bias term.

    Args:
        in_channels: Number of input feature maps.
        out_channels: Number of output feature maps.
        stride: Convolution stride; with the default of 1 the spatial size is preserved.

    Returns:
        nn.Conv2d: the configured (untrained) convolution layer.
    """
    layer_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_channels, out_channels, **layer_kwargs)
class PatchEmbed_new(nn.Module):
    """Flexible image-to-patch embedding built on a single strided convolution.

    Unlike the classic ViT patch embedding, ``stride`` may differ from
    ``patch_size``; when ``stride < patch_size`` neighbouring patches overlap.

    Args:
        img_size: Expected input (height, width), or an int for a square input.
        patch_size: Convolution kernel size, int or (kh, kw).
        in_chans: Number of input channels.
        embed_dim: Number of output channels, i.e. the token embedding size.
        stride: Convolution stride, int or (sh, sw).

    Attributes:
        patch_hw: (h, w) patch grid produced for an ``img_size`` input.
        num_patches: ``h * w``, the number of tokens produced for an ``img_size`` input.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        # Patches overlap whenever stride < patch_size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)
        _, _, h, w = self.get_output_shape(img_size)  # (1, embed_dim, h, w)
        self.patch_hw = (h, w)
        self.num_patches = h * w

    def get_output_shape(self, img_size):
        """Return the shape of ``self.proj`` applied to one ``img_size`` input.

        The probe is a zero tensor on the projection weights' device and dtype,
        evaluated under ``torch.no_grad()``. It therefore builds no autograd
        graph, and it does not draw from the global RNG (drawing random numbers
        here would shift the random initialisation of modules constructed
        afterwards).

        Args:
            img_size: (height, width) of the probe input.

        Returns:
            torch.Size: ``(1, embed_dim, h, w)``.
        """
        weight = self.proj.weight
        probe = torch.zeros(
            1, self.in_chans, img_size[0], img_size[1],
            device=weight.device, dtype=weight.dtype,
        )
        with torch.no_grad():
            return self.proj(probe).shape

    def forward(self, x):
        """Embed a (B, in_chans, H, W) batch into (B, h * w, embed_dim) tokens.

        Raises:
            ValueError: if ``x`` is not a 4-D batched tensor.
        """
        if x.dim() != 4:
            raise ValueError(f"expected a 4-D (B, C, H, W) input, got shape {tuple(x.shape)}")
        # e.g. patch 16 / stride 10: (32, 1, 1024, 128) -> (32, 768, 101, 12)
        x = self.proj(x)  # (B, embed_dim, h, w)
        x = x.flatten(2)  # (B, embed_dim, h * w)
        x = x.transpose(1, 2)  # (B, h * w, embed_dim)
        return x
class BinauralEncoder(_VisionTransformer):
    """Spatial Audio Spectrogram Transformer designed for Sound Event Localization and Detection.

    --------------------------------------------------------
    References:
    Spatial-AST from BAT: https://github.com/zszheng147/Spatial-AST and https://arxiv.org/abs/2402.01591
    --------------------------------------------------------

    Pipeline implemented by ``forward``:
      1. A (B, 2, T) two-channel waveform batch is converted to per-channel STFTs.
      2. Per-channel 128-bin log-mel spectrograms are normalised by ``self.bn``.
      3. The inter-channel phase difference (IPD) is encoded as cos/sin and
         multiplied by the mel filterbank matrix, giving two more feature maps.
      4. The resulting 4 feature maps are stretched to ``target_frame`` frames
         when shorter, fused into one channel by ``conv_downsample`` and cut
         into non-overlapping 16x16 patches.
      5. ``num_cls_tokens`` learnable tokens are prepended and the sequence is
         passed through ``self.blocks`` (provided by the base class).

    Args:
        num_cls_tokens: number of learnable tokens prepended to the patch sequence.
        **kwargs: forwarded unchanged to the base ``VisionTransformer``.
    """

    def __init__(self, num_cls_tokens=3, **kwargs):
        super().__init__(**kwargs)
        img_size = (1024, 128)  # (frames, mel bins) seen by the patch embedding
        in_chans = 1  # conv_downsample fuses the 4 feature maps into a single channel
        # NOTE(review): hard-coded; presumably must match the base model's embed_dim.
        emb_dim = 768

        # Replace the base class's `cls_token` with `num_cls_tokens` learnable tokens.
        del self.cls_token
        self.num_cls_tokens = num_cls_tokens
        self.cls_tokens = nn.Parameter(torch.zeros(1, num_cls_tokens, emb_dim))

        self.patch_embed = PatchEmbed_new(
            img_size=img_size, patch_size=(16, 16),
            in_chans=in_chans, embed_dim=emb_dim, stride=16
        )  # no overlap: stride == patch size == 16 -> 64 x 8 = 512 patches
        num_patches = self.patch_embed.num_patches
        # Fixed sin-cos embedding (frozen). It is zero-initialised here, so it is
        # presumably filled in elsewhere (e.g. from a checkpoint). forward_features_mask
        # only uses slots 1:, so slot 0 is unused.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False)

        # Frozen STFT front-end: 1024-point Hann window, hop of 320 samples.
        self.spectrogram_extractor = STFT(
            n_fft=1024, hop_length=320, win_length=1024, window='hann',
            center=True, pad_mode='reflect', freeze_parameters=True
        )
        # Frozen mel filterbank built for sr=32000, so inputs are presumably 32 kHz audio.
        self.logmel_extractor = LogmelFilterBank(
            sr=32000, n_fft=1024, n_mels=128, fmin=50,
            fmax=14000, ref=1.0, amin=1e-10, top_db=None, freeze_parameters=True
        )

        # 4 maps (2 log-mel + cos/sin IPD) -> 1 channel for the patch embedding.
        self.conv_downsample = nn.Sequential(
            conv3x3(4, 1),
            nn.BatchNorm2d(1),
            nn.GELU(),
        )
        self.bn = nn.BatchNorm2d(2, affine=False)  # normalises the two log-mel channels
        del self.norm  # remove the original norm
        self.target_frame = 1024  # must agree with img_size[0] above

    def forward_features_mask(self, x):
        """Add positional embeddings, prepend the class tokens and run the transformer blocks.

        Despite the name, no masking is performed here.

        Args:
            x: patch tokens from ``self.patch_embed``: (B, num_patches, emb_dim),
               i.e. (B, 512, 768) with this configuration.

        Returns:
            Output of the last block. Presumably (B, num_cls_tokens + num_patches, emb_dim),
            assuming the base class's blocks preserve shape.
        """
        B = x.shape[0]
        # Class tokens receive no positional embedding; pos_embed slot 0 is skipped.
        x = x + self.pos_embed[:, 1:, :]
        cls_tokens = self.cls_tokens.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, num_cls_tokens + num_patches, emb_dim)
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        return x

    def forward(self, waveforms):
        """Encode a batch of two-channel (binaural) waveforms into transformer tokens.

        Args:
            waveforms: (B, 2, T) tensor of raw audio.

        Returns:
            The output of ``forward_features_mask`` for the patch tokens.

        Raises:
            ValueError: if ``waveforms`` is not 3-D or does not have exactly two
                channels, or if the input is too long for the fixed positional
                embedding (more patches than ``self.patch_embed.num_patches``).
        """
        if waveforms.dim() != 3:
            raise ValueError(
                f"expected waveforms of shape (B, C, T), got {tuple(waveforms.shape)}"
            )
        B, C, T = waveforms.shape
        if C != 2:
            # self.bn and the even/odd-row IPD below both assume exactly two channels.
            raise ValueError(f"BinauralEncoder expects 2-channel input, got {C} channels")

        # Rows interleave as [b0c0, b0c1, b1c0, b1c1, ...].
        waveforms = waveforms.reshape(B * C, T)
        real, imag = self.spectrogram_extractor(waveforms)

        log_mel = self.logmel_extractor(torch.sqrt(real**2 + imag**2)).reshape(B, C, -1, 128)
        log_mel = self.bn(log_mel)

        # Inter-channel phase difference: channel 1 (odd rows) minus channel 0 (even rows).
        # It is not wrapped to [-pi, pi]; the cos/sin encoding makes that unnecessary.
        ipd = torch.atan2(imag[1::2], real[1::2]) - torch.atan2(imag[::2], real[::2])
        ipd_features = torch.cat([torch.cos(ipd), torch.sin(ipd)], dim=1)
        # Project the per-frequency-bin IPD features with the mel filterbank matrix
        # (melW presumably has shape (n_fft // 2 + 1, n_mels) — verify against torchlibrosa).
        ipd_mel = torch.matmul(ipd_features, self.logmel_extractor.melW)
        x = torch.cat([log_mel, ipd_mel], dim=1)  # (B, 4, frames, 128)

        n_frames = x.shape[2]
        if n_frames < self.target_frame:
            # Stretch the time axis up to target_frame; the mel axis is left unchanged.
            x = nn.functional.interpolate(
                x, (self.target_frame, x.shape[3]), mode="bicubic", align_corners=True
            )

        x = self.conv_downsample(x)
        x = self.patch_embed(x)

        n_patches = x.shape[1]
        if n_patches != self.patch_embed.num_patches:
            raise ValueError(
                f"input produced {n_patches} patches ({n_frames} STFT frames) but the "
                f"positional embedding covers {self.patch_embed.num_patches}; inputs must fit "
                f"the {self.target_frame}-frame window of this encoder"
            )
        x = self.forward_features_mask(x)
        return x