File size: 2,189 Bytes
dfb3989 8367fb2 dfb3989 8367fb2 dfb3989 8367fb2 dfb3989 8367fb2 504dc12 dfb3989 8367fb2 b3f64ee 8367fb2 b3f64ee 8367fb2 dfb3989 8367fb2 dfb3989 8367fb2 dfb3989 8367fb2 dfb3989 b3f64ee dfb3989 dd4f7ba b3f64ee dfb3989 1c165f8 dfb3989 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
# app.py
import streamlit as st
from PIL import Image
from transformers import pipeline
from gtts import gTTS
import tempfile
# —––––––– Page config
# Browser-tab title and page layout for the Streamlit app.
st.set_page_config(
    layout="centered",
    page_title="Storyteller for Kids",
)
# Heading rendered at the top of the page.
st.title("🖼️ ➡️ 📖 Interactive Storyteller")
# —––––––– Cache model loading
@st.cache_resource
def load_pipelines():
    """Create the captioning and story-generation pipelines.

    Wrapped in ``st.cache_resource`` so the loaded pipelines are cached
    instead of being rebuilt on each call.

    Returns:
        A ``(captioner, storyteller)`` tuple: an ``image-to-text`` pipeline
        backed by BLIP and a ``text-generation`` pipeline backed by
        GPT-Neo 2.7B.
    """
    caption_model = "Salesforce/blip-image-captioning-base"
    story_model = "EleutherAI/gpt-neo-2.7B"

    # 1) Image-to-text (captioning)
    image_captioner = pipeline("image-to-text", model=caption_model)

    # 2) Story generation; device=-1 keeps the model on the CPU.
    text_generator = pipeline("text-generation", model=story_model, device=-1)

    return image_captioner, text_generator
captioner, storyteller = load_pipelines()

# —––––––– Image upload
uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
if uploaded:
    # Force RGB mode; the uploader also accepts PNG, which may carry alpha
    # or palette data.
    image = Image.open(uploaded).convert("RGB")
    # NOTE(review): newer Streamlit releases deprecate `use_column_width` in
    # favour of `use_container_width` — confirm against the pinned version.
    st.image(image, caption="Your image", use_column_width=True)

    # —––––––– 1. Caption
    with st.spinner("🔍 Looking at the image..."):
        cap_outputs = captioner(image)
    cap = cap_outputs[0].get("generated_text", "").strip()
    st.markdown(f"**Caption:** {cap}")

    # —––––––– 2. Story generation
    prompt = (
        "Write a playful, 80–100 word story for 3–10 year-old children "
        f"based on this description:\n\n“{cap}”\n\nStory:"
    )
    with st.spinner("✍️ Writing a story..."):
        out = storyteller(
            prompt,
            max_new_tokens=120,  # allow space for ~100 words
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            num_return_sequences=1,
            # Without this the pipeline returns prompt + continuation, so the
            # instruction text would be shown and read aloud as the "story".
            return_full_text=False,
        )
    story = out[0]["generated_text"].strip()
    st.markdown("**Story:**")
    st.write(story)

    # —––––––– 3. Text-to-Speech
    if story:
        with st.spinner("🔊 Converting to speech..."):
            tts = gTTS(story, lang="en")
            # Use a self-cleaning temporary file and hand the bytes to
            # Streamlit, so no file handle or stray .mp3 outlives this run.
            with tempfile.TemporaryFile() as tmp:
                tts.write_to_fp(tmp)
                tmp.seek(0)
                audio_bytes = tmp.read()
        st.audio(audio_bytes, format="audio/mp3")
    else:
        # gTTS refuses empty text; report it instead of crashing the app.
        st.warning("The model did not produce any story text to read aloud.")
|