File size: 2,223 Bytes
c34bc48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import streamlit as st
import os
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
import torch
from nltk.corpus import wordnet
import nltk

def _ensure_wordnet():
    """Make sure the NLTK WordNet corpus is installed, downloading it only if missing."""
    try:
        nltk.data.find('corpora/wordnet')
    except LookupError:
        nltk.download('wordnet', quiet=True)


@st.cache_resource
def _load_captioning_components(name):
    """Load the captioning model, feature extractor and tokenizer for *name*.

    Streamlit re-executes the whole script on every widget interaction; caching
    the result as a shared resource means the weights are loaded once per
    server process instead of on every rerun.

    Returns:
        A ``(model, feature_extractor, tokenizer)`` tuple.
    """
    return (
        VisionEncoderDecoderModel.from_pretrained(name),
        ViTFeatureExtractor.from_pretrained(name),
        AutoTokenizer.from_pretrained(name),
    )


_ensure_wordnet()

# Load the pre-trained model for image captioning
model_name = "nlpconnect/vit-gpt2-image-captioning"
model, feature_extractor, tokenizer = _load_captioning_components(model_name)

def generate_caption(image, **generate_kwargs):
    """Generate a natural-language caption for a single image.

    Args:
        image: The image to caption, typically a ``PIL.Image.Image``. PIL
            images that are not in RGB mode (RGBA PNGs, greyscale, palette
            images) are converted to RGB first. The ViT feature extractor
            normalises exactly three channels and rejects 4-channel or
            single-channel input.
        **generate_kwargs: Optional keyword arguments forwarded unchanged to
            ``model.generate`` (for example ``max_length`` or ``num_beams``).

    Returns:
        The decoded caption, with special tokens and surrounding whitespace
        removed.
    """
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    output_ids = model.generate(pixel_values, **generate_kwargs)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()

def get_synonyms(word):
    """Collect the WordNet lemma names of every synset that *word* belongs to.

    Args:
        word: The word to look up.

    Returns:
        A set of lemma-name strings, empty when WordNet has no synsets for
        *word*. Lemma names are returned exactly as WordNet spells them; WordNet
        joins multi-word lemmas with underscores, e.g. ``"hot_dog"``.
    """
    return {
        lemma.name()
        for synset in wordnet.synsets(word)
        for lemma in synset.lemmas()
    }

def search_captions(query, captions, *, synonym_lookup=None):
    """Return the images whose caption mentions a query word or one of its synonyms.

    Matching is case-insensitive and on whole words. Each query word is
    expanded with its synonyms. Synonyms whose parts are joined with
    underscores (WordNet style, e.g. ``"hot_dog"``) are matched as
    space-separated phrases.

    Args:
        query: Free-text search string, split on whitespace into words.
        captions: Mapping of image path -> caption text.
        synonym_lookup: Optional callable ``word -> iterable of str`` used to
            expand each lower-cased query word. Defaults to ``get_synonyms``.

    Returns:
        A list of ``(path, caption)`` tuples for the matching images, in the
        iteration order of ``captions``. An empty query matches nothing.
    """
    lookup = get_synonyms if synonym_lookup is None else synonym_lookup

    query_words = query.lower().split()
    terms = set(query_words)
    for word in query_words:
        terms.update(lookup(word))

    # Normalise every term to a lower-case, single-space-separated phrase so
    # capitalised or underscore-joined synonyms can still match caption text.
    phrases = {' '.join(term.replace('_', ' ').lower().split()) for term in terms}
    phrases.discard('')

    results = []
    for path, caption in captions.items():
        # Pad with spaces so `in` only matches on whole-word boundaries.
        padded = f" {' '.join(caption.lower().split())} "
        if any(f" {phrase} " in padded for phrase in phrases):
            results.append((path, caption))
    return results

IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _cached_caption(image_path, image):
    """Return a caption for *image*, reusing one generated on an earlier rerun.

    Streamlit re-executes the whole script on every widget interaction (for
    example each edit of the search box). Captions are therefore memoised in
    ``st.session_state``, keyed by file path and modification time, so that
    an edited file gets a fresh caption.
    """
    if 'caption_cache' not in st.session_state:
        st.session_state['caption_cache'] = {}
    cache = st.session_state['caption_cache']
    key = (image_path, os.path.getmtime(image_path))
    if key not in cache:
        cache[key] = generate_caption(image)
    return cache[key]


def main():
    """Render a captioned gallery of a folder's images plus a caption search box."""
    st.title("Image Gallery with Captioning and Search")

    folder_path = st.text_input("Enter the folder path containing images:")
    if not folder_path:
        return
    if not os.path.isdir(folder_path):
        st.warning(f"'{folder_path}' is not a directory.")
        return

    # Require the leading dot (so e.g. "xpng" is not treated as an image), keep
    # only regular files, and sort because os.listdir order is arbitrary.
    image_files = sorted(
        name for name in os.listdir(folder_path)
        if name.lower().endswith(IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(folder_path, name))
    )
    captions = {}

    for image_file in image_files:
        image_path = os.path.join(folder_path, image_file)
        try:
            # `with` closes the file handle that Image.open keeps open lazily.
            with Image.open(image_path) as image:
                caption = _cached_caption(image_path, image)
                st.image(image, caption=caption)
        except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
            st.warning(f"Skipping {image_file}: {err}")
            continue
        captions[image_path] = caption

    query = st.text_input("Search images by caption:")
    if query:
        results = search_captions(query, captions)
        if not results:
            st.info("No images match that search.")
        for image_path, caption in results:
            st.image(image_path, caption=caption)

# Script entry point: render the app when this file is executed as the main
# script. NOTE(review): presumably launched via `streamlit run <this file>`;
# confirm against the project's run instructions.
if __name__ == "__main__":
    main()