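"""Streamlit image-captioning demo.

Loads a CLIP-prefix captioner, a COCO-trained variant, and a ViT+GPT2
encoder-decoder, then lets the user caption an uploaded image. Expects the
weight files 'model.h5', 'COCO_model.h5' and 'model_trained.pth' in the
working directory.

Assumed launch command (the actual script name may differ):
    streamlit run app.py
"""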
import torch
import clip
from PIL import Image
import skimage.io as io
import streamlit as st
from transformers import GPT2Tokenizer
from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel

from model import generate2, ClipCaptionModel
from engine import inference


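# ViT+GPT2 encoder-decoder captioner from the Hugging Face hub, with locally
# fine-tuned weights loaded on top (strict=False tolerates missing/extra keys).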
model_trained = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
model_trained.load_state_dict(torch.load('model_trained.pth', map_location=torch.device('cpu')), strict=False)
image_processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
vit_tokenizer = GPT2TokenizerFast.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

def show_n_generate(img, model, greedy=True):
    """Caption an uploaded image with the ViT+GPT2 encoder-decoder model."""
    image = Image.open(img)
    pixel_values = image_processor(image, return_tensors="pt").pixel_values

    if greedy:
        # Greedy decoding: always take the most likely next token.
        generated_ids = model.generate(pixel_values, max_new_tokens=30)
    else:
        # Top-k sampling for more varied captions.
        generated_ids = model.generate(
            pixel_values,
            do_sample=True,
            max_new_tokens=30,
            top_k=5)
    generated_text = vit_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text

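# CLIP-prefix captioning setup: a CLIP image embedding is projected into a short
# sequence of GPT2 prefix embeddings that the language model continues as a caption.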
device = "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Number of prefix tokens the CLIP embedding is projected into.
prefix_length = 10

model = ClipCaptionModel(prefix_length)
model.load_state_dict(torch.load('model.h5', map_location=torch.device('cpu')), strict=False)
model = model.eval()

coco_model = ClipCaptionModel(prefix_length)
coco_model.load_state_dict(torch.load('COCO_model.h5', map_location=torch.device('cpu')), strict=False)
coco_model = coco_model.eval()


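# Streamlit UI: upload an image, pick a captioning model, and display the prediction.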
def ui():
    st.markdown("# Image Captioning")
    # st.markdown("## Done By- Vageesh and Rushil")
    uploaded_file = st.file_uploader("Upload an Image", type=['png', 'jpeg', 'jpg'])

    if uploaded_file is not None:
        # Read the upload once and keep a PIL copy for CLIP preprocessing.
        image = io.imread(uploaded_file)
        pil_image = Image.fromarray(image)
        image = preprocess(pil_image).unsqueeze(0).to(device)

        option = st.selectbox('Please select the Model', ('Clip Captioning', 'Attention Decoder', 'VIT+GPT2'))

        if option == 'Clip Captioning':
            with torch.no_grad():
                # Encode the image with CLIP and project it into GPT2 prefix embeddings.
                prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
                prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
            generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)

            st.image(uploaded_file, width=500, channels='RGB')
            st.markdown("**PREDICTION:** " + generated_text_prefix)
        elif option == 'Attention Decoder':
            out = inference(uploaded_file)
            st.image(uploaded_file, width=500, channels='RGB')
            st.markdown("**PREDICTION:** " + out)

        # The ViT+GPT2 branch is currently disabled.
        # elif option == 'VIT+GPT2':
        #     out = show_n_generate(uploaded_file, greedy=False, model=model_trained)
        #     st.image(uploaded_file, width=500, channels='RGB')
        #     st.markdown("**PREDICTION:** " + out)


if __name__ == '__main__':
    ui()