balthou's picture
draft audio sep app
f6b56a2
import torch
from typing import List
class FilterBank(torch.nn.Module):
"""Convolution filter bank (linear)
Serves as an embedding for the audio signal
"""
def __init__(self, ch_in: int, out_dim=16, k_size=5, dilation_list: List[int] = [1, 2, 4, 8]):
super().__init__()
self.out_dim = out_dim
self.source_modality_conv = torch.nn.ModuleList()
for dilation in dilation_list:
self.source_modality_conv.append(
torch.nn.Conv1d(ch_in, out_dim//len(dilation_list), k_size, dilation=dilation, padding=(dilation*(k_size//2)))
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = torch.cat([conv(x) for conv in self.source_modality_conv], axis=1)
assert out.shape[1] == self.out_dim
return out
class ResConvolution(torch.nn.Module):
"""ResNet building block
https://paperswithcode.com/method/residual-connection
"""
def __init__(self, ch, hdim=None, k_size=5):
super().__init__()
hdim = hdim or ch
self.conv1 = torch.nn.Conv1d(ch, hdim, k_size, padding=k_size//2)
self.conv2 = torch.nn.Conv1d(hdim, ch, k_size, padding=k_size//2)
self.non_linearity = torch.nn.ReLU()
def forward(self, x_in):
x = self.conv1(x_in)
x = self.non_linearity(x)
x = self.conv2(x)
x += x_in
x = self.non_linearity(x)
return x
if __name__ == "__main__":
model = FilterBank(1, 16)
inp = torch.rand(2, 1, 2048)
out = model(inp)
print(model)
print(out[0].shape)