File size: 2,312 Bytes
8478037 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import torch
from open_clip import create_model
from transformers import PretrainedConfig, PreTrainedModel, CLIPProcessor
from transformers.models.clip.modeling_clip import CLIPOutput
from typing import Optional, Tuple, Union
class MarqoFashionCLIPConfig(PretrainedConfig):
    """Hugging Face configuration object for ``MarqoFashionCLIP``.

    Args:
        open_clip_model_name: Identifier that ``MarqoFashionCLIP`` hands to
            ``open_clip.create_model`` when building the wrapped model.
            Defaults to the empty string.
        **kwargs: Passed through untouched to ``PretrainedConfig.__init__``.
    """

    def __init__(self, open_clip_model_name: str = "", **kwargs):
        # Let the base class consume its standard options first.
        super().__init__(**kwargs)
        # Stored as a plain attribute so it serialises with the config.
        self.open_clip_model_name = open_clip_model_name
class MarqoFashionCLIP(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around an OpenCLIP model.

    The wrapped model is created with ``open_clip.create_model`` from
    ``config.open_clip_model_name`` and switched to eval mode. Every encode
    call runs under ``torch.inference_mode()``. Returned embeddings therefore
    carry no autograd history, and this wrapper is meant for inference only,
    not training.
    """

    config_class = MarqoFashionCLIPConfig

    def __init__(self, config: MarqoFashionCLIPConfig):
        super().__init__(config)
        self.config = config
        self.model = create_model(config.open_clip_model_name, output_dict=True)
        self.model.to(self.device)
        self.model.eval()

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        normalize: bool = False,
        **kwargs,
    ) -> torch.FloatTensor:
        """Encode a batch of preprocessed images with the OpenCLIP image tower.

        Args:
            pixel_values: Image tensor passed as-is to ``encode_image``
                (presumably ``(batch, channels, H, W)`` from the processor;
                confirm against the caller).
            normalize: Forwarded to ``encode_image``. OpenCLIP presumably
                L2-normalises the embeddings when True.
            **kwargs: Ignored. Accepted so that processor outputs can be
                splatted in directly.

        Returns:
            The image embeddings produced under ``torch.inference_mode()``.
        """
        with torch.inference_mode():
            image_features = self.model.encode_image(pixel_values, normalize=normalize)
        return image_features

    def get_text_features(
        self,
        input_ids: torch.Tensor,
        normalize: bool = False,
        **kwargs,
    ) -> torch.FloatTensor:
        """Encode a batch of token ids with the OpenCLIP text tower.

        Args:
            input_ids: Token id tensor passed as-is to ``encode_text``.
            normalize: Forwarded to ``encode_text``. OpenCLIP presumably
                L2-normalises the embeddings when True.
            **kwargs: Ignored (e.g. an ``attention_mask`` from the processor).

        Returns:
            The text embeddings produced under ``torch.inference_mode()``.
        """
        with torch.inference_mode():
            text_features = self.model.encode_text(input_ids, normalize=normalize)
        return text_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CLIPOutput]:
        """Embed both modalities and score every text against every image.

        Args:
            input_ids: Token ids for the text batch. Required.
            pixel_values: Preprocessed image batch. Required.
            return_dict: When truthy, return a ``CLIPOutput``. When False or
                None, return a tuple (None is deliberately treated as False,
                preserving this wrapper's existing default).
            **kwargs: Ignored. Accepted so that ``model(**processor_outputs)``
                works when the processor also emits keys such as
                ``attention_mask``, matching ``get_*_features``.

        Returns:
            ``(logits_per_image, logits_per_text, text_embeds, image_embeds)``
            or the equivalent ``CLIPOutput``.

        Raises:
            ValueError: If ``input_ids`` or ``pixel_values`` is None.
        """
        # Fail early with a clear message instead of passing None into OpenCLIP.
        missing = [
            name
            for name, value in (("input_ids", input_ids), ("pixel_values", pixel_values))
            if value is None
        ]
        if missing:
            raise ValueError(
                f"MarqoFashionCLIP.forward requires both `input_ids` and "
                f"`pixel_values`; missing: {', '.join(missing)}"
            )

        vision_outputs = self.get_image_features(pixel_values=pixel_values, normalize=True)
        text_outputs = self.get_text_features(input_ids=input_ids, normalize=True)

        # Raw dot products of the normalised embeddings. No logit scale is
        # applied here, so the values are similarity scores rather than
        # temperature-scaled logits.
        logits_per_text = text_outputs @ vision_outputs.T
        logits_per_image = logits_per_text.T

        if not return_dict:
            return logits_per_image, logits_per_text, text_outputs, vision_outputs
        return CLIPOutput(
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_outputs,
            image_embeds=vision_outputs,
        )
|