import streamlit as st
import gc
import os
from PIL import Image
import torch
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    DetrImageProcessor,
    DetrForObjectDetection,
)
import google.generativeai as genai
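
# Configure the Gemini client. The API key is read from the GOOGLE_API_KEY
# environment variable (assumed name) instead of being hardcoded in source.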
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
gemini_model = genai.GenerativeModel("gemini-1.5-flash")
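
# Model loaders are wrapped in st.cache_resource so weights are downloaded
# and initialized only once per session, not on every Streamlit rerun.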
@st.cache_resource
def load_blip():
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    return processor, model

@st.cache_resource
def load_detr():
    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
    model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
    return processor, model
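
# Load both models up front; with caching, these calls are cheap after the first run.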
blip_processor, blip_model = load_blip()
detr_processor, detr_model = load_detr()
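
# User interface: story settings and image upload.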
st.title("Art Of Visual Storytelling")

genre = st.selectbox(
    "Select the genre of the story:",
    ["Fantasy", "Adventure", "Sci-Fi", "Romance", "Mystery", "Horror", "Comedy", "Drama"]
)

story_length = st.slider(
    "Select the desired story length (number of words):",
    min_value=50, max_value=1000, value=200, step=50
)

uploaded_image = st.file_uploader("Upload an Image", type=['png', 'jpg', 'jpeg'])
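
# Run the full pipeline only once an image has been uploaded.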
if uploaded_image:
    raw_image = Image.open(uploaded_image).convert('RGB')
    st.image(raw_image, caption='Uploaded Image', use_column_width=True)
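
    # Object detection: run DETR on the image and keep predictions above the
    # 0.9 confidence threshold.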
    detr_inputs = detr_processor(images=raw_image, return_tensors="pt")
    detr_model.eval()
    with torch.no_grad():
        detr_outputs = detr_model(**detr_inputs)

    target_sizes = torch.tensor([raw_image.size[::-1]])
    results = detr_processor.post_process_object_detection(detr_outputs, target_sizes=target_sizes, threshold=0.9)[0]
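
    # Map each detected label id to its human-readable class name.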
    detected_objects = []
    for score, label in zip(results["scores"], results["labels"]):
        label_name = detr_model.config.id2label[label.item()]
        detected_objects.append(label_name)

    st.subheader("Detected Objects:")
    st.text(", ".join(set(detected_objects)))
    blip_inputs = blip_processor(raw_image, return_tensors="pt")
    blip_outputs = blip_model.generate(**blip_inputs)
    caption = blip_processor.decode(blip_outputs[0], skip_special_tokens=True)

    st.subheader("Generated Caption:")
    st.text(caption)
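
    # Story generation: combine the caption, detected objects, genre, and
    # desired length into a prompt and send it to Gemini.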
    if st.button("Generate Story"):
        prompt = (
            f"I have an image. The caption is '{caption}', and the detected objects are {', '.join(set(detected_objects))}. "
            f"Write a {genre.lower()} story with approximately {story_length} words using elements of this caption and the detected objects."
        )
        response = gemini_model.generate_content(prompt)
        st.subheader("Generated Story:")
        st.text_area("Story Output", value=response.text, height=300)
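
    # Free the large intermediate tensors and the image once they are no longer needed.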
    del raw_image, detr_inputs, detr_outputs, blip_inputs, blip_outputs, results
    gc.collect()