# app.py — Hugging Face Space by ThisHumanA
# Last commit: "Update app.py" (5b43d75, verified), 4.19 kB
import streamlit as st
import gc
import os
from PIL import Image
import torch
from transformers import (
BlipProcessor,
BlipForConditionalGeneration,
DetrImageProcessor,
DetrForObjectDetection,
)
import google.generativeai as genai
# Configure the Gemini client from the 'google_api_key' environment variable
# (set via the Space's secrets). Object detection and captioning do not need
# Gemini, so a missing key is reported but does not stop the app; in that case
# `gemini_model` is bound to None rather than left undefined, so later code
# cannot hit a NameError on it.
api_key = os.getenv('google_api_key')
if not api_key:
    st.error("API key is missing. Please configure 'google_api_key' in your Hugging Face Secrets.")
    gemini_model = None
else:
    genai.configure(api_key=api_key)
    gemini_model = genai.GenerativeModel("gemini-1.5-flash")
# BLIP image-captioning model and processor, loaded once via Streamlit's resource cache.
@st.cache_resource
def load_blip():
    """Return ``(processor, model)`` for the Salesforce BLIP base captioning checkpoint.

    Wrapped in ``st.cache_resource`` so script reruns reuse the already-loaded
    objects instead of calling ``from_pretrained`` again.
    """
    checkpoint = "Salesforce/blip-image-captioning-base"
    return (
        BlipProcessor.from_pretrained(checkpoint),
        BlipForConditionalGeneration.from_pretrained(checkpoint),
    )
# DETR object-detection model and processor, loaded once via Streamlit's resource cache.
@st.cache_resource
def load_detr():
    """Return ``(processor, model)`` for facebook/detr-resnet-50 at the "no_timm" revision.

    Wrapped in ``st.cache_resource`` so script reruns reuse the already-loaded
    objects instead of calling ``from_pretrained`` again.
    """
    repo_id, rev = "facebook/detr-resnet-50", "no_timm"
    detr_proc = DetrImageProcessor.from_pretrained(repo_id, revision=rev)
    detr_net = DetrForObjectDetection.from_pretrained(repo_id, revision=rev)
    return detr_proc, detr_net
blip_processor, blip_model = load_blip()
detr_processor, detr_model = load_detr()

# Streamlit app
st.title("Art Of Visual Storytelling")

# Dropdown menu for genre selection
genre = st.selectbox(
    "Select the genre of the story:",
    ["Fantasy", "Adventure", "Sci-Fi", "Romance", "Mystery", "Horror", "Comedy", "Drama"],
)

# Slider for specifying story length (in words)
story_length = st.slider(
    "Select the desired story length (number of words):",
    min_value=50, max_value=1000, value=200, step=50,
)

# Image upload and processing
uploaded_image = st.file_uploader("Upload an Image", type=['png', 'jpg', 'jpeg'])
if uploaded_image:
    # Decode the upload. PIL signals a non-image / corrupt file with an OSError
    # subclass (UnidentifiedImageError); report it instead of showing a traceback.
    try:
        raw_image = Image.open(uploaded_image).convert('RGB')
    except OSError:
        st.error("The uploaded file could not be read as an image.")
        st.stop()
    # NOTE(review): newer Streamlit releases deprecate `use_column_width` in favour of
    # `use_container_width`; kept as-is until the installed Streamlit version is confirmed.
    st.image(raw_image, caption='Uploaded Image', use_column_width=True)

    # Detect objects using the DETR model; threshold=0.9 discards low-confidence boxes.
    detr_inputs = detr_processor(images=raw_image, return_tensors="pt")
    detr_model.eval()
    with torch.no_grad():
        detr_outputs = detr_model(**detr_inputs)
    target_sizes = torch.tensor([raw_image.size[::-1]])  # PIL size is (width, height) -> (height, width)
    results = detr_processor.post_process_object_detection(
        detr_outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    # Map label ids to class names, then de-duplicate ONCE while preserving detection
    # order, so the list shown to the user and the list sent to Gemini are identical.
    detected_objects = [detr_model.config.id2label[label.item()] for label in results["labels"]]
    unique_objects = list(dict.fromkeys(detected_objects))
    objects_text = ", ".join(unique_objects)

    # Display the detected objects before generating the caption
    st.subheader("Detected Objects:")
    st.text(objects_text or "No objects detected.")

    # Generate caption using BLIP model
    blip_inputs = blip_processor(raw_image, return_tensors="pt")
    blip_outputs = blip_model.generate(**blip_inputs)
    caption = blip_processor.decode(blip_outputs[0], skip_special_tokens=True)

    # Display the generated caption after detected objects
    st.subheader("Generated Caption:")
    st.text(caption)

    # Language selection box appears after the caption
    st.subheader("Select Story Language:")
    language = st.selectbox(
        "Select the language of the story:",
        ["English", "Hindi", "Bengali", "Tamil"],
    )

    # Submit button to generate the story
    if st.button("Generate Story"):
        # The Gemini client is only configured when `api_key` is set; without this
        # guard the call below would fail on an unconfigured/undefined model.
        if not api_key:
            st.error("Story generation is unavailable: configure 'google_api_key' and reload the app.")
        else:
            # Avoid sending "the detected objects are . " when DETR found nothing.
            objects_clause = (
                f"the detected objects are {objects_text}"
                if unique_objects
                else "no objects were detected"
            )
            prompt = (
                f"I have an image. The caption is '{caption}', and {objects_clause}. "
                f"Write a {genre.lower()} story in {language.lower()} with approximately {story_length} words using elements of this caption and the detected objects."
            )
            try:
                response = gemini_model.generate_content(prompt)
                story = response.text
            except Exception as err:  # UI boundary: surface API / blocked-response errors to the user
                st.error(f"Story generation failed: {err}")
            else:
                st.subheader("Generated Story:")
                st.text_area("Story Output", value=story, height=300)

    # Drop references to the per-upload tensors and image, then prompt a GC pass.
    del raw_image, detr_inputs, detr_outputs, blip_inputs, blip_outputs, results
    gc.collect()