import torch
import numpy as np
from PIL import Image
import streamlit as st
from torchvision.transforms import v2
from transformers import GenerationConfig
from transformers import GPT2TokenizerFast
from transformers import ViTImageProcessor
from transformers import VisionEncoderDecoderModel

# Page configuration settings
st.set_page_config(
    layout="centered",
    page_title="Generate Caption",
    initial_sidebar_state="collapsed",
)

# Initializing session state keys (each key is checked individually so a
# partially initialized state still gets completed)
if "generate" not in st.session_state:
    st.session_state["generate"] = False
if "image" not in st.session_state:
    st.session_state["image"] = None

# Loading necessary resources and caching them
@st.cache_resource(show_spinner="Loading Resources...")
def loadResources():
    encoder = 'microsoft/swin-base-patch4-window7-224-in22k'
    decoder = 'gpt2'
    model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        encoder, decoder
    )
    processor = ViTImageProcessor.from_pretrained(encoder)
    tokenizer = GPT2TokenizerFast.from_pretrained(decoder)
    if 'gpt2' in decoder:
        # GPT-2 has no pad token, so reuse the end-of-sequence token
        tokenizer.pad_token = tokenizer.eos_token
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.decoder_start_token_id = tokenizer.bos_token_id
    else:
        model.config.decoder_start_token_id = tokenizer.cls_token_id
        model.config.pad_token_id = tokenizer.pad_token_id
    # The freshly assembled model above is replaced by the fine-tuned
    # checkpoint; only the processor and tokenizer set up here are kept
    model = torch.load("generator_model.pkl", map_location=torch.device("cpu"))
    model.eval()
    return processor, tokenizer, model

# Pre-processing image and caching
@st.cache_data
def preprocess_image(_processor, image):
    transforms = v2.Compose([
        # Convert the NumPy array to a tensor first; v2 transforms silently
        # pass plain ndarrays through unchanged
        v2.ToImage(),
        v2.Resize(size=(224, 224)),
        v2.ToDtype(torch.float32, scale=True),
    ])
    image = transforms(image)
    # Resizing and rescaling were already done above, so skip them in the
    # processor to avoid resizing twice and rescaling twice
    img = _processor(image, return_tensors='pt', do_resize=False, do_rescale=False)
    return img

# Generating caption and caching
@st.cache_data
def get_caption(_processor, _tokenizer, _model, image):
    image = preprocess_image(_processor, image)
    output = _model.generate(
        **image,
        generation_config=GenerationConfig(
            pad_token_id=_tokenizer.pad_token_id
        )
    )
    caption = _tokenizer.batch_decode(
        output, skip_special_tokens=True
    )
    # Capitalize each word; the `if item` guard skips empty strings produced
    # by consecutive spaces
    caption = " ".join(
        [item[0].upper() + item[1:] for item in caption[0].split(" ") if item]
    )
    return caption
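# Optional sanity check (a minimal sketch, not part of the Streamlit app):
# exercises the same load/preprocess/generate pipeline on a local file.
# The function name and the example path "sample.jpg" are hypothetical.
def caption_from_file(path):
    processor, tokenizer, model = loadResources()
    image = np.array(Image.open(path).convert("RGB"))
    return get_caption(processor, tokenizer, model, image)
# Example usage: print(caption_from_file("sample.jpg"))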
# Displaying elements
def DisplayInteractionElements():
    # Centered page title (the exact inline styling is an assumption)
    st.markdown(
        "<h1 style='text-align: center;'>👉 Caption Generator 👈</h1>",
        unsafe_allow_html=True,
    )
    st.file_uploader(
        accept_multiple_files=False,
        label='Upload an Image',
        type=['jpg', 'jpeg', 'png'],
        key="image_uploader",
    )
    if st.session_state['image_uploader']:
        image = st.session_state['image_uploader']
        im_file = Image.open(image).convert("RGB")
        im_file = np.array(im_file)
        st.session_state['image'] = im_file
        # Center the preview by placing it in the middle of three columns
        col1, col2, col3 = st.columns(3)
        col2.image(image=image, caption='Uploaded Image')
        st.button(
            label='Generate Caption',
            use_container_width=True,
            type='primary',
            on_click=generateCaption,
        )

# Triggering generate state
def generateCaption():
    st.session_state['generate'] = True

def main():
    DisplayInteractionElements()
    processor, tokenizer, model = loadResources()
    # Reset the generate flag whenever the uploaded image is removed
    if not st.session_state['image_uploader']:
        st.session_state['generate'] = False
    if st.session_state['generate'] and st.session_state['image_uploader']:
        caption = get_caption(processor, tokenizer, model, st.session_state['image'])
        # Centered caption display (the exact inline styling is an assumption)
        st.markdown(
            f"<h3 style='text-align: center;'>{caption}</h3>",
            unsafe_allow_html=True,
        )

if __name__ == "__main__":
    main()
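# To launch the app locally, run Streamlit against this file (assuming it is
# saved as app.py and the fine-tuned "generator_model.pkl" checkpoint sits in
# the same directory):
#
#   streamlit run app.py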