# Copied from a Hugging Face Spaces page (Space status at copy time: Sleeping).
import clip
import PIL.Image
from PIL import Image
import skimage.io as io
import streamlit as st
import torch
# NOTE(review): GPT2LMHeadModel, AdamW and get_linear_schedule_with_warmup are not
# used in this file; transformers.AdamW is deprecated in recent releases — confirm
# nothing else relies on them before removing.
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel

from engine import inference
from model import generate2, ClipCaptionModel
# --- "Fine Tuned Model": nlpconnect ViT-GPT2 encoder-decoder ------------------
_VIT_GPT2_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

# Architecture/config come from the hub checkpoint; the weights are then
# replaced by the locally fine-tuned state dict.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model_trained = VisionEncoderDecoderModel.from_pretrained(_VIT_GPT2_CHECKPOINT)
model_trained.load_state_dict(torch.load('model_trained.pth', map_location=torch.device('cpu')))
model_trained.eval()  # inference only; made explicit for clarity
image_processor = ViTImageProcessor.from_pretrained(_VIT_GPT2_CHECKPOINT)
# Bound to its own name so it cannot be shadowed by the module-level
# `tokenizer` used for the CLIP-prefix models.
vit_tokenizer = GPT2TokenizerFast.from_pretrained(_VIT_GPT2_CHECKPOINT)


def show_n_generate(img, model, greedy=True):
    """Generate a caption for one image with a VisionEncoderDecoder model.

    Args:
        img: Path or binary file-like object accepted by ``PIL.Image.open``
            (the Streamlit app passes an uploaded file).
        model: VisionEncoderDecoderModel to caption with.
        greedy: If True, decode greedily. Otherwise sample with top-k (k=5),
            so repeated calls may return different captions.

    Returns:
        The decoded caption, with special tokens stripped.
    """
    # ViTImageProcessor expects 3-channel input; uploads may be RGBA,
    # grayscale or palette images. `with` closes files PIL opened itself.
    with Image.open(img) as src:
        image = src.convert("RGB")
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    # No preview here: the Streamlit caller displays the image itself.

    generate_kwargs = {"max_new_tokens": 30}
    if not greedy:
        generate_kwargs.update(do_sample=True, top_k=5)
    generated_ids = model.generate(pixel_values, **generate_kwargs)
    return vit_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# --- CLIP-prefix captioning models ("Model" and "COCO Model" in the UI) -------
device = "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
# GPT-2 tokenizer passed to generate2 when decoding ClipCaptionModel output.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Prefix length handed to ClipCaptionModel; assumed to match the value both
# checkpoints were trained with — TODO confirm.
prefix_length = 10


def _load_clip_caption_model(checkpoint_path):
    """Build a ClipCaptionModel, load *checkpoint_path* into it, and return it in eval mode.

    eval() matters: an nn.Module starts in training mode, so without it any
    dropout-style layers would behave as in training during captioning.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=torch.device(device))
    caption_model = ClipCaptionModel(prefix_length)
    caption_model.load_state_dict(state_dict)
    return caption_model.eval()


model = _load_clip_caption_model('model.h5')
# Previously `model.eval()` was run twice and coco_model stayed in training mode.
coco_model = _load_clip_caption_model('COCO_model.h5')
def _clip_prefix_caption(caption_model, image):
    """Caption a CLIP-preprocessed image tensor with a ClipCaptionModel.

    The CLIP image embedding is projected with *caption_model*'s OWN
    clip_project layer, reshaped to (1, prefix_length, -1), and decoded
    with generate2.
    """
    with torch.no_grad():
        prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
        prefix_embed = caption_model.clip_project(prefix).reshape(1, prefix_length, -1)
        return generate2(caption_model, tokenizer, embed=prefix_embed)


def _show_prediction(uploaded_file, caption):
    """Render the uploaded image and its predicted caption."""
    st.image(uploaded_file, width=500, channels='RGB')
    st.markdown("**PREDICTION:** " + caption)


def ui():
    """Streamlit page: upload an image, pick a captioning model, show the caption."""
    st.markdown("# Image Captioning")
    uploaded_file = st.file_uploader("Upload an Image", type=['png', 'jpeg', 'jpg'])
    if uploaded_file is None:
        return

    option = st.selectbox('Please select the Model', ('Model', 'COCO Model', 'PreTrained Model', 'Fine Tuned Model'))
    clip_caption_models = {'Model': model, 'COCO Model': coco_model}

    if option in clip_caption_models:
        # Decode the upload only for the CLIP branches, so the other models
        # receive an unconsumed stream (what `inference` reads from it is not
        # visible here — NOTE(review): confirm it does not rely on a fresh stream).
        pil_image = PIL.Image.fromarray(io.imread(uploaded_file))
        image = preprocess(pil_image).unsqueeze(0).to(device)
        caption = _clip_prefix_caption(clip_caption_models[option], image)
    elif option == 'PreTrained Model':
        caption = inference(uploaded_file)
    elif option == 'Fine Tuned Model':
        caption = show_n_generate(uploaded_file, greedy=False, model=model_trained)
    else:
        return

    _show_prediction(uploaded_file, caption)
# `streamlit run app.py` executes this script with __name__ == "__main__".
if __name__ == '__main__':
    ui()