1 / app.py
mayf's picture
Update app.py
e537b6d verified
raw
history blame
4.11 kB
# Streamlit is imported first so st.set_page_config() can run before any other Streamlit command
import streamlit as st
# Page-level settings: browser-tab title and icon plus a centered layout.
_PAGE_SETTINGS = {
    "page_icon": "📖",
    "layout": "centered",
    "page_title": "Magic Story Generator",
}
st.set_page_config(**_PAGE_SETTINGS)
# Other imports AFTER Streamlit config
import re
import time
import tempfile
from PIL import Image
from gtts import gTTS
from transformers import pipeline
# --- Constants & Setup ---
# Heading rendered at the top of the page.
APP_TITLE = "📖✨ Turn Images into Children's Stories"
st.title(APP_TITLE)
# --- Model Loading (Cached) ---
@st.cache_resource(show_spinner=False)
def load_models():
    """Build the two Hugging Face pipelines the app relies on.

    The function is decorated with ``st.cache_resource`` (spinner disabled),
    so Streamlit hands back the cached pipelines instead of rebuilding them.

    Returns:
        tuple: ``(captioner, storyteller)`` -- the BLIP image-to-text
        pipeline followed by the Qwen3 text-generation pipeline.
    """
    # Image captioning runs on CPU (device=-1; 0 would select the first GPU).
    captioner = pipeline(
        task="image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=-1,
    )

    # Sampling defaults applied to every storyteller call unless overridden.
    generation_defaults = {
        "max_new_tokens": 250,
        "temperature": 0.7,
        "top_p": 0.85,
        "repetition_penalty": 1.15,
        # NOTE(review): 151645 is presumably the id of Qwen's <|im_end|>
        # token -- confirm against the tokenizer.
        "eos_token_id": 151645,
    }

    # Story generation model (Qwen3-1.7B).
    # NOTE(review): trust_remote_code=True lets the Hub repo run its own
    # Python code when loading; confirm this model still needs it.
    storyteller = pipeline(
        "text-generation",
        model="Qwen/Qwen3-1.7B",
        device_map="auto",
        trust_remote_code=True,
        torch_dtype="auto",
        **generation_defaults,
    )

    return captioner, storyteller
# Fetch both pipelines (load_models is wrapped in st.cache_resource).
caption_pipe, story_pipe = load_models()

# --- Main Application Flow ---
# File extensions offered by the upload widget.
accepted_image_types = ["jpg", "jpeg", "png"]
uploaded_image = st.file_uploader(
    label="Upload a children's book style image:",
    type=accepted_image_types,
)
# Qwen3 chat models open their reply with a <think>...</think> reasoning block
# by default. The official chat template switches that off by pre-filling an
# empty block at the start of the assistant turn; we do the same so the
# token budget is spent on the story instead of on hidden reasoning.
_NO_THINK_PREFILL = "<think>\n\n</think>\n\n"

# Upper bound on the number of characters shown and sent to text-to-speech.
_STORY_CHAR_LIMIT = 800

# Sentence terminator, optionally followed by closing quotes/brackets.
_SENTENCE_END = re.compile(r"[.!?][\"'”’)]*(?=\s|$)")
_TERMINATED = re.compile(r"[.!?][\"'”’)]*\Z")


def _extract_story(generated_text):
    """Return only the assistant's story from the raw pipeline output.

    Strips the echoed prompt, chat-template markers, any ``<think>`` reasoning
    (closed or cut off), and leading "assistant:" labels.
    """
    text = generated_text.split("<|im_start|>assistant\n")[-1]
    text = text.split("<|im_start|>")[0]  # Remove any new turns
    text = text.replace("<|im_end|>", "")
    # Closed reasoning blocks (including the empty pre-filled one).
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    # An unterminated block means everything after it is reasoning.
    text = text.split("<think>")[0]
    # A stray closing tag: the story is whatever follows it.
    text = text.split("</think>")[-1]
    text = re.sub(
        r'^(assistant[:>]?\s*)+',
        '',
        text.strip(),
        flags=re.IGNORECASE
    )
    return text.strip()


