|
|
|
|
|
import math
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio import transforms
from torchlibrosa.augmentation import SpecAugmentation

from .utils import mean_with_lens, max_with_lens, \
    init, pack_wrapper, generate_length_mask, PositionalEncoding


def init_layer(layer):
    """Initialize a Linear or Convolutional layer."""
    nn.init.xavier_uniform_(layer.weight)

    if hasattr(layer, 'bias'):
        if layer.bias is not None:
            layer.bias.data.fill_(0.)


def init_bn(bn):
    """Initialize a Batchnorm layer."""
    bn.bias.data.fill_(0.)
    bn.weight.data.fill_(1.)


class BaseEncoder(nn.Module):
    """
    Encode the given audio into an embedding.
    Base encoder class; cannot be called directly.
    All encoders should inherit from this class.
    """

    def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim):
        super(BaseEncoder, self).__init__()
        self.spec_dim = spec_dim
        self.fc_feat_dim = fc_feat_dim
        self.attn_feat_dim = attn_feat_dim

    def forward(self, x):
        raise NotImplementedError


class Block2D(nn.Module):

    def __init__(self, cin, cout, kernel_size=3, padding=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.BatchNorm2d(cin),
            nn.Conv2d(cin,
                      cout,
                      kernel_size=kernel_size,
                      padding=padding,
                      bias=False),
            nn.LeakyReLU(inplace=True, negative_slope=0.1))

    def forward(self, x):
        return self.block(x)


class LinearSoftPool(nn.Module):
    """LinearSoftPool

    Linear softmax pooling: weights each frame by its own probability,
    yielding a clip-level value close to the actual maximum.
    Taken from the paper:
    A Comparison of Five Multiple Instance Learning Pooling Functions for
    Sound Event Detection with Weak Labeling
    https://arxiv.org/abs/1810.09050
    """

    def __init__(self, pooldim=1):
        super().__init__()
        self.pooldim = pooldim

    def forward(self, logits, time_decision):
        return (time_decision**2).sum(self.pooldim) / time_decision.sum(
            self.pooldim)
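
# Illustrative sketch (not used by the models below): linear softmax pooling
# lies between mean and max pooling. With made-up frame probabilities:
#
#   probs = torch.tensor([[0.1, 0.2, 0.9]])     # (batch=1, time=3)
#   pooled = LinearSoftPool()(None, probs)      # logits are unused here
#   # mean = 0.4, max = 0.9, linear softmax = 0.86 / 1.2 ≈ 0.717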
|
|
|
|
|
class MeanPool(nn.Module):

    def __init__(self, pooldim=1):
        super().__init__()
        self.pooldim = pooldim

    def forward(self, logits, decision):
        return torch.mean(decision, dim=self.pooldim)


class AttentionPool(nn.Module):
    """Attention-weighted temporal pooling over frame-level decisions."""

    def __init__(self, inputdim, outputdim=10, pooldim=1, **kwargs):
        super().__init__()
        self.inputdim = inputdim
        self.outputdim = outputdim
        self.pooldim = pooldim
        self.transform = nn.Linear(inputdim, outputdim)
        self.activ = nn.Softmax(dim=self.pooldim)
        self.eps = 1e-7

    def forward(self, logits, decision):
        # clamp to avoid softmax overflow, then use the learned weights
        # to average the frame-level decisions over time
        w = self.activ(torch.clamp(self.transform(logits), -15, 15))
        detect = (decision * w).sum(
            self.pooldim) / (w.sum(self.pooldim) + self.eps)
        return detect


class MMPool(nn.Module):

    def __init__(self, dims):
        super().__init__()
        self.avgpool = nn.AvgPool2d(dims)
        self.maxpool = nn.MaxPool2d(dims)

    def forward(self, x):
        return self.avgpool(x) + self.maxpool(x)


def parse_poolingfunction(poolingfunction_name='mean', **kwargs):
    """parse_poolingfunction

    A helper function to parse any temporal pooling function.
    Pooling is done on dimension 1.

    :param poolingfunction_name:
    :param **kwargs:
    """
    poolingfunction_name = poolingfunction_name.lower()
    if poolingfunction_name == 'mean':
        return MeanPool(pooldim=1)
    elif poolingfunction_name == 'linear':
        return LinearSoftPool(pooldim=1)
    elif poolingfunction_name == 'attention':
        return AttentionPool(inputdim=kwargs['inputdim'],
                             outputdim=kwargs['outputdim'])
    else:
        raise ValueError(
            f"pooling function {poolingfunction_name} not supported")
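
# Usage sketch (the dimensions are assumptions for illustration only):
#
#   pool = parse_poolingfunction("attention", inputdim=256, outputdim=10)
#   logits = torch.randn(4, 100, 256)               # (batch, time, feature)
#   decision = torch.sigmoid(torch.randn(4, 100, 10))
#   clip_level = pool(logits, decision)             # -> (4, 10)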
|
|
|
|
|
def embedding_pooling(x, lens, pooling="mean"): |
|
if pooling == "max": |
|
fc_embs = max_with_lens(x, lens) |
|
elif pooling == "mean": |
|
fc_embs = mean_with_lens(x, lens) |
|
elif pooling == "mean+max": |
|
x_mean = mean_with_lens(x, lens) |
|
x_max = max_with_lens(x, lens) |
|
fc_embs = x_mean + x_max |
|
elif pooling == "last": |
|
indices = (lens - 1).reshape(-1, 1, 1).repeat(1, 1, x.size(-1)) |
|
|
|
fc_embs = torch.gather(x, 1, indices).squeeze(1) |
|
else: |
|
raise Exception(f"pooling method {pooling} not support") |
|
return fc_embs |
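
# Usage sketch of pooling over padded sequences (shapes assumed):
#
#   x = torch.randn(2, 6, 8)                # (batch, time, feature)
#   lens = torch.tensor([6, 3])             # valid lengths before padding
#   emb = embedding_pooling(x, lens, "mean+max")    # -> (2, 8)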
|
|
|
|
|
class Cdur5Encoder(BaseEncoder):

    def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean"):
        super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
        self.pooling = pooling
        self.features = nn.Sequential(
            Block2D(1, 32),
            nn.LPPool2d(4, (2, 4)),
            Block2D(32, 128),
            Block2D(128, 128),
            nn.LPPool2d(4, (2, 4)),
            Block2D(128, 128),
            Block2D(128, 128),
            nn.LPPool2d(4, (1, 4)),
            nn.Dropout(0.3),
        )
        # infer the GRU input size from a dummy forward pass
        with torch.no_grad():
            rnn_input_dim = self.features(
                torch.randn(1, 1, 500, spec_dim)).shape
            rnn_input_dim = rnn_input_dim[1] * rnn_input_dim[-1]

        self.gru = nn.GRU(rnn_input_dim,
                          128,
                          bidirectional=True,
                          batch_first=True)
        self.apply(init)

    def forward(self, input_dict):
        x = input_dict["spec"]
        lens = input_dict["spec_len"]
        if "upsample" not in input_dict:
            input_dict["upsample"] = False
        lens = torch.as_tensor(copy.deepcopy(lens))
        N, T, _ = x.shape
        x = x.unsqueeze(1)
        x = self.features(x)
        x = x.transpose(1, 2).contiguous().flatten(-2)
        x, _ = self.gru(x)
        if input_dict["upsample"]:
            x = nn.functional.interpolate(
                x.transpose(1, 2),
                T,
                mode='linear',
                align_corners=False).transpose(1, 2)
        else:
            # the two (2, 4) LPPool stages halve the time axis twice
            lens //= 4
        attn_emb = x
        fc_emb = embedding_pooling(x, lens, self.pooling)
        return {
            "attn_emb": attn_emb,
            "fc_emb": fc_emb,
            "attn_emb_len": lens
        }
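
# Usage sketch (spectrogram shapes are assumptions; fc_feat_dim and
# attn_feat_dim are stored by BaseEncoder but unused by this encoder):
#
#   enc = Cdur5Encoder(spec_dim=64, fc_feat_dim=None, attn_feat_dim=None)
#   out = enc({"spec": torch.randn(4, 500, 64),
#              "spec_len": [500, 480, 460, 400]})
#   # out["fc_emb"]: (4, 256); out["attn_emb"]: (4, 125, 256)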
|
|
|
|
|
def conv_conv_block(in_channel, out_channel):
    return nn.Sequential(
        nn.Conv2d(in_channel,
                  out_channel,
                  kernel_size=3,
                  bias=False,
                  padding=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(True),
        nn.Conv2d(out_channel,
                  out_channel,
                  kernel_size=3,
                  bias=False,
                  padding=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(True)
    )


class Cdur8Encoder(BaseEncoder):

    def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean"):
        super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
        self.pooling = pooling
        self.features = nn.Sequential(
            conv_conv_block(1, 64),
            MMPool((2, 2)),
            nn.Dropout(0.2, True),
            conv_conv_block(64, 128),
            MMPool((2, 2)),
            nn.Dropout(0.2, True),
            conv_conv_block(128, 256),
            MMPool((1, 2)),
            nn.Dropout(0.2, True),
            conv_conv_block(256, 512),
            MMPool((1, 2)),
            nn.Dropout(0.2, True),
            nn.AdaptiveAvgPool2d((None, 1)),
        )
        self.init_bn = nn.BatchNorm2d(spec_dim)
        self.embedding = nn.Linear(512, 512)
        self.gru = nn.GRU(512, 256, bidirectional=True, batch_first=True)
        self.apply(init)

    def forward(self, input_dict):
        x = input_dict["spec"]
        lens = input_dict["spec_len"]
        lens = torch.as_tensor(copy.deepcopy(lens))
        x = x.unsqueeze(1)
        x = x.transpose(1, 3)
        x = self.init_bn(x)
        x = x.transpose(1, 3)
        x = self.features(x)
        x = x.transpose(1, 2).contiguous().flatten(-2)
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.embedding(x))
        x, _ = self.gru(x)
        attn_emb = x
        # the two (2, 2) MMPool stages halve the time axis twice
        lens //= 4
        fc_emb = embedding_pooling(x, lens, self.pooling)
        return {
            "attn_emb": attn_emb,
            "fc_emb": fc_emb,
            "attn_emb_len": lens
        }


class Cnn10Encoder(BaseEncoder):

    def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim):
        super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
        self.features = nn.Sequential(
            conv_conv_block(1, 64),
            nn.AvgPool2d((2, 2)),
            nn.Dropout(0.2, True),
            conv_conv_block(64, 128),
            nn.AvgPool2d((2, 2)),
            nn.Dropout(0.2, True),
            conv_conv_block(128, 256),
            nn.AvgPool2d((2, 2)),
            nn.Dropout(0.2, True),
            conv_conv_block(256, 512),
            nn.AvgPool2d((2, 2)),
            nn.Dropout(0.2, True),
            nn.AdaptiveAvgPool2d((None, 1)),
        )
        self.init_bn = nn.BatchNorm2d(spec_dim)
        self.embedding = nn.Linear(512, 512)
        self.apply(init)

    def forward(self, input_dict):
        x = input_dict["spec"]
        lens = input_dict["spec_len"]
        lens = torch.as_tensor(copy.deepcopy(lens))
        x = x.unsqueeze(1)
        x = x.transpose(1, 3)
        x = self.init_bn(x)
        x = x.transpose(1, 3)
        x = self.features(x)
        x = x.transpose(1, 2).contiguous().flatten(-2)
        attn_emb = x
        # the four (2, 2) average-pooling stages downsample time by 16
        lens //= 16
        fc_emb = embedding_pooling(x, lens, "mean+max")
        fc_emb = F.dropout(fc_emb, p=0.5, training=self.training)
        fc_emb = self.embedding(fc_emb)
        fc_emb = F.relu_(fc_emb)
        return {
            "attn_emb": attn_emb,
            "fc_emb": fc_emb,
            "attn_emb_len": lens
        }


class ConvBlock(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1), bias=False)

        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1), bias=False)

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        x = input
        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.relu_(self.bn2(self.conv2(x)))
        if pool_type == 'max':
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg':
            x = F.avg_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg+max':
            x1 = F.avg_pool2d(x, kernel_size=pool_size)
            x2 = F.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2
        else:
            raise Exception(f"Incorrect pool_type: {pool_type}")

        return x


class Cnn14Encoder(nn.Module):

    def __init__(self, sample_rate=32000):
        super().__init__()
        sr_to_fmax = {
            32000: 14000,
            16000: 8000
        }
        # 32 ms window, 10 ms hop, 64-band mel spectrogram
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=32 * sample_rate // 1000,
            win_length=32 * sample_rate // 1000,
            hop_length=10 * sample_rate // 1000,
            f_min=50,
            f_max=sr_to_fmax[sample_rate],
            n_mels=64,
            norm="slaney",
            mel_scale="slaney"
        )
        self.hop_length = 10 * sample_rate // 1000
        self.db_transform = transforms.AmplitudeToDB()

        self.spec_augmenter = SpecAugmentation(time_drop_width=64,
            time_stripes_num=2, freq_drop_width=8, freq_stripes_num=2)

        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

        self.downsample_ratio = 32

        self.fc1 = nn.Linear(2048, 2048, bias=True)

        self.init_weight()

    def init_weight(self):
        init_bn(self.bn0)
        init_layer(self.fc1)

    def load_pretrained(self, pretrained):
        checkpoint = torch.load(pretrained, map_location="cpu")

        if "model" in checkpoint:
            state_keys = checkpoint["model"].keys()
            backbone = False
            for key in state_keys:
                if key.startswith("backbone."):
                    backbone = True
                    break

            if backbone:
                state_dict = {}
                for key, value in checkpoint["model"].items():
                    if key.startswith("backbone."):
                        model_key = key.replace("backbone.", "")
                        state_dict[model_key] = value
            else:
                state_dict = checkpoint["model"]
        elif "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
            state_dict_keys = list(filter(
                lambda x: "audio_encoder" in x, state_dict.keys()))
            state_dict = {
                key.replace('audio_encoder.', ''): state_dict[key]
                for key in state_dict_keys
            }
        else:
            raise Exception("Unknown checkpoint format")

        # only load weights whose names and shapes match the current model
        model_dict = self.state_dict()
        pretrained_dict = {
            k: v for k, v in state_dict.items() if (k in model_dict) and (
                model_dict[k].shape == v.shape)
        }
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict, strict=True)

    def forward(self, input_dict):
        """
        Input: (batch_size, n_samples)
        """
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        specaug = input_dict["specaug"]
        x = self.melspec_extractor(waveform)
        x = self.db_transform(x)    # (batch_size, mel_bins, time_steps)
        x = x.transpose(1, 2)
        x = x.unsqueeze(1)          # (batch_size, 1, time_steps, mel_bins)

        if self.training and specaug:
            x = self.spec_augmenter(x)

        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = torch.mean(x, dim=3)
        attn_emb = x.transpose(1, 2)

        # convert waveform lengths to frame counts after 32x downsampling
        wave_length = torch.as_tensor(wave_length)
        feat_length = torch.div(wave_length, self.hop_length,
            rounding_mode="floor") + 1
        feat_length = torch.div(feat_length, self.downsample_ratio,
            rounding_mode="floor")
        x_max = max_with_lens(attn_emb, feat_length)
        x_mean = mean_with_lens(attn_emb, feat_length)
        x = x_max + x_mean
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        fc_emb = F.dropout(x, p=0.5, training=self.training)

        output_dict = {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }
        return output_dict
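
# Usage sketch (a one-second batch at 32 kHz; values are assumptions):
#
#   enc = Cnn14Encoder(sample_rate=32000)
#   out = enc({"wav": torch.randn(2, 32000),
#              "wav_len": [32000, 28000],
#              "specaug": False})
#   # out["fc_emb"]: (2, 2048); out["attn_emb"]: (2, T', 2048)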
|
|
|
|
|
class RnnEncoder(BaseEncoder):

    def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim,
                 pooling="mean", **kwargs):
        super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
        self.pooling = pooling
        self.hidden_size = kwargs.get('hidden_size', 512)
        self.bidirectional = kwargs.get('bidirectional', False)
        self.num_layers = kwargs.get('num_layers', 1)
        self.dropout = kwargs.get('dropout', 0.2)
        self.rnn_type = kwargs.get('rnn_type', "GRU")
        self.in_bn = kwargs.get('in_bn', False)
        self.embed_dim = self.hidden_size * (self.bidirectional + 1)
        self.network = getattr(nn, self.rnn_type)(
            attn_feat_dim,
            self.hidden_size,
            num_layers=self.num_layers,
            bidirectional=self.bidirectional,
            dropout=self.dropout,
            batch_first=True)
        if self.in_bn:
            # normalize the RNN input, whose feature size is attn_feat_dim
            self.bn = nn.BatchNorm1d(attn_feat_dim)
        self.apply(init)

    def forward(self, input_dict):
        x = input_dict["attn"]
        lens = input_dict["attn_len"]
        lens = torch.as_tensor(lens)

        if self.in_bn:
            x = pack_wrapper(self.bn, x, lens)
        out = pack_wrapper(self.network, x, lens)

        attn_emb = out
        fc_emb = embedding_pooling(out, lens, self.pooling)
        return {
            "attn_emb": attn_emb,
            "fc_emb": fc_emb,
            "attn_emb_len": lens
        }


class Cnn14RnnEncoder(nn.Module):

    def __init__(self, sample_rate=32000, pretrained=None,
                 freeze_cnn=False, freeze_cnn_bn=False,
                 pooling="mean", **kwargs):
        super().__init__()
        self.cnn = Cnn14Encoder(sample_rate)
        self.rnn = RnnEncoder(64, 2048, 2048, pooling, **kwargs)
        if pretrained is not None:
            self.cnn.load_pretrained(pretrained)
        if freeze_cnn:
            assert pretrained is not None, "cnn is not pretrained but frozen"
            for param in self.cnn.parameters():
                param.requires_grad = False
        self.freeze_cnn_bn = freeze_cnn_bn

    def train(self, mode):
        super().train(mode=mode)
        if self.freeze_cnn_bn:
            # keep the CNN's BatchNorm layers in eval mode while training
            def bn_eval(module):
                class_name = module.__class__.__name__
                if class_name.find("BatchNorm") != -1:
                    module.eval()
            self.cnn.apply(bn_eval)
        return self

    def forward(self, input_dict):
        output_dict = self.cnn(input_dict)
        output_dict["attn"] = output_dict["attn_emb"]
        output_dict["attn_len"] = output_dict["attn_emb_len"]
        del output_dict["attn_emb"], output_dict["attn_emb_len"]
        output_dict = self.rnn(output_dict)
        return output_dict
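
# Usage sketch (kwargs and shapes are assumptions; pass a Cnn14 checkpoint
# path as pretrained to initialize the CNN):
#
#   enc = Cnn14RnnEncoder(sample_rate=32000, pretrained=None,
#                         bidirectional=True, hidden_size=256)
#   out = enc({"wav": torch.randn(2, 32000),
#              "wav_len": [32000, 24000],
#              "specaug": False})
#   # out["fc_emb"]: (2, 512) for a bidirectional 256-unit GRU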
|
|
|
|
|
class TransformerEncoder(BaseEncoder):

    def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, d_model, **kwargs):
        super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
        self.d_model = d_model
        dropout = kwargs.get("dropout", 0.2)
        self.nhead = kwargs.get("nhead", self.d_model // 64)
        self.nlayers = kwargs.get("nlayers", 2)
        self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)

        self.attn_proj = nn.Sequential(
            nn.Linear(attn_feat_dim, self.d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.LayerNorm(self.d_model)
        )
        layer = nn.TransformerEncoderLayer(d_model=self.d_model,
                                           nhead=self.nhead,
                                           dim_feedforward=self.dim_feedforward,
                                           dropout=dropout)
        self.model = nn.TransformerEncoder(layer, self.nlayers)
        self.cls_token = nn.Parameter(torch.zeros(d_model))
        self.init_params()

    def init_params(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, input_dict):
        attn_feat = input_dict["attn"]
        attn_feat_len = input_dict["attn_len"]
        attn_feat_len = torch.as_tensor(attn_feat_len)

        attn_feat = self.attn_proj(attn_feat)

        # prepend a learnable [CLS] token whose output serves as the
        # clip-level embedding
        cls_emb = self.cls_token.reshape(1, 1, self.d_model).repeat(
            attn_feat.size(0), 1, 1)
        attn_feat = torch.cat((cls_emb, attn_feat), dim=1)
        attn_feat = attn_feat.transpose(0, 1)

        # account for the prepended [CLS] token when masking padding
        attn_feat_len += 1
        src_key_padding_mask = ~generate_length_mask(
            attn_feat_len, attn_feat.size(0)).to(attn_feat.device)
        output = self.model(attn_feat, src_key_padding_mask=src_key_padding_mask)

        attn_emb = output.transpose(0, 1)
        fc_emb = attn_emb[:, 0]
        return {
            "attn_emb": attn_emb,
            "fc_emb": fc_emb,
            "attn_emb_len": attn_feat_len
        }
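
# Usage sketch (feature dimensions are assumptions for illustration):
#
#   enc = TransformerEncoder(spec_dim=64, fc_feat_dim=2048,
#                            attn_feat_dim=2048, d_model=512)
#   out = enc({"attn": torch.randn(2, 50, 2048), "attn_len": [50, 37]})
#   # out["fc_emb"] is the [CLS] output: (2, 512)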
|
|
|
|
|
class Cnn14TransformerEncoder(nn.Module):

    def __init__(self, sample_rate=32000, pretrained=None,
                 freeze_cnn=False, freeze_cnn_bn=False,
                 d_model=512, **kwargs):
        super().__init__()
        self.cnn = Cnn14Encoder(sample_rate)
        self.trm = TransformerEncoder(64, 2048, 2048, d_model, **kwargs)
        if pretrained is not None:
            self.cnn.load_pretrained(pretrained)
        if freeze_cnn:
            assert pretrained is not None, "cnn is not pretrained but frozen"
            for param in self.cnn.parameters():
                param.requires_grad = False
        self.freeze_cnn_bn = freeze_cnn_bn

    def train(self, mode):
        super().train(mode=mode)
        if self.freeze_cnn_bn:
            # keep the CNN's BatchNorm layers in eval mode while training
            def bn_eval(module):
                class_name = module.__class__.__name__
                if class_name.find("BatchNorm") != -1:
                    module.eval()
            self.cnn.apply(bn_eval)
        return self

    def forward(self, input_dict):
        output_dict = self.cnn(input_dict)
        output_dict["attn"] = output_dict["attn_emb"]
        output_dict["attn_len"] = output_dict["attn_emb_len"]
        del output_dict["attn_emb"], output_dict["attn_emb_len"]
        output_dict = self.trm(output_dict)
        return output_dict
|
|
|
|
|
|
|
|
|
|
|
|