# Hugging Face Spaces page header (status when captured: Sleeping)
import streamlit as st | |
from PIL import Image | |
from io import BytesIO | |
from transformers import pipeline | |
from gtts import gTTS | |
import tempfile | |
# ──────── Page config and title
# set_page_config() is issued before any other Streamlit call in this script.
st.set_page_config(page_title="Storyteller for Kids", layout="centered")
# Title icons restored from mis-encoded text (picture -> book).
st.title("🖼️ ➡️ 📖 Interactive Storyteller")
# ──────── Load pipelines (cached)
@st.cache_resource
def _build_pipelines(device):
    """Build the captioning and story pipelines on *device*.

    Wrapped in ``st.cache_resource`` so the models are constructed once per
    distinct ``device`` value and reused across script reruns, instead of
    being rebuilt on every widget interaction.

    Args:
        device: Value forwarded to ``pipeline(device=...)``; this script
            passes ``0`` for the first GPU or ``-1`` for CPU.

    Returns:
        A ``(captioner, storyteller)`` tuple of pipelines.
    """
    # 1. Image captioning. The task name for caption models is
    #    "image-to-text"; "image-captioning" is not a registered task.
    captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=device,
    )
    # 2. Story generation (you can swap to a kid-friendly fine-tuned model)
    storyteller = pipeline(
        "text-generation",
        model="gpt2",
        device=device,
    )
    return captioner, storyteller


def _pick_device(cpu_only):
    """Return ``0`` when a CUDA GPU is usable and allowed, otherwise ``-1``."""
    if cpu_only:
        return -1
    try:
        # Imported lazily so a missing torch install falls back to CPU here
        # instead of failing at module import time.
        import torch
    except ImportError:
        return -1
    return 0 if torch.cuda.is_available() else -1


def load_pipelines():
    """Return the ``(captioner, storyteller)`` pipelines for this session.

    Reads the ``"CPU_ONLY"`` key from ``st.session_state`` (treated as False
    when absent) to choose the device, then returns the cached pipelines
    built for that device.
    """
    cpu_only = st.session_state.get("CPU_ONLY", False)
    return _build_pipelines(_pick_device(cpu_only))
# ──────── Sidebar: CPU/GPU toggle (optional)
# The checkbox is declared before the models are loaded so that the
# "CPU_ONLY" key exists by the time load_pipelines() reads it, and so the
# sidebar is already drawn while the models are loading.
st.sidebar.write("### Settings")
st.sidebar.checkbox("Force CPU only", key="CPU_ONLY")

captioner, storyteller = load_pipelines()
# ──────── Main UI: image upload
# Flow for one uploaded picture: show it, caption it, turn the caption into a
# short story, then read the story aloud.
uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
if uploaded:
    try:
        image = Image.open(uploaded).convert("RGB")
    except OSError:
        # PIL raises OSError (including UnidentifiedImageError) for files it
        # cannot decode; report it instead of showing a traceback.
        st.error("Sorry, that file could not be read as an image.")
        st.stop()
    st.image(image, caption="Your picture", use_column_width=True)

    # ──────── 1. Caption
    with st.spinner("🔍 Looking at the image..."):
        caption = captioner(image)[0]["generated_text"]
    st.markdown(f"**Caption:** {caption}")

    # ──────── 2. Story generation
    prompt = (
        f"Use the following description to write a playful story (50–100 words) "
        f"for 3–10 year-old children:\n\n“{caption}”\n\nStory:"
    )
    with st.spinner("✍️ Writing a story..."):
        output = storyteller(
            prompt,
            # Budget counts generated tokens only, so the prompt length no
            # longer eats into the story (~100 words is roughly 150 tokens).
            max_new_tokens=150,
            # Return just the continuation; no need to cut the prompt off
            # by searching for "Story:" in the output.
            return_full_text=False,
            num_return_sequences=1,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
        )
    story = output[0]["generated_text"].strip()
    st.markdown("**Story:**")
    st.write(story)

    # ──────── 3. Text-to-Speech
    if story:
        with st.spinner("🔊 Converting to speech..."):
            # Keep the MP3 in memory: nothing is left on disk and no file
            # handle stays open after the run.
            audio_buffer = BytesIO()
            gTTS(story, lang="en").write_to_fp(audio_buffer)
        st.audio(audio_buffer.getvalue(), format="audio/mp3")
    else:
        # gTTS refuses empty text, so skip speech when nothing was generated.
        st.warning("The model did not produce any text, so there is nothing to read aloud.")