tournas's picture
Update app.py
718f0a4 verified
import os
import gradio as gr
import torch
import nltk
from openai import OpenAI
from transformers import pipeline
from diffusers import StableDiffusionPipeline
from ultralytics import YOLO
from gtts import gTTS
from PIL import Image
import numpy as np
from nltk.tokenize import sent_tokenize
import spaces
# Ensure minimal GPU usage
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fail fast at import time: story generation cannot work without the key.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("\u26a0\ufe0f OpenAI API Key is missing! Add it as a Secret in Hugging Face Spaces.")
client = OpenAI(api_key=api_key)

# Use smallest YOLO model
yolo_model = YOLO("yolov8n.pt")

# Lightweight Stable Diffusion configuration: float16 weights when CUDA is
# available, float32 otherwise.
# NOTE(review): the "runwayml/stable-diffusion-v1-5" repo was reportedly
# removed from the Hub; if loading fails, the community mirror
# "stable-diffusion-v1-5/stable-diffusion-v1-5" may be needed — confirm.
stable_diffusion = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(device)
# `enable_tiling` is a method and must be *called*. Assigning `True` to the
# attribute (as before) just replaced the method and never enabled tiling.
stable_diffusion.vae.enable_tiling()

# Tokenizer data for `sent_tokenize`: older NLTK releases load "punkt", and
# NLTK >= 3.8.2 loads "punkt_tab". Fetch both so either version works.
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

summarizer = pipeline(
    "summarization",
    model="sshleifer/distilbart-cnn-6-6",
)
def detect_objects(image):
    """Run YOLO on *image* and return the unique class labels it finds.

    Args:
        image: The uploaded picture. It is converted with ``np.array`` before
            inference (the Gradio input is configured with ``type="pil"``).

    Returns:
        Label strings in the order they were first detected, without
        duplicates (possibly empty). On any detection error, the placeholder
        list ``["generic", "objects"]`` is returned instead.
    """
    try:
        # Move model to appropriate device
        yolo_model.to(device)
        results = yolo_model(np.array(image))
        labels = [
            yolo_model.names[int(box.cls.item())]
            for result in results
            for box in result.boxes
        ]
        # dict.fromkeys de-duplicates while keeping first-seen order. A set's
        # order depends on per-process string hashing, which made the story
        # prompt differ between runs for the same image.
        return list(dict.fromkeys(labels))
    except Exception as e:
        print(f"Object detection error: {e}")
        return ["generic", "objects"]
def generate_story(detected_objects):
    """Ask the OpenAI chat model for a short story built around detected objects.

    Args:
        detected_objects: Sequence of object-label strings (for example the
            output of ``detect_objects``). May be empty.

    Returns:
        The story text with surrounding whitespace stripped. If the request
        fails, or the model returns no text, a fixed fallback sentence.
    """
    fallback = "A mysterious tale of adventure and discovery."
    try:
        objects = ", ".join(detected_objects or [])
        if objects:
            story_prompt = f"Write a concise, creative short story using these objects: {objects}"
        else:
            # Nothing detected: avoid sending a prompt that ends in a dangling
            # "using these objects: ".
            story_prompt = "Write a concise, creative short story."
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",  # More lightweight model
            messages=[{"role": "user", "content": story_prompt}],
            # Hard cap on reply length; a story that reaches it is cut off.
            max_tokens=150,
        )
        # The SDK types `message.content` as optional, so guard before strip().
        content = response.choices[0].message.content or ""
        return content.strip() or fallback
    except Exception as e:
        print(f"Story generation error: {e}")
        return fallback
def summarize_story(story):
    """Condense *story* into exactly four scene descriptions for illustration.

    The story is summarized with the module-level ``summarizer`` pipeline and
    split into sentences with ``sent_tokenize``. Each sentence becomes one
    scene. Short summaries are padded and long ones truncated, so callers
    always receive four scenes.

    Args:
        story: Full story text.

    Returns:
        A list of four scene strings. If summarization fails or yields no
        sentences, four generic placeholder scenes are returned.
    """
    scene_count = 4
    try:
        summary = summarizer(story, max_length=100, do_sample=False)[0]['summary_text']
        scenes = sent_tokenize(summary)
    except Exception as e:
        print(f"Story summarization error: {e}")
        scenes = []

    if not scenes:
        # Fallback to 4 generic scenes if summarization fails or is empty
        return [
            "A peaceful scene at the beginning",
            "An exciting moment of conflict",
            "A turning point in the story",
            "A dramatic conclusion"
        ]

    # Pad from the last *real* sentence. The previous loop re-read scenes[-1]
    # after every append, so prefixes stacked up:
    # "Continuation of previous scene: Continuation of previous scene: ...".
    padding = [f"Continuation of previous scene: {scenes[-1]}"] * max(0, scene_count - len(scenes))
    return (scenes + padding)[:scene_count]
