clip_gpt2 / app.py
Vageesh1's picture
Update app.py
e74c12b
raw
history blame
3.14 kB
import torch
import clip
import PIL.Image
from PIL import Image
import skimage.io as io
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel
from model import generate2,ClipCaptionModel
from engine import inference
# Hub id shared by the fine-tuned captioner, its image processor and tokenizer.
_VIT_GPT2_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

# Pre/post-processing helpers used by show_n_generate().
image_processor = ViTImageProcessor.from_pretrained(_VIT_GPT2_CHECKPOINT)
tokenizer = GPT2TokenizerFast.from_pretrained(_VIT_GPT2_CHECKPOINT)

# Start from the published architecture/weights, then overwrite the weights
# with the locally stored state dict (always mapped onto the CPU).
model_trained = VisionEncoderDecoderModel.from_pretrained(_VIT_GPT2_CHECKPOINT)
model_trained.load_state_dict(
    torch.load('model_trained.pth', map_location=torch.device('cpu'))
)
def show_n_generate(img, model, greedy = True):
    """Generate a caption for an image with an encoder-decoder captioner.

    Args:
        img: Anything ``PIL.Image.open`` accepts (a path or a binary
            file-like object such as a Streamlit upload).
        model: Object whose ``generate`` method takes the pixel values
            produced by the module-level ``image_processor``.
        greedy: If true, ``generate`` runs with the model's default
            generation settings; otherwise it samples with ``top_k=5``.
            Both modes are capped at 30 new tokens.

    Returns:
        The first decoded sequence as a string, special tokens skipped.

    Note:
        Decoding uses whatever the module-level name ``tokenizer`` is bound
        to when this function is called.
    """
    image = Image.open(img)
    # Uploaded PNGs can be RGBA, palette or grayscale. Normalise to RGB the
    # same way the checkpoint's model-card example does before processing.
    if image.mode != "RGB":
        image = image.convert(mode="RGB")

    pixel_values = image_processor(image, return_tensors="pt").pixel_values

    # Both decoding modes share the length cap; sampling only adds options.
    generation_kwargs = {"max_new_tokens": 30}
    if not greedy:
        generation_kwargs.update(do_sample=True, top_k=5)

    generated_ids = model.generate(pixel_values, **generation_kwargs)
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# --- CLIP-prefix captioning stack (everything is kept on the CPU) ---------
device = "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# NOTE: this rebinds the module-level name `tokenizer` that was assigned
# earlier in this file; code that reads the global after this point gets the
# plain "gpt2" tokenizer.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Number of prefix vectors the CLIP projection is reshaped into (see ui()).
prefix_length = 10

model = ClipCaptionModel(prefix_length)
model.load_state_dict(torch.load('model.h5', map_location=torch.device('cpu')))
model = model.eval()

coco_model = ClipCaptionModel(prefix_length)
coco_model.load_state_dict(torch.load('COCO_model.h5', map_location=torch.device('cpu')))
# Fix: the original repeated `model = model.eval()` here, which left
# coco_model in training mode.
coco_model = coco_model.eval()
def ui():
    """Render the Streamlit page: upload an image, pick a model, show a caption.

    Nothing beyond the title and the uploader is drawn until a file has been
    uploaded. After that, the selected model produces a caption that is shown
    under the uploaded image.
    """
    st.markdown("# Image Captioning")
    uploaded_file = st.file_uploader("Upload an Image", type=['png', 'jpeg', 'jpg'])
    if uploaded_file is None:
        return

    # Decode the upload and build the CLIP input batch (one image).
    image = io.imread(uploaded_file)
    pil_image = PIL.Image.fromarray(image)
    image = preprocess(pil_image).unsqueeze(0).to(device)

    # Fix: the first label used to be ' PreTrained Model' (leading space), so
    # it never matched the comparison below and produced no caption.
    option = st.selectbox(
        'Please select the Model',
        ('PreTrained Model', 'Trained Model', 'Fine Tuned Model'),
    )

    if option == 'PreTrained Model':
        with torch.no_grad():
            prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
            prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
            caption = generate2(model, tokenizer, embed=prefix_embed)
    elif option == 'Trained Model':
        caption = inference(uploaded_file)
    elif option == 'Fine Tuned Model':
        caption = show_n_generate(uploaded_file, greedy = False, model = model_trained)
    else:
        # Unknown selection: show nothing, as the original if/elif chain did.
        return

    st.image(uploaded_file, width = 500, channels = 'RGB')
    st.markdown("**PREDICTION:** " + caption)
# Script entry point: build the page only when this file is executed directly
# (not when it is imported as a module).
if __name__ == '__main__':
    ui()