File size: 3,944 Bytes
e6856f6
 
 
 
 
 
 
 
 
 
 
 
 
 
71abe40
e6856f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9cbbfa
 
 
e6856f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
import numpy as np
from PIL import Image
import streamlit as st
from torchvision.transforms import v2
from transformers import GenerationConfig
from transformers import GPT2TokenizerFast
from transformers import ViTImageProcessor
from transformers import VisionEncoderDecoderModel

# Page configuration settings
st.set_page_config(
    layout="centered",
    page_title="Generate Caption",
    initial_sidebar_state="collapsed",
)

# Initializing session state keys.
# Fix: initialize each key independently. The original guard
# `all(key not in st.session_state ...)` only ran when BOTH keys were
# missing, so if exactly one key already existed the other was never
# initialized and later reads would raise KeyError.
for _key, _default in (("generate", False), ("image", None)):
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Loading necessary resources and caching them
@st.cache_resource(show_spinner="Loading Resources...")
def loadResources():
    """Load the image processor, tokenizer, and fine-tuned caption model.

    Returns:
        tuple: (processor, tokenizer, model) where model is the unpickled
        fine-tuned generator, in eval mode, mapped to CPU.
    """
    encoder = 'microsoft/swin-base-patch4-window7-224-in22k'
    decoder = 'gpt2'

    processor = ViTImageProcessor.from_pretrained(encoder)
    tokenizer = GPT2TokenizerFast.from_pretrained(decoder)

    # GPT-2 ships without a pad token; reuse EOS so generate()/batch_decode
    # get a valid pad_token_id.
    if 'gpt2' in decoder:
        tokenizer.pad_token = tokenizer.eos_token

    # Fix: the original built a fresh VisionEncoderDecoderModel from the
    # pretrained encoder/decoder and configured its special-token ids, then
    # immediately discarded it by rebinding `model` to the torch.load below.
    # That download + configuration was dead work and is removed; the
    # returned model is exactly the unpickled one, as before.
    # NOTE(security): torch.load unpickles arbitrary code — only load
    # "generator_model.pkl" from a trusted source.
    model = torch.load("generator_model.pkl", map_location=torch.device("cpu"))
    model.eval()
    return processor, tokenizer, model

# Pre-processing image and caching
@st.cache_data
def preprocess_image(_processor, image):
    """Resize the image to 224x224, scale it to float32, and run it
    through the HF processor to get model-ready tensors.

    The leading underscore on `_processor` tells Streamlit's cache not to
    hash that argument.
    """
    pipeline = v2.Compose([
        v2.Resize(size=(224, 224)),
        v2.ToDtype(torch.float32, scale=True),
    ])
    transformed = pipeline(image)
    return _processor(transformed, return_tensors='pt')

# Generating caption and caching 
@st.cache_data
def get_caption(_processor, _tokenizer, _model, image):
    """Generate a title-cased caption for `image`.

    Args:
        _processor: HF image processor (underscore-prefixed so Streamlit
            does not try to hash it for caching).
        _tokenizer: tokenizer used to decode the generated token ids.
        _model: caption generator with a `.generate()` method.
        image: raw image data fed to preprocess_image.

    Returns:
        str: decoded caption with each word's first letter upper-cased.
    """
    inputs = preprocess_image(_processor, image)
    output = _model.generate(
        **inputs,
        generation_config=GenerationConfig(
            pad_token_id=_tokenizer.pad_token_id
        )
    )

    caption = _tokenizer.batch_decode(
        output,
        skip_special_tokens=True
    )

    # Fix: use item[:1] instead of item[0] — consecutive spaces in the
    # decoded text yield empty tokens from split(" "), and item[0] would
    # raise IndexError on them. item[:1] is safe and identical otherwise.
    caption = " ".join(item[:1].upper() + item[1:] for item in caption[0].split(" "))

    return caption

# Displaying elements
def DisplayInteractionElements():
    """Render the page title, the file uploader, and — once an image is
    uploaded — a centered preview plus the 'Generate Caption' button."""
    st.markdown('<div style="display: flex; justify-content: center;"><p style="font-size: 40px; font-weight: bold;">๐Ÿ‘‰ Caption Generator ๐Ÿ‘ˆ</p></div>', unsafe_allow_html=True)
    st.file_uploader(accept_multiple_files=False, label='Upload an Image', type=['jpg', 'jpeg', 'png'], key="image_uploader")

    uploaded = st.session_state['image_uploader']
    if uploaded:
        # Store the decoded RGB pixels for the captioning step.
        pil_image = Image.open(uploaded).convert("RGB")
        st.session_state['image'] = np.array(pil_image)

        # Three columns are created purely to center the preview in the
        # middle one; the outer two stay empty.
        _, middle, _ = st.columns(3)
        middle.image(image=uploaded, caption='Uploaded Image')

        st.button(label='Generate Caption', use_container_width=True, type='primary', on_click=generateCaption)

# Triggering generate state
def generateCaption():
    """Button callback: flag that a caption should be produced on the
    next Streamlit rerun."""
    st.session_state['generate'] = True

def main():
    """Drive the app: draw the UI, load the cached resources, and render
    a caption when the user has both uploaded an image and pressed the
    generate button."""
    DisplayInteractionElements()
    processor, tokenizer, model = loadResources()

    uploaded = st.session_state['image_uploader']

    # Clearing the uploader resets the generate flag so a stale caption
    # is not shown for a removed image.
    if not uploaded:
        st.session_state['generate'] = False

    if st.session_state['generate'] and uploaded:
        caption = get_caption(processor, tokenizer, model, st.session_state['image'])
        st.markdown(f'<div style="display: flex; justify-content: center;"><p style="font-size: 35px; font-weight: bold; color: blue;">{caption}</p></div>', unsafe_allow_html = True)


if __name__ == "__main__":
    main()