|
|
|
|
|
import streamlit as st |
|
from PIL import Image |
|
from transformers import pipeline |
|
from gtts import gTTS |
|
import tempfile |
|
|
|
|
|
# Page chrome: browser-tab title and centered layout first, then the visible
# heading. This stays ahead of every other Streamlit call in the file.
st.set_page_config(
    layout="centered",
    page_title="Storyteller for Kids",
)
st.title("🖼️ ➡️ 📖 Interactive Storyteller")
|
|
|
|
|
@st.cache_resource
def load_pipelines():
    """Build, warm up and return the captioning and story-generation pipelines.

    The function is wrapped in ``st.cache_resource`` so the models are meant to
    be constructed once and reused across Streamlit reruns instead of being
    reloaded on every interaction.

    Returns:
        tuple: ``(captioner, storyteller)`` where ``captioner`` is an
        ``image-to-text`` pipeline (``Salesforce/blip-image-captioning-base``)
        and ``storyteller`` is a ``text2text-generation`` pipeline
        (``google/flan-t5-large``).
    """
    # Imported locally so the file's import block stays untouched.
    # NOTE(review): assumes the pipelines run on the PyTorch backend - confirm.
    import torch

    # The original hard-coded ``device=0``, which fails on hosts without a
    # CUDA GPU. Use the first GPU when one exists, otherwise the CPU (-1).
    device = 0 if torch.cuda.is_available() else -1

    captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=device,
    )

    storyteller = pipeline(
        "text2text-generation",
        model="google/flan-t5-large",
        device=device,
    )

    # Warm-up: push one throwaway input through each model so the first real
    # request does not also pay the first-inference cost.
    warmup_size = (384, 384)  # same size the app resizes uploads to
    dummy = Image.new("RGB", warmup_size, color=(128, 128, 128))
    captioner(dummy)
    storyteller("Hello", max_new_tokens=1)

    return captioner, storyteller
|
|
|
# Load (or fetch from Streamlit's resource cache) the two model pipelines.
captioner, storyteller = load_pipelines()

# The rest of the script only runs once the user has supplied an image.
uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])

if uploaded:
    # Step 1: decode the upload. A file with an allowed extension can still be
    # corrupt or not an image at all; Pillow reports that as an OSError.
    try:
        image = Image.open(uploaded).convert("RGB")
    except OSError:
        st.error(
            "Could not read that file as an image. "
            "Please upload a valid JPG or PNG."
        )
        st.stop()

    # Normalise to the same 384x384 size used for the warm-up image.
    image = image.resize((384, 384), Image.LANCZOS)
    st.image(image, caption="Your image", use_container_width=True)

    # Step 2: caption the image.
    with st.spinner("🔍 Generating caption..."):
        cap = captioner(image)[0]["generated_text"].strip()
    st.markdown(f"**Caption:** {cap}")

    # Step 3: build the story prompt from the caption. The instructions are
    # deliberately subject-agnostic: the earlier wording hard-coded a panda
    # and "crunchy meat", which forced a panda story onto every image.
    prompt = (
        f"Here’s an image description: “{cap}”.\n\n"
        "Write a playful, 80–100 word story for 3–10 year-old children.\n"
        "- Focus only on the main subject of the description and what it’s doing.\n"
        "- Do not introduce characters that are not in the description.\n"
        "- Be vivid: mention feelings, sounds, colors or textures.\n\n"
        "Story:"
    )

    with st.spinner("✍️ Writing story..."):
        out = storyteller(
            prompt,
            max_new_tokens=130,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.3,
            no_repeat_ngram_size=3,
        )

    # Keep only the text after the last "Story:" marker, in case the model
    # echoes part of the prompt back.
    raw = out[0]["generated_text"]
    story = raw.split("Story:")[-1].strip()

    # Nothing to show or speak: stop here rather than hand gTTS empty text.
    if not story:
        st.warning(
            "The model returned an empty story. Please try another image."
        )
        st.stop()

    st.markdown("**Story:**")
    st.write(story)

    # Step 4: text-to-speech. The MP3 is written to a temporary file that is
    # closed and removed when the ``with`` block exits, and the audio is passed
    # to Streamlit as bytes. The previous ``delete=False`` file was never
    # closed or deleted, leaking one file per rerun.
    with st.spinner("🔊 Converting to speech..."):
        tts = gTTS(text=story, lang="en")
        with tempfile.TemporaryFile() as tmp:
            tts.write_to_fp(tmp)
            tmp.seek(0)
            audio_bytes = tmp.read()
    st.audio(audio_bytes, format="audio/mp3")
|
|