1 / app.py
mayf's picture
Update app.py
422a749 verified
raw
history blame
3.71 kB
import streamlit as st
from PIL import Image
from io import BytesIO
from huggingface_hub import InferenceApi
from transformers import pipeline
from gtts import gTTS
import tempfile
# —––––––– Page Config —–––––––
# Set the browser-tab title and a centered layout, then render the page
# heading shown at the top of the app.
# NOTE(review): Streamlit's documentation asks for set_page_config to run
# before other Streamlit commands — keep it ahead of any other st.* call.
st.set_page_config(page_title="Magic Story Generator", layout="centered")
st.title("📖✨ Turn Images into Children's Stories")
# —––––––– Clients (cached) —–––––––
@st.cache_resource
def load_clients():
    """Build the captioning client and the story-generation pipeline.

    The function is wrapped in ``st.cache_resource`` so the objects are
    created once and reused instead of being rebuilt on each script run.

    Returns:
        tuple: ``(caption_client, story_generator)`` where ``caption_client``
        is an ``InferenceApi`` bound to the BLIP captioning model and
        ``story_generator`` is a ``transformers`` text-generation pipeline.

    Raises:
        KeyError: if ``HF_TOKEN`` is missing from ``st.secrets``.
    """
    hf_token = st.secrets["HF_TOKEN"]

    # Remote image-captioning client.
    caption_client = InferenceApi("Salesforce/blip-image-captioning-base", token=hf_token)

    # Use the first GPU only when one is actually present; a hard-coded
    # device=0 makes pipeline construction fail on CPU-only hosts.
    try:
        import torch  # local import: the file does not import torch at top level

        device = 0 if torch.cuda.is_available() else -1
    except ImportError:
        device = -1

    # Local text-generation pipeline used to write the story.
    story_generator = pipeline(
        "text-generation",
        model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        tokenizer="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        token=hf_token,  # `use_auth_token` is the deprecated name for this argument
        device=device,
    )
    return caption_client, story_generator


caption_client, story_generator = load_clients()
# —––––––– Helper: Generate Caption —–––––––
def generate_caption(img):
    """Return a caption for ``img``, or ``""`` when captioning fails.

    The image is JPEG-encoded in memory and sent to ``caption_client``.
    Any failure is reported to the user with ``st.error`` and signalled to
    the caller by an empty string.

    Args:
        img: a PIL image; it must be saveable as JPEG.

    Returns:
        str: the stripped ``generated_text`` of the first result, or ``""``.
    """
    buf = BytesIO()
    img.save(buf, format="JPEG")

    try:
        out = caption_client(data=buf.getvalue())
    except Exception as e:  # UI boundary: report the failure instead of crashing the app
        st.error(f"Caption error: {e}")
        return ""

    # The inference endpoint reports failures (e.g. model still loading) as a
    # JSON object with an "error" key rather than raising; indexing that dict
    # with [0] would only produce an unhelpful "Caption error: 0".
    if isinstance(out, dict) and "error" in out:
        st.error(f"Caption error: {out['error']}")
        return ""

    try:
        return (out[0].get("generated_text") or "").strip()
    except (IndexError, KeyError, TypeError, AttributeError) as e:
        # Empty list or an unexpected payload shape.
        st.error(f"Caption error: {e}")
        return ""
# —––––––– Helper: Generate Story via pipeline —–––––––
def generate_story(prompt: str, *, generator=None) -> str:
    """Generate a short story from ``prompt`` and tidy it for display.

    Args:
        prompt: full prompt text handed to the text-generation pipeline.
        generator: optional callable with the same call signature and return
            shape as the text-generation pipeline (a list of dicts with a
            ``"generated_text"`` key). Defaults to the module-level
            ``story_generator``.

    Returns:
        str: the text following the first ``"Story:"`` marker (or the whole
        output when the marker is absent), limited to 100 words and ending
        with sentence punctuation. Empty output yields ``""``.
    """
    if generator is None:
        generator = story_generator

    # Up to ~200 new tokens leaves headroom for a 100-word story.
    outputs = generator(
        prompt,
        max_new_tokens=200,
        temperature=0.8,
        top_p=0.95,
        repetition_penalty=1.15,
        no_repeat_ngram_size=2,
        do_sample=True,
    )
    text = outputs[0]["generated_text"].strip()

    # Keep only what follows the first "Story:" marker, when present
    # (the prompt itself may be echoed back in the generated text).
    _, marker, tail = text.partition("Story:")
    if marker:
        text = tail.strip()

    # Cap the story at 100 words.
    max_words = 100
    words = text.split()
    if len(words) > max_words:
        text = " ".join(words[:max_words])

    # Close a cut-off sentence with a period, but do not produce "!." / "?."
    # and do not turn an empty result into a lone ".". Closing quotes or
    # brackets after the punctuation are ignored for the check.
    if text and not text.rstrip("\"'”’)").endswith((".", "!", "?")):
        text += "."
    return text
# —––––––– Main App Flow —–––––––
# Script body: upload an image -> caption it -> write a story -> read it aloud.
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    # A file with an allowed extension can still be corrupt or not an image;
    # PIL signals that with an OSError subclass.
    try:
        img = Image.open(uploaded).convert("RGB")
    except OSError as e:
        st.error(f"Couldn't read this image: {e}")
        st.stop()

    # Shrink very large images in place (aspect ratio preserved) before
    # displaying and captioning them.
    if max(img.size) > 2048:
        img.thumbnail((2048, 2048))
    st.image(img, use_container_width=True)

    caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    # The trailing "Story:" marker is what generate_story splits on.
    story_prompt = (
        f"Image description: {caption}\n\n"
        "Write a 50-100 word children's story that:\n"
        "1. Features the main subject as a friendly character\n"
        "2. Includes a simple adventure or discovery\n"
        "3. Ends with a happy or funny conclusion\n"
        "4. Uses simple language for ages 3-8\n\n"
        "Story:\n"
    )
    with st.spinner("📝 Writing magical story..."):
        story = generate_story(story_prompt)
    st.subheader("📚 Your Magical Story")
    st.write(story)

    # Audio Conversion
    with st.spinner("🔊 Adding story voice..."):
        try:
            # Render the MP3 into memory: a NamedTemporaryFile(delete=False)
            # would leave one orphaned .mp3 on disk for every script run.
            audio_buf = BytesIO()
            gTTS(text=story, lang="en", slow=False).write_to_fp(audio_buf)
            st.audio(audio_buf.getvalue(), format="audio/mp3")
        except Exception as e:  # best effort: the story is still shown without audio
            st.warning(f"⚠️ Couldn't make audio version: {e}")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")