from typing import List
import os

import torch
import torch.nn as nn
import numpy as np
from functools import partial
from core.models.common.get_model import register
from einops import rearrange

from transformers import CLIPTextModel
from .clip_modules import CLIPProcessor, CLIPModel, CLIPTokenizer, CLIPConfig


version = '0'
symbol = 'clip'


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


@register('clip_text_frozen', version)
class FrozenCLIPTextEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):  # clip-vit-base-patch32
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)


@register('clip_frozen', version)
class FrozenCLIP(AbstractEncoder):
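    """CLIP wrapper exposing both the text and vision towers, with
    freeze/unfreeze helpers.

    `encode()` dispatches to the method named by `encode_type`
    (e.g. 'encode_text', 'encode_vision', 'encode_text_noproj').
    """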
    def __init__(self,
                 version="openai/clip-vit-large-patch14",
                 max_length=77,
                 encode_type='encode_text',
                 fp16=False,
                 data_dir='.'):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.processor = CLIPProcessor.from_pretrained(version)
        config = CLIPConfig.from_pretrained(version)
        self.model = CLIPModel(config, add_temporal_attention=True)
        self.max_length = max_length
        self.encode_type = encode_type
        self.fp16 = fp16

    @property
    def dtype(self):
        # Hard-coded to fp32: self.fp16 only affects the pixel dtype in
        # encode_vision_pooled, not this property.
        return torch.float32

    @property
    def device(self):
        return self.model.text_projection.weight.device

    def get_device(self):
        # Infer the module's device from the text projection weights
        return self.model.text_projection.weight.device

    def freeze(self, modules):
        for module in modules:
            for param in module.parameters():
                param.requires_grad = False

    def unfreeze(self, modules):
        for module in modules:
            for param in module.parameters():
                param.requires_grad = True
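
    # Example (a sketch, not called anywhere in this module): to keep the whole
    # CLIP backbone fixed during training one could do
    #
    #   self.freeze([self.model])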

    def encode_text_pooled(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.get_device())
        outputs = self.model.get_text_features(input_ids=tokens)
        return outputs

    def encode_vision_pooled(self, images):
        inputs = self.processor(images=images, return_tensors="pt")
        pixels = inputs['pixel_values'].half() if self.fp16 else inputs['pixel_values']
        pixels = pixels.to(self.get_device())
        return self.model.get_image_features(pixel_values=pixels)

    def encode_text_noproj(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.get_device())
        # Note: token ids stay int64; casting them to a 16-bit integer type
        # would overflow CLIP's ~49k-token vocabulary, so no cast is applied
        # even when the model itself runs in fp16.
        outputs = self.model.text_model(input_ids=tokens)
        return outputs.last_hidden_state

    def encode_vision_noproj(self, vision_inputs):
        # vision_inputs = ((vision_inputs + 1) / 2).to('cpu').numpy()
        vision_inputs = vision_inputs.to('cpu').numpy()

        if vision_inputs.ndim == 5:
            num_frames = vision_inputs.shape[2]
            vision_inputs = rearrange(vision_inputs, 'b c f h w -> (b f) h w c')
        else:
            num_frames = 1
            vision_inputs = rearrange(vision_inputs, 'b c h w -> b h w c')

        # The processor expects a list of HWC images/frames
        vision_inputs = list(vision_inputs)
        inputs = self.processor(images=vision_inputs, return_tensors="pt")

        pixels = inputs['pixel_values'].to(self.dtype).to(self.device)

        if num_frames > 1:
            # pixel_values from the processor are channels-first: (b*f, c, h, w)
            pixels = rearrange(pixels, '(b f) c h w -> b f c h w', f=num_frames)
        outputs = self.model.vision_model(pixel_values=pixels)
        return outputs

    def encode_text(self, text):
        if isinstance(text, List):
            batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                            return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
            tokens = batch_encoding["input_ids"].to(self.get_device())
        else:
            tokens = text
        outputs = self.model.text_model(input_ids=tokens)
        z_pooled = outputs.pooler_output
        z_pooled = self.model.text_projection(z_pooled)
        z_pooled = z_pooled / torch.norm(z_pooled, dim=-1, keepdim=True)
        return z_pooled.unsqueeze(1)

    def encode_vision(self, images):
        z = self.encode_vision_noproj(images)
        z_pooled = z.pooler_output
        z_pooled = self.model.visual_projection(z_pooled)
        z_pooled = z_pooled / torch.norm(z_pooled, dim=-1, keepdim=True)
        return z_pooled.unsqueeze(1)

    def encode(self, *args, **kwargs):
        return getattr(self, self.encode_type)(*args, **kwargs)
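
    # Usage sketch (note: the CLIPModel here is built from config only, so in
    # practice pretrained weights are presumably loaded elsewhere by the
    # surrounding framework; shapes below assume clip-vit-large-patch14):
    #
    #   clip = FrozenCLIP(encode_type='encode_text')
    #   text_emb = clip.encode(["a photo of a cat"])          # (B, 1, 768), L2-normalized
    #   vis_emb = clip(frames, encode_type='encode_vision')   # (B, 1, 768), L2-normalized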

    def forward(self, input, encode_type):
        if encode_type == 'encode_text':
            return self.encode_text(input)
        elif encode_type == 'encode_vision':
            # If the input has a single channel (grayscale), replicate it to 3 channels
            if input.shape[1] == 1:
                input = torch.cat([input, input, input], dim=1)
            return self.encode_vision(input)