# CLIP-EBC/models/clip/_clip/text_encoder.py
import torch
from torch import nn, Tensor
from .blocks import LayerNorm, Transformer


class CLIPTextEncoder(nn.Module):
    """Text tower of CLIP: token embedding, causal Transformer, final LayerNorm,
    and a linear projection into the joint text-image embedding space."""

    def __init__(
        self,
        embed_dim: int,
        context_length: int,
        vocab_size: int,
        transformer_width: int,
        transformer_heads: int,
        transformer_layers: int,
    ) -> None:
        super().__init__()
        self.context_length = context_length
        self.vocab_size = vocab_size

        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
        )
        # These parameters are created uninitialized (torch.empty); they are expected to be
        # overwritten by pretrained CLIP weights or initialized explicitly before training.
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))

    def build_attention_mask(self) -> Tensor:
        # Build a causal (lower-triangular) attention mask over the text context.
        # PyTorch uses an additive attention mask, so blocked positions are filled with -inf.
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower triangle: each token attends only to itself and earlier tokens
        return mask

    @property
    def dtype(self) -> torch.dtype:
        # Infer the working dtype from the first attention block's projection weights.
        return self.transformer.resblocks[0].attn.in_proj_weight.dtype

    def forward(self, text: Tensor) -> Tensor:
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, transformer_width]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        # Take the features at the EOT token (the highest token id in each sequence)
        # and project them into the joint text-image embedding space.
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x
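

# A minimal smoke-test sketch (not part of the original module). It assumes the standard
# CLIP ViT-B/32 text-tower hyperparameters (context_length=77, vocab_size=49408,
# width=512, 8 heads, 12 layers) and a hypothetical random token batch. In practice the
# encoder's empty parameters are filled from a pretrained CLIP checkpoint rather than
# initialized here.
if __name__ == "__main__":
    encoder = CLIPTextEncoder(
        embed_dim=512,
        context_length=77,
        vocab_size=49408,
        transformer_width=512,
        transformer_heads=8,
        transformer_layers=12,
    )
    # Give the uninitialized parameters sane random values so the forward pass is well-defined.
    nn.init.normal_(encoder.positional_embedding, std=0.01)
    nn.init.normal_(encoder.text_projection, std=512 ** -0.5)

    tokens = torch.randint(0, 49408, (2, 77))  # stand-in for CLIP-tokenized text
    with torch.no_grad():
        features = encoder(tokens)
    print(features.shape)  # expected: torch.Size([2, 512])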