# mayf's picture
# Create app.py
# 93fc785 verified
# raw
# history blame
# 2.49 kB
import streamlit as st
from PIL import Image
from io import BytesIO
from transformers import pipeline
from gtts import gTTS
import tempfile
# —––––––– Page config and title
# Streamlit requires set_page_config to run before any other st.* call,
# so keep this at the very top of the script body.
st.set_page_config(page_title="Storyteller for Kids", layout="centered")
# The title emoji had been mis-encoded (UTF-8 bytes decoded with a legacy
# code page); restored to the intended "picture -> book" sequence.
st.title("🖼️ ➡️ 📖 Interactive Storyteller")
# —––––––– Sidebar: CPU/GPU toggle (optional)
# Rendered *before* the models are loaded so the checkbox value can be handed
# to the cached loader below and actually influence device placement.
st.sidebar.write("### Settings")
st.sidebar.checkbox("Force CPU only", key="CPU_ONLY")

# —––––––– Load pipelines (cached)
# `st.cache_resource` replaces the deprecated `st.experimental_singleton`;
# fall back to the old name so the app still starts on older Streamlit builds.
_cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton


def _select_device(cpu_only):
    """Return the transformers ``device`` index: 0 for the first GPU, -1 for CPU.

    The GPU is only selected when the caller did not force CPU *and* CUDA is
    really available; requesting ``device=0`` on a CPU-only host makes
    ``pipeline(...)`` fail while moving the model.
    """
    if cpu_only:
        return -1
    try:
        # Local import: torch is only needed for this probe.
        import torch
    except ImportError:
        return -1
    return 0 if torch.cuda.is_available() else -1


@_cache_resource
def load_pipelines(cpu_only=None):
    """Build and cache the captioning and story-generation pipelines.

    Args:
        cpu_only: When true, both models are placed on the CPU. When ``None``
            (the default, matching the original zero-argument call), the value
            is read from ``st.session_state["CPU_ONLY"]``. Because the cache is
            keyed on the arguments, passing the flag explicitly lets a toggle
            change trigger a reload on the requested device.

    Returns:
        A ``(captioner, storyteller)`` tuple of transformers pipelines.
    """
    if cpu_only is None:
        cpu_only = st.session_state.get("CPU_ONLY", False)
    device = _select_device(bool(cpu_only))

    # 1. Image captioning. The task name registered in transformers is
    #    "image-to-text"; "image-captioning" is not a known task and raises.
    captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=device,
    )
    # 2. Story generation (you can swap to a kid-friendly fine-tuned model)
    storyteller = pipeline(
        "text-generation",
        model="gpt2",
        device=device,
    )
    return captioner, storyteller


# Pass the toggle explicitly so the cached loader is keyed on it.
captioner, storyteller = load_pipelines(
    cpu_only=st.session_state.get("CPU_ONLY", False)
)
# —––––––– Main UI: image upload
# Upper bound on freshly generated tokens for the story. The prompt asks for
# 50–100 words; this leaves headroom since a word is often more than one token.
STORY_MAX_NEW_TOKENS = 130

uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
if uploaded:
    image = Image.open(uploaded).convert("RGB")
    st.image(image, caption="Your picture", use_column_width=True)

    # —––––––– 1. Caption
    with st.spinner("🔍 Looking at the image..."):
        caption = captioner(image)[0]["generated_text"]
    st.markdown(f"**Caption:** {caption}")

    # —––––––– 2. Story generation
    prompt = (
        "Use the following description to write a playful story (50–100 words) "
        f"for 3–10 year-old children:\n\n“{caption}”\n\nStory:"
    )
    with st.spinner("✍️ Writing a story..."):
        output = storyteller(
            prompt,
            # `max_length` counts prompt tokens too and was derived from a
            # space count, so the story budget was far smaller than intended.
            # `max_new_tokens` bounds only the generated continuation.
            max_new_tokens=STORY_MAX_NEW_TOKENS,
            # Return just the continuation instead of splitting the full text
            # on "Story:", which dropped text if the model emitted that marker.
            return_full_text=False,
            num_return_sequences=1,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
        )
    story = output[0]["generated_text"].strip()
    st.markdown("**Story:**")
    st.write(story)

    # —––––––– 3. Text-to-Speech
    if story:
        with st.spinner("🔊 Converting to speech..."):
            # Keep the MP3 in memory: the previous delete=False temp file was
            # never closed or removed, leaking one file per rerun.
            audio_buffer = BytesIO()
            gTTS(story, lang="en").write_to_fp(audio_buffer)
        st.audio(audio_buffer.getvalue(), format="audio/mp3")
    else:
        # gTTS rejects empty text, so skip synthesis rather than crash.
        st.warning("The model returned an empty story, so there is nothing to read aloud.")