import streamlit as st
import gc
import os
from PIL import Image
import torch
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    DetrImageProcessor,
    DetrForObjectDetection,
)
import google.generativeai as genai

# Configure Gemini with the API key from the environment
# (e.g. a Hugging Face Spaces secret named 'google_api_key')
api_key = os.getenv('google_api_key')  # Retrieve the secret API key
if not api_key:
    st.error("API key is missing. Please configure 'google_api_key' in your Hugging Face Secrets.")
    st.stop()  # Halt this run; gemini_model would otherwise be undefined when generating the story
genai.configure(api_key=api_key)
gemini_model = genai.GenerativeModel("gemini-1.5-flash")


# Load BLIP model and processor
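# @st.cache_resource keeps the loaded weights in memory across Streamlit reruns,
# so each model is downloaded and initialised only once per process.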
@st.cache_resource
def load_blip():
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    return processor, model

# Load DETR model and processor
@st.cache_resource
def load_detr():
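    # revision="no_timm" selects the checkpoint variant that does not depend on the timm package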
    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
    model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
    return processor, model

blip_processor, blip_model = load_blip()
detr_processor, detr_model = load_detr()

# Streamlit app
st.title("Art Of Visual Storytelling")

# Dropdown menu for genre selection
genre = st.selectbox(
    "Select the genre of the story:",
    ["Fantasy", "Adventure", "Sci-Fi", "Romance", "Mystery", "Horror", "Comedy", "Drama"]
)

# Slider for specifying story length
story_length = st.slider(
    "Select the desired story length (number of words):",
    min_value=50, max_value=1000, value=200, step=50
)

# Image upload and processing
uploaded_image = st.file_uploader("Upload an Image", type=['png', 'jpg', 'jpeg'])

if uploaded_image:
    # Display the uploaded image
    raw_image = Image.open(uploaded_image).convert('RGB')
    st.image(raw_image, caption='Uploaded Image', use_container_width=True)  # use_column_width is deprecated

    # Detect objects using DETR model
    detr_inputs = detr_processor(images=raw_image, return_tensors="pt")
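    # Run a single forward pass in inference mode: eval() disables dropout, no_grad() skips gradient tracking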
    detr_model.eval()
    with torch.no_grad():
        detr_outputs = detr_model(**detr_inputs)

    target_sizes = torch.tensor([raw_image.size[::-1]])  # PIL's .size is (width, height); DETR expects (height, width)
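    # threshold=0.9 keeps only detections the model scores with high confidence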
    results = detr_processor.post_process_object_detection(detr_outputs, target_sizes=target_sizes, threshold=0.9)[0]

    # Extract object names
    detected_objects = []
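    # id2label maps the model's class indices to human-readable COCO labels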
    for score, label in zip(results["scores"], results["labels"]):
        label_name = detr_model.config.id2label[label.item()]
        detected_objects.append(label_name)

    # Display the detected objects before generating the caption
    st.subheader("Detected Objects:")
    st.text(", ".join(set(detected_objects)))  # Show unique objects

    # Generate caption using BLIP model
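    # No text prompt is passed, so BLIP produces an unconditional caption for the whole image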
    blip_inputs = blip_processor(raw_image, return_tensors="pt")
    blip_outputs = blip_model.generate(**blip_inputs)
    caption = blip_processor.decode(blip_outputs[0], skip_special_tokens=True)

    # Display the generated caption after detected objects
    st.subheader("Generated Caption:")
    st.text(caption)

    # Language selection box appears after the caption
    st.subheader("Select Story Language:")
    language = st.selectbox(
        "Select the language of the story:",
        ["English", "Hindi", "Bengali", "Tamil"]
    )

    # Submit button to generate the story
    if st.button("Generate Story"):
        prompt = (
            f"I have an image. The caption is '{caption}', and the detected objects are {', '.join(set(detected_objects))}. "
            f"Write a {genre.lower()} story in {language.lower()} with approximately {story_length} words using elements of this caption and the detected objects."
        )
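        # response.text assumes Gemini returned a candidate; a safety-blocked request would raise here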
        response = gemini_model.generate_content(prompt)
        st.subheader("Generated Story:")
        st.text_area("Story Output", value=response.text, height=300)

    # Free per-image tensors; the cached models themselves stay loaded
    del raw_image, detr_inputs, detr_outputs, blip_inputs, blip_outputs, results
    gc.collect()