Spaces:
Sleeping
Sleeping
File size: 3,662 Bytes
0225049 f56e8d2 0225049 389e486 dd61d2c 655168b 0225049 389e486 d058d78 389e486 cd670ba 0225049 cbbcfd4 0225049 cbbcfd4 0225049 389e486 0225049 655168b 389e486 0225049 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import torch
import clip
import PIL.Image
from PIL import Image
import skimage.io as io
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel
from model import generate2,ClipCaptionModel
from engine import inference
# Fine-tuned ViT-GPT2 captioner: the public nlpconnect checkpoint supplies the
# architecture, image processor and tokenizer; the local 'model_trained.pth'
# state dict then overwrites the model's parameters (loaded onto the CPU).
_VIT_GPT2_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

model_trained = VisionEncoderDecoderModel.from_pretrained(_VIT_GPT2_CHECKPOINT)
model_trained.load_state_dict(
    torch.load('model_trained.pth', map_location=torch.device('cpu'))
)
image_processor = ViTImageProcessor.from_pretrained(_VIT_GPT2_CHECKPOINT)
# NOTE(review): the name `tokenizer` is reassigned to GPT2Tokenizer("gpt2")
# further down this module, so code that reads `tokenizer` at call time gets
# that tokenizer rather than this one.
tokenizer = GPT2TokenizerFast.from_pretrained(_VIT_GPT2_CHECKPOINT)
def show_n_generate(img, model, greedy=True, *, max_new_tokens=30, top_k=5):
    """Caption one image with a VisionEncoderDecoder (ViT-GPT2) model.

    Args:
        img: Anything ``PIL.Image.open`` accepts (a path or a binary
            file-like object, e.g. the Streamlit upload passed in by ``ui``).
        model: An object exposing ``generate(pixel_values, ...)``; ``ui``
            passes the fine-tuned ``model_trained``.
        greedy: If True, call ``generate`` without sampling (the model's
            default, non-sampling decoding); otherwise sample from the
            ``top_k`` most likely tokens at each step.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_k: Size of the sampling candidate pool; unused when ``greedy``.

    Returns:
        The decoded caption for the image, with special tokens removed.
    """
    image = Image.open(img)
    # Uploaded PNGs may be RGBA, palette or grayscale; convert to RGB so the
    # image processor receives 3-channel input (as the nlpconnect model-card
    # example does).
    if image.mode != "RGB":
        image = image.convert("RGB")
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    # No matplotlib preview here: the Streamlit caller renders the image itself
    # with st.image (and plt/np are not imported in this module).
    if greedy:
        generated_ids = model.generate(pixel_values, max_new_tokens=max_new_tokens)
    else:
        generated_ids = model.generate(
            pixel_values,
            do_sample=True,
            max_new_tokens=max_new_tokens,
            top_k=top_k,
        )
    # NOTE(review): `tokenizer` is the module global looked up at call time,
    # which by then is the plain "gpt2" GPT2Tokenizer rather than the
    # nlpconnect tokenizer; presumably they share GPT-2's vocabulary — confirm.
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# CLIP-prefix captioning setup: a CLIP ViT-B/32 image encoder plus two
# ClipCaptionModel checkpoints, all on the CPU.
device = "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
# NOTE: this rebinds the `tokenizer` name assigned earlier for the ViT-GPT2
# model; from here on `tokenizer` is the plain GPT-2 tokenizer that ui() hands
# to generate2.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Number of prefix embeddings the CLIP projection is reshaped into
# (reshape(1, prefix_length, -1) at caption time); must match the checkpoints.
prefix_length = 10

model = ClipCaptionModel(prefix_length)
model.load_state_dict(torch.load('model.h5', map_location=torch.device('cpu')))
model = model.eval()

coco_model = ClipCaptionModel(prefix_length)
coco_model.load_state_dict(torch.load('COCO_model.h5', map_location=torch.device('cpu')))
# eval() must be applied to coco_model itself; otherwise it stays in training
# mode (e.g. any dropout layers remain active while captioning).
coco_model = coco_model.eval()
def _caption_with_clip_prefix(caption_model, clip_image):
    """Caption a CLIP-preprocessed image batch with one ClipCaptionModel.

    ``clip_image`` is the tensor produced by ``preprocess(...).unsqueeze(0)``.
    The CLIP image embedding is projected by ``caption_model``'s own
    ``clip_project`` into ``prefix_length`` prefix embeddings, which
    ``generate2`` decodes with the same ``caption_model``.
    """
    with torch.no_grad():
        prefix = clip_model.encode_image(clip_image).to(device, dtype=torch.float32)
        # Use the projection belonging to the same model whose decoder consumes
        # the prefix; another model's projection presumably produces embeddings
        # this decoder was not trained on.
        prefix_embed = caption_model.clip_project(prefix).reshape(1, prefix_length, -1)
        return generate2(caption_model, tokenizer, embed=prefix_embed)


def ui():
    """Render the Streamlit image-captioning page.

    Shows an upload widget. Once an image is uploaded, the user picks one of
    four captioners, and the page displays the image with its predicted
    caption:

    - 'Model' / 'COCO Model': CLIP ViT-B/32 embedding -> ClipCaptionModel
      prefix -> generate2.
    - 'PreTrained Model': ``inference`` from ``engine`` on the raw upload.
    - 'Fine Tuned Model': ``show_n_generate`` with ``model_trained``, sampling.
    """
    st.markdown("# Image Captioning")
    uploaded_file = st.file_uploader("Upload an Image", type=['png', 'jpeg', 'jpg'])
    if uploaded_file is None:
        return

    pixels = io.imread(uploaded_file)
    pil_image = PIL.Image.fromarray(pixels)
    image = preprocess(pil_image).unsqueeze(0).to(device)

    option = st.selectbox(
        'Please select the Model',
        ('Model', 'COCO Model', 'PreTrained Model', 'Fine Tuned Model'),
    )
    if option == 'Model':
        caption = _caption_with_clip_prefix(model, image)
    elif option == 'COCO Model':
        caption = _caption_with_clip_prefix(coco_model, image)
    elif option == 'PreTrained Model':
        caption = inference(uploaded_file)
    elif option == 'Fine Tuned Model':
        caption = show_n_generate(uploaded_file, greedy=False, model=model_trained)
    else:
        # selectbox only returns one of the listed options; nothing to show.
        return

    st.image(uploaded_file, width=500, channels='RGB')
    st.markdown("**PREDICTION:** " + caption)
if __name__ == '__main__':
    # Render the page when this file is executed as the main script
    # (presumably via `streamlit run`); importing the module does not.
    ui()
|