Spaces:
Building
Building
import torch | |
from typing import List | |
class FilterBank(torch.nn.Module): | |
"""Convolution filter bank (linear) | |
Serves as an embedding for the audio signal | |
""" | |
def __init__(self, ch_in: int, out_dim=16, k_size=5, dilation_list: List[int] = [1, 2, 4, 8]): | |
super().__init__() | |
self.out_dim = out_dim | |
self.source_modality_conv = torch.nn.ModuleList() | |
for dilation in dilation_list: | |
self.source_modality_conv.append( | |
torch.nn.Conv1d(ch_in, out_dim//len(dilation_list), k_size, dilation=dilation, padding=(dilation*(k_size//2))) | |
) | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
out = torch.cat([conv(x) for conv in self.source_modality_conv], axis=1) | |
assert out.shape[1] == self.out_dim | |
return out | |
class ResConvolution(torch.nn.Module): | |
"""ResNet building block | |
https://paperswithcode.com/method/residual-connection | |
""" | |
def __init__(self, ch, hdim=None, k_size=5): | |
super().__init__() | |
hdim = hdim or ch | |
self.conv1 = torch.nn.Conv1d(ch, hdim, k_size, padding=k_size//2) | |
self.conv2 = torch.nn.Conv1d(hdim, ch, k_size, padding=k_size//2) | |
self.non_linearity = torch.nn.ReLU() | |
def forward(self, x_in): | |
x = self.conv1(x_in) | |
x = self.non_linearity(x) | |
x = self.conv2(x) | |
x += x_in | |
x = self.non_linearity(x) | |
return x | |
if __name__ == "__main__": | |
model = FilterBank(1, 16) | |
inp = torch.rand(2, 1, 2048) | |
out = model(inp) | |
print(model) | |
print(out[0].shape) | |