image-caption / model.py
Lwasinam's picture
Upload 5 files
22ca2be verified
raw
history blame
15.9 kB
import torch
import torch.nn as nn
import math
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from torchvision.transforms.functional import to_pil_image, to_tensor
import time
import numpy as np
from matplotlib.image import imread
from transformers import ViTFeatureExtractor
from io import BytesIO
from base64 import b64decode
import base64
from transformers import ViTImageProcessor, ViTModel
## code from @jankrepl on github
class PretrainedVit():
    """Wrapper around a pretrained Hugging Face ViT that returns every hidden state.

    This is a plain object, not an ``nn.Module``. Its ViT weights are therefore
    not registered as sub-modules of whatever object holds it: they do not appear
    in that object's ``parameters()``/``state_dict()``, and ``.to()``/``.eval()``
    calls on the holder do not reach them.
    """

    def __init__(self):
        self.model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        # Kept so code calling self.model directly still gets hidden states;
        # forward() also requests them explicitly per call.
        self.model.config.output_hidden_states = True

    def forward(self, x):
        """Run the ViT on ``x`` and return its hidden states as a list.

        Parameters
        ----------
        x : torch.Tensor
            Pixel values passed positionally to the ViT model (presumably
            ``(batch, 3, 224, 224)`` from a ViT image processor — confirm).

        Returns
        -------
        list
            ``list(outputs.hidden_states)`` from the ViT call.
        """
        # Follow the input's device. The previous "cuda if available" choice moved
        # the model to the GPU even for CPU inputs and then failed on the
        # device mismatch.
        self.model.to(x.device)
        outputs = self.model(x, output_hidden_states=True)
        return list(outputs.hidden_states)
class PatchEmbed(nn.Module):
    """Split an image into patches, embed them, and add CLS/register tokens.

    Parameters
    ----------
    img_size : int
        Size of the image (it is a square).
    patch_size : int
        Size of the patch (it is a square).
    in_chans : int
        Number of input channels.
    embed_dim : int
        The embedding dimension.
    num_registers : int
        Number of learnable register tokens appended after the patch tokens.

    Attributes
    ----------
    n_patches : int
        Number of patches inside of our image (``(img_size // patch_size) ** 2``).
    norm : RMSNorm
        RMS normalisation applied to the whole token sequence before the
        positional embedding is added.
    proj : nn.Conv2d
        Convolutional layer that does both the splitting into patches
        and their embedding.
    """

    def __init__(self, img_size, patch_size, in_chans=3, embed_dim=1024, num_registers=6):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # Size the norm by embed_dim. It was RMSNorm's default (1024), which only
        # matched the default embed_dim.
        self.norm = RMSNorm(embed_dim)
        self.n_patches = (img_size // patch_size) ** 2
        # One learnable position per token: CLS + patches + registers.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.n_patches + 1 + num_registers, embed_dim)
        )
        # Learnable CLS token, prepended to every sequence.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Learnable register tokens, appended to every sequence.
        self.register_token = nn.Parameter(torch.zeros(num_registers, embed_dim))
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, in_chans, img_size, img_size)`. Other spatial sizes
            give a patch count that no longer matches `pos_embed`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, 1 + n_patches + num_registers, embed_dim)`.
        """
        x = self.proj(x)  # (n_samples, embed_dim, n_patches ** 0.5, n_patches ** 0.5)
        x = x.flatten(2).transpose(1, 2)  # (n_samples, n_patches, embed_dim)
        batch_size = x.shape[0]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        register_tokens = self.register_token.unsqueeze(0).expand(batch_size, -1, -1)
        # Token order: [CLS, patches..., registers...]
        x = torch.cat([cls_tokens, x, register_tokens], dim=1)
        # Previously the normalised tensor was bound to an unused name `X`,
        # so this normalisation was silently skipped.
        x = self.norm(x)
        return x + self.pos_embed  # learnable positional embedding, same shape
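## RMSNorm: only instantiated by PatchEmbed, which Transformer constructs but never calls in encode/decode in this file, so it is effectively unused on the captioning path.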
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-feature gain.

    Every vector along the last dimension is divided by
    ``sqrt(mean(x ** 2) + eps)`` and then scaled element-wise by ``weight``.

    Parameters
    ----------
    dim : int
        Size of the last dimension, i.e. the length of ``weight``.
    eps : float
        Added inside the square root to avoid division by zero.
    """

    def __init__(self, dim: int = 1024, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = dim
        self.weight = nn.Parameter(torch.ones(self.dim))  # gain ("gamma"), starts at 1

    def _norm(self, x: torch.Tensor):
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)  # 1 / sqrt(...)
        return x * inv_rms

    def forward(self, x: torch.Tensor):
        # Normalise in float32, cast back to x's dtype, then apply the gain.
        normed = self._norm(x.float()).type_as(x)
        return self.weight * normed
class LayerNormalization(nn.Module):
    """Layer normalisation over the last dimension with a scalar gain and bias.

    Differs from ``torch.nn.LayerNorm``:
    - it uses ``Tensor.std`` (Bessel-corrected) as the spread;
    - ``eps`` is added to the std rather than inside the square root;
    - ``alpha`` and ``bias`` are single scalars shared by all features.
    """

    def __init__(self, eps:float=1e-12) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(1))  # learnable scalar gain
        self.bias = nn.Parameter(torch.zeros(1))  # learnable scalar offset

    def forward(self, x):
        # Statistics keep a trailing size-1 dim so they broadcast over features.
        centre = x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True)
        centred = x - centre
        denom = spread + self.eps  # eps guards against a (near-)zero std
        return self.alpha * centred / denom + self.bias
class FeedForwardBlock(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)  # expand: d_model -> d_ff
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)  # contract: d_ff -> d_model

    def forward(self, x):
        hidden = torch.relu(self.linear_1(x))  # last dim becomes d_ff
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)  # last dim back to d_model
class InputEmbeddings(nn.Module):
    """Token-id embedding lookup whose output is scaled by ``sqrt(d_model)``."""

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # The sqrt(d_model) scaling follows the original Transformer paper.
        scale = math.sqrt(self.d_model)
        vectors = self.embedding(x)  # appends a trailing d_model dimension
        return vectors * scale
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input, followed by dropout.

    Parameters
    ----------
    d_model : int
        Feature size. Odd values are supported: the last (even) column gets a
        sine with no matching cosine.
    seq_len : int
        Maximum number of positions precomputed in the ``pe`` buffer.
    dropout : float
        Dropout probability applied after the encoding is added.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(seq_len, d_model)  # (seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        # 1 / 10000^(2i / d_model) for each even column 2i: length ceil(d_model / 2)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # Even columns: sin(pos / 10000^(2i / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd columns: cos(pos / 10000^(2i / d_model)). For odd d_model there is one
        # fewer odd column than even ones, so truncate div_term to fit.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, seq_len, d_model)
        # Buffer: saved in state_dict and moved by .to(), but not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add the encodings for the first x.shape[1] positions (broadcast over batch).
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
        return self.dropout(x)
class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper computing ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()

    def forward(self, x, sublayer):
        normed = self.norm(x)
        update = self.dropout(sublayer(normed))
        return x + update
class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention.

    Parameters
    ----------
    d_model : int
        Feature size of the inputs and of the output.
    h : int
        Number of heads; must divide ``d_model``.
    dropout : float
        Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model  # Embedding vector size
        self.h = h  # Number of heads
        assert d_model % h == 0, "d_model is not divisible by h"
        self.d_k = d_model // h  # Dimension of vector seen by each head
        self.w_q = nn.Linear(d_model, d_model)  # Wq
        self.w_k = nn.Linear(d_model, d_model)  # Wk
        self.w_v = nn.Linear(d_model, d_model)  # Wv
        self.w_o = nn.Linear(d_model, d_model)  # Wo
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

        ``query``/``key``/``value`` are ``(batch, h, seq_len, d_k)``. Positions
        where ``mask == 0`` get (effectively) zero weight. The mask must broadcast
        to the score shape (presumably ``(batch, 1, seq_q, seq_k)`` — confirm with
        callers). Returns ``(output, attention_weights)``.
        """
        d_k = query.shape[-1]
        # (batch, h, seq_q, d_k) @ (batch, h, d_k, seq_k) --> (batch, h, seq_q, seq_k)
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Use the most negative finite value of the score dtype. A hard-coded
            # -1e9 cannot be represented in float16 and makes masked_fill_ raise
            # under mixed precision.
            attention_scores.masked_fill_(mask == 0, torch.finfo(attention_scores.dtype).min)
        attention_scores = attention_scores.softmax(dim=-1)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        # (batch, h, seq_q, seq_k) @ (batch, h, seq_k, d_k) --> (batch, h, seq_q, d_k)
        return (attention_scores @ value), attention_scores

    def forward(self, q, k, v, mask, is_cross=False):
        """Project, split into heads, attend, merge heads, and project again.

        ``q``/``k``/``v`` are ``(batch, seq_len, d_model)``. The per-head attention
        weights of the last call are kept in ``self.attention_scores``. When
        ``is_cross`` is true they are also passed to a module-level
        ``attention_viz`` function, if one has been defined.
        """
        query = self.w_q(q)
        key = self.w_k(k)
        value = self.w_v(v)
        # (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)
        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
        if is_cross:
            # `attention_viz` is neither defined nor imported in this module, so
            # calling it unconditionally raised NameError. Look it up at call time
            # so a visualiser can still be injected as a module global.
            visualise = globals().get("attention_viz")
            if visualise is not None:
                visualise(self.attention_scores)
        # (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
        return self.w_o(x)
class EncoderBlock(nn.Module):
    """Encoder "block" that only selects one entry of its input sequence.

    The attention, feed-forward and residual sub-modules are still constructed,
    so they show up in ``parameters()``/``state_dict()``. ``forward`` does not use
    them: it returns element 11 of ``x``. ``Transformer.encode`` passes a list of
    ViT hidden states here, so this is presumably the last ViT layer's output —
    confirm against the ViT depth.
    """

    # Position within ``x`` of the entry returned by forward().
    _SELECTED_HIDDEN_STATE = 11

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float, layer: int ) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList(
            ResidualConnection(dropout) for _ in range(2)
        )
        self.layer = layer

    def forward(self, x, src_mask, index):
        # src_mask and index are accepted for interface compatibility only.
        return x[self._SELECTED_HIDDEN_STATE]
class Encoder(nn.Module):
    """Applies only the FIRST layer of ``layers`` and then a LayerNormalization.

    Remaining layers are held (and therefore parameterised) but never called.
    With no layers at all, the input is normalised unchanged.
    """

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, mask):
        first_layer = next(iter(self.layers), None)
        if first_layer is not None:
            x = first_layer(x, mask, 0)  # index 0: it is the first layer
        return self.norm(x)
class DecoderBlock(nn.Module):
    """Pre-norm Transformer decoder layer: masked self-attention, cross-attention
    over the encoder output, then a feed-forward network, each wrapped in its own
    residual connection.
    """

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList(
            ResidualConnection(dropout) for _ in range(3)
        )

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        def self_attend(hidden):
            return self.self_attention_block(hidden, hidden, hidden, tgt_mask)

        def cross_attend(hidden):
            return self.cross_attention_block(hidden, encoder_output, encoder_output, src_mask)

        self_res, cross_res, ff_res = self.residual_connections
        x = self_res(x, self_attend)
        x = cross_res(x, cross_attend)
        return ff_res(x, self.feed_forward_block)
class Decoder(nn.Module):
    """Runs every decoder layer in order, then a final LayerNormalization."""

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)
class ProjectionLayer(nn.Module):
    """Map decoder features to log-probabilities over the target vocabulary."""

    def __init__(self, d_model, vocab_size) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x) -> torch.Tensor:
        """Return ``log_softmax(proj(x))`` over the last (vocabulary) dimension.

        The return annotation was previously ``None``, but a tensor is returned.
        """
        # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
        return torch.log_softmax(self.proj(x), dim=-1)
class Transformer(nn.Module):
    """Image-captioning encoder/decoder built around a pretrained ViT.

    ``encode`` feeds the ViT hidden states through ``encoder``. ``decode`` embeds
    target tokens and runs ``decoder`` against the encoder output. ``project``
    turns decoder features into vocabulary log-probabilities.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, tgt_embed: InputEmbeddings, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer, att: PretrainedVit) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_embed = tgt_embed
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer
        # Registered (so it is in parameters()/state_dict()) but not used by encode().
        self.patch_embed = PatchEmbed(img_size=224, patch_size=14)
        # PretrainedVit is a plain object, not an nn.Module, so its weights are not
        # registered as sub-modules of this Transformer.
        self.att = att

    def encode(self, src, src_mask):
        hidden_states = self.att.forward(src)
        # Drop entry 0 of the hidden-state list (presumably the ViT embedding
        # output — confirm) and hand the rest to the encoder.
        return self.encoder(hidden_states[1:], src_mask)

    def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        embedded = self.tgt_pos(self.tgt_embed(tgt))  # (batch, seq_len, d_model)
        return self.decoder(embedded, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        return self.projection_layer(x)  # (batch, seq_len, vocab_size)
def build_transformer(tgt_vocab_size: int, tgt_seq_len: int, d_model: int=768, N: int=10, h: int=12, dropout: float=0.1, d_ff: int=3072) -> Transformer:
    """Assemble the captioning Transformer and Xavier-initialise its weights.

    Parameters
    ----------
    tgt_vocab_size : int
        Size of the caption vocabulary.
    tgt_seq_len : int
        Maximum caption length (size of the positional-encoding table).
    d_model : int
        Model width. The decoder's cross-attention consumes the ViT hidden states
        directly, so this presumably has to equal the ViT hidden size — confirm.
    N : int
        Number of encoder blocks and of decoder blocks. Only the first encoder
        block is ever called by Encoder.forward, but all N are built.
    h : int
        Attention heads per attention block.
    dropout : float
        Dropout probability used throughout.
    d_ff : int
        Hidden size of the feed-forward blocks.

    Returns
    -------
    Transformer
        Every parameter with more than one dimension is Xavier-uniform
        initialised. The pretrained ViT is not reached by this loop because
        PretrainedVit is not an nn.Module.
    """
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)
    # Pretrained ViT that supplies the image hidden states
    att = PretrainedVit()

    encoder_blocks = []
    for layer_idx in range(N):
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_blocks.append(
            EncoderBlock(encoder_self_attention_block, feed_forward_block, dropout, layer_idx)
        )

    decoder_blocks = []
    for _ in range(N):
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_blocks.append(
            DecoderBlock(decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
        )

    encoder = Encoder(nn.ModuleList(encoder_blocks))
    decoder = Decoder(nn.ModuleList(decoder_blocks))
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)
    transformer = Transformer(encoder, decoder, tgt_embed, tgt_pos, projection_layer, att)

    # Xavier-init weight matrices/tensors; 1-D params (biases, norm gains) keep their defaults.
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return transformer