import torch
from torch import nn, Tensor

from .blocks import LayerNorm, Transformer


class CLIPTextEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        context_length: int,
        vocab_size: int,
        transformer_width: int,
        transformer_heads: int,
        transformer_layers: int,
    ) -> None:
        super().__init__()

        self.context_length = context_length
        self.vocab_size = vocab_size

        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
        )
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))

    def build_attention_mask(self):
        # causal attention mask so each text token attends only to earlier tokens
        # pytorch uses an additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the diagonal and below
        return mask

    @property
    def dtype(self):
        return self.transformer.resblocks[0].attn.in_proj_weight.dtype

    def forward(self, text: Tensor):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, transformer_width]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # take features from the eot embedding (the eot token has the highest id in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x
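

# Minimal usage sketch (not part of the original module). The hyperparameters
# below are assumptions mirroring a CLIP-base text encoder, and the token ids
# (49406/49407 for start/end-of-text) are assumed; any values consistent with
# the constructor signature would work. positional_embedding and
# text_projection are created with torch.empty, so they are initialized here
# (or loaded from a checkpoint) before the forward pass.
if __name__ == "__main__":
    encoder = CLIPTextEncoder(
        embed_dim=512,
        context_length=77,
        vocab_size=49408,
        transformer_width=512,
        transformer_heads=8,
        transformer_layers=12,
    )
    nn.init.normal_(encoder.positional_embedding, std=0.01)
    nn.init.normal_(encoder.text_projection, std=512 ** -0.5)

    # fake token ids: batch of 2 sequences padded with zeros to the context
    # length; the eot token is assumed to carry the largest id, which is what
    # forward() relies on when selecting the pooled feature via argmax
    tokens = torch.zeros(2, 77, dtype=torch.long)
    tokens[:, 0] = 49406  # assumed start-of-text id
    tokens[:, 1] = 49407  # assumed end-of-text id
    features = encoder(tokens)
    print(features.shape)  # torch.Size([2, 512])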