"""Streamlit app that captions an uploaded image with a CLIP-prefix GPT-2 model."""
import torch
import clip
import PIL.Image
import skimage.io as io
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from model import generate2, ClipCaptionModel

# --- Model loading -----------------------------------------------------------
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so the models below are reloaded each time. Consider wrapping the
# loading in st.cache_resource if the installed Streamlit version provides it.
device = "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Number of prefix embeddings produced per image; passed to ClipCaptionModel and
# used to reshape the projected CLIP features below.
prefix_length = 10

model = ClipCaptionModel(prefix_length)
model.load_state_dict(torch.load('model.h5', map_location=torch.device('cpu')))
model = model.eval()

coco_model = ClipCaptionModel(prefix_length)
coco_model.load_state_dict(torch.load('COCO_model.h5', map_location=torch.device('cpu')))
# Fix: the original called model.eval() a second time here, leaving coco_model
# in training mode (e.g. dropout active) during inference.
coco_model = coco_model.eval()


def _generate_caption(caption_model, image):
    """Return the caption generated by ``caption_model`` for ``image``.

    Args:
        caption_model: a loaded ``ClipCaptionModel`` (``model`` or ``coco_model``).
        image: the output of CLIP's ``preprocess`` with a leading batch
            dimension added via ``unsqueeze(0)``.

    Returns:
        The text produced by ``generate2``.
    """
    with torch.no_grad():
        prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
        # Fix: project with the SAME model that decodes the prefix. The original
        # COCO branch used model.clip_project and then decoded with coco_model.
        prefix_embed = caption_model.clip_project(prefix).reshape(1, prefix_length, -1)
        return generate2(caption_model, tokenizer, embed=prefix_embed)


def ui():
    """Render the captioning page: upload an image, pick a model, show the caption."""
    st.markdown("# Image Captioning")
    uploaded_file = st.file_uploader("Upload an Image", type=['png', 'jpeg', 'jpg'])
    if uploaded_file is None:
        return

    image = io.imread(uploaded_file)
    pil_image = PIL.Image.fromarray(image)
    image = preprocess(pil_image).unsqueeze(0).to(device)

    # Selectbox label -> model; insertion order gives the same options tuple
    # ('Model', 'COCO Model') the original offered.
    models = {'Model': model, 'COCO Model': coco_model}
    option = st.selectbox('Please select the Model', tuple(models))

    generated_text_prefix = _generate_caption(models[option], image)
    st.image(uploaded_file, width=500, channels='RGB')
    st.markdown("**PREDICTION:** " + generated_text_prefix)


if __name__ == '__main__':
    ui()