1 / app.py
mayf's picture
Update app.py
c83a777 verified
raw
history blame
3.42 kB
import os
import time
import streamlit as st
from PIL import Image
from transformers import pipeline
from gtts import gTTS
import tempfile
# --- Requirements ---
# Update requirements.txt to include:
"""
streamlit>=1.20
pillow>=9.0
torch>=2.0.0
transformers>=4.51  # Qwen3 models require transformers 4.51 or newer
sentencepiece>=0.2.0
gTTS>=2.3.1
accelerate>=0.30
"""
# --- Page Setup ---
# Configure the browser tab title and layout, then render the page heading.
st.set_page_config(
    layout="centered",
    page_title="Magic Story Generator",
)
st.title("📖✨ Turn Images into Children's Stories")
# --- Load Pipelines (cached) ---
@st.cache_resource(show_spinner=False)
def load_pipelines():
    """Build the two Hugging Face pipelines used by the app.

    Wrapped in ``st.cache_resource`` so the models are constructed once and
    reused instead of being rebuilt on every script run.

    Returns:
        A ``(captioner, storyteller)`` tuple: an image-to-text pipeline
        (BLIP base) and a text-generation pipeline (Qwen3-1.7B).
    """
    # Sampling / length settings baked into the story pipeline as defaults.
    generation_defaults = {
        "max_new_tokens": 150,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.2,
        "eos_token_id": 151645,  # Specific to Qwen3 tokenizer
    }

    # 1) Image captioning (BLIP); device=-1 is the pipeline's CPU setting.
    caption_pipe = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=-1,
    )

    # 2) Story generation with Qwen3-1.7B; placement and dtype are left
    #    to the library ("auto").
    story_pipe = pipeline(
        "text-generation",
        model="Qwen/Qwen3-1.7B",
        device_map="auto",
        trust_remote_code=True,
        torch_dtype="auto",
        **generation_defaults,
    )

    return caption_pipe, story_pipe
# --- Story post-processing helpers ---
_MAX_STORY_WORDS = 100
_SENTENCE_ENDINGS = ('.', '!', '?')
_CLOSING_MARKS = "\"'”’)"


def _extract_story(generated_text: str) -> str:
    """Return only the assistant's reply from the generated text.

    Args:
        generated_text: The ``generated_text`` string produced by the
            storyteller pipeline; it may still contain the prompt.

    Returns:
        The text after the last assistant marker, with ``<|im_end|>`` and any
        ``<think>`` reasoning block removed and outer whitespace stripped.
        Empty when the reply consists solely of an unfinished reasoning block.
    """
    reply = generated_text.split("<|im_start|>assistant\n")[-1]
    reply = reply.replace("<|im_end|>", "")
    if "</think>" in reply:
        # Keep only what follows the (possibly empty) reasoning block.
        reply = reply.split("</think>")[-1]
    elif "<think>" in reply:
        # Reasoning was cut off by the token limit: there is no story to show.
        reply = reply.split("<think>")[0]
    return reply.strip()


def _trim_story(story: str, max_words: int = _MAX_STORY_WORDS) -> str:
    """Limit the story to ``max_words`` words and make it end cleanly.

    Args:
        story: Non-empty story text.
        max_words: Maximum number of whitespace-separated words to keep.

    Returns:
        The story, truncated if needed. When the text does not already end a
        sentence it is cut back to the last complete sentence; if there is
        none, a period is appended.
    """
    words = story.split()
    if len(words) > max_words:
        story = " ".join(words[:max_words])
    # A closing quote or bracket after the punctuation still counts as an ending.
    if story.rstrip(_CLOSING_MARKS).endswith(_SENTENCE_ENDINGS):
        return story
    last_end = max(story.rfind(mark) for mark in _SENTENCE_ENDINGS)
    if last_end > 0:
        return story[:last_end + 1]
    return story + '.'


captioner, storyteller = load_pipelines()

# --- Main App ---
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    # Load and display the image; an undecodable upload is reported, not raised.
    try:
        img = Image.open(uploaded).convert("RGB")
    except OSError:
        st.error("😢 Couldn't read this image file. Try another one!")
        st.stop()
    st.image(img, use_container_width=True)

    # Generate caption
    with st.spinner("🔍 Generating caption..."):
        cap = captioner(img)
    # Guard against a non-list or empty result before indexing.
    first_result = cap[0] if isinstance(cap, list) and cap else {}
    caption = first_result.get("generated_text", "").strip()
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    # Build the chat-formatted prompt. "/no_think" is Qwen3's soft switch that
    # disables the reasoning block, so the token budget is spent on the story.
    prompt = (
        "<|im_start|>system\n"
        f"You are a children's story writer. Create a 50-100 word story based on this image description: {caption}\n"
        "<|im_end|>\n"
        "<|im_start|>user\n"
        "Write a coherent, child-friendly story that flows naturally with simple vocabulary. /no_think<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    with st.spinner("📝 Writing story..."):
        start = time.time()
        out = storyteller(
            prompt,
            do_sample=True,
            num_return_sequences=1,
        )
        gen_time = time.time() - start
    st.text(f"⏱ Generated in {gen_time:.1f}s")

    # Process output
    story = _extract_story(out[0]['generated_text'])
    if not story:
        st.error("😢 The story came out empty. Please try again!")
        st.stop()
    # Enforce ≤100 words and proper ending
    story = _trim_story(story)

    # Display story
    st.subheader("📚 Your Magical Story")
    st.write(story)

    # Convert to audio. The mp3 goes through an auto-deleted temporary file and
    # is handed to Streamlit as bytes, so nothing is left behind on disk.
    with st.spinner("🔊 Converting to audio..."):
        try:
            tts = gTTS(text=story, lang="en", slow=False)
            with tempfile.TemporaryFile(suffix=".mp3") as tmp:
                tts.write_to_fp(tmp)
                tmp.seek(0)
                audio_bytes = tmp.read()
            st.audio(audio_bytes, format="audio/mp3")
        except Exception as e:
            # Best-effort feature: a TTS or network failure must not hide the story.
            st.warning(f"⚠️ TTS failed: {e}")

# Footer
st.markdown("---\nMade with ❤️ by your friendly story wizard")