1 / app.py
mayf's picture
Update app.py
4d1f328 verified
raw
history blame
4.62 kB
# story_generator.py
import io
import re
import tempfile
import time

import streamlit as st
from gtts import gTTS
from PIL import Image
from transformers import pipeline
# --- Initialize Streamlit Config ---
# Page-level settings (browser-tab title, favicon, centered layout). This call
# is kept ahead of every other st.* call in the script.
st.set_page_config(page_icon="📖", layout="centered", page_title="Magic Story Generator")
# --- Model Loading (Cached) ---
@st.cache_resource(show_spinner=False)
def load_models():
    """Load the image-captioning and story-generation pipelines.

    Wrapped in ``st.cache_resource`` so the models are built once and reused
    across Streamlit reruns instead of being reloaded on every interaction.

    Returns:
        A ``(captioner, storyteller)`` tuple: a BLIP ``image-to-text``
        pipeline and a Qwen3 ``text-generation`` pipeline.
    """
    # Image captioning model (device=-1 selects the CPU).
    captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=-1,
    )
    # Story generation model.
    # - `revision` is a first-class pipeline() argument. Putting it inside
    #   model_kwargs collides with the factory's own `revision` keyword
    #   ("got multiple values for keyword argument").
    # - The sampling settings are generate() arguments, not model-construction
    #   arguments. Extra keyword arguments to pipeline() become the pipeline's
    #   default generation kwargs, and per-call arguments still override them.
    storyteller = pipeline(
        "text-generation",
        model="Qwen/Qwen3-1.7B",
        revision="main",
        device_map="auto",
        trust_remote_code=True,
        torch_dtype="auto",
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        # Presumably Qwen's <|im_end|> token id; verify against the tokenizer.
        pad_token_id=151645,
    )
    return captioner, storyteller
# --- Text Processing Utilities ---
# Qwen3 can emit chain-of-thought wrapped in <think>...</think>. The `\Z`
# alternative also removes a reasoning block that the token limit cut off
# before its closing tag, so raw reasoning is never shown as the story.
_THINK_BLOCK_RE = re.compile(r"<think>.*?(?:</think>|\Z)", re.DOTALL)
# Leftover role labels such as "assistant:" or "assistant ->" at the start.
_ROLE_PREFIX_RE = re.compile(r'^(assistant[\s\-\:>]*)+', re.IGNORECASE)
# Split after terminal punctuation followed by spaces. Newlines are not split
# on, so paragraph breaks survive inside the pieces.
_SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?]) +')
# Characters that may legitimately follow a sentence's final punctuation,
# e.g. the closing quote in: He shouted, "Let's go!"
_CLOSING_CHARS = '"\'”’)'


def clean_generated_text(raw_text):
    """Extract the assistant's reply from raw model output and tidy it up.

    Args:
        raw_text: Text returned by the generation pipeline. It may contain
            the echoed ChatML prompt (``<|im_start|>``/``<|im_end|>``
            markers), extra chat turns, and Qwen3 ``<think>`` blocks.

    Returns:
        The reply as space-joined sentences. Each sentence starts with an
        upper-cased first character and ends in ``.``, ``!`` or ``?``
        (optionally followed by a closing quote/bracket); a period is
        appended when missing. Returns ``""`` when nothing remains.
    """
    # Keep only what follows the first assistant marker. If there is no
    # marker (the prompt was not echoed back), keep the whole text.
    clean_text = raw_text.split("<|im_start|>assistant\n", 1)[-1]
    # Drop any further chat turns the model produced.
    clean_text = clean_text.split("<|im_start|>")[0]
    clean_text = clean_text.replace("<|im_end|>", "")
    clean_text = _THINK_BLOCK_RE.sub("", clean_text).strip()
    clean_text = _ROLE_PREFIX_RE.sub('', clean_text).strip()

    # Format punctuation and capitalization.
    sentences = []
    for sent in _SENTENCE_SPLIT_RE.split(clean_text):
        sent = sent.strip()
        if not sent:
            continue
        # Look past closing quotes/brackets, so dialogue such as `"Hi!"` is
        # recognised as terminated instead of becoming `"Hi!".`
        if sent.rstrip(_CLOSING_CHARS)[-1:] not in {'.', '!', '?'}:
            sent += '.'
        sentences.append(sent[0].upper() + sent[1:])
    return ' '.join(sentences)
