# pages/Caption_Generator.py -- Streamlit page for the Artwork-Caption-Generation app
import torch
import numpy as np
from PIL import Image
import streamlit as st
from torchvision.transforms import v2
from transformers import GenerationConfig
from transformers import GPT2TokenizerFast
from transformers import ViTImageProcessor
from transformers import VisionEncoderDecoderModel
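# This page pairs a Swin Transformer image encoder with a GPT-2 text decoder
# (a VisionEncoderDecoderModel) fine-tuned for artwork captioning; the
# fine-tuned weights are unpickled from generator_model.pkl at startup.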
# Page configuration settings
st.set_page_config(
    layout="centered",
    page_title="Generate Caption",
    initial_sidebar_state="auto",
)
# Initializing session state keys (each key independently, so a partially
# initialized state never leaves one key missing)
if "generate" not in st.session_state:
    st.session_state["generate"] = False
if "image" not in st.session_state:
    st.session_state["image"] = None
# Loading necessary resources and caching them
@st.cache_resource(show_spinner="Loading Resources...")
def loadResources():
    encoder = 'microsoft/swin-base-patch4-window7-224-in22k'
    decoder = 'gpt2'
    processor = ViTImageProcessor.from_pretrained(encoder)
    tokenizer = GPT2TokenizerFast.from_pretrained(decoder)
    # Load the fine-tuned model saved with torch.save; the pickled object is a
    # VisionEncoderDecoderModel (hence the import above), so configuring a fresh
    # encoder-decoder here would only be thrown away.
    model = torch.load("generator_model.pkl", map_location=torch.device("cpu"))
    if 'gpt2' in decoder:
        # GPT-2 has no dedicated padding token, so reuse the end-of-sequence token
        tokenizer.pad_token = tokenizer.eos_token
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.decoder_start_token_id = tokenizer.bos_token_id
    else:
        # BERT-style decoders start generation from the [CLS] token instead
        model.config.decoder_start_token_id = tokenizer.cls_token_id
        model.config.pad_token_id = tokenizer.pad_token_id
    model.eval()
    return processor, tokenizer, model
# Pre-processing image and caching
@st.cache_data
def preprocess_image(_processor, image):
    transforms = v2.Compose([
        v2.ToImage(),                           # numpy HWC array -> CHW tensor
        v2.Resize(size=(224, 224)),
        v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float [0, 1]
    ])
    image = transforms(image)
    # Pixel values are already in [0, 1], so skip the processor's own 1/255 rescaling
    img = _processor(image, return_tensors='pt', do_rescale=False)
    return img
# Generating caption and caching
@st.cache_data
def get_caption(_processor, _tokenizer, _model, image):
    # Leading underscores tell st.cache_data not to hash those arguments,
    # so the cache is keyed on the image array alone.
    image = preprocess_image(_processor, image)
    with torch.no_grad():  # inference only, no gradients needed
        output = _model.generate(
            **image,
            generation_config=GenerationConfig(
                pad_token_id=_tokenizer.pad_token_id
            )
        )
    caption = _tokenizer.batch_decode(
        output,
        skip_special_tokens=True
    )
    return caption[0]
# Displaying elements
def DisplayInteractionElements():
    st.markdown('<div style="display: flex; justify-content: center;"><p style="font-size: 40px; font-weight: bold;">πŸ‘‰ Caption Generator πŸ‘ˆ</p></div>', unsafe_allow_html=True)
    st.file_uploader(accept_multiple_files=False, label='Upload an Image', type=['jpg', 'jpeg', 'png'], key="image_uploader")
    if st.session_state['image_uploader']:
        image = st.session_state['image_uploader']
        im_file = Image.open(image).convert("RGB")
        im_file = np.array(im_file)
        st.session_state['image'] = im_file
        col1, col2, col3 = st.columns(3)
        col2.image(image=image, caption='Uploaded Image')
        st.button(label='Generate Caption', use_container_width=True, type='primary', on_click=generateCaption)
# Triggering generate state
def generateCaption():
    st.session_state['generate'] = True
def main():
    DisplayInteractionElements()
    processor, tokenizer, model = loadResources()
    # Reset the generate flag whenever the uploaded image is cleared
    if not st.session_state['image_uploader']:
        st.session_state['generate'] = False
    if st.session_state['generate'] and st.session_state['image_uploader']:
        caption = get_caption(processor, tokenizer, model, st.session_state['image'])
        st.markdown(f'<div style="display: flex; justify-content: center;"><p style="font-size: 35px; font-weight: bold; color: blue;">{caption}</p></div>', unsafe_allow_html=True)

if __name__ == "__main__":
    main()
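# A typical way to launch this page during development (assuming the
# generator_model.pkl checkpoint sits in the working directory):
#   streamlit run pages/Caption_Generator.py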