import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio import transforms as taT, functional as taF

# Device name constant; nothing in the definitions below reads it.
DEVICE = "mps"


class ShortChunkCNN(nn.Module):
    """Residual 2-D CNN followed by a two-layer dense head with sigmoid outputs.

    The network normalises a single-channel 2-D input, passes it through
    seven stride-2 residual blocks, max-pools whatever spatial extent is
    left, and maps the resulting feature vector to ``n_class`` independent
    scores in ``[0, 1]``.

    Args:
        n_channels: Base channel width. The residual stack widens to
            ``2 * n_channels`` and finally ``4 * n_channels``.
        sample_rate: Accepted for interface compatibility; not used by any
            layer defined in this class.
        n_class: Number of output units.
    """

    def __init__(
        self,
        n_channels: int = 128,
        sample_rate: int = 16000,
        n_class: int = 50,
    ) -> None:
        super().__init__()

        # Spectrogram
        # NOTE(review): presumably the input is a precomputed spectrogram
        # shaped (batch, 1, n_bins, n_frames) -- confirm against the caller.
        self.spec_bn = nn.BatchNorm2d(1)

        # CNN
        # Each block uses a 3x3 kernel with padding 1 and stride 2, so every
        # stage maps a spatial size n to ceil(n / 2).
        self.res_layers = nn.Sequential(
            Res_2d(1, n_channels, stride=2),
            Res_2d(n_channels, n_channels, stride=2),
            Res_2d(n_channels, n_channels * 2, stride=2),
            Res_2d(n_channels * 2, n_channels * 2, stride=2),
            Res_2d(n_channels * 2, n_channels * 2, stride=2),
            Res_2d(n_channels * 2, n_channels * 2, stride=2),
            Res_2d(n_channels * 2, n_channels * 4, stride=2),
        )

        # Dense
        self.dense1 = nn.Linear(n_channels * 4, n_channels * 4)
        self.bn = nn.BatchNorm1d(n_channels * 4)
        self.dense2 = nn.Linear(n_channels * 4, n_class)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute per-class scores for a batch of inputs.

        Args:
            x: 4-D tensor with a single channel in dim 1 (required by the
                ``BatchNorm2d(1)`` input layer).

        Returns:
            Tensor of shape ``(batch, n_class)`` with values in ``[0, 1]``.
        """
        x = self.spec_bn(x)

        # CNN
        x = self.res_layers(x)

        # Global max pooling.
        # flatten(2) merges dims 2 and 3 into one axis. When dim 2 has already
        # been reduced to size 1 (at most 128 input positions on that axis)
        # this is the same as squeeze(2). When it has not, squeeze(2) would be
        # a silent no-op and hand a 4-D tensor to the 1-D pooling below;
        # flattening lets the pooling cover every remaining position instead.
        x = x.flatten(2)
        if x.size(-1) != 1:
            x = F.max_pool1d(x, x.size(-1))
        x = x.squeeze(2)

        # Dense
        x = self.dense1(x)
        x = self.bn(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.dense2(x)
        return torch.sigmoid(x)


class Res_2d(nn.Module):
    """Two-convolution residual block with an optional projected shortcut.

    The main path is conv -> batch norm -> ReLU -> conv -> batch norm. The
    shortcut is the identity unless the block changes the tensor shape
    (``stride != 1`` or a different channel count), in which case the input
    goes through its own strided conv and batch norm so both paths can be
    summed.

    Args:
        input_channels: Number of channels of the incoming tensor.
        output_channels: Number of channels produced by the block.
        shape: Square kernel size used by every convolution; padding is
            ``shape // 2``.
        stride: Stride of the first main-path convolution and of the
            shortcut convolution.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        shape: int = 3,
        stride: int = 2,
    ) -> None:
        super().__init__()

        # convolution
        self.conv_1 = nn.Conv2d(
            input_channels,
            output_channels,
            shape,
            stride=stride,
            padding=shape // 2,
        )
        self.bn_1 = nn.BatchNorm2d(output_channels)
        self.conv_2 = nn.Conv2d(
            output_channels,
            output_channels,
            shape,
            padding=shape // 2,
        )
        self.bn_2 = nn.BatchNorm2d(output_channels)

        # residual
        # A projection is only needed when the main path alters the shape.
        self.diff = False
        if (stride != 1) or (input_channels != output_channels):
            self.conv_3 = nn.Conv2d(
                input_channels,
                output_channels,
                shape,
                stride=stride,
                padding=shape // 2,
            )
            self.bn_3 = nn.BatchNorm2d(output_channels)
            self.diff = True
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block: ``relu(shortcut(x) + main_path(x))``."""
        # convolution
        out = self.conv_1(x)
        out = self.bn_1(out)
        out = self.relu(out)
        out = self.conv_2(out)
        out = self.bn_2(out)

        # residual
        if self.diff:
            x = self.bn_3(self.conv_3(x))
        out = x + out
        return self.relu(out)