# app.py — copied from a Hugging Face Space file view
# Author: mayf · Commit: ac11067 ("Update app.py", verified) · Size: 3.92 kB
# Import Streamlit first
import streamlit as st
# Page-level settings; kept ahead of every other st.* call, since Streamlit
# has historically required set_page_config to be the first command.
st.set_page_config(
    page_icon="📖",
    layout="centered",
    page_title="Magic Story Generator",
)
# Other imports
import io
import re
import tempfile
import time

import torch
from gtts import gTTS
from PIL import Image
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
# --- Initialize Models First ---
@st.cache_resource(show_spinner=False)
def load_models():
    """Build the image-captioning and story-generation pipelines.

    Wrapped in ``st.cache_resource``, so Streamlit reruns reuse the same
    pipeline objects instead of reloading the models.

    Returns:
        tuple: ``(caption_pipe, story_pipe)`` — an ``image-to-text`` pipeline
        backed by BLIP, and a ``text-generation`` pipeline backed by
        Qwen3-0.6B with default sampling settings baked in.

    If anything fails while loading, an error is shown and the script run is
    halted with ``st.stop()``.
    """
    use_cuda = torch.cuda.is_available()
    try:
        # 1. Image Captioning Model (GPU 0 when available, otherwise CPU).
        caption_pipe = pipeline(
            "image-to-text",
            model="Salesforce/blip-image-captioning-base",
            device=0 if use_cuda else -1,
        )

        # 2. Story Generation Model
        story_tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen3-0.6B",
            trust_remote_code=True,
        )
        story_model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen3-0.6B",
            device_map="auto",
            # Half precision only pays off on a GPU. CPU float16 support is
            # slow or missing kernels in many PyTorch builds, so use float32
            # when no CUDA device is present.
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        )
        story_pipe = pipeline(
            "text-generation",
            model=story_model,
            tokenizer=story_tokenizer,
            max_new_tokens=230,
            temperature=0.9,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.1,
            # NOTE(review): presumably Qwen's "<|im_end|>" token id — confirm
            # against story_tokenizer if the model is ever swapped.
            eos_token_id=151645,
        )
        return caption_pipe, story_pipe
    except Exception as e:
        st.error(f"🚨 Model loading failed: {str(e)}")
        st.stop()
# Initialize models immediately when app starts
caption_pipe, story_pipe = load_models()
# --- Rest of Application ---
st.title("📖✨ Turn Images into Children's Stories")
# Qwen3 may emit "<think>...</think>" reasoning blocks. "<think>" does not
# match the "<|...|>" special-token pattern, so it needs its own patterns.
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", flags=re.DOTALL)
_THINK_TAG_RE = re.compile(r"</?think>")
_SPECIAL_TOKEN_RE = re.compile(r"<\|.*?\|>")
# Matches from "Okay, I need" up to (not including) the next newline or the
# end of the text, i.e. a single self-talk line.
_THINKING_LINE_RE = re.compile(r"Okay, I need.*?(?=\n|$)")


def clean_story_text(raw_text):
    """Strip model artifacts from generated text and return the story.

    Removes, in order:
      * complete ``<think>...</think>`` blocks (may span lines),
      * any leftover ``<think>`` / ``</think>`` tag (e.g. when generation was
        cut off before the closing tag),
      * ``<|...|>`` special tokens,
      * a self-talk line beginning with ``"Okay, I need"``,
    then trims surrounding whitespace.

    Args:
        raw_text: Text produced by the story model.

    Returns:
        The cleaned story text (possibly empty).
    """
    clean = _THINK_BLOCK_RE.sub("", raw_text)
    clean = _THINK_TAG_RE.sub("", clean)
    clean = _SPECIAL_TOKEN_RE.sub("", clean)
    clean = _THINKING_LINE_RE.sub("", clean)
    return clean.strip()
uploaded_image = st.file_uploader(
    "Upload a children's book style image:",
    type=["jpg", "jpeg", "png"]
)

if uploaded_image:
    # Decode the upload up front. A corrupt or mislabelled file makes PIL raise
    # OSError (UnidentifiedImageError subclasses it); report it on the page
    # instead of crashing the script with a traceback.
    try:
        image = Image.open(uploaded_image).convert("RGB")
    except OSError as e:
        st.error(f"❌ Could not read the image: {str(e)}")
        st.stop()
    st.image(image, use_container_width=True)

    # Step 1: caption the image.
    with st.spinner("🔍 Analyzing image..."):
        try:
            caption_result = caption_pipe(image)
            image_caption = caption_result[0].get("generated_text", "")
            st.success(f"**Image Understanding:** {image_caption}")
        except Exception as e:
            st.error(f"❌ Image analysis failed: {str(e)}")
            st.stop()

    # Step 2: turn the caption into a short story.
    story_prompt = (
        f"Write a children's story about: {image_caption}\n"
        "Rules:\n"
        "- Use simple words (Grade 2 level)\n"
        "- Exclude thinking processes\n"
        "- 3 paragraphs maximum\n"
        "Story:"
    )
    try:
        with st.spinner("📝 Crafting magical story..."):
            story_result = story_pipe(
                story_prompt,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2,
                # Return only the generated continuation. This keeps the
                # prompt out of the story without splitting on "Story:",
                # which would also cut the story if the model wrote it.
                return_full_text=False,
            )
        final_story = clean_story_text(story_result[0]["generated_text"])
    except Exception as e:
        st.error(f"❌ Story generation failed: {str(e)}")
    else:
        if not final_story:
            # Nothing to show, and an empty string gives gTTS nothing to speak.
            st.warning("⚠️ The story came out empty. Please try again.")
        else:
            st.subheader("✨ Your Magical Story")
            st.write(final_story)

            # Step 3: narrate the story. The MP3 is rendered into memory so
            # no temporary files pile up on disk across reruns.
            try:
                with st.spinner("🔊 Creating audio version..."):
                    audio_buffer = io.BytesIO()
                    gTTS(text=final_story, lang="en", slow=False).write_to_fp(
                        audio_buffer
                    )
                st.audio(audio_buffer.getvalue(), format="audio/mp3")
            except Exception as e:
                st.error(f"❌ Audio conversion failed: {str(e)}")
# --- Footer ---
st.markdown("---")  # horizontal rule separating the app from the credits
st.markdown(
    "📚 Made with ♥ by The Story Wizard • [Report Issues](https://example.com)"
)