1 / app.py
mayf's picture
Update app.py
2aae3c9 verified
raw
history blame
4.24 kB
import os
import time
import streamlit as st
from PIL import Image
from io import BytesIO
from huggingface_hub import InferenceApi, login
from transformers import pipeline
from gtts import gTTS
import tempfile
# ---------- Page configuration ----------
# Browser-tab title and centered layout for the app, then the page heading.
st.set_page_config(
    layout="centered",
    page_title="Magic Story Generator",
)
st.title("📖✨ Turn Images into Children's Stories")
# ---------- Clients (cached) ----------
@st.cache_resource(show_spinner=False)
def load_clients():
    """Authenticate with Hugging Face and build the two model clients.

    Decorated with ``st.cache_resource`` so the model load is not repeated on
    every Streamlit rerun.

    Returns:
        tuple: ``(caption_client, story_generator)``. ``caption_client`` is an
        ``InferenceApi`` for BLIP image captioning. ``story_generator`` is a
        ``transformers`` text-generation pipeline running on the CPU.

    Raises:
        Whatever ``st.secrets`` raises when ``HF_TOKEN`` is not configured.
    """
    hf_token = st.secrets["HF_TOKEN"]

    # Authenticate for both the HF Hub client and transformers downloads.
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
    login(hf_token)

    # Local model cache directory. transformers reads TRANSFORMERS_CACHE when
    # it is imported, and it was imported at the top of the file. Setting the
    # env var here therefore does not redirect downloads by itself, so the
    # directory is also passed explicitly through ``model_kwargs`` below.
    cache_dir = "./hf_cache"
    os.makedirs(cache_dir, exist_ok=True)
    os.environ["TRANSFORMERS_CACHE"] = cache_dir

    # 1) BLIP image-captioning client (remote Inference API).
    caption_client = InferenceApi(
        repo_id="Salesforce/blip-image-captioning-base",
        token=hf_token,
    )

    # 2) Local text-generation pipeline, forced onto the CPU.
    t0 = time.time()
    story_generator = pipeline(
        task="text-generation",
        model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        tokenizer="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        device=-1,  # -1 selects the CPU
        # Forwarded by pipeline() to the from_pretrained() loaders.
        model_kwargs={"cache_dir": cache_dir},
    )
    st.text(f"✅ Story model loaded in {time.time() - t0:.1f}s (cached thereafter)")
    return caption_client, story_generator


caption_client, story_generator = load_clients()
# ---------- Helper: generate caption ----------
def generate_caption(img: Image.Image) -> str:
    """Caption *img* using the remote BLIP client.

    The image is JPEG-encoded in memory and sent to ``caption_client``.

    Args:
        img: Image to describe. Modes other than RGB (e.g. RGBA, P) are
            converted to RGB first, because JPEG cannot store them.

    Returns:
        The stripped ``generated_text`` of the first result. Returns ``""``
        when encoding or the request fails, when the API returns an error
        payload, or when the response contains no caption. Failures are
        reported with ``st.error``.
    """
    try:
        if img.mode != "RGB":
            img = img.convert("RGB")
        buf = BytesIO()
        img.save(buf, format="JPEG")
        resp = caption_client(data=buf.getvalue())
    except Exception as e:
        st.error(f"Caption generation error: {e}")
        return ""

    # NOTE(review): the Inference API appears to report failures (e.g. model
    # still loading) as a JSON object with an "error" key rather than raising.
    # Surface that message instead of silently returning "".
    if isinstance(resp, dict) and "error" in resp:
        st.error(f"Caption generation error: {resp['error']}")
        return ""

    if isinstance(resp, list) and resp and isinstance(resp[0], dict):
        # ``or ""`` also covers an explicit null generated_text.
        return (resp[0].get("generated_text") or "").strip()
    return ""
# ---------- Helper: generate story via pipeline ----------
def generate_story(caption: str) -> str:
    """Write a short children's story inspired by an image caption.

    Args:
        caption: Text description of the uploaded image.

    Returns:
        The generated story with the prompt removed, limited to 100 words.
        When the limit applies, the text is cut back to its last ``.``,
        ``!`` or ``?``. If no such mark exists, a period is appended.
    """
    prompt = (
        "You are a creative children’s-story author.\n"
        "Below is an image description:\n"
        f"{caption}\n"
        "Write a coherent 50–100 word story that:\n"
        "1. Introduces the main character.\n"
        "2. Shows a simple problem or discovery.\n"
        "3. Has a happy resolution.\n"
        "4. Uses clear language for ages 3–8.\n"
        "5. Keeps sentences under 20 words.\n"
        "Story:\n"
    )
    t0 = time.time()
    outputs = story_generator(
        prompt,
        max_new_tokens=120,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        no_repeat_ngram_size=3,
        do_sample=True,
        return_full_text=False,  # return only the completion, not prompt + completion
    )
    st.text(f"⏱ Generated in {time.time() - t0:.1f}s on CPU")

    text = outputs[0]["generated_text"]
    # Defensive prompt-echo removal. It must be checked before stripping,
    # because stripping changes the text's leading characters and the
    # prefix comparison would then fail.
    if text.startswith(prompt):
        text = text[len(prompt):]
    # NOTE(review): R1-distill checkpoints may emit a reasoning section that
    # ends with "</think>". Keep only what follows it (no-op when absent).
    if "</think>" in text:
        text = text.rsplit("</think>", 1)[-1]
    text = text.strip()

    # Enforce the 100-word limit, avoiding a mid-sentence ending when possible.
    words = text.split()
    if len(words) > 100:
        text = " ".join(words[:100])
        last_end = max(text.rfind(mark) for mark in ".!?")
        if last_end > 0:
            text = text[: last_end + 1]
        else:
            text += "."
    return text
# ---------- Main app flow ----------
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    try:
        img = Image.open(uploaded).convert("RGB")
    except OSError:
        # PIL's UnidentifiedImageError (corrupt or non-image upload) and
        # truncated-file errors are OSError subclasses.
        st.error("😢 Couldn't open this file as an image. Try another one!")
        st.stop()
    # Downscale very large uploads in place (thumbnail keeps the aspect ratio).
    if max(img.size) > 2048:
        img.thumbnail((2048, 2048))
    st.image(img, use_container_width=True)

    with st.spinner("🔍 Generating caption..."):
        caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    with st.spinner("📝 Writing magical story..."):
        story = generate_story(caption)
    if not story:
        st.error("😢 Couldn't write a story for this image. Try another one!")
        st.stop()
    st.subheader("📚 Your Magical Story")
    st.write(story)

    with st.spinner("🔊 Converting to audio..."):
        try:
            # Render the MP3 in memory so no temporary files pile up on disk.
            audio_buf = BytesIO()
            gTTS(text=story, lang="en", slow=False).write_to_fp(audio_buf)
            st.audio(audio_buf.getvalue(), format="audio/mp3")
        except Exception as e:
            st.warning(f"⚠️ TTS failed: {e}")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")