import torch
import transformers
from typing import Optional, Tuple, Union

from transformers import AutoConfig, AutoModel
from transformers.models.clip.modeling_clip import CLIPTextModelOutput


class MCLIPConfig(transformers.PretrainedConfig):
    model_type = "M-CLIP"

    def __init__(self, modelBase='xlm-roberta-large', transformerDimSize=1024, imageDimSize=768, **kwargs):
        self.transformerDimensions = transformerDimSize
        self.numDims = imageDimSize
        self.modelBase = modelBase
        super().__init__(**kwargs)
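
# Note (added for context): the defaults above describe one pairing, namely
# XLM-RoBERTa-Large (hidden size 1024) projected into a 768-dim CLIP embedding
# space. Released checkpoints store their own sizes in config.json, so treat
# these defaults as placeholders rather than guarantees about any model.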


class MultilingualCLIP(transformers.PreTrainedModel):
    config_class = MCLIPConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        # Multilingual text backbone (config.modelBase, e.g. XLM-R) whose pooled
        # output is mapped by a single linear layer into the CLIP embedding space.
        self.transformer = transformers.AutoModel.from_pretrained(config.modelBase)
        self.LinearTransformation = torch.nn.Linear(in_features=config.transformerDimensions,
                                                    out_features=config.numDims)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPTextModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the backbone's pooled output (pooler_output for
        # BERT/RoBERTa-style models); project it into the CLIP embedding space.
        pooled_output = text_outputs[1]
        text_embeds = self.LinearTransformation(pooled_output)

        if not return_dict:
            outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
            return tuple(output for output in outputs if output is not None)

        return CLIPTextModelOutput(
            text_embeds=text_embeds,
            last_hidden_state=text_outputs.last_hidden_state,
            hidden_states=text_outputs.hidden_states,
            attentions=text_outputs.attentions,
        )

    @classmethod
    def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init=True):
        # Override the Hugging Face loading hook: load the checkpoint verbatim
        # instead of the default key-by-key matching, and report no missing
        # keys, unexpected keys, or error messages.
        model.load_state_dict(state_dict)
        return model, [], [], []


AutoConfig.register("M-CLIP", MCLIPConfig)
AutoModel.register(MCLIPConfig, MultilingualCLIP)
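

# Minimal usage sketch, not part of the original file. Assumptions are marked:
# the checkpoint name below is illustrative, and the tokenizer is taken from
# the default modelBase ('xlm-roberta-large') so it matches the text backbone.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "M-CLIP/XLM-Roberta-Large-Vit-L-14"  # illustrative checkpoint name
    model = MultilingualCLIP.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")

    batch = tokenizer(["a photo of a dog", "ein Foto von einem Hund"],
                      padding=True, return_tensors="pt")
    with torch.no_grad():
        # forward() accepts input_ids/attention_mask, so the batch unpacks
        # directly; with return_dict enabled it returns a CLIPTextModelOutput.
        text_embeds = model(**batch).text_embeds
    print(text_embeds.shape)  # (batch_size, config.numDims)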