"""Streamlit front-end that captions an uploaded image with one of four models.

* "Model" / "COCO Model": a CLIP ViT-B/32 image embedding is projected into a
  GPT-2 prefix by a ``ClipCaptionModel`` (weights ``model.h5`` /
  ``COCO_model.h5``) and decoded with ``generate2``.
* "PreTrained Model": delegates to ``engine.inference``.
* "Fine Tuned Model": the ``nlpconnect/vit-gpt2-image-captioning``
  VisionEncoderDecoder with weights loaded from ``model_trained.pth``.
"""
import torch
import clip
import PIL.Image
import skimage.io as io
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel
from model import generate2, ClipCaptionModel
from engine import inference

VIT_GPT2_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"
device = "cpu"
prefix_length = 10


def _load_caption_weights(path):
    """Build a ClipCaptionModel, load the state dict at *path* and return it in eval mode."""
    caption_model = ClipCaptionModel(prefix_length)
    caption_model.load_state_dict(torch.load(path, map_location=torch.device(device)))
    return caption_model.eval()


# Streamlit re-executes this script on every widget interaction; caching the
# loaders keeps the weights from being re-read from disk on each rerun.
@st.cache_resource
def _load_vit_gpt2():
    """Return (fine-tuned VisionEncoderDecoder, its image processor, its tokenizer)."""
    vit_model = VisionEncoderDecoderModel.from_pretrained(VIT_GPT2_CHECKPOINT)
    vit_model.load_state_dict(torch.load("model_trained.pth", map_location=torch.device(device)))
    vit_model.eval()
    processor = ViTImageProcessor.from_pretrained(VIT_GPT2_CHECKPOINT)
    vit_tok = GPT2TokenizerFast.from_pretrained(VIT_GPT2_CHECKPOINT)
    return vit_model, processor, vit_tok


@st.cache_resource
def _load_clip_caption():
    """Return (CLIP model, CLIP preprocess, GPT-2 tokenizer, base caption model, COCO caption model)."""
    clip_net, clip_preprocess = clip.load("ViT-B/32", device=device, jit=False)
    gpt2_tok = GPT2Tokenizer.from_pretrained("gpt2")
    # Each checkpoint gets its own eval() call (the original re-ran model.eval()
    # and left coco_model in training mode).
    base = _load_caption_weights("model.h5")
    coco = _load_caption_weights("COCO_model.h5")
    return clip_net, clip_preprocess, gpt2_tok, base, coco


# Module-level names kept from the original script. The ViT-GPT2 tokenizer now
# has its own name so it is no longer silently overwritten by the "gpt2" one.
model_trained, image_processor, vit_tokenizer = _load_vit_gpt2()
clip_model, preprocess, tokenizer, model, coco_model = _load_clip_caption()


def show_n_generate(img, greedy=True, model=None):
    """Caption *img* with a VisionEncoderDecoder model.

    Args:
        img: a path or binary file-like object that ``PIL.Image.open`` accepts
            (e.g. a Streamlit ``UploadedFile``).
        greedy: if True, decode greedily; otherwise sample with ``top_k=5``.
        model: the VisionEncoderDecoderModel to run; defaults to ``model_trained``.

    Returns:
        The decoded caption string for the single input image.
    """
    if model is None:
        model = model_trained
    # The ViT processor expects 3 channels; PNG uploads may be RGBA or greyscale.
    image = PIL.Image.open(img).convert("RGB")
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    if greedy:
        generated_ids = model.generate(pixel_values, max_new_tokens=30)
    else:
        generated_ids = model.generate(pixel_values, do_sample=True, max_new_tokens=30, top_k=5)
    # Decode with the checkpoint's own tokenizer, not the "gpt2" one used by the CLIP models.
    return vit_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]


def _clip_prefix_caption(caption_model, image_tensor):
    """Caption a CLIP-preprocessed image tensor with *caption_model*.

    The prefix projection comes from *caption_model* itself, so each checkpoint
    uses its own ``clip_project`` weights.
    """
    with torch.no_grad():
        prefix = clip_model.encode_image(image_tensor).to(device, dtype=torch.float32)
        prefix_embed = caption_model.clip_project(prefix).reshape(1, prefix_length, -1)
        return generate2(caption_model, tokenizer, embed=prefix_embed)


def ui():
    """Render the page: upload an image, pick a model, then show the image and its caption."""
    st.markdown("# Image Captioning")
    uploaded_file = st.file_uploader("Upload an Image", type=["png", "jpeg", "jpg"])
    if uploaded_file is None:
        return

    # convert("RGB") so alpha/greyscale PNGs are handled uniformly by every model.
    pil_image = PIL.Image.open(uploaded_file).convert("RGB")
    option = st.selectbox(
        "Please select the Model",
        ("Model", "COCO Model", "PreTrained Model", "Fine Tuned Model"),
    )

    if option in ("Model", "COCO Model"):
        caption_model = model if option == "Model" else coco_model
        image_tensor = preprocess(pil_image).unsqueeze(0).to(device)
        caption = _clip_prefix_caption(caption_model, image_tensor)
    elif option == "PreTrained Model":
        uploaded_file.seek(0)  # rewind: decoding pil_image consumed the stream
        caption = inference(uploaded_file)
    else:  # "Fine Tuned Model"
        uploaded_file.seek(0)
        caption = show_n_generate(uploaded_file, greedy=False, model=model_trained)

    st.image(pil_image, width=500, channels="RGB")
    st.markdown("**PREDICTION:** " + caption)


if __name__ == "__main__":
    ui()