# Source: Hugging Face Space file "app.py" (revision cab8adc, ~3.72 kB);
# web-UI chrome from the hosting page removed so the file parses as Python.
# app.py
import streamlit as st
from PIL import Image
from io import BytesIO
from huggingface_hub import InferenceApi
from gtts import gTTS
import tempfile
# —––––––– Page config
# NOTE: st.set_page_config must be the first Streamlit command executed in the
# script, so this block has to stay ahead of any other st.* call.
st.set_page_config(page_title="Storyteller for Kids", layout="centered")
st.title("🖼️ ➡️ 📖 Interactive Storyteller")
# —––––––– Inference clients (cached)
@st.cache_resource
def load_clients():
    """Build and cache the two Hugging Face inference clients.

    Returns:
        tuple: (image-captioning client, story-generation client), both
        authenticated with the HF_TOKEN stored in Streamlit secrets.

    NOTE(review): `InferenceApi` is deprecated in huggingface_hub in favor of
    `InferenceClient`; kept here because the call sites depend on its interface.
    """
    token = st.secrets["HF_TOKEN"]

    captioner = InferenceApi(
        repo_id="Salesforce/blip-image-captioning-base",
        task="image-to-text",
        token=token,
    )
    storyteller = InferenceApi(
        repo_id="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        task="text-generation",
        token=token,
    )
    return captioner, storyteller


caption_client, story_client = load_clients()
# —––––––– Main UI
uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
if not uploaded:
    st.info("Please upload a JPG/PNG image to begin.")
else:
    # 1) Display image (convert to RGB so PNG re-encoding below always works)
    img = Image.open(uploaded).convert("RGB")
    st.image(img, use_container_width=True)

    # 2) Generate caption
    with st.spinner("🔍 Generating caption..."):
        try:
            buf = BytesIO()
            img.save(buf, format="PNG")
            cap_out = caption_client(data=buf.getvalue())
            # The endpoint may return a list of dicts, a single dict, or raw text.
            if isinstance(cap_out, list) and cap_out:
                cap_text = cap_out[0].get("generated_text", "").strip()
            elif isinstance(cap_out, dict):
                cap_text = cap_out.get("generated_text", "").strip()
            else:
                cap_text = str(cap_out).strip()
        except Exception as e:
            st.error(f"🚨 Caption generation failed: {str(e)}")
            st.stop()

    if not cap_text:
        st.error("😕 Couldn’t generate a caption. Try another image.")
        st.stop()
    st.markdown(f"**Caption:** {cap_text}")

    # 3) Build story prompt
    prompt = (
        f"Here’s an image description: “{cap_text}”.\n\n"
        "Write an 80–100 word playful story for 3–10 year-old children that:\n"
        "1) Describes the scene and main subject.\n"
        "2) Explains what it’s doing and how it feels.\n"
        "3) Concludes with a fun, imaginative ending.\n\n"
        "Story:"
    )

    # 4) Generate story.
    # FIX: InferenceApi.__call__ only accepts (inputs, params, data, raw_response);
    # generation options passed as direct keyword arguments raise TypeError.
    # They must be wrapped in the `params` dict.
    with st.spinner("✍️ Generating story..."):
        try:
            story_out = story_client(
                inputs=prompt,
                params={
                    "max_new_tokens": 250,
                    "temperature": 0.7,
                    "top_p": 0.9,
                    "top_k": 50,
                    "repetition_penalty": 1.1,
                    "do_sample": True,
                    "no_repeat_ngram_size": 2,
                },
            )
            # Handle both list-of-dicts and single-dict response formats.
            if isinstance(story_out, list):
                story_text = story_out[0].get("generated_text", "")
            else:
                story_text = story_out.get("generated_text", "")
            # The model echoes the prompt; keep only the content after "Story:".
            story = story_text.split("Story:")[-1].strip()
        except Exception as e:
            st.error(f"🚨 Story generation failed: {str(e)}")
            st.stop()

    # 5) Text-to-Speech.
    # FIX: the previous NamedTemporaryFile(delete=False) leaked one .mp3 per
    # rerun and re-opened the file by name (fragile on Windows). gTTS can write
    # straight into an in-memory buffer that st.audio accepts directly.
    with st.spinner("🔊 Converting to speech..."):
        try:
            tts = gTTS(text=story, lang="en")
            audio_buf = BytesIO()
            tts.write_to_fp(audio_buf)
            audio_buf.seek(0)
            st.audio(audio_buf, format="audio/mp3")
        except Exception as e:
            st.error(f"🔇 Audio conversion failed: {str(e)}")