1 / app.py
mayf's picture
Update app.py
fd1d947 verified
raw
history blame
3.77 kB
import os
import time
import streamlit as st
from PIL import Image
from io import BytesIO
from transformers import pipeline
from huggingface_hub import login
from gtts import gTTS
import tempfile
# —––––––– Requirements —–––––––
# pip install streamlit pillow gTTS transformers huggingface_hub torch
# (torch is the backend the transformers pipelines below run on)
# —––––––– Page Config —–––––––
# Set the browser-tab title and a centered layout, then draw the page header.
st.set_page_config(
    layout="centered",
    page_title="Magic Story Generator (Local Pipeline)",
)
st.title("📖✨ Turn Images into Children's Stories")
# —––––––– Load Clients & Pipelines (cached) —–––––––
def _read_hf_token():
    """Return a Hugging Face access token, or None when none is configured.

    Streamlit secrets are consulted first (as before); the ``HF_TOKEN``
    environment variable is used as a fallback so the app also starts when
    no secrets file is present.
    """
    try:
        token = st.secrets.get("HF_TOKEN")
    except Exception:
        # Streamlit may raise (rather than return None) when no secrets.toml
        # exists. The token is optional, so fall back to the environment.
        token = None
    return token or os.environ.get("HF_TOKEN")


@st.cache_resource(show_spinner=False)
def load_clients():
    """Optionally authenticate, then build the captioning and story pipelines.

    Decorated with ``st.cache_resource`` so the models are loaded once and
    reused across script reruns instead of being rebuilt each time.

    Returns:
        tuple: ``(captioner, storyteller)`` -- a BLIP image-to-text pipeline
        and a DeepSeek-R1-Distill-Qwen text-generation pipeline, both on CPU.
    """
    # Authenticate so private or gated model repos can be downloaded if needed.
    hf_token = _read_hf_token()
    if hf_token:
        # HF_TOKEN is the variable huggingface_hub/transformers read;
        # HUGGINGFACEHUB_API_TOKEN is kept for tools that use the older name.
        os.environ["HF_TOKEN"] = hf_token
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
        login(hf_token)

    # 1) Image-captioning pipeline (BLIP)
    captioner = pipeline(
        task="image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=-1,  # CPU; change to 0 for GPU
    )

    # 2) Story-generation pipeline (DeepSeek-R1-Distill-Qwen).
    # Generation kwargs given here become the defaults for every call.
    storyteller = pipeline(
        task="text-generation",
        model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        tokenizer="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        # NOTE(review): Qwen2 models are supported natively by recent
        # transformers releases; trust_remote_code lets the model repo run
        # arbitrary Python and may be removable -- confirm before dropping.
        trust_remote_code=True,
        device=-1,  # CPU; set 0+ for GPU
        # temperature/top_p only apply when sampling is on; enable it
        # explicitly instead of relying on the model's generation_config.
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        repetition_penalty=1.1,
        no_repeat_ngram_size=2,
        max_new_tokens=120,
        return_full_text=False,  # return only the continuation, not the prompt
    )
    return captioner, storyteller


captioner, storyteller = load_clients()
# —––––––– Helpers —–––––––
def generate_caption(img: Image.Image) -> str:
    """Caption *img* with the BLIP image-to-text pipeline.

    Returns the whitespace-stripped ``generated_text`` of the first result,
    or an empty string when the pipeline does not return a non-empty list.
    """
    outputs = captioner(img)
    if not (isinstance(outputs, list) and outputs):
        return ""
    first = outputs[0]
    return first.get("generated_text", "").strip()
def _strip_reasoning(text: str) -> str:
    """Remove a ``<think>...</think>`` reasoning preamble from model output.

    DeepSeek-R1 distilled models can emit reasoning inside ``<think>`` tags
    before the answer. Everything up to the last closing tag is dropped, and
    a stray opening tag (reasoning cut off by the token limit) is removed.
    Text without tags comes back stripped but otherwise unchanged.
    """
    answer = text.rpartition("</think>")[2]  # whole text if no closing tag
    return answer.replace("<think>", "").strip()


def _truncate_words(text: str, max_words: int) -> str:
    """Cap *text* at *max_words* whitespace-separated words.

    Text within the limit is returned unchanged. Clipped text loses any
    dangling separator and gets a final period unless it already ends with
    sentence punctuation (optionally followed by a closing quote).
    """
    words = text.split()
    if len(words) <= max_words:
        return text
    clipped = " ".join(words[:max_words]).rstrip(",;:- ")
    # Inspect the clipped text, not the original: how the full story ended
    # says nothing about where the cut landed.
    if not clipped.rstrip("\"'”’").endswith((".", "!", "?")):
        clipped += "."
    return clipped


def generate_story(caption: str) -> str:
    """Generate a short children's story seeded by an image caption.

    Prompts the story pipeline with the caption, shows the elapsed
    generation time in the app, strips any ``<think>`` reasoning preamble
    and caps the story at 100 words.

    Args:
        caption: Description of the uploaded image.

    Returns:
        The story text; empty if the model produced nothing usable.
    """
    # Build a simple prompt incorporating the caption
    prompt = (
        f"Image description: {caption}\n"
        "Write a coherent 50-100 word children's story that flows naturally."
    )
    # perf_counter is monotonic, so the timing is immune to clock changes.
    t0 = time.perf_counter()
    outputs = storyteller(prompt)
    gen_time = time.perf_counter() - t0
    st.text(f"⏱ Generated in {gen_time:.1f}s")
    # The pipeline returns a list of dicts; guard against an empty result.
    first = outputs[0] if outputs else {}
    story = _strip_reasoning(first.get("generated_text", ""))
    return _truncate_words(story, 100)
# —––––––– Main App —–––––––
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    try:
        img = Image.open(uploaded).convert("RGB")
    except OSError:
        # Pillow signals undecodable or truncated image data with OSError
        # (UnidentifiedImageError is a subclass); show a message, not a trace.
        st.error("😢 Couldn't open this file as an image. Try another one!")
        st.stop()
    # Downscale very large uploads; thumbnail() preserves the aspect ratio.
    if max(img.size) > 2048:
        img.thumbnail((2048, 2048))
    st.image(img, use_container_width=True)

    with st.spinner("🔍 Generating caption..."):
        caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    with st.spinner("📝 Writing story..."):
        story = generate_story(caption)
    if not story:
        # Nothing to show or speak; gTTS would reject empty text anyway.
        st.error("😢 Couldn't write a story for this image. Try another one!")
        st.stop()
    st.subheader("📚 Your Magical Story")
    st.write(story)

    with st.spinner("🔊 Converting to audio..."):
        try:
            # Render the MP3 in memory. The previous
            # NamedTemporaryFile(delete=False) left a file on disk every run.
            audio_buf = BytesIO()
            gTTS(text=story, lang="en", slow=False).write_to_fp(audio_buf)
            st.audio(audio_buf.getvalue(), format="audio/mp3")
        except Exception as e:
            # Audio is best-effort: degrade to a warning, keep the story shown.
            st.warning(f"⚠️ TTS failed: {e}")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")