from collections import OrderedDict
from typing import List, Union

import torch
from torch import nn

from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

__all__ = ["CLIPTEXT", "build_text_encoder"]
|
|
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)
|
|
class QuickGELU(nn.Module):
    """The GELU approximation used by CLIP: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)
|
|
class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
|
|
class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(
            *[ResidualAttentionBlock(width, heads, attn_mask)
              for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)
|
|
class CLIPTEXT(nn.Module):
    """The text transformer of CLIP (defaults match the ViT-B/32 text encoder)."""

    def __init__(self,
                 embed_dim=512,
                 context_length=77,
                 vocab_size=49408,
                 transformer_width=512,
                 transformer_heads=8,
                 transformer_layers=12
                 ):
        super().__init__()

        self._tokenizer = _Tokenizer()
        self.context_length = context_length

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))

        self.initialize_parameters()
|
    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
    def build_attention_mask(self):
        # Causal mask: each position may only attend to itself and earlier
        # positions. PyTorch uses an additive attention mask, so blocked
        # positions are filled with -inf.
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # keep -inf strictly above the diagonal, zero elsewhere
        return mask
|
    @property
    def device(self):
        return self.text_projection.device

    @property
    def dtype(self):
        return self.text_projection.dtype
|
    def tokenize(self,
                 texts: Union[str, List[str]],
                 context_length: int = 77) -> torch.LongTensor:
        """
        Tokenize `texts` into a (len(texts), context_length) LongTensor of
        token ids, zero-padded on the right. Sequences longer than
        `context_length` are randomly cropped to fit.
        """
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self._tokenizer.encoder["<|startoftext|>"]
        eot_token = self._tokenizer.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                # Randomly crop a window of `context_length` tokens.
                st = torch.randint(
                    len(tokens) - context_length + 1, (1,))[0].item()
                tokens = tokens[st: st + context_length]
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result
|
    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # Take the features at the eot embedding (the eot token has the highest
        # id in each sequence) and project them into the joint embedding space.
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x
|
    def forward(self, captions):
        """
        captions: a list of strings
        """
        text = self.tokenize(captions).to(self.device)
        features = self.encode_text(text)
        return features
|
|
def build_text_encoder(pretrain=True):
    text_encoder = CLIPTEXT()
    if pretrain:
        import clip
        pretrained_model, _ = clip.load("ViT-B/32", device='cpu')
        state_dict = pretrained_model.state_dict()
        # Keep only the text tower: drop the visual branch and any keys that
        # CLIPTEXT does not define.
        to_delete_keys = ["logit_scale", "input_resolution",
                          "context_length", "vocab_size"]
        to_delete_keys += [k for k in state_dict.keys() if k.startswith('visual.')]
        for k in to_delete_keys:
            if k in state_dict:
                del state_dict[k]
        print('Loading pretrained CLIP')
        text_encoder.load_state_dict(state_dict)

    return text_encoder
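

# Minimal usage sketch: build the encoder and embed a few captions. Assumes
# the `clip` package is installed and the ViT-B/32 weights can be loaded
# (downloaded on first use); set pretrain=False to skip that.
if __name__ == "__main__":
    encoder = build_text_encoder(pretrain=True)
    encoder.eval()
    with torch.no_grad():
        features = encoder(["a photo of a cat", "a photo of a dog"])
    print(features.shape)  # torch.Size([2, 512]) with the default embed_dim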