# NOTE(review): the lines here were web-page chrome ("Spaces:" / "Running") left
# over from copying this file out of a hosting UI — not part of the source.
from typing import List, Optional, Tuple, Union

import torch
from transformers import VisionEncoderDecoderModel
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
class OrderVisionEncoderDecoderModel(VisionEncoderDecoderModel):
    """VisionEncoderDecoderModel variant whose decoder consumes bounding boxes.

    Instead of feeding ``decoder_input_ids`` (token ids) to the decoder, this
    model forwards quantized bounding boxes (``decoder_input_boxes``) plus a
    padding mask and per-image box counts. The encoder side is the stock
    vision encoder; only the decoder input plumbing differs from the parent
    class's ``forward``.
    """

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        # Shape (batch_size, num_boxes, 4); coords scaled to 0-1000, 1001 used as padding
        decoder_input_boxes: Optional[torch.LongTensor] = None,
        # Shape (batch_size, num_boxes); 0 if padding, 1 otherwise
        decoder_input_boxes_mask: Optional[torch.LongTensor] = None,
        # Shape (batch_size,); number of real boxes in each image
        decoder_input_boxes_counts: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[List[List[int]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Run the encoder (unless ``encoder_outputs`` is given) and decode boxes.

        Args:
            pixel_values: Input images for the vision encoder. Required when
                ``encoder_outputs`` is not supplied.
            decoder_input_boxes: Quantized boxes fed to the decoder (see inline
                shape comments above).
            decoder_input_boxes_mask: 1 for real boxes, 0 for padding.
            decoder_input_boxes_counts: Number of valid boxes per image.
            encoder_outputs: Precomputed encoder outputs; skips the encoder pass.
            past_key_values: Cached decoder key/values for incremental decoding.
            decoder_inputs_embeds: Optional precomputed decoder input embeddings.
            labels: Per-image target orderings used by the decoder to compute a loss.
            return_dict: If falsy, returns the concatenated decoder+encoder tuples
                instead of a ``Seq2SeqLMOutput``.

        Returns:
            ``Seq2SeqLMOutput`` when ``return_dict`` is truthy, else a tuple.

        Raises:
            ValueError: If neither ``pixel_values`` nor ``encoder_outputs`` is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Route extra kwargs: "decoder_"-prefixed ones go to the decoder (with
        # the prefix stripped), everything else goes to the encoder.
        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")
            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
        elif isinstance(encoder_outputs, tuple):
            encoder_outputs = BaseModelOutput(*encoder_outputs)

        encoder_hidden_states = encoder_outputs[0]

        # Project encoder states into the decoder's hidden size when they differ
        # and no dedicated cross-attention size is configured.
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # No masking over encoder patch features — the decoder attends to all of them.
        encoder_attention_mask = None

        # Decode: boxes (not token ids) are the decoder inputs.
        decoder_outputs = self.decoder(
            input_boxes=decoder_input_boxes,
            input_boxes_mask=decoder_input_boxes_mask,
            input_boxes_counts=decoder_input_boxes_counts,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            labels=labels,
            **kwargs_decoder,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=decoder_outputs.loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )