"""Streamlit app that captions an uploaded image and turns the caption into a
short children's story, shown as text and played back as speech."""

import io
import os
import time
import tempfile

import streamlit as st
from PIL import Image
from transformers import pipeline
from gtts import gTTS

# Task prefix sent to the story model; some models echo it back in their output.
STORY_TASK_PREFIX = "generate story:"
# Hard cap on the number of words shown / spoken.
MAX_STORY_WORDS = 100

# ---------- Page Setup ----------
# set_page_config must be the first Streamlit command executed in the script.
st.set_page_config(page_title="Magic Story Generator", layout="centered")
st.title("📖✨ Turn Images into Children's Stories")


# ---------- Load Pipelines (cached) ----------
@st.cache_resource(show_spinner=False)
def load_pipelines():
    """Build the captioning and story-generation pipelines (CPU, device=-1).

    Cached with ``st.cache_resource`` so models are loaded once per server
    process rather than on every script rerun.

    Returns:
        tuple: ``(captioner, storyteller)`` Hugging Face pipelines.
    """
    # 1) Image captioning pipeline
    captioner = pipeline(
        task="image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=-1,
    )

    # 2) Story generation pipeline. The extra keyword arguments are forwarded
    #    to generation on every call (sampled output, so stories vary).
    storyteller = pipeline(
        task="text2text-generation",
        model="laxya007/story-generator-t5-small",
        tokenizer="t5-small",
        device=-1,
        max_length=200,
        temperature=0.7,
        do_sample=True,
    )
    return captioner, storyteller


def _clean_story(raw: str, max_words: int = MAX_STORY_WORDS) -> str:
    """Normalise raw model output into display-ready story text.

    Strips leftover T5 special tokens and an echoed task prefix, then keeps
    at most ``max_words`` whitespace-separated words.
    """
    # The pipeline normally skips special tokens when decoding; strip them
    # defensively in case they leak through.
    text = raw.replace("<pad>", "").replace("</s>", "").strip()
    if text.startswith(STORY_TASK_PREFIX):
        text = text[len(STORY_TASK_PREFIX):].strip()
    words = text.split()
    return " ".join(words[:max_words]) if len(words) > max_words else text


def _tts_mp3_bytes(text: str) -> bytes:
    """Synthesize ``text`` to English speech and return the MP3 bytes.

    Written to an in-memory buffer so no temp files accumulate across reruns.
    """
    buffer = io.BytesIO()
    gTTS(text=text, lang="en", slow=False).write_to_fp(buffer)
    return buffer.getvalue()


captioner, storyteller = load_pipelines()

# ---------- Main App ----------
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])

if uploaded is not None:
    # Uploads are untrusted: a corrupt or mislabelled file makes PIL raise
    # OSError (UnidentifiedImageError is a subclass), which we report instead
    # of crashing the page.
    try:
        img = Image.open(uploaded).convert("RGB")
    except OSError:
        st.error("😢 That file doesn't look like a valid image. Try another one!")
        st.stop()
    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favour of use_container_width — confirm against the pinned version.
    st.image(img, use_column_width=True)

    # Generate caption
    with st.spinner("🔍 Generating caption..."):
        cap = captioner(img)
    caption = cap[0].get("generated_text", "").strip() if cap else ""

    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()

    st.success(f"**Caption:** {caption}")

    # Generate story
    prompt = f"generate story: {caption}"
    with st.spinner("📝 Writing story..."):
        start = time.perf_counter()
        raw_story = storyteller(prompt)[0]["generated_text"]
        gen_time = time.perf_counter() - start

    st.text(f"⏱ Generated in {gen_time:.1f}s")

    story = _clean_story(raw_story)
    if not story:
        # gTTS refuses empty text, and an empty story is useless to show.
        st.error("😢 The story came out empty. Try another image!")
        st.stop()

    # Display story
    st.subheader("📚 Your Magical Story")
    st.write(story)

    # Audio conversion is best-effort: gTTS needs network access, so any
    # failure is surfaced as a warning while the text story stays visible.
    with st.spinner("🔊 Converting to audio..."):
        try:
            audio_bytes = _tts_mp3_bytes(story)
        except Exception as e:
            st.warning(f"⚠️ Audio conversion failed: {e}")
        else:
            st.audio(audio_bytes, format="audio/mp3")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")