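# app.py: Streamlit demo that captions images with a CLIP vision encoder and a GPT-2 decoder,
# connected by a small transformer mapping network that turns image embeddings into GPT-2 prefixes.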
import streamlit as st
from PIL import Image
import time
from tqdm.auto import tqdm
import numpy as np
import torch
from torch import nn
print(torch.__version__)
device = torch.device('cpu')  # inference runs on CPU
print(device)
print('importing tokenizer')
from transformers import GPT2Tokenizer, GPT2LMHeadModel, DataCollatorWithPadding

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token_id = 0  # GPT-2 has no pad token by default; use id 0 so the collator can pad batches
collator = DataCollatorWithPadding(tokenizer=tokenizer)
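
# --- Mapping network components ---

# Self-attention block: multi-head self-attention with a residual connection, followed by LayerNorm.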
class EncoderAttention(nn.Module):
def __init__(self,embed_dim=768, num_heads=8, dropout=0.1):
super().__init__()
self.mha = nn.MultiheadAttention(embed_dim, num_heads,batch_first=True, dropout=dropout)
self.layernorm = nn.LayerNorm(embed_dim)
def forward(self,x):
attn, _ = self.mha(query=x,
value=x,
key=x,
need_weights=False,
)
x = x + attn
return self.layernorm(x)
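
# Feed-forward block: two linear layers with ReLU and dropout, plus a residual connection and LayerNorm.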
class FeedForward(nn.Module):
def __init__(self, embed_dim=768, dropout_rate=0.1):
super().__init__()
self.seq = nn.Sequential(
nn.Linear(embed_dim, embed_dim*2),
nn.ReLU(),
nn.Linear(embed_dim*2, embed_dim),
nn.Dropout(dropout_rate)
)
self.layernorm = nn.LayerNorm(embed_dim)
def forward(self, x):
x = x + self.seq(x)
return self.layernorm(x)
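
# One mapper layer: the self-attention block followed by the feed-forward block.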
class MapperLayer(nn.Module):
def __init__(self, embed_dim=768, num_heads=8, dropout_rate=0.1):
super().__init__()
self.attn = EncoderAttention( num_heads=num_heads,
embed_dim=embed_dim,
dropout=dropout_rate)
self.ff = FeedForward(embed_dim=embed_dim,
dropout_rate=dropout_rate)
def forward(self, x):
x = self.attn(x)
x = self.ff(x)
return x
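
# A stack of MapperLayers applied in sequence.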
class Transformer(nn.Module):
def __init__(self,
num_layers=8,
num_heads=8,
embed_dim=768,
dropout_rate=0.1
):
super().__init__()
layers = [MapperLayer(embed_dim=embed_dim,
num_heads=num_heads,
dropout_rate=dropout_rate) for i in range(num_layers)]
self.layers = nn.ModuleList(layers)
def forward(self,x):
for layer in self.layers:
x = layer(x)
return x
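
# TransformerMapper: projects a CLIP image embedding into clip_length tokens, appends learned prefix
# constants, runs them through the transformer stack, and returns the transformed prefix embeddings.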
class TransformerMapper(nn.Module):
    def __init__(self,
                 dim_clip=768,
                 embed_dim=768,
                 prefix_length=16,
                 clip_length=10,
                 num_layers=8,
                 num_heads=8,
                 dropout_rate=0.1,
                 ):
        super().__init__()
        self.clip_length = clip_length
        self.transformer = Transformer(
            num_layers=num_layers,
            num_heads=num_heads,
            embed_dim=embed_dim,
            dropout_rate=dropout_rate,
        )
        # maps the CLIP embedding to clip_length prefixes (B, clip_length*embed_dim)
        self.linear = nn.Linear(dim_clip, self.clip_length * embed_dim)
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, embed_dim), requires_grad=True)

    def forward(self, x):
        x = self.linear(x).view(x.shape[0], self.clip_length, -1)
        prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)  # (B, prefix_len, embed_dim)
        prefix = torch.cat((x, prefix), dim=1)
        return self.transformer(prefix)[:, self.clip_length:]
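
# ClipCaptionModel: feeds the mapped CLIP prefix together with the GPT-2 token embeddings into GPT-2
# via inputs_embeds, extending the attention mask to cover the prefix positions.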
class ClipCaptionModel(nn.Module):
    def __init__(self,
                 dim_clip=768,
                 embed_dim=768,
                 prefix_length=16,
                 clip_length=10,
                 num_layers=8,
                 num_heads=8,
                 dropout_rate=0.1,
                 ):
        super().__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        self.clip_project = TransformerMapper(
            dim_clip=dim_clip,
            embed_dim=self.gpt_embedding_size,
            prefix_length=prefix_length,
            clip_length=clip_length,
            num_layers=num_layers,
            num_heads=num_heads,
            dropout_rate=dropout_rate)

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self,
                tokens: torch.Tensor,
                prefix: torch.Tensor,
                mask: torch.Tensor,
                labels=None):
        # create embeddings for the GPT model: mapped CLIP prefix followed by the token embeddings
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix)
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        # extend the attention mask so it also covers the prefix positions
        if mask.shape[1] != embedding_cat.shape[1]:
            dummy_mask = torch.ones(tokens.shape[0], self.prefix_length, dtype=torch.int64, device=mask.device)
            mask = torch.cat([dummy_mask, mask], dim=1)
        return self.gpt(inputs_embeds=embedding_cat,
                        labels=labels,
                        attention_mask=mask)
print('loading model')
print()
## Prepare Model
CliPGPT = ClipCaptionModel()
path = "model_epoch_1.pt"  # trained ClipCaptionModel weights (epoch 1 checkpoint)
state_dict = torch.load(path, map_location=device)
# Apply the weights to the model
CliPGPT.load_state_dict(state_dict)
CliPGPT.to(device)
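
# CLIP vision encoder and processor used to embed the uploaded image.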
print('importing CLIP')
from transformers import CLIPProcessor, CLIPModel
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
model.eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
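
# Temperature sampling: scale the logits, apply softmax, then draw one token id from the distribution.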
def sample_from_logits(logits, temperature=0.3):
logits = logits / temperature
probabilities = torch.softmax(logits, dim=-1)
return torch.multinomial(probabilities, 1).squeeze()
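
# Caption generation: embed the image with CLIP, then autoregressively decode tokens with the
# prefix-conditioned GPT-2, stopping at the EOS token or after max_tokens steps.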
def generate(image,
             device=device,
             max_tokens=48,
             temperature=0.3,
             verbose=True,
             sample=True,
             ):
    model.to(device)
    CliPGPT.to(device)
    # encode the image with the CLIP vision tower and keep the pooled embedding
    with torch.inference_mode():
        pixel_values = torch.tensor(np.stack(processor.image_processor(image).pixel_values, axis=0)).to(device)
        embeds = model.vision_model(pixel_values)
        embeds = embeds.pooler_output
    CliPGPT.eval()
    prefix_length = CliPGPT.prefix_length
    # prepare initial token '#' used as the token to begin generation of the caption
    tokens = ['#']
    input_ids, attention_mask = collator(tokenizer(tokens)).values()
    # autoregressive decoding loop
    for i in tqdm(range(max_tokens), desc='generating... '):
        input_ids = input_ids.to(device)
        embeds = embeds.to(device)
        attention_mask = attention_mask.to(device)
        with torch.inference_mode():
            out = CliPGPT(
                tokens=input_ids,
                prefix=embeds,
                mask=attention_mask,
            )
        logits = out.logits
        # drop the logits of the prefix positions; only the text positions matter
        logits = logits[:, prefix_length:, :]
        # pick the next token: temperature sampling or greedy argmax
        if sample:
            next_token = sample_from_logits(logits[:, -1, :],
                                            temperature=temperature)
        else:
            next_token = torch.argmax(logits[:, -1, :], dim=-1).squeeze()
        token = next_token.item()
        if token == tokenizer.eos_token_id:
            break
        # append the decoded token to the running caption string
        tokens = [tokens[0] + tokenizer.decode(next_token)]
        # re-tokenize the updated caption for the next step
        input_ids, attention_mask = collator(tokenizer(tokens)).values()
        if verbose:
            print(token)
            print(tokens[0])
            print()
    return tokens[0].replace('#', '').strip()
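
# --- Streamlit UI ---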
print('app starts')
st.title("CLIP GPT-2 Image Captioning")
st.write("This is a web app for generating image captions using a model built with CLIP & GPT-2.")
# Image upload section
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
if uploaded_file is not None:
# Display the uploaded image
image = Image.open(uploaded_file)
st.image(image, caption='Uploaded Image', use_column_width=True)
# Generate caption button
if st.button('Submit'):
with st.spinner('Generating caption...'):
start_time = time.time()
caption = generate(image)
end_time = time.time()
st.text_area('Output', caption)
st.write(f"Inference time: {end_time - start_time:.2f} seconds")