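# Streamlit demo app for image captioning. Three model options are wired into
# the UI: CLIP-prefix captioning (ClipCaptionModel + GPT-2), an attention-based
# decoder (engine.inference), and a ViT+GPT2 encoder-decoder (currently disabled).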
import torch
import clip
import PIL.Image
from PIL import Image
import skimage.io as io
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel
from model import generate2, ClipCaptionModel
from engine import inference
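
# ViT+GPT2 captioning model: start from the pretrained nlpconnect checkpoint,
# then overlay fine-tuned weights from model_trained.pth (strict=False ignores
# missing/unexpected keys).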
model_trained = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
model_trained.load_state_dict(torch.load('model_trained.pth', map_location=torch.device('cpu')), strict=False)
image_processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = GPT2TokenizerFast.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
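

# Caption an uploaded image with the ViT+GPT2 model, either greedily or with
# top-k sampling (used by the currently disabled "VIT+GPT2" option below).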
def show_n_generate(img, model, greedy=True):
    image = Image.open(img)
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    if greedy:
        generated_ids = model.generate(pixel_values, max_new_tokens=30)
    else:
        generated_ids = model.generate(
            pixel_values,
            do_sample=True,
            max_new_tokens=30,
            top_k=5)
    generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text
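

# CLIP-prefix captioning setup: CLIP ViT-B/32 encodes the image, and a
# ClipCaptionModel projects the CLIP embedding into a GPT-2 prefix of length 10.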
device = "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
prefix_length = 10
model = ClipCaptionModel(prefix_length)
model.load_state_dict(torch.load('model.h5',map_location=torch.device('cpu')),strict=False)
model = model.eval()
coco_model = ClipCaptionModel(prefix_length)
coco_model.load_state_dict(torch.load('COCO_model.h5',map_location=torch.device('cpu')),strict=False)
model = model.eval()
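

# Streamlit UI: upload an image, pick a captioning model, and display the prediction.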
def ui():
    st.markdown("# Image Captioning")
    # st.markdown("## Done By- Vageesh and Rushil")
    uploaded_file = st.file_uploader("Upload an Image", type=['png', 'jpeg', 'jpg'])

    if uploaded_file is not None:
        # Read the upload and preprocess it for CLIP
        image = io.imread(uploaded_file)
        pil_image = PIL.Image.fromarray(image)
        image = preprocess(pil_image).unsqueeze(0).to(device)

        option = st.selectbox('Please select the Model', ('Clip Captioning', 'Attention Decoder', 'VIT+GPT2'))

        if option == 'Clip Captioning':
            with torch.no_grad():
                # Encode the image with CLIP and project it into a GPT-2 prefix
                prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
                prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
            generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
            st.image(uploaded_file, width=500, channels='RGB')
            st.markdown("**PREDICTION:** " + generated_text_prefix)

        elif option == 'Attention Decoder':
            out = inference(uploaded_file)
            st.image(uploaded_file, width=500, channels='RGB')
            st.markdown("**PREDICTION:** " + out)

        # elif option == 'VIT+GPT2':
        #     out = show_n_generate(uploaded_file, greedy=False, model=model_trained)
        #     st.image(uploaded_file, width=500, channels='RGB')
        #     st.markdown("**PREDICTION:** " + out)


if __name__ == '__main__':
    ui()