# unit_test/uniperceiver/modeling/decode_strategy/caption_beam_searcher_v3.py
# Copyright (c) 2019, AImageLab
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from uniperceiver.config import configurable
from uniperceiver.functional import expand_tensor
from .decode_strategy import DecodeStrategy
from .build import DECODE_STRATEGY_REGISTRY
from uniperceiver.utils import comm
import math
from torch.cuda.amp import autocast
@DECODE_STRATEGY_REGISTRY.register()
class CaptionBeamSearcherV3(DecodeStrategy):

    def data_half(self, data):
        if self.fp16:
            for k, v in data.items():
                if isinstance(v, torch.Tensor) and v.dtype == torch.float32:
                    data[k] = v.half()
                    # print(k)
            return data
        else:
            return data
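
    # Keep only the beam_size best continuations per batch item: candidate scores are
    # flattened over (beam, word) and sorted in descending order.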
    def _select(self, batch_size, beam_size, t, candidate_logprob):
        selected_logprob, selected_idx = torch.sort(candidate_logprob.view(batch_size, -1), -1, descending=True)
        selected_logprob, selected_idx = selected_logprob[:, :beam_size], selected_idx[:, :beam_size]
        return selected_idx, selected_logprob
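
    # Reorder the cached states so that each of the beam_size slots follows its selected
    # parent beam; states are viewed as (batch, beam, ...) before gathering.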
    def _expand_state(self, states, selected_beam, batch_size, beam_size, cur_beam_size):
        for i in range(len(states)):
            shape = list(states[i].shape)
            beam = selected_beam
            for _ in shape[1:]:
                beam = beam.unsqueeze(-1)
            states[i] = torch.gather(states[i].view(*([batch_size, cur_beam_size] + shape[1:])), 1,
                                     beam.expand(*([batch_size, beam_size] + shape[1:])))
            states[i] = states[i].view(*([-1, ] + shape[1:]))
    def _forward(self, batched_inputs, model):
        # Only the caption tasks are generative tasks for now.
        # For caption tasks, the computation proceeds as follows:
        # 1. encode the image sequence and save it for later reuse;
        # 2. if no cached encoded dictionary exists, encode the dictionary and cache it; otherwise reuse the cache;
        # 3. compute attention. We use cross attention instead of self attention.
        # batched_inputs[kfg.IMAGE] = batched_inputs.pop(kfg.VIDEO).squeeze(1)
        inputs = batched_inputs
        inputs = self.data_half(inputs)
        out_size = batched_inputs.get('OUT_SIZE', 1)

        task_info = inputs['task_info']
        bs = task_info['batch_size']
        if isinstance(bs, torch.Tensor):
            bs = bs.item()

        image_input = inputs['input_sample_list']
        vocab_input = inputs['shared_target_sets'][self.vocab_name]

        # 1. encode the image/video sequence.
        moe_embedding = None
        for image_data in image_input:
            if 'moe_embedding' in image_data:
                moe_embedding = image_data['moe_embedding']
        image_encode = model._forward_data(image_input, task_info=task_info, return_all=True)[0]['data']

        # 2. encode the vocabulary; if it has not been pre-computed yet, compute and cache it.
        if getattr(self, 'pre_computed_word_embeds', None) is None:
            vocab_encode = model._forward_data(vocab_input, task_info=task_info, return_all=False)[0]
            self.pre_computed_word_embeds = vocab_encode
        else:
            vocab_encode = self.pre_computed_word_embeds

        # 3. compute attention
        comm._CAPTION_GEN_MODE = True
        task_info.update({"prefix_spe_before_fuse": False})
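
        # Beam-search bookkeeping: accumulated sequence log-probs, a mask marking beams that
        # have already emitted EOS, and the start token (the SPE token) for every sample.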
        beam_size = self.beam_size
        log_probs = []
        selected_words = None
        seq_logprob = torch.zeros((bs, 1, 1)).cuda()  # bs, 1, 1
        seq_mask = torch.ones((bs, beam_size, 1)).cuda()
        wt = Variable(torch.zeros(bs, dtype=torch.long).cuda().unsqueeze(1)) + self.spe_token_id
        u_tokens_type = wt.new_zeros(wt.shape)  # [Note] we assume the type tokens are 0.
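
        # Pre-allocate, for every layer, a buffer holding the encoded prefix (visual tokens)
        # followed by the hidden states of all generated words, sized for beam_size * bs rows.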
        history_states = image_encode[:-1]
        len_prefix = history_states[0].shape[1]
        total_history_states = [
            history_states[0].new_zeros(beam_size * bs, image_encode[0].shape[1] + self.max_seq_len, image_encode[0].shape[2])
            for _ in history_states
        ]
        for i, ths in enumerate(total_history_states):
            hs = history_states[i]
            ths[:hs.shape[0], :hs.shape[1], :] = hs

        outputs = []
        common_info = {
            "modality": "text",
            'data_type': 'input',
            'moe_embedding': moe_embedding,
        }
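
        # Decode token by token: each step feeds only the newest token together with the cached
        # history states, then ranks every (beam, word) continuation.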
        for t in range(self.max_seq_len):
            cur_beam_size = 1 if t == 0 else beam_size

            history_states = [ths[:(cur_beam_size * bs), :(len_prefix + t), :] for ths in total_history_states]
            step_data = {
                "data": wt,
                "time_step": t,
                "sample_info": {
                    "data_cum_length": [1, len_prefix, len_prefix + t + 1]
                }
            }
            step_data.update(common_info)
            step_encode = model._forward_data([step_data], task_info=task_info, history_states=history_states, return_all=False)
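
            # Score the encoded step token against the shared vocabulary embeddings to obtain
            # next-word logits.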
            step_predictor_input = {
                "input_sample_list": step_encode,
                "target_sample_list": [],
                "shared_target_sets": {self.vocab_name: [vocab_encode]},
                "target_set_list": [self.vocab_name],
                "target_idx_list": [],
                "task_info": task_info
            }
            logit = model.loss_prepare(**step_predictor_input)['output']
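
            # Turn the logits into per-word log-probabilities (autocast keeps log_softmax out of
            # fp16 unless FORCE_SOFTMAX_FP16 is set) and add them to the running beam scores.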
            with autocast(enabled=not self.cfg.SOLVER.FORCE_SOFTMAX_FP16):
                word_logprob = F.log_softmax(logit, dim=-1)
            word_logprob = word_logprob.view(bs, cur_beam_size, -1)
            candidate_logprob = seq_logprob + word_logprob
            # # Mask sequence if it reaches EOS
            # if t > 0:
            #     mask = (selected_words.view(bs, cur_beam_size) != 0).float().unsqueeze(-1)  # why "not equal to 0"?
            #     seq_mask = seq_mask * mask
            #     word_logprob = word_logprob * seq_mask.expand_as(word_logprob)
            #     old_seq_logprob = seq_logprob.expand_as(candidate_logprob).contiguous()
            #     old_seq_logprob[:, :, 1:] = -999
            #     candidate_logprob = seq_mask * candidate_logprob + old_seq_logprob * (1 - seq_mask)
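
            # Freeze finished beams: once a beam has emitted EOS its word log-probs are zeroed and
            # its candidate scores are pinned to the previous sequence score at the EOS slot (all
            # other words get -999), so it keeps its score and can only extend with EOS.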
            if t > 0:
                mask = (selected_words.view(bs, cur_beam_size) != self.eos_token_id).float().unsqueeze(-1)
                seq_mask = seq_mask * mask
                word_logprob = word_logprob * seq_mask.expand_as(word_logprob)
                old_seq_logprob = seq_logprob.expand_as(candidate_logprob).contiguous()
                old_seq_logprob[:, :, :self.eos_token_id] = -999
                old_seq_logprob[:, :, self.eos_token_id + 1:] = -999
                candidate_logprob = seq_mask * candidate_logprob + old_seq_logprob * (1 - seq_mask)
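
            # Keep the beam_size best (beam, word) pairs; the flat index is split back into the
            # parent beam and the word id inside that beam's distribution.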
            selected_idx, selected_logprob = self._select(bs, beam_size, t, candidate_logprob)  # (bs, beam)
            selected_beam = torch.div(selected_idx, candidate_logprob.shape[-1], rounding_mode='floor')
            selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1]

            self._expand_state(history_states, selected_beam, bs, beam_size, cur_beam_size)

            seq_logprob = selected_logprob.unsqueeze(-1)
            seq_mask = torch.gather(seq_mask, 1, selected_beam.unsqueeze(-1))
            outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs)
            outputs.append(selected_words.unsqueeze(-1))

            this_word_logprob = torch.gather(word_logprob, 1,
                                             selected_beam.unsqueeze(-1).expand(bs, beam_size, word_logprob.shape[-1]))
            this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1))
            log_probs = list(
                torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(bs, beam_size, 1)) for o in log_probs)
            log_probs.append(this_word_logprob)
            selected_words = selected_words.view(-1, 1)
            # wt = selected_words

            if t == 0:
                u_tokens_type = expand_tensor(u_tokens_type, beam_size)
                wt = expand_tensor(wt, beam_size)
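
            # Re-encode the words selected at this step and append their hidden states to the
            # pre-allocated history buffers so the next step can attend to them.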
            step_selected_data = {"data": selected_words, "time_step": t, "sample_info": {"data_cum_length": [1, len_prefix, len_prefix + t + 1]}}
            step_selected_data.update(common_info)
            step_selected_encode = model._forward_data([step_selected_data], task_info=task_info, history_states=history_states, return_all=True)

            for i, ths in enumerate(total_history_states):
                hs = history_states[i]
                ths[:hs.shape[0], :hs.shape[1], :] = hs
                ths[:hs.shape[0], hs.shape[1], :] = step_selected_encode[0]['data'][i].squeeze(1)
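
        # Finalize: apply the length penalty, sort beams by score, and keep the top out_size
        # hypotheses per sample (squeezed when out_size == 1).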
        outputs = torch.cat(outputs, -1)

        if self.len_penalty > 0:
            step = outputs.ne(self.eos_token_id).sum(-1, keepdim=True) + 1
            seq_logprob /= step ** self.len_penalty

        seq_logprob, sort_idxs = torch.sort(seq_logprob, 1, descending=True)
        outputs = torch.gather(outputs, 1, sort_idxs.expand(bs, beam_size, self.max_seq_len))
        log_probs = torch.cat(log_probs, -1)
        log_probs = torch.gather(log_probs, 1, sort_idxs.expand(bs, beam_size, self.max_seq_len))

        outputs = outputs.contiguous()[:, :out_size]
        log_probs = log_probs.contiguous()[:, :out_size]
        if out_size == 1:
            outputs = outputs.squeeze(1)
            log_probs = log_probs.squeeze(1)

        comm._CAPTION_GEN_MODE = False

        ids = torch.stack([torch.tensor(v['id']) for v in inputs['input_sample_list'][0]['sample_info']])

        return {
            "IDS": ids,
            "G_SENTS_IDS": outputs,
            "G_LOGP": log_probs
        }