def generate_images(story):
    """Render one Stable Diffusion illustration per story scene.

    Args:
        story: Full story text. It is split into scenes via ``summarize_story``.

    Returns:
        A list of at least four 256x256 PIL images, in scene order. If a
        scene's generation raises, a light-gray placeholder takes that
        scene's own slot, so the gallery stays aligned with the scenes.
    """
    scenes = summarize_story(story)
    images = []
    for scene in scenes:
        prompt_text = f"Simple illustration: {scene}, soft colors, story scene"
        try:
            with torch.no_grad():
                image = stable_diffusion(
                    prompt_text,
                    num_inference_steps=20,  # Reduced steps
                    guidance_scale=6.0,  # Slightly lower guidance
                    height=256,  # Smaller image
                    width=256,
                ).images[0]
        except Exception as e:
            print(f"Image generation error: {e}")
            # Placeholder in this scene's position. Previously failures were
            # only padded at the end, which shifted later images out of
            # alignment with their scenes.
            image = Image.new('RGB', (256, 256), color='lightgray')
        finally:
            # Aggressive memory clearing, also after a failed generation.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        images.append(image)

    # Safety net in case fewer than four scenes were produced.
    while len(images) < 4:
        images.append(Image.new('RGB', (256, 256), color='lightgray'))
    return images
def text_to_speech(story):
    """Narrate *story* with gTTS and save the result as an MP3 file.

    Each call writes to its own temporary file. The old fixed
    ``story_audio.mp3`` path was shared by every concurrent user of the app,
    so one request could overwrite another's narration.

    Args:
        story: Text to speak.

    Returns:
        Path to the saved MP3 file, or ``None`` if *story* is empty or not a
        string, or if speech synthesis fails.
    """
    import tempfile  # only needed here; keeps the file's import block untouched

    # gTTS refuses empty text; skip the call instead of logging an error.
    if not isinstance(story, str) or not story.strip():
        return None

    audio_file_path = None
    try:
        tts = gTTS(text=story, lang="en", slow=False)
        fd, audio_file_path = tempfile.mkstemp(prefix="story_audio_", suffix=".mp3")
        os.close(fd)  # gTTS.save() opens the path itself
        tts.save(audio_file_path)
        return audio_file_path
    except Exception as e:
        print(f"Text-to-speech error: {e}")
        # Don't leave a partially written or empty file behind.
        if audio_file_path and os.path.exists(audio_file_path):
            os.remove(audio_file_path)
        return None
@spaces.GPU
def full_pipeline(image):
    """Run the whole storytelling pipeline for one uploaded image.

    Steps: detect objects, write a story, split it into scenes, illustrate
    the scenes, and narrate the story.

    Args:
        image: Value from the Gradio ``Image(type="pil")`` input.

    Returns:
        A 4-tuple matching the interface outputs:
        - story text;
        - the scenes as one newline-separated string, because the
          "Story Scenes" output is a Textbox, which would otherwise display a
          Python list repr;
        - a list of PIL images for the gallery;
        - an audio file path, or ``None``.
        On an unexpected error, a placeholder tuple of the same shape.
    """
    # Wrap entire process with error handling
    try:
        detected_objects = detect_objects(image)
        story = generate_story(detected_objects)
        scenes = summarize_story(story)
        # NOTE: generate_images summarizes the story again internally. The
        # summarizer runs with do_sample=False, so it yields the same scenes
        # shown in the Textbox.
        images = generate_images(story)
        audio = text_to_speech(story)
        return (
            story or "A story could not be generated.",
            "\n".join(scenes or ["Scene 1", "Scene 2"]),
            images,
            audio,
        )
    except Exception as e:
        print(f"Full pipeline error: {e}")
        return (
            "An unexpected error occurred.",
            "Something went wrong",
            [Image.new('RGB', (256, 256), color='lightgray')],
            None,
        )
# **Gradio UI**
# One image in; story text, scene list, illustrations and narration out.
# The output order must match the tuple returned by full_pipeline.
_image_input = gr.Image(type="pil")
_story_outputs = [
    gr.Textbox(label="Generated Story"),
    gr.Textbox(label="Story Scenes"),
    gr.Gallery(label="Generated Images"),
    gr.Audio(label="Story Audio"),
]

APP_TITLE = "AI-Powered Storytelling Assistant"
APP_DESCRIPTION = (
    "Upload an image, and the AI will detect objects, generate a story, "
    "create images, and narrate the story."
)

demo = gr.Interface(
    fn=full_pipeline,
    inputs=_image_input,
    outputs=_story_outputs,
    title=APP_TITLE,
    description=APP_DESCRIPTION,
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()