"""Streamlit image-captioning gallery: model/NLP setup and helper functions."""
import streamlit as st
import os
import zipfile
import tempfile
import base64
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import pandas as pd
from nltk.corpus import wordnet
import spacy
import io
from spacy.cli import download

# Load the spaCy English model, downloading only if it is missing.
# (Unconditionally calling download() on every Streamlit rerun is slow and
# needs network access each time.)
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")

# Fetch NLTK WordNet data (no-op if already present); quiet=True keeps the
# Streamlit log clean across reruns.
import nltk
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)

model_name = "NourFakih/Vit-GPT2-COCO2017Flickr-85k-11"


@st.cache_resource
def _load_captioning_model(name):
    """Load and configure the ViT-GPT2 captioning model once per process.

    Streamlit reruns the whole script on every widget interaction;
    st.cache_resource prevents rebuilding the model each time.
    Returns (model, feature_extractor, tokenizer).
    """
    model = VisionEncoderDecoderModel.from_pretrained(name)
    feature_extractor = ViTImageProcessor.from_pretrained(name)
    tokenizer = AutoTokenizer.from_pretrained(name)
    # GPT-2 has no pad token; reuse EOS so padding/generation work.
    tokenizer.pad_token = tokenizer.eos_token
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.decoder_start_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    return model, feature_extractor, tokenizer


model, feature_extractor, tokenizer = _load_captioning_model(model_name)


def generate_caption(image):
    """Generate a single text caption for a PIL image."""
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    output_ids = model.generate(pixel_values)
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return caption


def get_synonyms(word):
    """Return the set of WordNet lemma names across every synset of *word*."""
    synonyms = set()
    for syn in wordnet.synsets(word):
        for lemma in syn.lemmas():
            synonyms.add(lemma.name())
    return synonyms


def preprocess_query(query):
    """Expand *query* into lowercase surface tokens, lemmas, and WordNet synonyms."""
    doc = nlp(query)
    tokens = set()
    for token in doc:
        tokens.add(token.text.lower())
        tokens.add(token.lemma_.lower())
        tokens.update(get_synonyms(token.text.lower()))
    return tokens


def search_captions(query, captions):
    """Return (path, caption) pairs whose caption shares any expanded token with *query*.

    Both sides go through preprocess_query, so matches include lemma and
    synonym overlap, not just exact words.
    """
    query_tokens = preprocess_query(query)
    results = []
    for path, caption in captions.items():
        caption_tokens = preprocess_query(caption)
        if query_tokens & caption_tokens:
            results.append((path, caption))
    return results


st.title("Image Captioning Gallery")

# Sidebar search box; matching against generated captions happens in the
# main body once captions exist.
with st.sidebar:
    query = st.text_input("Search images by caption:")
# --- Input strategy: local folder, individual uploads, or a ZIP archive ---
# Note the leading dots: the original tuple ('png', 'jpg', 'jpeg') also
# matched names like 'notes_png' because str.endswith has no anchor.
IMAGE_EXTS = ('.png', '.jpg', '.jpeg')

input_option = st.selectbox("Select input method:", ["Folder Path", "Upload Images", "Upload ZIP"])

image_files = []
if input_option == "Folder Path":
    folder_path = st.text_input("Enter the folder path containing images:")
    if folder_path and os.path.isdir(folder_path):
        image_files = [os.path.join(folder_path, f)
                       for f in os.listdir(folder_path)
                       if f.lower().endswith(IMAGE_EXTS)]
elif input_option == "Upload Images":
    uploaded_files = st.file_uploader("Upload image files", type=["png", "jpg", "jpeg"],
                                      accept_multiple_files=True)
    if uploaded_files:
        for uploaded_file in uploaded_files:
            # Persist each upload to disk so PIL can open it by path later.
            with tempfile.NamedTemporaryFile(delete=False,
                                             suffix=os.path.splitext(uploaded_file.name)[1]) as temp_file:
                temp_file.write(uploaded_file.read())
                image_files.append(temp_file.name)
elif input_option == "Upload ZIP":
    uploaded_zip = st.file_uploader("Upload a ZIP file containing images", type=["zip"])
    if uploaded_zip:
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            temp_file.write(uploaded_zip.read())
        # Extract into a fresh per-run temp dir instead of a fixed /tmp/images:
        # portable on Windows and no collisions between concurrent sessions.
        extract_dir = tempfile.mkdtemp(prefix="caption_gallery_")
        with zipfile.ZipFile(temp_file.name, 'r') as zip_ref:
            zip_ref.extractall(extract_dir)
            image_files = [os.path.join(extract_dir, f)
                           for f in zip_ref.namelist()
                           if f.lower().endswith(IMAGE_EXTS)]

# Keep captions in session state: Streamlit reruns this whole script on every
# interaction, so a plain `captions = {}` was emptied before the sidebar
# search could ever use it.
if 'captions' not in st.session_state:
    st.session_state.captions = {}

if st.button("Generate Captions", key='generate_captions'):
    for image_file in image_files:
        try:
            image = Image.open(image_file)
            st.session_state.captions[image_file] = generate_caption(image)
        except Exception as e:
            st.error(f"Error processing {image_file}: {e}")

captions = st.session_state.captions


def _render_card(col, image_path, caption):
    """Render one image+caption card in *col*; errors shown inline, not fatal."""
    with col:
        try:
            with open(image_path, "rb") as img_file:
                encoded_image = base64.b64encode(img_file.read()).decode()
            # NOTE(review): the original f-string HTML was truncated in this
            # file; reconstructed as a simple base64 <img> card showing the
            # caption and path — confirm intended styling against the app.
            st.markdown(
                f"""
                <div style="text-align: center;">
                    <img src="data:image/jpeg;base64,{encoded_image}" width="150"/>
                    <p>{caption}</p>
                    <p><small>{image_path}</small></p>
                </div>
                """,
                unsafe_allow_html=True,
            )
        except Exception as e:
            st.error(f"Error displaying {image_path}: {e}")


# Display all captioned images in a 4-column grid.
st.subheader("Images and Captions:")
cols = st.columns(4)
for idx, (image_path, caption) in enumerate(captions.items()):
    _render_card(cols[idx % 4], image_path, caption)

# Sidebar search results: token/lemma/synonym overlap between the query and
# the generated captions.
if query and captions:
    st.subheader("Search Results:")
    result_cols = st.columns(4)
    for idx, (image_path, caption) in enumerate(search_captions(query, captions)):
        _render_card(result_cols[idx % 4], image_path, caption)