def _format_story(text, max_chars=_STORY_CHAR_LIMIT):
    """Tidy story punctuation and clip it to ``max_chars`` characters.

    Sentence starts are capitalised, an over-long story is cut back to the
    last complete sentence that fits, and a final period is added only when
    the text does not already end in ``.``, ``!`` or ``?``. Returns ``""``
    for blank input.
    """
    text = text.strip()
    if not text:
        return ""
    # Capitalise the first word character of the text and of every sentence;
    # whitespace (including paragraph breaks) is left untouched.
    text = re.sub(
        r"(^|[.!?]\s+)(\w)",
        lambda m: m.group(1) + m.group(2).upper(),
        text,
    )
    if len(text) > max_chars:
        window = text[:max_chars]
        boundaries = [m.end() for m in _SENTENCE_END.finditer(window)]
        if boundaries:
            text = window[:boundaries[-1]]
        else:
            # No sentence fits: hard cut, leaving room for the final period.
            text = window[:max_chars - 1].rstrip()
    if not _TERMINATED.search(text):
        text += "."
    return text


if uploaded_image:
    # Process image. Pillow signals unreadable/corrupt files with OSError
    # subclasses, so report that instead of crashing the script run.
    try:
        image = Image.open(uploaded_image).convert("RGB")
    except OSError:
        st.error("❌ Couldn't read this file as an image. Please try another!")
        st.stop()
    st.image(image, use_container_width=True)

    # Generate caption
    with st.spinner("🔍 Analyzing image..."):
        caption_result = caption_pipe(image)
    image_caption = ""
    if caption_result:
        image_caption = caption_result[0].get("generated_text", "").strip()
    if not image_caption:
        st.error("❌ Couldn't understand this image. Please try another!")
        st.stop()
    st.success(f"**Image Understanding:** {image_caption}")

    # Create story prompt (ChatML layout, assistant turn left open).
    story_prompt = (
        "<|im_start|>system\n"
        f"You are a children's book author. Create a 100-150 word story based on: {image_caption}\n"
        "Use simple language, friendly characters, and a positive lesson.<|im_end|>\n"
        "<|im_start|>user\n"
        "Write a child-friendly story with a clear beginning, middle, and end.<|im_end|>\n"
        "<|im_start|>assistant\n"
        + _NO_THINK_PREFILL
    )

    # Generate story
    with st.spinner("📝 Crafting magical story..."):
        story_result = story_pipe(
            story_prompt,
            do_sample=True,
            num_return_sequences=1,
            # Same id that load_models() passes as eos_token_id.
            pad_token_id=151645
        )

    # Process output
    final_story = _format_story(_extract_story(story_result[0]['generated_text']))
    if not final_story:
        st.error("❌ The story generator returned no text. Please try again!")
        st.stop()

    # Display story
    st.subheader("✨ Your Magical Story")
    st.write(final_story)

    # Audio conversion
    with st.spinner("🔊 Creating audio version..."):
        try:
            speech = gTTS(text=final_story, lang="en", slow=False)
            # Stream the MP3 through an auto-deleted temp file and hand the
            # bytes to Streamlit, so nothing is left behind on disk.
            with tempfile.TemporaryFile(suffix=".mp3") as tmp_file:
                speech.write_to_fp(tmp_file)
                tmp_file.seek(0)
                audio_bytes = tmp_file.read()
            st.audio(audio_bytes, format="audio/mp3")
        except Exception as e:
            # Best-effort feature: the story is already shown, so only report.
            st.error(f"❌ Audio conversion failed: {e}")
# Footer: a horizontal rule followed by the credit line and issue link.
footer_text = "📚 Made with ♥ by The Story Wizard • [Report Issues](https://example.com)"
st.markdown("---")
st.markdown(footer_text)