# anytext/text_embedding_module/frozen_clip_embedder_t3.py
import torch
from torch import nn
from transformers import CLIPTextModel, CLIPTokenizer
from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class FrozenCLIPEmbedderT3(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(
self,
version="openai/clip-vit-large-patch14",
device="cpu",
max_length=77,
freeze=True,
use_fp16=False,
):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(
version, use_safetensors=True, torch_dtype=torch.float16 if use_fp16 else torch.float32
).to(device)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
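        # The nested functions below are bound onto the loaded CLIP submodules
        # (embeddings, encoder, text model, and the wrapper) to replace their
        # forward methods. This threads an optional `embedding_manager` hook
        # through the stack and makes the text model return raw per-token
        # hidden states instead of a pooled output.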
def embedding_forward(
self,
input_ids=None,
position_ids=None,
inputs_embeds=None,
embedding_manager=None,
):
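            # Mirrors CLIPTextEmbeddings.forward, with one addition: if an
            # embedding_manager is supplied, it may replace selected token
            # embeddings before the position embeddings are added.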
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
if embedding_manager is not None:
inputs_embeds = embedding_manager(input_ids, inputs_embeds)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
self.transformer.text_model.embeddings.forward = embedding_forward.__get__(
self.transformer.text_model.embeddings
)
def encoder_forward(
self,
inputs_embeds,
attention_mask=None,
causal_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
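            # Mirrors CLIPEncoder.forward, but returns only the final hidden
            # states tensor rather than a BaseModelOutput.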
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
return hidden_states
self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)
def text_encoder_forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
embedding_manager=None,
):
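            # Mirrors CLIPTextTransformer.forward, minus the pooled output:
            # the embedding_manager is forwarded to the patched embeddings,
            # and the layer-normalized last hidden state is returned.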
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is None:
raise ValueError("You have to specify either input_ids")
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(
input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager
)
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = _create_4d_causal_attention_mask(
input_shape, hidden_states.dtype, device=hidden_states.device
)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
last_hidden_state = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = self.final_layer_norm(last_hidden_state)
return last_hidden_state
self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)
def transformer_forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
embedding_manager=None,
):
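            # Thin wrapper so the extra embedding_manager argument reaches the
            # patched text model.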
return self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
embedding_manager=embedding_manager,
)
self.transformer.forward = transformer_forward.__get__(self.transformer)
def freeze(self):
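        # Switch to eval mode and disable gradients so the text encoder stays
        # frozen during training.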
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text, **kwargs):
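        # Tokenize without truncation, split the ids into 77-token windows,
        # encode each window separately, and concatenate along the sequence axis.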
batch_encoding = self.tokenizer(
text,
truncation=False,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="longest",
return_tensors="pt",
)
input_ids = batch_encoding["input_ids"]
tokens_list = self.split_chunks(input_ids)
z_list = []
for tokens in tokens_list:
tokens = tokens.to(self.device)
_z = self.transformer(input_ids=tokens, **kwargs)
z_list += [_z]
return torch.cat(z_list, dim=1)
def encode(self, text, **kwargs):
return self(text, **kwargs)
def split_chunks(self, input_ids, chunk_size=75):
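        # Split the ids (excluding BOS/EOS) into groups of `chunk_size` tokens and
        # re-wrap each group with BOS and EOS, padding the last group to a full
        # 77-token window with the EOS/pad id.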
tokens_list = []
bs, n = input_ids.shape
        id_start = input_ids[:, 0].unsqueeze(1)  # BOS token column, shape [bs, 1]
        id_end = input_ids[:, -1].unsqueeze(1)  # EOS/pad token column, shape [bs, 1]
        if n == 2:  # empty caption: only BOS and EOS, emit a single fully padded chunk
            tokens_list.append(torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1))
trimmed_encoding = input_ids[:, 1:-1]
num_full_groups = (n - 2) // chunk_size
for i in range(num_full_groups):
group = trimmed_encoding[:, i * chunk_size : (i + 1) * chunk_size]
group_pad = torch.cat((id_start, group, id_end), dim=1)
tokens_list.append(group_pad)
remaining_columns = (n - 2) % chunk_size
if remaining_columns > 0:
remaining_group = trimmed_encoding[:, -remaining_columns:]
padding_columns = chunk_size - remaining_group.shape[1]
padding = id_end.expand(bs, padding_columns)
remaining_group_pad = torch.cat((id_start, remaining_group, padding, id_end), dim=1)
tokens_list.append(remaining_group_pad)
return tokens_list
def to(self, *args, **kwargs):
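        # Keep self.device in sync when the module is moved.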
self.transformer = self.transformer.to(*args, **kwargs)
self.device = self.transformer.device
return self
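

# Minimal usage sketch (hypothetical prompt; assumes the
# "openai/clip-vit-large-patch14" weights can be downloaded):
#
#     encoder = FrozenCLIPEmbedderT3(device="cpu")
#     z = encoder.encode(["a storefront with a sign that reads 'OPEN'"])
#     print(z.shape)  # [1, 77 * num_chunks, 768] for the ViT-L/14 text encoder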