# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class SeqGenerationHead(BaseModule):
"""Generation head for multi-modal pre-trained task, adopted by BLIP.
Normally used for generation task.
Args:
decoder (dict): Decoder for blip generation head.
init_cfg (dict, optional): the config to control the initialization.
Defaults to None.
"""
def __init__(
self,
decoder: dict,
ignore_index=-100,
loss: dict = dict(type='LabelSmoothLoss', label_smooth_val=0.1),
init_cfg: Optional[dict] = None,
) -> None:
        super().__init__(init_cfg=init_cfg)
self.decoder = MODELS.build(decoder)
self.loss_fn = MODELS.build(loss)
self.ignore_index = ignore_index

    def forward(self, input_ids: torch.Tensor,
                encoder_hidden_states: torch.Tensor,
                encoder_attention_mask: torch.Tensor, labels: torch.Tensor):
        """Forward to get decoder output.

        Args:
            input_ids (torch.Tensor): The tokenized input text tensor.
            encoder_hidden_states (torch.Tensor): Hidden states from image
                embeddings.
            encoder_attention_mask (torch.Tensor): Attention mask of the
                image embeddings hidden states.
            labels (torch.Tensor): Decoder targets for calculating the loss.

        Returns:
            dict[str, Tensor]: A dictionary of decoder outputs.
        """
decoder_out = self.decoder(
input_ids=input_ids,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
labels=labels,
return_dict=True,
)
return decoder_out

    def loss(self, input_ids, encoder_hidden_states, encoder_attention_mask,
             labels):
        """Calculate losses from the extracted features.

        Args:
            input_ids (torch.Tensor): The tokenized input text tensor.
            encoder_hidden_states (torch.Tensor): Hidden states from image
                embeddings.
            encoder_attention_mask (torch.Tensor): Attention mask of the
                image embeddings hidden states.
            labels (torch.Tensor): Decoder targets for calculating the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
decoder_out = self(
input_ids=input_ids,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
labels=labels,
)
prediction_scores = decoder_out['logits']
        # we are doing next-token prediction;
        # shift prediction scores and labels by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
vocab_size = prediction_scores.shape[-1]
# mask ignored index
if (labels == self.ignore_index).any():
labels = labels.view(-1).clone()
ignore_mask = (labels == self.ignore_index)
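            # replace ignored targets with a valid class index (0); these
            # positions are excluded from the loss via ``weight``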
labels.masked_fill_(ignore_mask, 0)
weight = torch.logical_not(ignore_mask)
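            # normalize by the number of non-ignored target tokens
            # (clamped to at least 1 to avoid division by zero)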
avg_factor = max(weight.sum(), 1)
else:
weight = None
avg_factor = labels.size(0)
lm_loss = self.loss_fn(
shifted_prediction_scores.view(-1, vocab_size),
labels,
weight=weight,
avg_factor=avg_factor,
)
losses = {
'seq_gen_lm_loss': lm_loss,
}
return losses

    def predict(self,
                input_ids,
                encoder_hidden_states,
                sep_token_id,
                pad_token_id,
                use_nucleus_sampling=False,
                num_beams=3,
                max_length=20,
                min_length=2,
                top_p=0.9,
                repetition_penalty=1.0,
                **kwargs):
        """Decoder prediction method.

        Args:
            input_ids (torch.Tensor): The tokenized input text tensor.
            encoder_hidden_states (torch.Tensor): Hidden states from image
                embeddings.
            sep_token_id (int): Token id of the separation token.
            pad_token_id (int): Token id of the pad token.
            use_nucleus_sampling (bool): Whether to use nucleus sampling in
                prediction. Defaults to False.
            num_beams (int): Number of beams used in prediction.
                Defaults to 3.
            max_length (int): Max length of generated text in prediction.
                Defaults to 20.
            min_length (int): Min length of generated text in prediction.
                Defaults to 2.
            top_p (float): If < 1.0, only keep the top tokens with cumulative
                probability >= top_p (nucleus filtering). Defaults to 0.9.
            repetition_penalty (float): The parameter for repetition penalty.
                Defaults to 1.0.
            **kwargs: Other arguments that might be used in generation.

        Returns:
            dict[str, Tensor]: A dictionary of generation outputs.
        """
device = encoder_hidden_states.device
        # TODO: In older versions of transformers, an additional
        # repeat_interleave of the hidden states should be added here.
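        # attention mask that lets the decoder attend to every visual token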
image_atts = torch.ones(
encoder_hidden_states.size()[:-1], dtype=torch.long).to(device)
model_kwargs = {
'encoder_hidden_states': encoder_hidden_states,
'encoder_attention_mask': image_atts,
}
model_kwargs.update(kwargs)
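        # both branches delegate to the decoder's ``generate`` method;
        # only the decoding strategy differs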
if use_nucleus_sampling:
# nucleus sampling
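            # NOTE: this branch keeps the repetition penalty fixed at 1.1
            # and does not use the ``repetition_penalty`` argument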
outputs = self.decoder.generate(
input_ids=input_ids,
max_length=max_length,
min_length=min_length,
do_sample=True,
top_p=top_p,
num_return_sequences=1,
eos_token_id=sep_token_id,
pad_token_id=pad_token_id,
repetition_penalty=1.1,
**model_kwargs)
else:
# beam search
outputs = self.decoder.generate(
input_ids=input_ids,
max_length=max_length,
min_length=min_length,
num_beams=num_beams,
eos_token_id=sep_token_id,
pad_token_id=pad_token_id,
repetition_penalty=repetition_penalty,
**model_kwargs)
return outputs
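

# A minimal usage sketch (illustrative only, not part of the module): the
# decoder type, config keys and token ids below are assumptions and depend
# on the concrete BLIP setup, e.g. a BERT-style LM decoder and tokenizer.
#
#     head = SeqGenerationHead(
#         decoder=dict(type='XBertLMHeadDecoder', med_config=...),
#     )
#     losses = head.loss(input_ids, encoder_hidden_states,
#                        encoder_attention_mask, labels)
#     captions = head.predict(input_ids, encoder_hidden_states,
#                             sep_token_id=102, pad_token_id=0)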