# --- Main Application UI ---
def _synthesize_mp3(text):
    """Convert *text* to English speech with gTTS and return the MP3 bytes.

    The audio goes into an in-memory buffer instead of a
    ``NamedTemporaryFile(delete=False)``. That file was never removed, so
    every run leaked a file on disk, and reopening it by name while it is
    still open fails on Windows.
    """
    buffer = io.BytesIO()
    gTTS(text=text, lang='en', slow=False).write_to_fp(buffer)
    return buffer.getvalue()


st.title("📖✨ Magic Story Generator")
uploaded_image = st.file_uploader(
    "Upload a children's book style image:",
    type=["jpg", "jpeg", "png"]
)

if uploaded_image:
    # Display uploaded image. An unreadable or truncated upload makes PIL
    # raise OSError (UnidentifiedImageError is a subclass of it).
    try:
        image = Image.open(uploaded_image).convert("RGB")
    except OSError:
        st.error("❌ Couldn't read that image file. Please upload a valid JPG or PNG.")
        st.stop()
    st.image(image, use_column_width=True)

    # Load models only when needed (load_models is cached across reruns).
    caption_pipe, story_pipe = load_models()

    # Generate image caption. A pipeline failure and an empty caption are
    # reported the same way.
    with st.spinner("🔍 Analyzing image..."):
        try:
            caption_result = caption_pipe(image)
            image_caption = caption_result[0].get("generated_text", "").strip()
        except Exception:
            image_caption = ""
    if not image_caption:
        st.error("❌ Failed to analyze image. Please try another.")
        st.stop()
    st.success(f"**Image Understanding:** {image_caption}")

    # Create story prompt (ChatML format used by Qwen).
    story_prompt = (
        "<|im_start|>system\n"
        f"You are a children's book author. Create a 150-word story based on: {image_caption}\n"
        "Include these elements:\n"
        "- Friendly characters\n"
        "- Simple vocabulary\n"
        "- Positive lesson\n"
        "- Clear story structure\n"
        "<|im_end|>\n"
        "<|im_start|>user\n"
        "Write an engaging story suitable for ages 6-8.<|im_end|>\n"
        "<|im_start|>assistant\n"
        # An empty reasoning block switches Qwen3 to non-thinking mode (this
        # is what its chat template emits for enable_thinking=False).
        # Without it, the model may spend the token budget on <think> text.
        "<think>\n\n</think>\n\n"
    )

    # Generate story text.
    with st.spinner("📝 Crafting magical story..."):
        try:
            story_result = story_pipe(
                story_prompt,
                max_new_tokens=300,
                do_sample=True,
                num_return_sequences=1,
                # Return only the newly generated text, not the echoed prompt.
                return_full_text=False,
            )
            raw_story = story_result[0]['generated_text']
        except Exception:
            st.error("❌ Story generation failed. Please try again.")
            st.stop()

    # Process and display story. An empty result (e.g. output that was
    # nothing but markers) is treated as a failed generation rather than
    # shown as a blank story.
    final_story = clean_generated_text(raw_story)
    if not final_story:
        st.error("❌ Story generation failed. Please try again.")
        st.stop()
    st.subheader("✨ Your Story")
    st.write(final_story)

    # Generate audio version (best-effort: the text story is already shown).
    with st.spinner("🔊 Creating audio version..."):
        try:
            audio_bytes = _synthesize_mp3(final_story)
        except Exception:
            st.warning("⚠️ Audio conversion failed. Text version still available.")
        else:
            st.audio(audio_bytes, format="audio/mp3")

# Footer
st.markdown("---")
st.caption("Made with ♥ by The Story Wizard • [Report Issues](https://example.com)")