"""Streamlit app: turn an uploaded picture into a short spoken story for kids.

Flow: image -> caption (BLIP) -> story (GPT-2) -> MP3 narration (gTTS).
"""
import tempfile  # NOTE(review): no longer used below; kept in case other code relies on it
from io import BytesIO

import streamlit as st
from gtts import gTTS, gTTSError
from PIL import Image
from transformers import pipeline

# Page config must be the first Streamlit command of the script.
st.set_page_config(page_title="Storyteller for Kids", layout="centered")
st.title("🖼️ ➡️ 📖 Interactive Storyteller")

# Cap on generated story length. max_new_tokens counts only the continuation,
# not the prompt; ~150 GPT-2 tokens is very roughly 100 English words.
STORY_MAX_NEW_TOKENS = 150


def _pick_device(cpu_only: bool) -> int:
    """Return the transformers ``device`` index: 0 (first CUDA GPU) or -1 (CPU).

    Hard-coding ``device=0`` fails on machines without CUDA, so fall back to
    the CPU whenever a GPU is not actually available (or is not requested).
    """
    if cpu_only:
        return -1
    try:
        import torch
    except ImportError:  # no PyTorch -> nothing can be placed on a GPU
        return -1
    return 0 if torch.cuda.is_available() else -1


@st.cache_resource(max_entries=1)
def load_pipelines(cpu_only: bool = False):
    """Load and cache the captioning and story-generation pipelines.

    Args:
        cpu_only: Force both models onto the CPU even when a GPU exists.
            This is a real argument rather than a ``st.session_state``
            lookup so that the cache is keyed on it: toggling the setting
            loads models on the new device. ``max_entries=1`` keeps only the
            most recent pair cached instead of both device variants.

    Returns:
        ``(captioner, storyteller)`` Hugging Face pipelines.
    """
    device = _pick_device(cpu_only)
    # "image-to-text" is the transformers task for captioning models such as
    # BLIP; "image-captioning" is not a registered task name.
    captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=device,
    )
    # Story generation (you can swap to a kid-friendly fine-tuned model).
    storyteller = pipeline("text-generation", model="gpt2", device=device)
    return captioner, storyteller


def _write_story(storyteller, caption: str) -> str:
    """Generate a short children's story from an image ``caption``.

    ``return_full_text=False`` makes the pipeline return only the generated
    continuation. The prompt therefore never has to be split back out of
    the output, which broke whenever the model itself emitted "Story:".
    """
    prompt = (
        f"Use the following description to write a playful story (50–100 words) "
        f"for 3–10 year-old children:\n\n“{caption}”\n\nStory:"
    )
    output = storyteller(
        prompt,
        max_new_tokens=STORY_MAX_NEW_TOKENS,
        num_return_sequences=1,
        do_sample=True,
        top_p=0.9,
        temperature=0.8,
        return_full_text=False,
        # GPT-2 defines no pad token; reuse EOS, as transformers would
        # otherwise do itself (with a warning) for open-ended generation.
        pad_token_id=storyteller.tokenizer.eos_token_id,
    )
    return output[0]["generated_text"].strip()


def _narrate(text: str) -> bytes:
    """Render ``text`` as English MP3 audio with gTTS and return the bytes.

    The audio is kept in memory. A ``NamedTemporaryFile(delete=False)``
    would leave one orphaned file on disk per script rerun.

    Raises:
        gTTSError: if the Google TTS request fails (gTTS needs network access).
    """
    buffer = BytesIO()
    gTTS(text, lang="en").write_to_fp(buffer)
    return buffer.getvalue()


# Sidebar: CPU/GPU toggle. It is rendered *before* the models load so its
# value is known on the very first run.
st.sidebar.write("### Settings")
cpu_only = st.sidebar.checkbox("Force CPU only", key="CPU_ONLY")

captioner, storyteller = load_pipelines(cpu_only)

# Main UI: image upload
uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
if uploaded:
    try:
        image = Image.open(uploaded).convert("RGB")
    except OSError:  # covers PIL.UnidentifiedImageError and truncated files
        st.error("That file could not be read as an image. Please try another one.")
        st.stop()
    # NOTE(review): newer Streamlit releases deprecate use_column_width;
    # confirm against the pinned Streamlit version.
    st.image(image, caption="Your picture", use_column_width=True)

    # 1. Caption
    with st.spinner("🔍 Looking at the image..."):
        caption = captioner(image)[0]["generated_text"]
    st.markdown(f"**Caption:** {caption}")

    # 2. Story generation
    with st.spinner("✍️ Writing a story..."):
        story = _write_story(storyteller, caption)
    st.markdown("**Story:**")
    st.write(story)

    # 3. Text-to-speech
    if not story:
        # gTTS rejects empty input, so skip narration instead of crashing.
        st.warning("The model did not produce any story text to read aloud.")
    else:
        try:
            with st.spinner("🔊 Converting to speech..."):
                audio = _narrate(story)
        except gTTSError as err:
            st.error(f"Text-to-speech failed: {err}")
        else:
            st.audio(audio, format="audio/mp3")