import torch
import torch.nn as nn
from einops import rearrange
from transformers import CLIPTextModel

from core.models.common.get_model import register
# Local CLIP variants; CLIPTokenizer/CLIPProcessor here stand in for the
# transformers versions, and CLIPModel supports optional temporal attention.
from .clip_modules import CLIPProcessor, CLIPModel, CLIPTokenizer, CLIPConfig

version = '0'
symbol = 'clip'
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
@register('clip_text_frozen', version)
class FrozenCLIPTextEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
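
# Usage sketch (illustrative, not part of the original module): embedding a
# batch of prompts with the frozen text encoder. Assumes a CUDA device and the
# default "openai/clip-vit-large-patch14" checkpoint (hidden size 768).
#
#   embedder = FrozenCLIPTextEmbedder().to('cuda')
#   z = embedder.encode(["a photo of a cat", "a photo of a dog"])
#   # z.shape -> torch.Size([2, 77, 768])
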
@register('clip_frozen', version)
class FrozenCLIP(AbstractEncoder):
def __init__(self,
version="openai/clip-vit-large-patch14",
max_length=77,
encode_type='encode_text',
fp16=False,
data_dir='.'):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.processor = CLIPProcessor.from_pretrained(version)
config = CLIPConfig.from_pretrained(version)
        # Local CLIPModel built from config only (weights are presumably loaded
        # from a checkpoint elsewhere), with temporal attention enabled for video.
        self.model = CLIPModel(config, add_temporal_attention=True)
self.max_length = max_length
self.encode_type = encode_type
self.fp16 = fp16
    @property
    def dtype(self):
        # Track the fp16 flag so pixel inputs are cast consistently with the model.
        return torch.float16 if self.fp16 else torch.float32
@property
def device(self):
return self.model.text_projection.weight.device
    def get_device(self):
        # Kept alongside the `device` property for call sites that expect a method.
        return self.device
def encode_text_pooled(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.get_device())
outputs = self.model.get_text_features(input_ids=tokens)
return outputs
def encode_vision_pooled(self, images):
inputs = self.processor(images=images, return_tensors="pt")
pixels = inputs['pixel_values'].half() if self.fp16 else inputs['pixel_values']
pixels = pixels.to(self.get_device())
return self.model.get_image_features(pixel_values=pixels)
def encode_text_noproj(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.get_device())
        # Note: token ids must remain an integer (long) tensor even in fp16 mode;
        # only floating-point activations and weights are cast to half.
outputs = self.model.text_model(input_ids=tokens)
return outputs.last_hidden_state
    def encode_vision_noproj(self, vision_inputs):
        # vision_inputs = ((vision_inputs + 1) / 2).to('cpu').numpy()  # optional rescale from [-1, 1] to [0, 1]
        vision_inputs = vision_inputs.to('cpu').numpy()
        if vision_inputs.ndim == 5:
            # Video batch (b, c, f, h, w): fold the frame axis into the batch dim.
            num_frames = vision_inputs.shape[2]
            vision_inputs = rearrange(vision_inputs, 'b c f h w -> (b f) h w c')
        else:
            # Image batch (b, c, h, w).
            num_frames = 1
            vision_inputs = rearrange(vision_inputs, 'b c h w -> b h w c')
        # Split into a list of HWC frames for the processor.
        vision_inputs = [vi for vi in vision_inputs]
        inputs = self.processor(images=vision_inputs, return_tensors="pt")
        # The processor returns channels-first pixel values: ((b f), c, h, w).
        pixels = inputs['pixel_values'].to(self.dtype).to(self.device)
        if num_frames > 1:
            # Restore the frame axis; note the processor output is channels-first,
            # so the split applies to (b f), not to the spatial dims.
            pixels = rearrange(pixels, '(b f) c h w -> b f c h w', f=num_frames)
        outputs = self.model.vision_model(pixel_values=pixels)
        return outputs
def encode_text(self, text):
        # Accept either a list of raw strings (tokenize here) or pre-tokenized ids.
        if isinstance(text, list):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.get_device())
else:
tokens = text
outputs = self.model.text_model(input_ids=tokens)
z_pooled = outputs.pooler_output
z_pooled = self.model.text_projection(z_pooled)
z_pooled = z_pooled / torch.norm(z_pooled, dim=-1, keepdim=True)
return z_pooled.unsqueeze(1)
def encode_vision(self, images):
z = self.encode_vision_noproj(images)
z_pooled = z.pooler_output
z_pooled = self.model.visual_projection(z_pooled)
z_pooled = z_pooled / torch.norm(z_pooled, dim=-1, keepdim=True)
return z_pooled.unsqueeze(1)
def encode(self, *args, **kwargs):
return getattr(self, self.encode_type)(*args, **kwargs)
def forward(self, input, encode_type):
if encode_type == 'encode_text':
return self.encode_text(input)
elif encode_type == 'encode_vision':
            # If the input has a single channel (grayscale), replicate it across 3 channels.
if input.shape[1] == 1:
input = torch.cat([input, input, input], dim=1)
return self.encode_vision(input)
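
# Usage sketch (hypothetical; assumes the registry resolves 'clip_frozen' to this
# class and that the local CLIPModel's weights are loaded from a checkpoint
# elsewhere). Pooled outputs are L2-normalized and shaped (batch, 1, 768) for
# the ViT-L/14 projection dim:
#
#   enc = FrozenCLIP(encode_type='encode_text')
#   z_text = enc.encode(["a photo of a cat"])          # via encode_text
#   images = torch.rand(2, 3, 224, 224)                # (b, c, h, w) image batch
#   z_img = enc(images, encode_type='encode_vision')   # via forward dispatch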