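# NOTE: the original first lines ("Spaces: Sleeping / Sleeping") were page-header
# residue captured together with this file; they are not part of the program.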
import streamlit as st | |
import os | |
from PIL import Image | |
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer | |
import torch | |
from nltk.corpus import wordnet | |
import nltk | |
# Fetch the WordNet corpus that get_synonyms() reads; quiet=True keeps the
# downloader's progress output out of the app logs.
nltk.download('wordnet', quiet=True)

# Pre-trained checkpoint used for image captioning.
model_name = "nlpconnect/vit-gpt2-image-captioning"


@st.cache_resource
def _load_captioning_components(checkpoint):
    """Load the captioning model, feature extractor and tokenizer once.

    Streamlit re-executes the script on each user interaction;
    ``st.cache_resource`` keeps the loaded objects alive across those reruns
    so the checkpoint is not reloaded every time.

    Args:
        checkpoint: Name or path of the pre-trained checkpoint to load.

    Returns:
        A ``(model, feature_extractor, tokenizer)`` tuple.
    """
    return (
        VisionEncoderDecoderModel.from_pretrained(checkpoint),
        ViTFeatureExtractor.from_pretrained(checkpoint),
        AutoTokenizer.from_pretrained(checkpoint),
    )


# Module-level names used by generate_caption().
model, feature_extractor, tokenizer = _load_captioning_components(model_name)
def generate_caption(image):
    """Generate a text caption for *image* with the module-level model.

    Args:
        image: The image to describe. ``main`` passes a ``PIL.Image.Image``;
            any other input type is forwarded to the feature extractor as-is.

    Returns:
        The decoded caption string with special tokens removed.
    """
    # PNG/JPEG files can be palette, grayscale or RGBA; the checkpoint's
    # published usage example converts inputs to RGB before extraction, so
    # normalise the mode here rather than failing on non-RGB files.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    output_ids = model.generate(pixel_values)
    # Only the first generated sequence is decoded.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
def get_synonyms(word):
    """Collect the WordNet lemma names of every synset found for *word*.

    Args:
        word: The word to look up in WordNet.

    Returns:
        A set of lemma-name strings exactly as WordNet reports them; empty
        when WordNet has no synset for *word*.
    """
    return {
        lemma.name()
        for synset in wordnet.synsets(word)
        for lemma in synset.lemmas()
    }
def search_captions(query, captions):
    """Return the images whose caption mentions a query word or a synonym.

    Matching is done on whole whitespace-separated words and ignores case,
    so the query "Dog" matches the caption "a dog on the grass".

    Args:
        query: Free-text search string; split on whitespace into words.
        captions: Mapping of image path to caption text.

    Returns:
        A list of ``(path, caption)`` tuples, in the iteration order of
        *captions*, for every caption that shares at least one word with the
        query words or their synonyms. Empty when the query has no words.
    """
    query_words = query.split()
    if not query_words:
        return []

    # Expand the query with synonyms once, lower-casing every term so the
    # comparison below is case-insensitive.
    search_terms = {word.lower() for word in query_words}
    for word in query_words:
        search_terms.update(synonym.lower() for synonym in get_synonyms(word))

    # Each caption is split once and tested with a single set operation
    # instead of being re-split for every search term.
    return [
        (path, caption)
        for path, caption in captions.items()
        if not search_terms.isdisjoint(caption.lower().split())
    ]
def main():
    """Render the gallery page: caption every image in a folder, then search.

    Asks for a folder path, captions each PNG/JPEG file found directly in it,
    shows the images with their captions, and offers a text box that filters
    the images through ``search_captions``.
    """
    st.title("Image Gallery with Captioning and Search")
    folder_path = st.text_input("Enter the folder path containing images:")
    if not folder_path:
        return
    if not os.path.isdir(folder_path):
        st.error(f"Folder not found: {folder_path}")
        return

    # Match real extensions (with the dot) and sort, because os.listdir()
    # returns entries in arbitrary order.
    image_files = sorted(
        name
        for name in os.listdir(folder_path)
        if name.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    if not image_files:
        st.info("No PNG or JPEG images found in this folder.")

    # Captioning is the expensive step and the script reruns on every widget
    # interaction, so captions are remembered per image path for the session.
    if "caption_cache" not in st.session_state:
        st.session_state["caption_cache"] = {}
    caption_cache = st.session_state["caption_cache"]

    captions = {}
    for image_file in image_files:
        image_path = os.path.join(folder_path, image_file)
        caption = caption_cache.get(image_path)
        if caption is None:
            try:
                # The pixel data is decoded lazily, so captioning stays inside
                # the ``with`` block; the file handle is closed afterwards.
                with Image.open(image_path) as image:
                    caption = generate_caption(image)
            except OSError as exc:
                # One unreadable or corrupt file should not take down the page.
                st.warning(f"Skipping unreadable image {image_file}: {exc}")
                continue
            caption_cache[image_path] = caption
        captions[image_path] = caption
        st.image(image_path, caption=caption)

    query = st.text_input("Search images by caption:")
    if query:
        results = search_captions(query, captions)
        if not results:
            st.info("No captions match this search.")
        for image_path, caption in results:
            st.image(image_path, caption=caption)
# Run the app when this file is executed as a script.
if __name__ == "__main__":
    main()