import os

import gradio as gr
import torch
import nltk
from openai import OpenAI
from transformers import pipeline
from diffusers import StableDiffusionPipeline
from ultralytics import YOLO
from gtts import gTTS
from PIL import Image
import numpy as np
from nltk.tokenize import sent_tokenize
import spaces

# Run on the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError(
        "\u26a0\ufe0f OpenAI API Key is missing! Add it as a Secret in Hugging Face Spaces."
    )

client = OpenAI(api_key=api_key)

# Use the smallest YOLOv8 model (nano)
yolo_model = YOLO("yolov8n.pt")

# Lightweight Stable Diffusion configuration:
# half precision on GPU, full precision on CPU
stable_diffusion = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(device)
# Enable tiled VAE decoding to reduce memory usage.
# enable_tiling is a method and must be called, not assigned to.
stable_diffusion.vae.enable_tiling()

# Tokenizer data required by sent_tokenize
nltk.download("punkt", quiet=True)

summarizer = pipeline(
    "summarization",
    model="sshleifer/distilbart-cnn-6-6",
)


def detect_objects(image):
    """Return the unique object labels that YOLO detects in a PIL image."""
    try:
        # Move model to appropriate device
        yolo_model.to(device)
        image_array = np.array(image)
        results = yolo_model(image_array)

        detected_objects = []
        for r in results:
            for box in r.boxes:
                class_id = int(box.cls.item())
                label = yolo_model.names[class_id]
                detected_objects.append(label)

        return list(set(detected_objects))  # Remove duplicates
    except Exception as e:
        print(f"Object detection error: {e}")
        return ["generic", "objects"]


def generate_story(detected_objects):
    """Ask the OpenAI chat model for a short story built around the detected objects."""
    try:
        story_prompt = (
            "Write a concise, creative short story using these objects: "
            f"{', '.join(detected_objects)}"
        )
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",  # More lightweight model
            messages=[{"role": "user", "content": story_prompt}],
            max_tokens=150,  # Reduced token count
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Story generation error: {e}")
        return "A mysterious tale of adventure and discovery."


def summarize_story(story):
    """Summarize the story and split the summary into exactly four scenes."""
    try:
        summary = summarizer(story, max_length=100, do_sample=False)[0]["summary_text"]
        scenes = sent_tokenize(summary)

        # If there are fewer than 4 scenes, pad by extending the last one.
        # An empty summary raises IndexError here and falls through to the
        # generic scenes below.
        while len(scenes) < 4:
            scenes.append(f"Continuation of previous scene: {scenes[-1]}")

        return scenes[:4]  # Explicitly return 4 scenes
    except Exception as e:
        print(f"Story summarization error: {e}")
        # Fall back to 4 generic scenes if summarization fails
        return [
            "A peaceful scene at the beginning",
            "An exciting moment of conflict",
            "A turning point in the story",
            "A dramatic conclusion",
        ]


def generate_images(story):
    """Generate one 256x256 illustration per scene, always returning four images."""
    scenes = summarize_story(story)
    images = []

    for prompt in scenes:
        try:
            with torch.no_grad():
                # Wrap each scene in a shared style prompt
                prompt_text = f"Simple illustration: {prompt}, soft colors, story scene"
                image = stable_diffusion(
                    prompt_text,
                    num_inference_steps=20,  # Reduced steps
                    guidance_scale=6.0,  # Slightly lower guidance
                    height=256,  # Smaller image
                    width=256,
                ).images[0]
            images.append(image)

            # Release cached GPU memory after every image
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as e:
            print(f"Image generation error: {e}")

    # Pad with gray placeholders so that 4 images are always returned
    while len(images) < 4:
        images.append(Image.new("RGB", (256, 256), color="lightgray"))

    return images


def text_to_speech(story):
    """Narrate the story with gTTS and return the MP3 path, or None on failure."""
    try:
        # Narrate the full story text
        tts = gTTS(text=story, lang="en", slow=False)
        audio_file_path = "story_audio.mp3"
        tts.save(audio_file_path)
        return audio_file_path
    except Exception as e:
        print(f"Text-to-speech error: {e}")
        return None


@spaces.GPU
def full_pipeline(image):
    """Run detection, story generation, summarization, illustration, and narration."""
    # Wrap entire process with error handling
    try:
        detected_objects = detect_objects(image)
        story = generate_story(detected_objects)
        scenes = summarize_story(story)
        images = generate_images(story)
        audio = text_to_speech(story)

        return (
            story or "A story could not be generated.",
            # The scenes output is a Textbox, so join the list into one string
            "\n".join(scenes or ["Scene 1", "Scene 2"]),
            images,
            audio,
        )
    except Exception as e:
        print(f"Full pipeline error: {e}")
        return (
            "An unexpected error occurred.",
            "Something went wrong",
            [Image.new("RGB", (256, 256), color="lightgray")],
            None,
        )


# Gradio UI
demo = gr.Interface(
    fn=full_pipeline,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Generated Story"),
        gr.Textbox(label="Story Scenes"),
        gr.Gallery(label="Generated Images"),
        gr.Audio(label="Story Audio"),
    ],
    title="AI-Powered Storytelling Assistant",
    description=(
        "Upload an image, and the AI will detect objects, generate a story, "
        "create images, and narrate the story."
    ),
)

if __name__ == "__main__":
    demo.launch()