tournas's picture
Update app.py
1aaa563 verified
raw
history blame
5.27 kB
import os
import tempfile
import gradio as gr
import torch
import nltk
from openai import OpenAI
from transformers import pipeline
from diffusers import StableDiffusionPipeline
from ultralytics import YOLO
from gtts import gTTS
from PIL import Image
import numpy as np
from nltk.tokenize import sent_tokenize
import spaces
# Run models on the GPU whenever CUDA is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The OpenAI key is supplied through the environment (a Hugging Face Spaces
# secret). Fail at import time, not on the first request, if it is unset or empty.
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    raise ValueError("\u26a0\ufe0f OpenAI API Key is missing! Add it as a Secret in Hugging Face Spaces.")
client = OpenAI(api_key=api_key)
# Smallest YOLOv8 checkpoint ("n"/nano) keeps object detection cheap.
yolo_model = YOLO("yolov8n.pt")

# Stable Diffusion v1.5: half precision when CUDA is available, full precision on CPU.
# NOTE(review): the "runwayml/stable-diffusion-v1-5" repo appears to have been removed
# from the Hugging Face Hub; if loading fails, the mirror
# "stable-diffusion-v1-5/stable-diffusion-v1-5" is the usual replacement — confirm.
stable_diffusion = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(device)
# Decode latents tile-by-tile to reduce VAE peak memory. `enable_tiling` is a
# method and must be called; assigning `True` to it would just replace the method
# with a bool and leave tiling disabled.
stable_diffusion.vae.enable_tiling()
# Sentence-tokenizer data for `sent_tokenize`. NLTK 3.8.2+ loads the "punkt_tab"
# resource rather than the pickled "punkt" one, so fetch both. By default
# nltk.download returns False (instead of raising) for ids an older index lacks.
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

# Small distilled BART summarizer used to condense a story into scene prompts.
summarizer = pipeline(
    "summarization",
    model="sshleifer/distilbart-cnn-6-6",
)
def detect_objects(image):
    """Detect objects in an image with YOLO and return their distinct class labels.

    Args:
        image: Image to analyse; it is converted with ``np.array`` before
            inference (the Gradio input component supplies a PIL image).

    Returns:
        list[str]: Unique labels in first-detected order, so the same image
        always produces the same list (and therefore the same story prompt).
        If anything fails, the error is printed and the placeholder
        ``["generic", "objects"]`` is returned.
    """
    try:
        # Make sure the detector lives on the same device as the other models.
        yolo_model.to(device)
        results = yolo_model(np.array(image))
        labels = (
            yolo_model.names[int(box.cls.item())]
            for result in results
            for box in result.boxes
        )
        # dict.fromkeys de-duplicates while preserving insertion order; a set
        # would order labels by string hash, which is randomized per process.
        return list(dict.fromkeys(labels))
    except Exception as e:
        print(f"Object detection error: {e}")
        return ["generic", "objects"]
def generate_story(detected_objects):
    """Ask the OpenAI chat API for a short story built around the given objects.

    Args:
        detected_objects: Iterable of object-label strings to weave into the story.

    Returns:
        str: The generated story with surrounding whitespace stripped. The fixed
        fallback sentence is returned instead if the request fails or the model
        returns no text.
    """
    fallback = "A mysterious tale of adventure and discovery."
    try:
        # With no detections the prompt would otherwise end in a dangling ": ".
        objects_text = ", ".join(detected_objects) or "everyday objects"
        story_prompt = f"Write a concise, creative short story using these objects: {objects_text}"
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",  # Lightweight, low-cost chat model
            messages=[{"role": "user", "content": story_prompt}],
            max_tokens=150,  # Keep the story (and latency) short
        )
        # The SDK types message.content as optional; normalise None to "".
        story = (response.choices[0].message.content or "").strip()
    except Exception as e:
        print(f"Story generation error: {e}")
        return fallback
    return story or fallback
def summarize_story(story, max_scenes=2):
    """Condense a story into a few one-sentence scene descriptions.

    Args:
        story: Full story text to summarize.
        max_scenes: Maximum number of scene sentences to return (default 2,
            to bound how many images are generated downstream).

    Returns:
        list[str]: Up to ``max_scenes`` sentences from the summary. If
        summarization or tokenization fails, the error is printed and two
        placeholder scenes are returned.
    """
    try:
        # The CNN-tuned distilbart checkpoints ship a default min_length (56 in
        # the published config — verify) above max_length=50. That makes the
        # minimum unreachable and forces every summary to the cap, mid-sentence,
        # so pass an explicit, compatible minimum.
        summary = summarizer(story, max_length=50, min_length=10, do_sample=False)[0]["summary_text"]
        scenes = sent_tokenize(summary)
        return scenes[:max_scenes]
    except Exception as e:
        print(f"Story summarization error: {e}")
        return ["A peaceful scene", "An exciting moment"]
