import os

import gradio as gr
import nltk
import numpy as np
import spaces
import torch
from diffusers import StableDiffusionPipeline
from gtts import gTTS
from nltk.tokenize import sent_tokenize
from openai import OpenAI
from PIL import Image
from transformers import pipeline
from ultralytics import YOLO

# Use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("\u26a0\ufe0f OpenAI API Key is missing! Add it as a Secret in Hugging Face Spaces.")

client = OpenAI(api_key=api_key)

# Use the smallest YOLOv8 model
yolo_model = YOLO("yolov8n.pt")

# Lightweight Stable Diffusion configuration: half precision on GPU, full precision on CPU
stable_diffusion = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(device)
# enable_tiling is a method; assigning True to it would not turn tiling on
stable_diffusion.vae.enable_tiling()  # Decode in tiles to reduce memory usage

# Tokenizer data for sent_tokenize; newer NLTK releases look for "punkt_tab"
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

summarizer = pipeline(
    "summarization",
    model="sshleifer/distilbart-cnn-6-6",
)


def detect_objects(image):
    """Return the sorted, de-duplicated class labels YOLO finds in a PIL image."""
    try:
        # Move model to the appropriate device
        yolo_model.to(device)

        image_array = np.array(image)
        results = yolo_model(image_array)

        detected_objects = []
        for r in results:
            for box in r.boxes:
                class_id = int(box.cls.item())
                label = yolo_model.names[class_id]
                detected_objects.append(label)

        return sorted(set(detected_objects))  # Remove duplicates
    except Exception as e:
        print(f"Object detection error: {e}")
        return ["generic", "objects"]


def generate_story(detected_objects):
    """Ask the chat model for a short story that features the detected objects."""
    try:
        story_prompt = (
            "Write a concise, creative short story using these objects: "
            f"{', '.join(detected_objects)}"
        )
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",  # More lightweight model
            messages=[{"role": "user", "content": story_prompt}],
            max_tokens=150,  # Reduced token count
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Story generation error: {e}")
        return "A mysterious tale of adventure and discovery."


def summarize_story(story):
    """Summarize the story and return at most two sentences to use as scenes."""
    try:
        summary = summarizer(story, max_length=50, do_sample=False)[0]["summary_text"]
        scenes = sent_tokenize(summary)
        return scenes[:2]  # Limit to 2 scenes to reduce computational load
    except Exception as e:
        print(f"Story summarization error: {e}")
        return ["A peaceful scene", "An exciting moment"]


def generate_images(scenes):
    """Render one small illustration per scene; return a gray placeholder if all fail."""
    images = []

    for prompt in scenes:
        try:
            with torch.no_grad():
                # Simplified, less computationally intensive prompt
                prompt_text = f"Simple illustration: {prompt}, soft colors"
                image = stable_diffusion(
                    prompt_text,
                    num_inference_steps=20,  # Reduced steps
                    guidance_scale=6.0,  # Slightly lower guidance
                    height=256,  # Smaller image
                    width=256,
                ).images[0]
                images.append(image)

            # Release cached GPU memory between scenes
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as e:
            print(f"Image generation error: {e}")

    # Fallback if no images were generated
    return images if images else [Image.new("RGB", (256, 256), color="lightgray")]


def text_to_speech(story):
    """Narrate the story with gTTS and return the MP3 path, or None on failure."""
    try:
        tts = gTTS(text=story[:500], lang="en", slow=False)  # Limit to first 500 chars
        audio_file_path = "story_audio.mp3"
        tts.save(audio_file_path)
        return audio_file_path
    except Exception as e:
        print(f"Text-to-speech error: {e}")
        return None


@spaces.GPU
def full_pipeline(image):
    # Wrap the entire process with error handling
    try:
        detected_objects = detect_objects(image)
        story = generate_story(detected_objects)
        # Summarize once and reuse the scenes for both the text box and the images
        scenes = summarize_story(story)
        images = generate_images(scenes)
        audio = text_to_speech(story)

        # gr.Textbox expects a string, so join the scene list into lines
        scenes_text = "\n".join(scenes) if scenes else "Scene 1\nScene 2"

        return (
            story or "A story could not be generated.",
            scenes_text,
            images,
            audio,
        )
    except Exception as e:
        print(f"Full pipeline error: {e}")
        return (
            "An unexpected error occurred.",
            "Something went wrong",
            [Image.new("RGB", (256, 256), color="lightgray")],
            None,
        )


# Gradio UI
demo = gr.Interface(
    fn=full_pipeline,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Generated Story"),
        gr.Textbox(label="Story Scenes"),
        gr.Gallery(label="Generated Images"),
        gr.Audio(label="Story Audio"),
    ],
    title="AI-Powered Storytelling Assistant",
    description="Upload an image, and the AI will detect objects, generate a story, create images, and narrate the story.",
)

if __name__ == "__main__":
    demo.launch()