File size: 2,312 Bytes
8478037
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
from open_clip import create_model
from transformers import PretrainedConfig, PreTrainedModel, CLIPProcessor
from transformers.models.clip.modeling_clip import CLIPOutput
from typing import Optional, Tuple, Union

class MarqoFashionCLIPConfig(PretrainedConfig):
    """Hugging Face configuration for the open_clip-backed FashionCLIP wrapper.

    Args:
        open_clip_model_name: Model name later handed to
            ``open_clip.create_model`` when the model is built. Defaults to
            an empty string.
        **kwargs: Standard ``PretrainedConfig`` options, forwarded unchanged.
    """

    def __init__(self, open_clip_model_name: str = "", **kwargs):
        # Let the base class consume every generic HF config option first.
        super().__init__(**kwargs)
        # Stored verbatim; read back by the model class at construction time.
        self.open_clip_model_name = open_clip_model_name
        
        
class MarqoFashionCLIP(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around an open_clip model.

    The wrapped model is created with ``open_clip.create_model`` using
    ``config.open_clip_model_name`` and is switched to eval mode in
    ``__init__``. Both feature extractors run under
    ``torch.inference_mode()``, so their outputs carry no autograd history;
    as written, this wrapper is for inference only.
    """

    config_class = MarqoFashionCLIPConfig

    def __init__(self, config: MarqoFashionCLIPConfig):
        super().__init__(config)
        self.config = config
        # Build the underlying open_clip model from the name stored in the config.
        self.model = create_model(config.open_clip_model_name, output_dict=True)
        self.model.to(self.device)
        self.model.eval()

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        normalize: bool = False,
        **kwargs
    ) -> torch.FloatTensor:
        """Encode images with the wrapped open_clip model.

        Args:
            pixel_values: Preprocessed image tensor passed directly to
                ``encode_image`` (presumably produced by the matching
                processor; shape not checked here).
            normalize: Forwarded to ``encode_image``'s ``normalize`` flag.
            **kwargs: Accepted and ignored, so extra processor keys can be
                passed through.

        Returns:
            The image embeddings produced under ``torch.inference_mode()``.
        """
        with torch.inference_mode():
            image_features = self.model.encode_image(pixel_values, normalize=normalize)
        return image_features

    def get_text_features(
        self,
        input_ids: torch.Tensor,
        normalize: bool = False,
        **kwargs
    ) -> torch.FloatTensor:
        """Encode tokenized text with the wrapped open_clip model.

        Args:
            input_ids: Token id tensor passed directly to ``encode_text``.
            normalize: Forwarded to ``encode_text``'s ``normalize`` flag.
            **kwargs: Accepted and ignored (e.g. an ``attention_mask`` from a
                tokenizer), because ``encode_text`` only takes the ids.

        Returns:
            The text embeddings produced under ``torch.inference_mode()``.
        """
        with torch.inference_mode():
            text_features = self.model.encode_text(input_ids, normalize=normalize)
        return text_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CLIPOutput]:
        """Compute text-image similarity logits for a batch of both modalities.

        Args:
            input_ids: Token ids for the text batch (required).
            pixel_values: Preprocessed images for the image batch (required).
            return_dict: When truthy, return a ``CLIPOutput``. Otherwise,
                including the default ``None``, return a plain tuple.
            **kwargs: Accepted and ignored, so ``model(**processor_output)``
                works even when the processor also emits keys such as
                ``attention_mask``.

        Returns:
            ``(logits_per_image, logits_per_text, text_embeds, image_embeds)``
            as a tuple, or the same fields wrapped in ``CLIPOutput``.

        Raises:
            ValueError: If either ``input_ids`` or ``pixel_values`` is missing.
        """
        # Fail early with a clear message instead of an obscure error deep
        # inside the open_clip encoders.
        if input_ids is None or pixel_values is None:
            raise ValueError(
                "MarqoFashionCLIP.forward requires both `input_ids` and "
                "`pixel_values`; use get_text_features / get_image_features "
                "to encode a single modality."
            )

        vision_outputs = self.get_image_features(pixel_values=pixel_values, normalize=True)
        text_outputs = self.get_text_features(input_ids=input_ids, normalize=True)

        # Plain dot products of the normalize=True embeddings (cosine similarity,
        # assuming open_clip L2-normalizes); no logit scale / temperature applied.
        logits_per_text = text_outputs @ vision_outputs.T
        logits_per_image = logits_per_text.T

        # Tuple output stays the default so callers that unpack the result keep working.
        if not return_dict:
            return logits_per_image, logits_per_text, text_outputs, vision_outputs

        return CLIPOutput(
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_outputs,
            image_embeds=vision_outputs
        )