# CRSTC/src/MDFDSED.py
import torch
import torch.nn as nn
import torch.nn.functional as F
########################################################################################################################
# DYconv #
########################################################################################################################
class Dynamic_conv2d(nn.Module):
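    """Dynamic convolution: the output is a softmax-weighted combination of
    n_basis_kernels basis kernels, with the weights predicted by attention2d
    per frame (pool_dim='freq', TDY), per frequency bin (pool_dim='time', FDY),
    or per clip (pool_dim='both'). With dy_chan_proportion, the output channels
    are split across several dynamic branches (optionally with different
    dilations via dilated_DY / dilation_size) plus a static conv branch;
    aggconv realizes the basis kernels as channel slices of a few shared
    convolutions instead of separate weight tensors."""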
def __init__(self, in_planes, out_planes, freq_size, kernel_size, stride=1, padding=0, groups=1, bias=False,
n_basis_kernels=4, temperature=31, reduction=4, pool_dim='freq', conv1d_kernel=[3, 1],
dilated_DY=0, dilation_size=[[0, 0], [0, 0], [0, 0], [0, 0]], dy_chan_proportion=None, aggconv=False):
super(Dynamic_conv2d, self).__init__()
self.in_planes = in_planes
self.out_planes = out_planes
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.n_basis_kernels = n_basis_kernels
self.pool_dim = pool_dim
self.groups = groups
self.dilated_DY = dilated_DY
self.dilation_size = dilation_size
self.dy_chan_proportion = dy_chan_proportion
self.aggconv = aggconv
if dy_chan_proportion is not None:
self.n_attention = len(dilation_size)
dy_out_planes = int(out_planes * dy_chan_proportion[0] / dy_chan_proportion[1])
self.dy_out_planes = dy_out_planes
self.stt_out_plane = out_planes - dy_out_planes * self.n_attention
            if isinstance(n_basis_kernels, int):
                n_basis_kernels = [n_basis_kernels] * self.n_attention
            if not self.dilated_DY:
                # without dilated_DY, every basis kernel uses an undilated kernel
                self.dilation_size = [[[1, 1]] * n_bk for n_bk in n_basis_kernels]
if not aggconv:
if self.stt_out_plane > 0:
self.stt_conv = nn.Conv2d(in_planes, self.stt_out_plane, kernel_size, stride, padding, bias=bias)
                # One basis-kernel weight tensor per attention branch; nn.ParameterList
                # registers them so they are trained and saved with the module.
                self.weight = nn.ParameterList(
                    nn.Parameter(torch.randn(n_bk, dy_out_planes, in_planes,
                                             self.kernel_size, self.kernel_size),
                                 requires_grad=True)
                    for n_bk in n_basis_kernels)
                for j in range(self.n_attention):
                    for i in range(n_basis_kernels[j]):
                        nn.init.kaiming_normal_(self.weight[j][i])
                if bias:
                    # One bias tensor of shape [n_bk, dy_out_planes] per branch.
                    self.bias = nn.ParameterList(
                        nn.Parameter(torch.zeros(n_bk, dy_out_planes), requires_grad=True)
                        for n_bk in n_basis_kernels)
                else:
                    self.bias = None
else:
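                # Aggregated-conv setup: count how many basis-kernel outputs each
                # dilation rate (1-3 along frequency) must produce, then build one
                # shared conv per rate; the dilation-1 conv also carries the static
                # channels.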
output_sizes = [0, 0, 0]
for i in range(self.n_attention):
for dil in self.dilation_size[i]:
output_sizes[dil[1]-1] += 1
self.output_sizes = output_sizes
self.conv_dil1 = nn.Conv2d(in_planes,
self.stt_out_plane + dy_out_planes * output_sizes[0], kernel_size, stride,
self.padding, bias=bias)
if self.output_sizes[1] > 0:
self.conv_dil2 = nn.Conv2d(in_planes, dy_out_planes * output_sizes[1], kernel_size, stride,
(self.padding + 1, self.padding + 1), dilation=2, bias=bias)
if self.output_sizes[2] > 0:
self.conv_dil3 = nn.Conv2d(in_planes, dy_out_planes * output_sizes[2], kernel_size, stride,
(self.padding + 2, self.padding + 2), dilation=3, bias=bias)
            # One attention2d module per dynamic branch.
            self.attentions = nn.ModuleList(
                attention2d(in_planes, conv1d_kernel, freq_size, self.stride,
                            n_basis_kernels[i], temperature, reduction, pool_dim)
                for i in range(self.n_attention))
else:
self.n_attention = 1
dy_out_planes = out_planes
self.dy_out_planes = out_planes
self.attention = attention2d(in_planes, conv1d_kernel, freq_size, self.stride,
n_basis_kernels, temperature, reduction, pool_dim)
self.weight = nn.Parameter(torch.randn(n_basis_kernels, dy_out_planes, in_planes,
self.kernel_size, self.kernel_size),
requires_grad=True)
for i in range(self.n_basis_kernels):
nn.init.kaiming_normal_(self.weight[i])
            if bias:
                self.bias = nn.Parameter(torch.zeros(n_basis_kernels, dy_out_planes), requires_grad=True)
else:
self.bias = None
def forward(self, x): # x size : [bs, in_chan, frames, freqs]
if self.dy_chan_proportion is not None:
if not self.aggconv:
if self.stt_out_plane:
stt_output = self.stt_conv(x)
att_outputs = ()
for i in range(self.n_attention):
bias = self.bias[i] if self.bias is not None else None
att_output = self.attention_forward(x, self.attentions[i], self.dilation_size[i], self.weight[i], bias)
att_outputs += (att_output,)
else:
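                # Aggconv path: the shared convs compute all static and basis-kernel
                # responses at once; each branch's outputs are sliced out of the
                # channel dimension below.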
out_dil1 = self.conv_dil1(x)
if self.output_sizes[1] > 0:
out_dil2 = self.conv_dil2(x)
if self.output_sizes[2] > 0:
out_dil3 = self.conv_dil3(x)
if self.stt_out_plane > 0:
stt_output = out_dil1[:, :self.stt_out_plane]
conv_start_idxes = [self.stt_out_plane, 0, 0]
att_outputs = ()
for i in range(self.n_attention):
bk_outs = []
for dil_size in self.dilation_size[i]:
if dil_size[1] == 1:
bk_outs.append(out_dil1[:, conv_start_idxes[0]: conv_start_idxes[0] + self.dy_out_planes])
conv_start_idxes[0] += self.dy_out_planes
if dil_size[1] == 2:
bk_outs.append(out_dil2[:, conv_start_idxes[1]: conv_start_idxes[1] + self.dy_out_planes])
conv_start_idxes[1] += self.dy_out_planes
if dil_size[1] == 3:
bk_outs.append(out_dil3[:, conv_start_idxes[2]: conv_start_idxes[2] + self.dy_out_planes])
conv_start_idxes[2] += self.dy_out_planes
att_output = self.attention_forward_aggconv(x, self.attentions[i], bk_outs)
att_outputs += (att_output,)
if self.stt_out_plane > 0:
output = torch.cat((stt_output,) + att_outputs, dim=1)
else:
output = torch.cat(att_outputs, dim=1)
else:
            output = self.attention_forward(x, self.attention, self.dilation_size, self.weight, self.bias)
return output
def attention_forward_aggconv(self, x, attention, bk_outs):
        kernel_attention = attention(x) # att size : [bs, n_ker, 1, frames, 1] ('freq') or [bs, n_ker, 1, 1, freqs] ('time')
output = torch.stack(bk_outs, dim=1)
if self.pool_dim in ['freq']:
assert kernel_attention.shape[-2] == output.shape[-2]
elif self.pool_dim in ['time']:
assert kernel_attention.shape[-1] == output.shape[-1]
output = torch.sum(output * kernel_attention, dim=1) # output size : [bs, out_chan, frames, freqs]
return output
def attention_forward(self, x, attention, dilation_size, weight, bias):
        kernel_attention = attention(x) # att size : [bs, n_ker, 1, frames, 1] ('freq') or [bs, n_ker, 1, 1, freqs] ('time')
if self.dilated_DY:
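            # Each basis kernel has its own dilation, so the kernels are applied one
            # by one; the padding grows with the dilation to keep the output size.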
output = []
            for i in range(weight.size(0)): # one conv per basis kernel
padding = (self.padding + dilation_size[i][0] - 1, self.padding + dilation_size[i][1] - 1)
if bias is not None:
output.append(F.conv2d(x, weight=weight[i], bias=bias[i], stride=self.stride,
padding=padding, dilation=dilation_size[i], groups=self.groups))
else:
output.append(F.conv2d(x, weight=weight[i], bias=None, stride=self.stride,
padding=padding, dilation=dilation_size[i], groups=self.groups))
output = torch.stack(output, dim=1)
else:
aggregate_weight = weight.view(-1, self.in_planes, self.kernel_size, self.kernel_size)
# weight size : [n_ker * out_chan, in_chan, ks, ks]
if bias is not None:
aggregate_bias = bias.view(-1)
output = F.conv2d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, padding=self.padding,
groups=self.groups)
else:
output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
groups=self.groups)
# output size : [bs, n_ker * out_chan, frames, freqs]
            output = output.view(x.size(0), weight.size(0), self.dy_out_planes, output.size(-2), output.size(-1))
# output size : [bs, n_ker, out_chan, frames, freqs]
if self.pool_dim in ['freq']:
assert kernel_attention.shape[-2] == output.shape[-2]
elif self.pool_dim in ['time']:
assert kernel_attention.shape[-1] == output.shape[-1]
output = torch.sum(output * kernel_attention, dim=1) # output size : [bs, out_chan, frames, freqs]
return output
class attention2d(nn.Module):
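    """Computes softmax attention weights over the basis kernels. The input is
    average-pooled over frequency (pool_dim='freq': per-frame weights), over
    time (pool_dim='time': per-frequency weights), or over both ('both': one
    global weight vector per clip), then passed through a small conv1d or
    linear stack; temperature smooths the resulting softmax."""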
def __init__(self, in_planes, kernel_size, freq_size, stride, n_basis_kernels,
temperature, reduction, pool_dim):
super(attention2d, self).__init__()
self.freq_size = freq_size
self.pool_dim = pool_dim
self.temperature = temperature
hidden_planes = in_planes // reduction
if hidden_planes < 4:
hidden_planes = 4
padding_1 = int((kernel_size[0] - 1) / 2)
padding_2 = int((kernel_size[1] - 1) / 2)
if pool_dim == 'both':
self.fc1 = nn.Linear(in_planes, hidden_planes)
self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Linear(hidden_planes, n_basis_kernels)
else:
self.conv1d1 = nn.Conv1d(in_planes, hidden_planes, kernel_size[0], stride=stride, padding=padding_1,
bias=False)
self.bn = nn.BatchNorm1d(hidden_planes)
self.relu = nn.ReLU(inplace=True)
self.conv1d2 = nn.Conv1d(hidden_planes, n_basis_kernels, kernel_size[1], padding=padding_2, bias=True)
# initialize
if pool_dim in ["freq", "time"]:
for m in self.modules():
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x): # x size : [bs, chan, frames, freqs]
### Pool dimensions and apply pre-processings
if self.pool_dim == 'freq': #TDY
x = torch.mean(x, dim=3) # x size : [bs, chan, frames]
elif self.pool_dim == 'time': #FDY
x = torch.mean(x, dim=2) # x size : [bs, chan, freqs]
elif self.pool_dim == 'both': #DY
x = F.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)
### extract attention weights
if self.pool_dim == 'both':
x = self.relu(self.fc1(x)) #x size : [bs, sqzd_chan]
att = self.fc2(x).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) #att size : [bs, n_ker, 1, 1, 1]
elif self.pool_dim == 'freq':
x = self.relu(self.bn(self.conv1d1(x))) #x size : [bs, sqzd_chan, frames]
            att = self.conv1d2(x).unsqueeze(2).unsqueeze(4) #att size : [bs, n_ker, 1, frames, 1]
else: #self.pool_dim == 'time', FDY
x = self.relu(self.bn(self.conv1d1(x))) #x size : [bs, sqzd_chan, freqs]
att = self.conv1d2(x) #att size : [bs, n_ker, freqs]
att = att.unsqueeze(2).unsqueeze(3) #att size : [bs, n_ker, 1, 1, freqs]
return F.softmax(att / self.temperature, 1)
########################################################################################################################
# CRNN #
########################################################################################################################
class GLU(nn.Module):
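    """Gated linear unit: a channel-wise linear projection of x, gated by sigmoid(x)."""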
def __init__(self, in_dim):
super(GLU, self).__init__()
self.sigmoid = nn.Sigmoid()
self.linear = nn.Linear(in_dim, in_dim)
    def forward(self, x): #x size = [bs, chan, frames, freqs]
        lin = self.linear(x.permute(0, 2, 3, 1)) #lin size = [bs, frames, freqs, chan]
        lin = lin.permute(0, 3, 1, 2) #lin size = [bs, chan, frames, freqs]
sig = self.sigmoid(x)
res = lin * sig
return res
class ContextGating(nn.Module):
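    """Context gating: x gated by sigmoid(linear(x)), a channel-wise projection of itself."""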
def __init__(self, in_dim):
super(ContextGating, self).__init__()
        self.sigmoid = nn.Sigmoid()
self.linear = nn.Linear(in_dim, in_dim)
    def forward(self, x): #x size = [bs, chan, frames, freqs]
        lin = self.linear(x.permute(0, 2, 3, 1)) #lin size = [bs, frames, freqs, chan]
        lin = lin.permute(0, 3, 1, 2) #lin size = [bs, chan, frames, freqs]
sig = self.sigmoid(lin)
res = x * sig
return res
class BiGRU(nn.Module):
def __init__(self, n_in, n_hidden, dropout=0, num_layers=1):
super(BiGRU, self).__init__()
self.rnn = nn.GRU(n_in, n_hidden, bidirectional=True, dropout=dropout, batch_first=True, num_layers=num_layers)
def forward(self, x):
#self.rnn.flatten_parameters()
x, _ = self.rnn(x)
return x
class DYCNN(nn.Module):
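    """CNN front-end: a stack of conv / batch-norm / activation / dropout /
    average-pooling stages. DY_layers selects per stage whether a plain
    nn.Conv2d or a Dynamic_conv2d is used."""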
def __init__(self,
n_input_ch,
activation="Relu",
dropout=0,
kernel=[3, 3, 3],
pad=[1, 1, 1],
stride=[1, 1, 1],
dilation=[1, 1, 1],
n_filt=[64, 64, 64],
pooling=[(1, 4), (1, 4), (1, 4)],
pre_conv=None,
normalization="batch",
DY_layers=[0, 0, 0, 0, 0, 0, 0],
n_basis_kernels=4,
temperature=31,
dy_reduction=4,
pool_dim='freq',
conv1d_kernel=[3, 1],
dilated_DY=[0, 0, 0, 0, 0, 0, 0],
dilation_size=[[0, 0], [0, 0], [0, 0], [0, 0]],
dy_chan_proportion=None,
aggconv=False,):
super(DYCNN, self).__init__()
self.n_filt = n_filt
self.n_filt_last = n_filt[-1]
cnn = nn.Sequential()
if len(n_filt) == 7:
freq_dims = [128, 64, 32, 16, 8, 4, 2]
if pre_conv is not None:
cnn.add_module("pre_conv", nn.Conv2d(n_input_ch, pre_conv, 3, 1, 1))
n_input_ch = pre_conv
def conv(i, normalization="batch", dropout=None, activ='relu'):
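            """Appends stage i to `cnn`: a static or dynamic convolution followed
            by optional batch norm, activation (relu / glu / cg), and dropout."""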
in_dim = n_input_ch if i == 0 else n_filt[i - 1]
out_dim = n_filt[i]
# convolution
if DY_layers[i] == 1:
cnn.add_module("conv{0}".format(i), Dynamic_conv2d(in_dim, out_dim, freq_dims[i], kernel[i], stride[i],
pad[i],
n_basis_kernels=n_basis_kernels,
temperature=temperature,
pool_dim=pool_dim,
reduction=dy_reduction,
conv1d_kernel=conv1d_kernel,
dilated_DY=dilated_DY[i],
dilation_size=dilation_size,
dy_chan_proportion=dy_chan_proportion,
aggconv=aggconv))
else:
cnn.add_module("conv{0}".format(i), nn.Conv2d(in_dim, out_dim, kernel[i], stride[i], pad[i],
dilation[i]))
# normalization
if normalization == "batch":
cnn.add_module("batchnorm{0}".format(i), nn.BatchNorm2d(out_dim, eps=0.001, momentum=0.99))
# non-linearity
if activ.lower() == "relu":
cnn.add_module("Relu{0}".format(i), nn.ReLU())
elif activ.lower() == "glu":
cnn.add_module("glu{0}".format(i), GLU(out_dim))
elif activ.lower() == "cg":
cnn.add_module("cg{0}".format(i), ContextGating(out_dim))
# dropout
if dropout is not None:
cnn.add_module("dropout{0}".format(i), nn.Dropout(dropout))
for i in range(len(n_filt)):
conv(i, normalization=normalization, dropout=dropout, activ=activation)
cnn.add_module("pooling{0}".format(i), nn.AvgPool2d(pooling[i]))
self.cnn = cnn
def forward(self, x): #x size : [bs, chan, frames, freqs]
x = self.cnn(x)
return x
class DYCRNN(nn.Module):
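    """CRNN for sound event detection: DYCNN front-end, bidirectional GRU, and a
    frame-wise sigmoid classifier (strong predictions); an optional attention
    head pools the frames into clip-level (weak) predictions."""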
def __init__(self,
n_input_ch,
n_class=10,
n_RNN_cell=128,
n_RNN_layer=2,
rec_dropout=0,
attention=True,
conv_dropout=0.5,
**convkwargs):
super(DYCRNN, self).__init__()
self.n_input_ch = n_input_ch
self.attention = attention
self.n_class = n_class
self.cnn = DYCNN(n_input_ch=n_input_ch, dropout=conv_dropout, **convkwargs)
rnn_in = self.cnn.n_filt[-1]
self.rnn = BiGRU(n_in=rnn_in, n_hidden=n_RNN_cell, dropout=rec_dropout, num_layers=n_RNN_layer)
self.dropout = nn.Dropout(conv_dropout)
self.sigmoid = nn.Sigmoid()
linear_in = n_RNN_cell * 2
self.linear = nn.Linear(linear_in, n_class)
if self.attention:
self.linear_att = nn.Linear(linear_in, n_class)
            if self.attention == "time":
                self.softmax = nn.Softmax(dim=1)  # softmax over the time (frame) dimension
            else:  # True or "class": softmax over the class dimension
                self.softmax = nn.Softmax(dim=-1)
def forward(self, x): # input size: [bs, freqs, frames]
#cnn
x = x.transpose(1, 2).unsqueeze(1) # x size: [bs, chan, frames, freqs]
x = self.cnn(x) # x size: [bs, chan, frames, 1]
x = x.squeeze(-1) # x size: [bs, chan, frames]
x = x.permute(0, 2, 1) # x size: [bs, frames, chan]
#rnn
x = self.rnn(x) # x size: [bs, frames, 2 * chan]
x = self.dropout(x)
strong = self.linear(x) # strong size: [bs, frames, n_class]
strong = self.sigmoid(strong)
if self.attention:
attention = self.linear_att(x) # attention size: [bs, frames, n_class]
attention = self.softmax(attention) # attention size: [bs, frames, n_class]
attention = torch.clamp(attention, min=1e-7, max=1)
weak = (strong * attention).sum(1) / attention.sum(1) # weak size: [bs, n_class]
else:
weak = strong.mean(1)
return strong.transpose(1, 2), weak
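########################################################################################################################
#                                                      smoke test                                                      #
########################################################################################################################
if __name__ == "__main__":
    # Minimal smoke test with a hypothetical configuration (not the original
    # training setup): the pooling sizes are chosen so the 128 frequency bins
    # collapse to 1 before the RNN, as DYCRNN.forward assumes.
    model = DYCRNN(n_input_ch=1, n_class=10,
                   n_filt=[16, 32, 64],
                   pooling=[(1, 4), (1, 4), (1, 8)])
    model.eval()
    with torch.no_grad():
        mel = torch.randn(2, 128, 156)  # [bs, freqs, frames]
        strong, weak = model(mel)
    print(strong.shape)  # torch.Size([2, 10, 156])
    print(weak.shape)    # torch.Size([2, 10])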