# Uploaded by xcczach ("Upload 73 files"), commit 35c1cfd.
import torch
import torch.nn as nn
from torchlibrosa.stft import STFT, LogmelFilterBank
from timm.models.layers import to_2tuple
from .vision_transformer import VisionTransformer as _VisionTransformer
def conv3x3(in_channels, out_channels, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with 1-pixel zero padding.

    With the default ``stride=1`` the output keeps the input's spatial size.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_channels, out_channels, **conv_kwargs)
class PatchEmbed_new(nn.Module):
    """Flexible image-to-patch embedding with a configurable (possibly overlapping) stride.

    A single ``nn.Conv2d`` with ``kernel_size=patch_size`` and ``stride=stride``
    maps an input of shape (B, in_chans, H, W) to ``embed_dim`` channels. The
    resulting spatial grid is then flattened into a token sequence. Patches
    overlap when ``stride`` is smaller than ``patch_size``.

    Attributes:
        img_size: (H, W) the grid below is computed for.
        patch_size: (kh, kw) conv kernel size.
        in_chans: number of input channels.
        proj: the patch-projection convolution.
        patch_hw: (h, w) size of the patch grid produced for ``img_size``.
        num_patches: ``h * w``, the number of tokens produced for ``img_size``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        # kernel_size=patch_size with a separate stride -> overlapping patches when stride < patch
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)
        _, _, h, w = self.get_output_shape(img_size)  # (1, embed_dim, h, w)
        self.patch_hw = (h, w)
        self.num_patches = h * w

    def get_output_shape(self, img_size):
        """Return the shape of ``self.proj`` applied to one (in_chans, *img_size) input.

        The dummy forward pass uses a zero tensor under ``torch.no_grad()``:
        no autograd graph is built, and the global RNG state is left untouched.
        A random probe would shift every subsequent random initialisation.
        The probe matches the conv weight's device/dtype so it also works after
        the module has been moved.
        """
        weight = self.proj.weight
        probe = torch.zeros(
            1, self.in_chans, img_size[0], img_size[1],
            device=weight.device, dtype=weight.dtype,
        )
        with torch.no_grad():
            return self.proj(probe).shape

    def forward(self, x):
        """Embed patches of ``x``.

        Args:
            x: tensor of shape (B, in_chans, H, W).

        Returns:
            Tensor of shape (B, h * w, embed_dim), where (h, w) is the patch grid.

        Raises:
            ValueError: if ``x`` is not 4-D.
        """
        if x.dim() != 4:
            raise ValueError(f"expected a 4-D input (B, C, H, W), got shape {tuple(x.shape)}")
        x = self.proj(x)  # (B, embed_dim, h, w)
        x = x.flatten(2)  # (B, embed_dim, h * w)
        return x.transpose(1, 2)  # (B, h * w, embed_dim)
class BinauralEncoder(_VisionTransformer):
    """Spatial Audio Spectrogram Transformer designed for Sound Event Localization and Detection.

    Turns a batch of two-channel waveforms into a token sequence. The token
    sequence is ``num_cls_tokens`` learnable class tokens followed by one token
    per 16x16 patch of a (1024 frames x 128 mel bins) spatial feature map.
    The feature map is built from per-channel log-mel spectrograms plus the
    cos/sin of the inter-channel phase difference.

    --------------------------------------------------------
    References:
    Spatial-AST from BAT: https://github.com/zszheng147/Spatial-AST and https://arxiv.org/abs/2402.01591
    --------------------------------------------------------
    """

    def __init__(self, num_cls_tokens=3, **kwargs):
        """Build the encoder.

        Args:
            num_cls_tokens: number of learnable class tokens prepended to the
                patch tokens (replacing the backbone's single ``cls_token``).
            **kwargs: forwarded unchanged to the ``VisionTransformer`` backbone.
                If it contains ``embed_dim``, that width is also used for the
                patch embedding and tokens defined here.
        """
        super().__init__(**kwargs)
        img_size = (1024, 128)  # (frames, mel bins) seen by the patch embedding
        in_chans = 1  # conv_downsample fuses the 4 feature maps into one channel
        # Keep this width in sync with the backbone's transformer blocks; 768 was
        # previously hard-coded, so configs that worked before are unchanged.
        emb_dim = kwargs.get("embed_dim", 768)

        # Replace the backbone's single class token with num_cls_tokens tokens.
        del self.cls_token
        self.num_cls_tokens = num_cls_tokens
        self.cls_tokens = nn.Parameter(torch.zeros(1, num_cls_tokens, emb_dim))

        # Non-overlapping 16x16 patches: 1024/16 x 128/16 = 64 x 8 = 512 patches.
        self.patch_embed = PatchEmbed_new(
            img_size=img_size, patch_size=(16, 16),
            in_chans=in_chans, embed_dim=emb_dim, stride=16,
        )
        num_patches = self.patch_embed.num_patches
        # Frozen positional embedding (zeros here). The original comment calls it a
        # fixed sin-cos embedding; presumably filled from a checkpoint — TODO confirm.
        # Row 0 is never used in forward_features_mask.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False)

        self.spectrogram_extractor = STFT(
            n_fft=1024, hop_length=320, win_length=1024, window='hann',
            center=True, pad_mode='reflect', freeze_parameters=True,
        )
        self.logmel_extractor = LogmelFilterBank(
            sr=32000, n_fft=1024, n_mels=128, fmin=50,
            fmax=14000, ref=1.0, amin=1e-10, top_db=None, freeze_parameters=True,
        )

        # Fuse [log-mel ch0, log-mel ch1, cos(IPD), sin(IPD)] into a single map.
        self.conv_downsample = nn.Sequential(
            conv3x3(4, 1),
            nn.BatchNorm2d(1),
            nn.GELU(),
        )
        # Normalises the two log-mel channels (no learnable affine).
        self.bn = nn.BatchNorm2d(2, affine=False)
        del self.norm  # no final norm: forward returns the raw block outputs
        # Frame count the positional embedding covers (equals img_size[0]).
        self.target_frame = 1024

    def forward_features_mask(self, x):
        """Add positional embeddings, prepend the class tokens and run the blocks.

        Despite the name, no masking is performed.

        Args:
            x: patch tokens from ``self.patch_embed``, shape (B, num_patches, emb_dim).

        Returns:
            Tensor of shape (B, num_cls_tokens + num_patches, emb_dim).
        """
        B = x.shape[0]
        # Patches use rows 1: of pos_embed; the class tokens get no positional embedding.
        x = x + self.pos_embed[:, 1:, :]
        cls_tokens = self.cls_tokens.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, num_cls_tokens + num_patches, emb_dim)
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        return x

    @torch.no_grad()
    def forward(self, waveforms):
        """Encode two-channel waveforms into transformer tokens.

        The whole pass runs under ``torch.no_grad()``.

        Args:
            waveforms: tensor of shape (B, 2, T). Exactly two channels are required.
                ``self.bn`` normalises two log-mel channels, and the phase
                difference pairs channel 1 with channel 0.

        Returns:
            Tensor of shape (B, num_cls_tokens + num_patches, emb_dim).

        Raises:
            ValueError: if ``waveforms`` is not 3-D or does not have 2 channels.
        """
        B, C, T = waveforms.shape
        if C != 2:
            raise ValueError(f"expected 2-channel (binaural) waveforms, got {C} channels")
        # Rows become (b0c0, b0c1, b1c0, ...), so [::2] / [1::2] later select
        # channel 0 / channel 1 of every example.
        waveforms = waveforms.reshape(B * C, T)
        real, imag = self.spectrogram_extractor(waveforms)

        # Log-mel of the STFT magnitude for each channel -> (B, 2, frames, 128).
        log_mel = self.logmel_extractor(torch.sqrt(real**2 + imag**2)).reshape(B, C, -1, 128)
        log_mel = self.bn(log_mel)

        # Inter-channel phase difference, encoded as cos/sin and multiplied by the
        # same mel filterbank matrix so it lines up with the log-mel bins.
        ipd = torch.atan2(imag[1::2], real[1::2]) - torch.atan2(imag[::2], real[::2])
        ipd_mel = torch.matmul(
            torch.cat([torch.cos(ipd), torch.sin(ipd)], dim=1),
            self.logmel_extractor.melW,
        )
        x = torch.cat([log_mel, ipd_mel], dim=1)  # 4 channels, matching conv3x3(4, 1)

        # Stretch short clips to the frame count the positional embedding expects.
        if x.shape[2] < self.target_frame:
            x = nn.functional.interpolate(x, (self.target_frame, x.shape[3]), mode="bicubic", align_corners=True)
        x = self.conv_downsample(x)
        # Long clips would yield more patches than pos_embed has rows. Crop after the
        # 3x3 conv so inputs that previously worked give bit-identical results (the
        # 16x16 patch conv already ignored the trailing frames).
        x = x[:, :, : self.target_frame]
        x = self.patch_embed(x)
        x = self.forward_features_mask(x)
        return x