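"""AI-Powered Storytelling Assistant (Hugging Face Space running on ZeroGPU).

Detects objects in an uploaded image with YOLOv8, asks GPT-3.5 Turbo for a short
story about them, summarizes the story into four scenes, illustrates each scene
with Stable Diffusion, and narrates the full story with gTTS.
"""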
import os
import gradio as gr
import torch
import nltk
from openai import OpenAI
from transformers import pipeline
from diffusers import StableDiffusionPipeline
from ultralytics import YOLO
from gtts import gTTS
from PIL import Image
import numpy as np
from nltk.tokenize import sent_tokenize
import spaces
# Ensure minimal GPU usage
device = "cuda" if torch.cuda.is_available() else "cpu"

api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("\u26a0\ufe0f OpenAI API Key is missing! Add it as a Secret in Hugging Face Spaces.")

client = OpenAI(api_key=api_key)
# Use smallest YOLO model
yolo_model = YOLO("yolov8n.pt")
# Lightweight Stable Diffusion configuration
stable_diffusion = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
).to(device)
stable_diffusion.vae.enable_tiling()  # Enable VAE tiling to reduce memory usage (method call, not attribute assignment)
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)  # Required by sent_tokenize on newer NLTK releases
summarizer = pipeline(
    "summarization",
    model="sshleifer/distilbart-cnn-6-6"
)
def detect_objects(image):
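    """Run YOLOv8 object detection on a PIL image and return the unique class labels found."""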
    try:
        # Move model to appropriate device
        yolo_model.to(device)
        image_array = np.array(image)
        results = yolo_model(image_array)

        detected_objects = []
        for r in results:
            for box in r.boxes:
                class_id = int(box.cls.item())
                label = yolo_model.names[class_id]
                detected_objects.append(label)
        return list(set(detected_objects))  # Remove duplicates
    except Exception as e:
        print(f"Object detection error: {e}")
        return ["generic", "objects"]
def generate_story(detected_objects):
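    """Ask GPT-3.5 Turbo for a short story woven around the detected objects."""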
    try:
        story_prompt = f"Write a concise, creative short story using these objects: {', '.join(detected_objects)}"
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",  # More lightweight model
            messages=[{"role": "user", "content": story_prompt}],
            max_tokens=150  # Reduced token count
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Story generation error: {e}")
        return "A mysterious tale of adventure and discovery."
def summarize_story(story):
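    """Summarize the story with DistilBART and split it into exactly four scene descriptions."""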
    try:
        summary = summarizer(story, max_length=100, do_sample=False)[0]['summary_text']
        scenes = sent_tokenize(summary)

        # If fewer than 4 scenes, create additional scenes
        while len(scenes) < 4:
            # Duplicate or slightly modify existing scenes to reach 4
            scenes.append(f"Continuation of previous scene: {scenes[-1]}")

        return scenes[:4]  # Explicitly return 4 scenes
    except Exception as e:
        print(f"Story summarization error: {e}")
        # Fallback to 4 generic scenes if summarization fails
        return [
            "A peaceful scene at the beginning",
            "An exciting moment of conflict",
            "A turning point in the story",
            "A dramatic conclusion"
        ]
def generate_images(story):
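    """Render one Stable Diffusion illustration per scene, always returning four images."""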
    scenes = summarize_story(story)
    images = []

    for prompt in scenes:
        try:
            with torch.no_grad():
                # Create slightly varied prompts for each scene
                prompt_text = f"Simple illustration: {prompt}, soft colors, story scene"
                image = stable_diffusion(
                    prompt_text,
                    num_inference_steps=20,  # Reduced steps
                    guidance_scale=6.0,  # Slightly lower guidance
                    height=256,  # Smaller image
                    width=256
                ).images[0]
            images.append(image)

            # Aggressive memory clearing
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as e:
            print(f"Image generation error: {e}")
            # Placeholder for the scene that failed, so one error doesn't skew the count
            images.append(Image.new('RGB', (256, 256), color='lightgray'))

    # Fallback to ensure exactly 4 images
    while len(images) < 4:
        images.append(Image.new('RGB', (256, 256), color='lightgray'))
    return images
def text_to_speech(story):
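    """Narrate the story with gTTS and return the path of the saved MP3 file."""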
    try:
        tts = gTTS(text=story[:500], lang="en", slow=False)  # Limit narration to the first 500 chars
        audio_file_path = "story_audio.mp3"
        tts.save(audio_file_path)
        return audio_file_path
    except Exception as e:
        print(f"Text-to-speech error: {e}")
        return None
@spaces.GPU  # Request a ZeroGPU slot for the GPU-backed steps (detection and diffusion)
def full_pipeline(image):
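    """End-to-end pipeline: image -> objects -> story -> scenes -> illustrations -> narration."""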
    # Wrap entire process with error handling
    try:
        detected_objects = detect_objects(image)
        story = generate_story(detected_objects)
        scenes = summarize_story(story)
        images = generate_images(story)
        audio = text_to_speech(story)

        return (
            story or "A story could not be generated.",
            "\n".join(scenes) if scenes else "No scenes available.",  # Textbox expects a string, not a list
            images,
            audio
        )
    except Exception as e:
        print(f"Full pipeline error: {e}")
        return (
            "An unexpected error occurred.",
            "Something went wrong",
            [Image.new('RGB', (256, 256), color='lightgray')],
            None
        )
# Gradio UI
demo = gr.Interface(
    fn=full_pipeline,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Generated Story"),
        gr.Textbox(label="Story Scenes"),
        gr.Gallery(label="Generated Images"),
        gr.Audio(label="Story Audio"),
    ],
    title="AI-Powered Storytelling Assistant",
    description="Upload an image, and the AI will detect objects, generate a story, create images, and narrate the story."
)
if __name__ == "__main__":
    demo.launch()