File size: 3,688 Bytes
a236b47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gc
import os

import google.generativeai as genai
import streamlit as st
import torch
from PIL import Image
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    DetrImageProcessor,
    DetrForObjectDetection,
)

# Configure Generative AI.
# SECURITY: an API key was previously hard-coded here. A key committed to
# source control is effectively public and must be revoked. Read it from the
# environment instead (set GOOGLE_API_KEY before launching the app).
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY", ""))
gemini_model = genai.GenerativeModel("gemini-1.5-flash")

# Load BLIP model and processor
# BLIP captioning model — cached by Streamlit so it loads once per session.
@st.cache_resource
def load_blip():
    """Return the (processor, model) pair for BLIP image captioning."""
    checkpoint = "Salesforce/blip-image-captioning-base"
    return (
        BlipProcessor.from_pretrained(checkpoint),
        BlipForConditionalGeneration.from_pretrained(checkpoint),
    )

# Load DETR model and processor
# DETR object-detection model — cached by Streamlit so it loads once per session.
@st.cache_resource
def load_detr():
    """Return the (processor, model) pair for DETR object detection."""
    checkpoint = "facebook/detr-resnet-50"
    return (
        DetrImageProcessor.from_pretrained(checkpoint, revision="no_timm"),
        DetrForObjectDetection.from_pretrained(checkpoint, revision="no_timm"),
    )

blip_processor, blip_model = load_blip()
detr_processor, detr_model = load_detr()

# Streamlit app: caption an uploaded image, list detected objects, then
# have Gemini write a story from the caption + objects.
st.title("Art Of Visual Storytelling")

# Dropdown menu for genre selection
genre = st.selectbox(
    "Select the genre of the story:",
    ["Fantasy", "Adventure", "Sci-Fi", "Romance", "Mystery", "Horror", "Comedy", "Drama"]
)

# Slider for specifying story length
story_length = st.slider(
    "Select the desired story length (number of words):",
    min_value=50, max_value=1000, value=200, step=50
)

# Image upload and processing
uploaded_image = st.file_uploader("Upload an Image", type=['png', 'jpg', 'jpeg'])

if uploaded_image:
    # Display the uploaded image
    raw_image = Image.open(uploaded_image).convert('RGB')
    st.image(raw_image, caption='Uploaded Image', use_column_width=True)

    # Detect objects using DETR (inference only — disable autograd so no
    # computation graph is built).
    detr_inputs = detr_processor(images=raw_image, return_tensors="pt")
    detr_model.eval()
    with torch.no_grad():
        detr_outputs = detr_model(**detr_inputs)

    # post_process expects (height, width); PIL's .size is (width, height).
    target_sizes = torch.tensor([raw_image.size[::-1]])
    results = detr_processor.post_process_object_detection(
        detr_outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    # Map label ids to human-readable names. The original loop also unpacked
    # the scores but never used them, so iterate labels only.
    detected_objects = [
        detr_model.config.id2label[label.item()] for label in results["labels"]
    ]

    # Display the detected objects before generating the caption
    st.subheader("Detected Objects:")
    st.text(", ".join(set(detected_objects)))  # Show unique objects

    # Generate caption using BLIP. Also inference-only, so wrap in no_grad
    # for consistency with the DETR path and to avoid autograd overhead.
    blip_inputs = blip_processor(raw_image, return_tensors="pt")
    with torch.no_grad():
        blip_outputs = blip_model.generate(**blip_inputs)
    caption = blip_processor.decode(blip_outputs[0], skip_special_tokens=True)

    # Display the generated caption after detected objects
    st.subheader("Generated Caption:")
    st.text(caption)

    # Submit button to generate the story
    if st.button("Generate Story"):
        prompt = (
            f"I have an image. The caption is '{caption}', and the detected objects are {', '.join(set(detected_objects))}. "
            f"Write a {genre.lower()} story with approximately {story_length} words using elements of this caption and the detected objects."
        )
        response = gemini_model.generate_content(prompt)
        st.subheader("Generated Story:")
        st.text_area("Story Output", value=response.text, height=300)

    # Cleanup memory: drop the large per-run objects before the next rerun.
    del raw_image, detr_inputs, detr_outputs, blip_inputs, blip_outputs, results
    gc.collect()