|
import streamlit as st |
|
import gc |
|
import os |
|
from PIL import Image |
|
import torch |
|
from transformers import ( |
|
BlipProcessor, |
|
BlipForConditionalGeneration, |
|
DetrImageProcessor, |
|
DetrForObjectDetection, |
|
) |
|
import google.generativeai as genai |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Gemini API configuration --------------------------------------------
# The key is read from the environment (on Hugging Face Spaces, add it as a
# Secret named 'google_api_key').
api_key = os.getenv('google_api_key')

if not api_key:
    st.error("API key is missing. Please configure 'google_api_key' in your Hugging Face Secrets.")
    # Halt the script run here: without a key, `gemini_model` would never be
    # defined and the "Generate Story" button below would raise a NameError.
    st.stop()

genai.configure(api_key=api_key)
# Flash variant: fast and inexpensive, sufficient for short story generation.
gemini_model = genai.GenerativeModel("gemini-1.5-flash")
|
|
|
|
|
|
|
@st.cache_resource
def load_blip():
    """Load and cache the BLIP image-captioning processor/model pair.

    Wrapped in ``st.cache_resource`` so the checkpoint is downloaded and
    instantiated only once per server process, not on every rerun.
    """
    checkpoint = "Salesforce/blip-image-captioning-base"
    return (
        BlipProcessor.from_pretrained(checkpoint),
        BlipForConditionalGeneration.from_pretrained(checkpoint),
    )
|
|
|
|
|
@st.cache_resource
def load_detr():
    """Load and cache the DETR object-detection processor/model pair.

    Uses the ``no_timm`` revision of the ResNet-50 checkpoint so the timm
    package is not required. Cached once per server process.
    """
    checkpoint = "facebook/detr-resnet-50"
    return (
        DetrImageProcessor.from_pretrained(checkpoint, revision="no_timm"),
        DetrForObjectDetection.from_pretrained(checkpoint, revision="no_timm"),
    )
|
|
|
# Instantiate (or fetch from the resource cache) the two vision models.
blip_processor, blip_model = load_blip()
detr_processor, detr_model = load_detr()

st.title("Art Of Visual Storytelling")

# --- Story configuration widgets ------------------------------------------
genre = st.selectbox(
    "Select the genre of the story:",
    ["Fantasy", "Adventure", "Sci-Fi", "Romance", "Mystery", "Horror", "Comedy", "Drama"],
)

story_length = st.slider(
    "Select the desired story length (number of words):",
    min_value=50,
    max_value=1000,
    value=200,
    step=50,
)

# Accepts common raster formats; the image is decoded further below.
uploaded_image = st.file_uploader("Upload an Image", type=['png', 'jpg', 'jpeg'])
|
|
|
if uploaded_image:
    # Decode the upload into 3-channel RGB, which both models expect.
    raw_image = Image.open(uploaded_image).convert('RGB')
    # NOTE(review): `use_column_width` is deprecated in newer Streamlit in
    # favor of `use_container_width` — migrate when upgrading the pin.
    st.image(raw_image, caption='Uploaded Image', use_column_width=True)

    # --- Object detection (DETR) ------------------------------------------
    detr_inputs = detr_processor(images=raw_image, return_tensors="pt")
    detr_model.eval()
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        detr_outputs = detr_model(**detr_inputs)

    # Post-processing expects (height, width); PIL `.size` is (width, height).
    target_sizes = torch.tensor([raw_image.size[::-1]])
    results = detr_processor.post_process_object_detection(
        detr_outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    # Map label ids to human-readable names. Confidence scores are not needed
    # beyond the 0.9 threshold already applied above.
    detected_objects = [detr_model.config.id2label[label.item()] for label in results["labels"]]
    # Sorted de-duplicated names keep the display (and the prompt below)
    # deterministic across reruns — `set` iteration order is not.
    unique_objects = sorted(set(detected_objects))

    st.subheader("Detected Objects:")
    st.text(", ".join(unique_objects))

    # --- Image captioning (BLIP) ------------------------------------------
    blip_model.eval()
    blip_inputs = blip_processor(raw_image, return_tensors="pt")
    with torch.no_grad():  # same inference-only treatment as the DETR path
        blip_outputs = blip_model.generate(**blip_inputs)
    caption = blip_processor.decode(blip_outputs[0], skip_special_tokens=True)

    st.subheader("Generated Caption:")
    st.text(caption)

    # --- Story generation (Gemini) ----------------------------------------
    st.subheader("Select Story Language:")
    language = st.selectbox(
        "Select the language of the story:",
        ["English", "Hindi", "Bengali", "Tamil"]
    )

    if st.button("Generate Story"):
        prompt = (
            f"I have an image. The caption is '{caption}', and the detected objects are {', '.join(unique_objects)}. "
            f"Write a {genre.lower()} story in {language.lower()} with approximately {story_length} words using elements of this caption and the detected objects."
        )
        response = gemini_model.generate_content(prompt)
        st.subheader("Generated Story:")
        st.text_area("Story Output", value=response.text, height=300)

    # Drop the large intermediates so repeated uploads in one session don't
    # accumulate tensors in memory, then nudge the collector.
    del raw_image, detr_inputs, detr_outputs, blip_inputs, blip_outputs, results
    gc.collect()
|
|
|
|