"""Streamlit app that turns an uploaded image into a short children's story.

Flow: caption the image (BLIP) -> generate a story from the caption (FLAN-T5)
-> synthesize speech for the story (gTTS) and offer it for playback/download.
"""

import hashlib
import io
import os
import tempfile

import streamlit as st
from gtts import gTTS
from PIL import Image
from transformers import pipeline

# Enforces the "Max size: 5MB" hint shown on the uploader widget.
MAX_UPLOAD_BYTES = 5 * 1024 * 1024

# st.session_state key holding the results computed for the current image.
_RESULTS_KEY = "storyteller_results"

# ---------- Page config ----------
st.set_page_config(
    page_title="Storyteller for Kids",
    page_icon="📚",
    layout="centered",
    initial_sidebar_state="collapsed",
)
st.title("🖼️➡️📖 Interactive Storyteller")


# ---------- Cached model loading ----------
@st.cache_resource
def load_pipelines():
    """Build the captioning and story-generation pipelines.

    Decorated with ``st.cache_resource`` so the models are constructed once
    and reused across script reruns.

    Returns:
        A ``(captioner, storyteller)`` tuple of transformers pipelines.
    """
    captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        max_new_tokens=50,
    )
    storyteller = pipeline(
        "text2text-generation",
        model="google/flan-t5-xxl",
        device_map="auto",
        model_kwargs={"load_in_8bit": True},
    )
    return captioner, storyteller


def _build_prompt(caption: str) -> str:
    """Return the story-generation prompt for an image caption."""
    # NOTE(review): the original prompt's line breaks were lost in the
    # flattened source; one requirement per line is assumed here.
    return "\n".join(
        [
            "Create a children's story (3-10 years old) based on this "
            f"description: {caption}",
            "Requirements:",
            "- 50-100 words",
            "- Playful and imaginative",
            "- Positive message",
            "- Simple vocabulary",
            "- Include animal characters",
            "Story:",
        ]
    )


def _results_for(digest: str) -> dict:
    """Return the session-scoped result dict for the image with ``digest``.

    Streamlit reruns the whole script on every widget interaction (including
    a click on the download button). Keeping caption/story/audio here avoids
    re-running the models and, because generation samples, avoids replacing
    the story the user is looking at with a different one.
    """
    cached = st.session_state.get(_RESULTS_KEY)
    if cached is None or cached.get("digest") != digest:
        # A different image was uploaded: start from a clean slate.
        cached = {"digest": digest}
        st.session_state[_RESULTS_KEY] = cached
    return cached


def _synthesize_audio(story: str) -> bytes:
    """Return MP3 bytes of ``story`` read aloud in English via gTTS."""
    # An in-memory buffer needs no temp-file cleanup, so a failing gTTS call
    # surfaces its own error instead of a secondary cleanup error.
    buffer = io.BytesIO()
    gTTS(text=story, lang="en", slow=False).write_to_fp(buffer)
    return buffer.getvalue()


# ---------- Main workflow ----------
def main() -> None:
    """Render the upload widget and run caption -> story -> audio for it."""
    captioner, storyteller = load_pipelines()

    uploaded = st.file_uploader(
        "Upload an image:",
        type=["jpg", "jpeg", "png"],
        help="Max size: 5MB",
    )
    if not uploaded:
        return

    if uploaded.size > MAX_UPLOAD_BYTES:
        st.error("Image is larger than 5MB. Please upload a smaller file.")
        return

    # Top-level UI boundary: any failure is reported to the user rather than
    # crashing the app.
    try:
        # ----- Display image -----
        image = Image.open(uploaded).convert("RGB")
        # NOTE(review): use_column_width is deprecated in recent Streamlit
        # releases; kept as-is for compatibility with the installed version.
        st.image(image, caption="Your Image", use_column_width=True)

        digest = hashlib.sha256(uploaded.getvalue()).hexdigest()
        results = _results_for(digest)

        # ----- Generate caption -----
        if "caption" not in results:
            with st.spinner("🔍 Analyzing image content..."):
                cap_outputs = captioner(image)
            results["caption"] = (
                cap_outputs[0].get("generated_text", "").strip()
            )
        cap = results["caption"]
        st.subheader("Image Understanding")
        st.info(f"**Detected:** {cap}")

        # ----- Generate story -----
        st.subheader("Story Creation")
        if "story" not in results:
            with st.spinner("✍️ Crafting a magical story..."):
                story_output = storyteller(
                    _build_prompt(cap),
                    max_length=300,
                    do_sample=True,
                    top_p=0.95,
                    temperature=0.85,
                    num_beams=4,
                    repetition_penalty=1.2,
                )
            results["story"] = story_output[0]["generated_text"].strip()
        story = results["story"]
        st.success("**Generated Story:**")
        st.write(story)

        # ----- Text-to-speech -----
        st.subheader("Audio Version")
        if not story:
            # gTTS rejects empty text; tell the user instead of erroring.
            st.warning("The story is empty, so there is nothing to read aloud.")
            return
        if "audio" not in results:
            with st.spinner("🔊 Generating audio..."):
                results["audio"] = _synthesize_audio(story)
        audio_bytes = results["audio"]
        st.audio(audio_bytes, format="audio/mp3")
        st.download_button(
            label="Download Audio Story",
            data=audio_bytes,
            file_name="kids_story.mp3",
            mime="audio/mpeg",
        )
    except Exception as e:
        st.error(f"Error processing your request: {str(e)}")


if __name__ == "__main__":
    main()