|
from typing import Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
from transformers import PreTrainedModel, PreTrainedEncoder, PreTrainedDecoder |
|
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput |
|
from transformers.utils import logging |
|
|
|
# Module-level logger, obtained from the ``logging`` helper imported from
# ``transformers.utils`` at the top of this file.
logger = logging.get_logger(__name__)
|
|
|
class CSUMLMEncoder(PreTrainedEncoder):
    """Encoder component of the CSUMLM model (skeleton).

    Only the constructor and the ``forward`` signature exist so far: no
    sub-modules are created and the forward computation is not implemented.

    NOTE(review): ``PreTrainedEncoder`` is imported from ``transformers`` at
    the top of this file, but it does not appear to be part of the public
    ``transformers`` API -- confirm the intended base class.
    """

    def __init__(self, config):
        """Initialise the encoder by forwarding ``config`` to the base class.

        Args:
            config: Model configuration object, passed to the base-class
                constructor unchanged.
        """
        super().__init__(config)
        # TODO: create the encoder sub-modules here; none are defined yet.

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encoder forward pass (not implemented yet).

        Every argument is optional and defaults to ``None``. None of them is
        consumed yet because the computation has not been written; the
        signature is kept so callers can already target the final interface.

        Raises:
            NotImplementedError: Always, until the forward pass is written.
        """
        # The previous body returned the name ``encoder_outputs`` without ever
        # assigning it, which would surface as a confusing NameError at call
        # time. Fail explicitly instead so the missing implementation is
        # obvious to the caller.
        raise NotImplementedError(
            "CSUMLMEncoder.forward is not implemented: the computation that "
            "should produce `encoder_outputs` is missing."
        )
|
|
|
class CSUMLMDecoder(PreTrainedDecoder):
    """Decoder component of the CSUMLM model (skeleton).

    Only the constructor and the ``forward`` signature exist so far: no
    sub-modules are created and the forward computation is not implemented.

    NOTE(review): ``PreTrainedDecoder`` is imported from ``transformers`` at
    the top of this file, but it does not appear to be part of the public
    ``transformers`` API -- confirm the intended base class.
    """

    def __init__(self, config):
        """Initialise the decoder by forwarding ``config`` to the base class.

        Args:
            config: Model configuration object, passed to the base-class
                constructor unchanged.
        """
        super().__init__(config)
        # TODO: create the decoder sub-modules here; none are defined yet.

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Decoder forward pass (not implemented yet).

        Every argument is optional and defaults to ``None``. None of them is
        consumed yet because the computation has not been written; the
        signature is kept so callers can already target the final interface.

        Raises:
            NotImplementedError: Always, until the forward pass is written.
        """
        # The previous body returned the name ``decoder_outputs`` without ever
        # assigning it, which would surface as a confusing NameError at call
        # time. Fail explicitly instead so the missing implementation is
        # obvious to the caller.
        raise NotImplementedError(
            "CSUMLMDecoder.forward is not implemented: the computation that "
            "should produce `decoder_outputs` is missing."
        )
|
|
|
class CSUMLMModel(PreTrainedModel):
    """Top-level CSUMLM model that owns an encoder, a decoder and a fusion module.

    The constructor wires up the three sub-modules; the forward computation
    that would combine them is not implemented yet.
    """

    def __init__(self, config):
        """Build the encoder, decoder and multimodal fusion sub-modules.

        Args:
            config: Model configuration object. It is passed to the base-class
                constructor and to each sub-module constructor unchanged.
        """
        super().__init__(config)
        self.encoder = CSUMLMEncoder(config)
        self.decoder = CSUMLMDecoder(config)
        # NOTE(review): ``MultimodalFusion`` is neither defined nor imported in
        # the visible part of this file -- confirm where it comes from,
        # otherwise this line raises NameError on construction.
        self.multimodal_fusion = MultimodalFusion(config)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Full-model forward pass (not implemented yet).

        Every argument is optional and defaults to ``None``. None of them is
        consumed yet because the computation has not been written; the
        signature is kept so callers can already target the final interface.

        Raises:
            NotImplementedError: Always, until the forward pass is written.
        """
        # The previous body returned the name ``output`` without ever
        # assigning it, which would surface as a confusing NameError at call
        # time. Fail explicitly instead so the missing implementation is
        # obvious to the caller.
        raise NotImplementedError(
            "CSUMLMModel.forward is not implemented: the computation that "
            "should produce `output` is missing."
        )
|
|
|
|
|
# Register the model class with the Transformers auto-class machinery at
# import time. No ``auto_class`` argument is given, so the library default
# applies.
CSUMLMModel.register_for_auto_class()