import random

import torch
import torch.nn as nn

from .base_model import CaptionModel
from .utils import repeat_tensor
import audio_to_text.captioning.models.decoder
import audio_to_text.captioning.models.encoder


class TransformerModel(CaptionModel):

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (
                audio_to_text.captioning.models.decoder.TransformerDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)

    def seq_forward(self, input_dict):
        # Teacher forcing: feed the caption shifted right by one position,
        # masking out padding tokens.
        cap = input_dict["cap"]
        cap_padding_mask = (cap == self.pad_idx).to(cap.device)
        cap_padding_mask = cap_padding_mask[:, :-1]
        output = self.decoder(
            {
                "word": cap[:, :-1],
                "attn_emb": input_dict["attn_emb"],
                "attn_emb_len": input_dict["attn_emb_len"],
                "cap_padding_mask": cap_padding_mask
            }
        )
        return output
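
    # A minimal sketch (illustrative only, token ids hypothetical) of what
    # `seq_forward` sees, assuming pad_idx = 0 and one caption
    # "<start> a dog barks <end>" encoded as [1, 23, 57, 89, 2]:
    #
    #   cap         = torch.tensor([[1, 23, 57, 89, 2]])
    #   cap[:, :-1] # [[1, 23, 57, 89]] -- decoder input, <end> dropped
    #   cap[:, 1:]  # [[23, 57, 89, 2]] -- the usual training target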

    def prepare_decoder_input(self, input_dict, output):
        decoder_input = {
            "attn_emb": input_dict["attn_emb"],
            "attn_emb_len": input_dict["attn_emb_len"]
        }
        t = input_dict["t"]

        # Determine the input word prefix for step t.
        if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]:
            # Scheduled sampling: condition on the ground-truth prefix.
            word = input_dict["cap"][:, :t + 1]
        else:
            # Condition on the model's own predictions so far.
            start_word = torch.tensor(
                [self.start_idx] * input_dict["attn_emb"].size(0)
            ).unsqueeze(1).long()
            if t == 0:
                word = start_word
            else:
                word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)

        decoder_input["word"] = word

        cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
        decoder_input["cap_padding_mask"] = cap_padding_mask
        return decoder_input
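
    # Scheduled-sampling sketch (hypothetical values): `ss_ratio` interpolates
    # between teacher forcing and free running during training, e.g.
    #
    #   input_dict["ss_ratio"] = 1.0  # always feed the ground-truth prefix
    #   input_dict["ss_ratio"] = 0.0  # always feed the model's own predictions
    #
    # Annealing the ratio from 1.0 toward 0.0 over training is the common recipe.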

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        decoder_input = {}
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]

        # At the first step, tile the encoder outputs of sample i so that
        # every beam shares the same audio context.
        if t == 0:
            attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
            attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size)
            output_i["attn_emb"] = attn_emb
            output_i["attn_emb_len"] = attn_emb_len
        decoder_input["attn_emb"] = output_i["attn_emb"]
        decoder_input["attn_emb_len"] = output_i["attn_emb_len"]

        # Prepend <start> to each beam's partial sequence.
        start_word = torch.tensor([self.start_idx] * beam_size).unsqueeze(1).long()
        if t == 0:
            word = start_word
        else:
            word = torch.cat((start_word, output_i["seq"]), dim=-1)
        decoder_input["word"] = word
        cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
        decoder_input["cap_padding_mask"] = cap_padding_mask

        return decoder_input
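
    # Shape sketch (illustrative): with attn_emb[i] of shape (T, D) and
    # beam_size = 3, the three beams attend to identical audio embeddings
    # while their word prefixes diverge:
    #
    #   tiled = repeat_tensor(input_dict["attn_emb"][i], 3)  # (3, T, D)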


class M2TransformerModel(CaptionModel):

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (
                audio_to_text.captioning.models.decoder.M2TransformerDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)
        self.check_encoder_compatibility()

    def check_encoder_compatibility(self):
        assert isinstance(self.encoder, audio_to_text.captioning.models.encoder.M2TransformerEncoder), \
            f"only M2TransformerEncoder is compatible with {self.__class__.__name__}"

    def seq_forward(self, input_dict):
        cap = input_dict["cap"]
        output = self.decoder(
            {
                "word": cap[:, :-1],
                "attn_emb": input_dict["attn_emb"],
                "attn_emb_mask": input_dict["attn_emb_mask"],
            }
        )
        return output

    def prepare_decoder_input(self, input_dict, output):
        decoder_input = {
            "attn_emb": input_dict["attn_emb"],
            "attn_emb_mask": input_dict["attn_emb_mask"]
        }
        t = input_dict["t"]

        # Determine the input word prefix for step t.
        if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]:
            # Scheduled sampling: condition on the ground-truth prefix.
            word = input_dict["cap"][:, :t + 1]
        else:
            # Condition on the model's own predictions so far.
            start_word = torch.tensor(
                [self.start_idx] * input_dict["attn_emb"].size(0)
            ).unsqueeze(1).long()
            if t == 0:
                word = start_word
            else:
                word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)

        decoder_input["word"] = word

        return decoder_input

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        decoder_input = {}
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]

        # At the first step, tile the encoder outputs of sample i so that
        # every beam shares the same audio context.
        if t == 0:
            attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
            attn_emb_mask = repeat_tensor(input_dict["attn_emb_mask"][i], beam_size)
            output_i["attn_emb"] = attn_emb
            output_i["attn_emb_mask"] = attn_emb_mask
        decoder_input["attn_emb"] = output_i["attn_emb"]
        decoder_input["attn_emb_mask"] = output_i["attn_emb_mask"]

        # Prepend <start> to each beam's partial sequence.
        start_word = torch.tensor([self.start_idx] * beam_size).unsqueeze(1).long()
        if t == 0:
            word = start_word
        else:
            word = torch.cat((start_word, output_i["seq"]), dim=-1)
        decoder_input["word"] = word

        return decoder_input


class EventEncoder(nn.Module):
    """
    Encode the label information in AudioCaps and AudioSet.
    """

    def __init__(self, emb_dim, vocab_size=527):
        super().__init__()
        self.label_embedding = nn.Parameter(
            torch.randn((vocab_size, emb_dim)), requires_grad=True)

    def forward(self, word_idxs):
        # Normalize the multi-hot label vector into weights over classes and
        # take the weighted average of the learned label embeddings.
        weights = word_idxs / word_idxs.sum(dim=1, keepdim=True)
        embeddings = weights @ self.label_embedding
        return embeddings
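
# Usage sketch (illustrative; the class indices below are hypothetical):
#
#   enc = EventEncoder(emb_dim=256)
#   labels = torch.zeros(1, 527)
#   labels[0, [0, 137]] = 1.0  # two active AudioSet classes
#   emb = enc(labels)          # (1, 256): the mean of the two embeddings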


class EventCondTransformerModel(TransformerModel):

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (
                audio_to_text.captioning.models.decoder.EventTransformerDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)
        self.label_encoder = EventEncoder(decoder.emb_dim, 527)
        self.train_forward_keys += ["events"]
        self.inference_forward_keys += ["events"]

    def prepare_decoder_input(self, input_dict, output):
        decoder_input = super().prepare_decoder_input(input_dict, output)
        decoder_input["events"] = self.label_encoder(input_dict["events"])
        return decoder_input

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]
        # Encode the event labels once and tile them across beams.
        if t == 0:
            output_i["events"] = repeat_tensor(
                self.label_encoder(input_dict["events"])[i], beam_size)
        decoder_input["events"] = output_i["events"]
        return decoder_input


class KeywordCondTransformerModel(TransformerModel):

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (
                audio_to_text.captioning.models.decoder.KeywordProbTransformerDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)
        self.train_forward_keys += ["keyword"]
        self.inference_forward_keys += ["keyword"]

    def seq_forward(self, input_dict):
        cap = input_dict["cap"]
        cap_padding_mask = (cap == self.pad_idx).to(cap.device)
        cap_padding_mask = cap_padding_mask[:, :-1]
        keyword = input_dict["keyword"]
        output = self.decoder(
            {
                "word": cap[:, :-1],
                "attn_emb": input_dict["attn_emb"],
                "attn_emb_len": input_dict["attn_emb_len"],
                "keyword": keyword,
                "cap_padding_mask": cap_padding_mask
            }
        )
        return output

    def prepare_decoder_input(self, input_dict, output):
        decoder_input = super().prepare_decoder_input(input_dict, output)
        decoder_input["keyword"] = input_dict["keyword"]
        return decoder_input

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]
        # Tile the keyword probabilities across beams once, at the first step.
        if t == 0:
            output_i["keyword"] = repeat_tensor(input_dict["keyword"][i],
                                                beam_size)
        decoder_input["keyword"] = output_i["keyword"]
        return decoder_input