import torch
from torch import nn, Tensor

from .blocks import LayerNorm, Transformer

class CLIPTextEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        context_length: int,
        vocab_size: int,
        transformer_width: int,
        transformer_heads: int,
        transformer_layers: int,
    ) -> None:
        super().__init__()

        self.context_length = context_length
        self.vocab_size = vocab_size

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
        )

        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))

    def build_attention_mask(self) -> Tensor:
        # lazily create the causal attention mask so each text token attends only
        # to itself and to earlier tokens in the sequence
        # pytorch uses an additive attention mask; fill masked positions with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self) -> torch.dtype:
        return self.transformer.resblocks[0].attn.in_proj_weight.dtype

    def forward(self, text: Tensor) -> Tensor:
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, transformer_width]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # take features from the EOT embedding (the EOT token has the highest
        # token id in each sequence) and project into the joint embedding space
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x
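

# Usage sketch (not part of the original module): a hedged example of how this
# encoder might be instantiated and called. The hyperparameters below are
# assumptions chosen to match the CLIP ViT-B text tower, and random token ids
# stand in for a real tokenizer. Note that forward() selects the EOT position
# via text.argmax(dim=-1), so it assumes the EOT token has the highest id, and
# the torch.empty parameters are expected to be loaded from a pretrained
# checkpoint (or explicitly initialized) before the outputs are meaningful.
#
#   encoder = CLIPTextEncoder(
#       embed_dim=512,
#       context_length=77,
#       vocab_size=49408,
#       transformer_width=512,
#       transformer_heads=8,
#       transformer_layers=12,
#   )
#   tokens = torch.randint(0, 49408, (4, 77))  # [batch, context_length] token ids
#   text_features = encoder(tokens)            # -> [4, 512] text embeddings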