import torch
from torch import nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
import math

from uniperceiver.config import configurable
from uniperceiver.functional import expand_tensor
from uniperceiver.utils import comm

from .decode_strategy import DecodeStrategy
from .build import DECODE_STRATEGY_REGISTRY


@DECODE_STRATEGY_REGISTRY.register()
class CaptionBeamSearcherV3(DecodeStrategy):
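    """Beam-search caption decoding with cached history states.

    The image is encoded once and stored as a per-layer history cache; at
    each step a single query token is run against that cache and scored
    over the shared vocabulary embedding, so the growing prefix is never
    re-encoded.
    """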

    def data_half(self, data):
        # Cast every float32 tensor to float16 when running in fp16 mode.
        if self.fp16:
            for k, v in data.items():
                if isinstance(v, torch.Tensor) and v.dtype == torch.float32:
                    data[k] = v.half()
        return data

    def _select(self, batch_size, beam_size, t, candidate_logprob):
        # Flatten the (beam, vocab) axes and keep the `beam_size` best
        # continuations per sample; `t` is unused but kept for the common
        # call signature.
        selected_logprob, selected_idx = torch.sort(
            candidate_logprob.view(batch_size, -1), -1, descending=True)
        selected_logprob, selected_idx = selected_logprob[:, :beam_size], selected_idx[:, :beam_size]
        return selected_idx, selected_logprob

    def _expand_state(self, states, selected_beam, batch_size, beam_size, cur_beam_size):
        # Reorder each cached state so that every surviving beam inherits
        # the history of the beam it was expanded from.
        for i in range(len(states)):
            shape = list(states[i].shape)
            beam = selected_beam
            for _ in shape[1:]:
                beam = beam.unsqueeze(-1)
            states[i] = torch.gather(
                states[i].view(*([batch_size, cur_beam_size] + shape[1:])), 1,
                beam.expand(*([batch_size, beam_size] + shape[1:])))
            states[i] = states[i].view(*([-1] + shape[1:]))

    def _forward(self, batched_inputs, model):
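        """Generate captions with beam search; return token ids and log-probs.

        As used below, `batched_inputs` carries `input_sample_list` (the
        image data), `shared_target_sets` (the vocabulary set), `task_info`
        (including the batch size), and an optional `OUT_SIZE` giving the
        number of beams to return per sample.
        """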
        inputs = self.data_half(batched_inputs)

        out_size = batched_inputs.get('OUT_SIZE', 1)

        task_info = inputs['task_info']
        bs = task_info['batch_size']
        if isinstance(bs, torch.Tensor):
            bs = bs.item()

        image_input = inputs['input_sample_list']
        vocab_input = inputs['shared_target_sets'][self.vocab_name]

        # Reuse the MoE routing embedding from the image data, if present,
        # so the text steps are routed consistently with the image encoding.
        moe_embedding = None
        for image_data in image_input:
            if 'moe_embedding' in image_data:
                moe_embedding = image_data['moe_embedding']
        image_encode = model._forward_data(image_input, task_info=task_info, return_all=True)[0]['data']

        # The vocabulary embedding does not depend on the input, so encode
        # it once and cache it across calls.
        if getattr(self, 'pre_computed_word_embeds', None) is None:
            vocab_encode = model._forward_data(vocab_input, task_info=task_info, return_all=False)[0]
            self.pre_computed_word_embeds = vocab_encode
        else:
            vocab_encode = self.pre_computed_word_embeds
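
        # Decoding scheme (as implemented below): the spe token is fed as
        # the query at every step, while the words selected so far live only
        # in the cached history states that the model attends to.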
        comm._CAPTION_GEN_MODE = True
        task_info.update({"prefix_spe_before_fuse": False})

        beam_size = self.beam_size
        log_probs = []
        selected_words = None
        seq_logprob = torch.zeros((bs, 1, 1)).cuda()      # cumulative score per beam
        seq_mask = torch.ones((bs, beam_size, 1)).cuda()  # drops to 0 once a beam emits EOS
        wt = torch.full((bs, 1), self.spe_token_id, dtype=torch.long).cuda()
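
        # Per-layer history cache: one [beam_size * bs, prefix_len + max_seq_len, dim]
        # buffer per layer, pre-filled with the image prefix and grown by one
        # encoded token per decoding step.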
        history_states = image_encode[:-1]
        len_prefix = history_states[0].shape[1]
        total_history_states = [
            history_states[0].new_zeros(beam_size * bs,
                                        image_encode[0].shape[1] + self.max_seq_len,
                                        image_encode[0].shape[2])
            for _ in history_states
        ]
        for i, ths in enumerate(total_history_states):
            hs = history_states[i]
            ths[:hs.shape[0], :hs.shape[1], :] = hs

        outputs = []
        common_info = {
            "modality": "text",
            "data_type": "input",
            "moe_embedding": moe_embedding,
        }
        for t in range(self.max_seq_len):
            cur_beam_size = 1 if t == 0 else beam_size

            # View of the cache restricted to the beams and steps filled so far.
            history_states = [ths[:(cur_beam_size * bs), :(len_prefix + t), :] for ths in total_history_states]

            step_data = {
                "data": wt,
                "time_step": t,
                "sample_info": {
                    "data_cum_length": [1, len_prefix, len_prefix + t + 1]
                }
            }
            step_data.update(common_info)

            step_encode = model._forward_data([step_data], task_info=task_info, history_states=history_states, return_all=False)

            step_predictor_input = {
                "input_sample_list": step_encode,
                "target_sample_list": [],
                "shared_target_sets": {self.vocab_name: [vocab_encode]},
                "target_set_list": [self.vocab_name],
                "target_idx_list": [],
                "task_info": task_info
            }
            logit = model.loss_prepare(**step_predictor_input)['output']

            # Compute log-softmax under autocast (which upcasts it to fp32)
            # unless fp16 softmax is explicitly forced.
            with autocast(enabled=not self.cfg.SOLVER.FORCE_SOFTMAX_FP16):
                word_logprob = F.log_softmax(logit, dim=-1)
            word_logprob = word_logprob.view(bs, cur_beam_size, -1)
            candidate_logprob = seq_logprob + word_logprob
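
            # Freeze finished beams: once a beam has emitted EOS, zero its
            # mask, pin its score, and force every continuation except EOS
            # down to -999 so it can never be selected again.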
            if t > 0:
                mask = (selected_words.view(bs, cur_beam_size) != self.eos_token_id).float().unsqueeze(-1)
                seq_mask = seq_mask * mask
                word_logprob = word_logprob * seq_mask.expand_as(word_logprob)
                old_seq_logprob = seq_logprob.expand_as(candidate_logprob).contiguous()
                old_seq_logprob[:, :, :self.eos_token_id] = -999
                old_seq_logprob[:, :, self.eos_token_id + 1:] = -999
                candidate_logprob = seq_mask * candidate_logprob + old_seq_logprob * (1 - seq_mask)
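
            # Top-k over the flattened (beam, vocab) scores; the flat index
            # decomposes into the source beam (idx // vocab_size) and the
            # word id (idx % vocab_size).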
            selected_idx, selected_logprob = self._select(bs, beam_size, t, candidate_logprob)
            selected_beam = torch.div(selected_idx, candidate_logprob.shape[-1], rounding_mode='floor')
            selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1]

            self._expand_state(history_states, selected_beam, bs, beam_size, cur_beam_size)

            # Carry the beam bookkeeping (scores, masks, emitted words and
            # per-step log-probs) over to the newly selected beams.
            seq_logprob = selected_logprob.unsqueeze(-1)
            seq_mask = torch.gather(seq_mask, 1, selected_beam.unsqueeze(-1))
            outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs)
            outputs.append(selected_words.unsqueeze(-1))

            this_word_logprob = torch.gather(word_logprob, 1,
                                             selected_beam.unsqueeze(-1).expand(bs, beam_size, word_logprob.shape[-1]))
            this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1))
            log_probs = list(
                torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(bs, beam_size, 1)) for o in log_probs)
            log_probs.append(this_word_logprob)
            selected_words = selected_words.view(-1, 1)

            # After the first step the batch axis grows from bs to bs * beam_size.
            if t == 0:
                wt = expand_tensor(wt, beam_size)
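
            # Encode the word just selected and append its hidden states to
            # the per-layer cache so the next step can attend to it.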
            step_selected_data = {"data": selected_words, "time_step": t, "sample_info": {"data_cum_length": [1, len_prefix, len_prefix + t + 1]}}
            step_selected_data.update(common_info)

            step_selected_encode = model._forward_data([step_selected_data], task_info=task_info, history_states=history_states, return_all=True)

            # Write the beam-reordered history back and add the new token.
            for i, ths in enumerate(total_history_states):
                hs = history_states[i]
                ths[:hs.shape[0], :hs.shape[1], :] = hs
                ths[:hs.shape[0], hs.shape[1], :] = step_selected_encode[0]['data'][i].squeeze(1)

        outputs = torch.cat(outputs, -1)
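
        # Optional length penalty: divide each beam score by
        # (non-EOS length + 1) ** len_penalty, then re-rank the beams.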
        if self.len_penalty > 0:
            step = outputs.ne(self.eos_token_id).sum(-1, keepdim=True) + 1
            seq_logprob /= step ** self.len_penalty
        seq_logprob, sort_idxs = torch.sort(seq_logprob, 1, descending=True)

        outputs = torch.gather(outputs, 1, sort_idxs.expand(bs, beam_size, self.max_seq_len))
        log_probs = torch.cat(log_probs, -1)
        log_probs = torch.gather(log_probs, 1, sort_idxs.expand(bs, beam_size, self.max_seq_len))

        # Keep the top `out_size` beams; squeeze the beam axis when only the
        # best sequence is requested.
        outputs = outputs.contiguous()[:, :out_size]
        log_probs = log_probs.contiguous()[:, :out_size]
        if out_size == 1:
            outputs = outputs.squeeze(1)
            log_probs = log_probs.squeeze(1)

        comm._CAPTION_GEN_MODE = False

        ids = torch.stack([torch.tensor(v['id']) for v in inputs['input_sample_list'][0]['sample_info']])

        return {
            "IDS": ids,
            "G_SENTS_IDS": outputs,
            "G_LOGP": log_probs,
        }