|
|
|
|
|
import streamlit as st |
|
from PIL import Image |
|
from io import BytesIO |
|
from huggingface_hub import InferenceApi |
|
from gtts import gTTS |
|
import tempfile |
|
|
|
|
|
# Page chrome. set_page_config is kept as the very first Streamlit call
# (older Streamlit versions reject it anywhere else).
st.set_page_config(layout="centered", page_title="Storyteller for Kids")
st.title("🖼️ ➡️ 📖 Interactive Storyteller")
|
|
|
|
|
@st.cache_resource
def load_clients():
    """Create the Hugging Face Inference API clients used by the app.

    Decorated with ``st.cache_resource`` so the clients are built once and
    reused across script reruns instead of being recreated every time.

    Returns:
        tuple: ``(caption_client, story_client)`` -- an ``InferenceApi`` for
        the BLIP image-captioning model (task ``"image-to-text"``) and one
        for DeepSeek-R1-Distill-Qwen-1.5B (task ``"text-generation"``), both
        authenticated with the ``HF_TOKEN`` Streamlit secret.

    NOTE(review): huggingface_hub documents ``InferenceApi`` as deprecated in
    favour of ``InferenceClient``; migrating also requires changing every
    place these clients are called.
    """
    token = st.secrets["HF_TOKEN"]

    def _client(repo_id, task):
        # All clients share one token; only the model and task differ.
        return InferenceApi(repo_id=repo_id, task=task, token=token)

    return (
        _client("Salesforce/blip-image-captioning-base", "image-to-text"),
        _client("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "text-generation"),
    )
|
|
|
# One cached call hands back both API clients used by the rest of the app.
(caption_client, story_client) = load_clients()
|
|
|
|
|
def _extract_generated_text(output):
    """Pull the ``generated_text`` field out of an Inference API response.

    The response may be a list of dicts, a single dict, or some other value.
    A list is reduced to its first element (an empty list yields ``""``);
    any non-dict value is stringified.

    Args:
        output: The raw value returned by an ``InferenceApi`` call.

    Returns:
        The stripped generated text, or ``""`` when none is present.

    Raises:
        RuntimeError: If the payload is an error object (a dict with an
            ``"error"`` key), so the caller can show the real reason rather
            than a generic "empty result" message.
            NOTE(review): assumes the hosted API reports failures such as
            "model is loading" this way -- confirm for the hub version in use.
    """
    if isinstance(output, list):
        output = output[0] if output else None
    if output is None:
        return ""
    if isinstance(output, dict):
        if "error" in output:
            raise RuntimeError(f"Inference API error: {output['error']}")
        return str(output.get("generated_text", "")).strip()
    return str(output).strip()


def _clean_story(generated):
    """Reduce raw model output to just the story text; ``""`` if none."""
    # NOTE(review): R1-distilled models may prefix their answer with a
    # <think>...</think> reasoning section; only the text after it is the story.
    if "</think>" in generated:
        generated = generated.split("</think>")[-1]
    elif "<think>" in generated:
        # Reasoning was opened but never closed (e.g. the token budget ran
        # out), so there is no story to show or read aloud.
        return ""
    # If the prompt was echoed back, keep only what follows the last "Story:".
    return generated.split("Story:")[-1].strip()


uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])

if not uploaded:
    st.info("Please upload a JPG/PNG image to begin.")
else:
    # --- Load and preview the image ---------------------------------------
    try:
        img = Image.open(uploaded).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        # Corrupt / unsupported files would otherwise surface as a traceback.
        st.error(f"🚨 Couldn’t read that image: {e}")
        st.stop()

    st.image(img, use_container_width=True)

    # --- Caption the image ------------------------------------------------
    with st.spinner("🔍 Generating caption..."):
        try:
            buf = BytesIO()
            img.save(buf, format="PNG")
            cap_text = _extract_generated_text(caption_client(data=buf.getvalue()))
        except Exception as e:
            st.error(f"🚨 Caption generation failed: {e}")
            st.stop()

    if not cap_text:
        st.error("😕 Couldn’t generate a caption. Try another image.")
        st.stop()

    st.markdown(f"**Caption:** {cap_text}")

    # --- Write the story --------------------------------------------------
    prompt = (
        f"Here’s an image description: “{cap_text}”.\n\n"
        "Write an 80–100 word playful story for 3–10 year-old children that:\n"
        "1) Describes the scene and main subject.\n"
        "2) Explains what it’s doing and how it feels.\n"
        "3) Concludes with a fun, imaginative ending.\n\n"
        "Story:"
    )

    with st.spinner("✍️ Generating story..."):
        try:
            # InferenceApi.__call__ only accepts inputs/params/data/raw_response;
            # generation settings must go in ``params`` (as keyword arguments
            # they raise TypeError).
            story_out = story_client(
                prompt,
                params={
                    "max_new_tokens": 250,
                    "temperature": 0.7,
                    "top_p": 0.9,
                    "top_k": 50,
                    "repetition_penalty": 1.1,
                    "do_sample": True,
                    "no_repeat_ngram_size": 2,
                    # Ask for the continuation only, not prompt + continuation.
                    "return_full_text": False,
                },
            )
            story = _clean_story(_extract_generated_text(story_out))
        except Exception as e:
            st.error(f"🚨 Story generation failed: {e}")
            st.stop()

    if not story:
        # gTTS cannot speak empty text, so stop here with a clear message.
        st.error("😕 Couldn’t generate a story. Try another image.")
        st.stop()

    st.markdown(f"**Story:** {story}")

    # --- Read it aloud ----------------------------------------------------
    with st.spinner("🔊 Converting to speech..."):
        try:
            # Render into memory: no temp files are left behind on each run.
            audio_buf = BytesIO()
            gTTS(text=story, lang="en").write_to_fp(audio_buf)
        except Exception as e:
            st.error(f"🔇 Audio conversion failed: {e}")
        else:
            st.audio(audio_buf.getvalue(), format="audio/mpeg")
|
|
|
|