def generate_images(story):
    """Render one small Stable Diffusion illustration per summarized scene.

    The story is first condensed with ``summarize_story``. Each scene is drawn as
    a 256x256 image using a reduced step count and guidance scale to keep compute
    low. A scene that fails to render is skipped after its error is printed.

    Returns:
        list[PIL.Image.Image]: The rendered images. If none succeeded, a single
        light-gray 256x256 placeholder is returned instead.
    """
    scene_list = summarize_story(story)
    rendered = []
    for scene in scene_list:
        try:
            with torch.no_grad():
                # Plain-style prompt: cheaper and more consistent at low resolution.
                prompt_text = f"Simple illustration: {scene}, soft colors"
                generation = stable_diffusion(
                    prompt_text,
                    num_inference_steps=20,  # fewer denoising steps
                    guidance_scale=6.0,  # slightly weaker prompt guidance
                    height=256,
                    width=256,
                )
                rendered.append(generation.images[0])
                # Release cached GPU blocks between scenes.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        except Exception as e:
            print(f"Image generation error: {e}")
    if not rendered:
        return [Image.new('RGB', (256, 256), color='lightgray')]
    return rendered
def text_to_speech(story):
    """Narrate the beginning of a story with Google Text-to-Speech (gTTS).

    Args:
        story: Story text; only the first 500 characters are spoken, to keep the
            request and the resulting audio short.

    Returns:
        str | None: Path to a freshly created ``.mp3`` file. ``None`` is returned
        if synthesis or saving fails, and the error is printed.
    """
    audio_file_path = None
    try:
        tts = gTTS(text=story[:500], lang="en", slow=False)  # Limit to first 500 chars
        # Use a unique file per call. A fixed name in the working directory lets
        # concurrent requests overwrite (or serve) each other's narration.
        # NOTE(review): these files are not deleted here; rely on tmp cleanup or
        # Gradio's cache settings.
        fd, audio_file_path = tempfile.mkstemp(prefix="story_audio_", suffix=".mp3")
        os.close(fd)  # gTTS reopens the path itself; don't leak the descriptor
        tts.save(audio_file_path)
        return audio_file_path
    except Exception as e:
        print(f"Text-to-speech error: {e}")
        if audio_file_path is not None:
            try:
                os.remove(audio_file_path)  # drop the empty/partial file
            except OSError:
                pass
        return None
@spaces.GPU
def full_pipeline(image):
    """Run detection -> story -> scenes -> images -> narration for one image.

    Each stage already falls back on its own errors. The outer ``try`` is a last
    line of defence so the UI always receives four values.

    Returns:
        tuple: ``(story, scenes_text, images, audio_path)``, matching the four
        Gradio outputs:
        - ``story`` is a string.
        - ``scenes_text`` has one scene per line.
        - ``images`` is a list of PIL images.
        - ``audio_path`` is an MP3 path, or ``None``.
    """
    try:
        detected_objects = detect_objects(image)
        story = generate_story(detected_objects)
        scenes = summarize_story(story)
        images = generate_images(story)
        audio = text_to_speech(story)
        # Scenes are displayed in a Textbox, which would show a raw Python list
        # repr ("['...', '...']"); present one scene per line instead.
        scenes_text = "\n".join(scenes or ["Scene 1", "Scene 2"])
        return (
            story or "A story could not be generated.",
            scenes_text,
            images,
            audio,
        )
    except Exception as e:
        print(f"Full pipeline error: {e}")
        return (
            "An unexpected error occurred.",
            "Something went wrong",
            [Image.new('RGB', (256, 256), color='lightgray')],
            None,
        )
# Gradio UI: one uploaded image in; story, scenes, illustrations and narration out.
_image_input = gr.Image(type="pil")
_result_outputs = [
    gr.Textbox(label="Generated Story"),
    gr.Textbox(label="Story Scenes"),
    gr.Gallery(label="Generated Images"),
    gr.Audio(label="Story Audio"),
]
demo = gr.Interface(
    fn=full_pipeline,
    inputs=_image_input,
    outputs=_result_outputs,
    title="AI-Powered Storytelling Assistant",
    description="Upload an image, and the AI will detect objects, generate a story, create images, and narrate the story.",
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()