# 1 / app.py
# Captured from a Hugging Face Space file view (raw / history / blame):
# mayf, commit eec20c9 (verified) "Update app.py", 4.72 kB.
# FIRST import and FIRST Streamlit command
import streamlit as st
# Page chrome (tab title, icon, layout); kept as the first Streamlit command.
st.set_page_config(
    page_icon="📖",
    layout="centered",
    page_title="Magic Story Generator",
)
# Other imports
import re
import time
import torch
import tempfile
from PIL import Image
from gtts import gTTS
from transformers import pipeline
# --- Constants & Setup ---
# Heading rendered at the top of the page.
APP_TITLE = "📖✨ Turn Images into Children's Stories"
st.title(APP_TITLE)
# --- Model Loading (Cached) ---
@st.cache_resource(show_spinner=False)
def load_models():
    """Build the two Hugging Face pipelines used by the app.

    The function is wrapped in ``st.cache_resource`` so the pipelines are
    constructed once and reused on subsequent script reruns.

    Returns:
        tuple: ``(captioner, storyteller)`` where ``captioner`` is an
        ``image-to-text`` pipeline and ``storyteller`` is a
        ``text-generation`` pipeline with the sampling defaults set below.
    """
    use_gpu = torch.cuda.is_available()

    # Image captioning model
    captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=0 if use_gpu else -1,
    )

    # The captioner already falls back to CPU, so the storyteller must as
    # well. 8-bit loading goes through bitsandbytes, which generally needs a
    # CUDA device, and float16 is poorly supported on CPU; only request both
    # when a GPU is actually present.
    if use_gpu:
        precision_kwargs = {
            "model_kwargs": {"load_in_8bit": True},
            "torch_dtype": torch.float16,
        }
    else:
        precision_kwargs = {"torch_dtype": torch.float32}

    # Optimized story generation model
    # NOTE(review): confirm this repo id exists on the Hub -- TODO verify.
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # model repo; drop it if the installed transformers supports the model
    # natively -- TODO confirm.
    storyteller = pipeline(
        "text-generation",
        model="Qwen/Qwen3-0.5B",
        device_map="auto",
        trust_remote_code=True,
        max_new_tokens=200,
        temperature=0.9,
        top_k=50,
        top_p=0.9,
        repetition_penalty=1.1,
        eos_token_id=151645,
        **precision_kwargs,
    )

    return captioner, storyteller
# Cached pipeline handles used by the flow below.
caption_pipe, story_pipe = load_models()

# --- Main Application Flow ---
ACCEPTED_IMAGE_TYPES = ["jpg", "jpeg", "png"]
uploaded_image = st.file_uploader(
    label="Upload a children's book style image:",
    type=ACCEPTED_IMAGE_TYPES,
)
def _extract_story(generated_text, assistant_marker="<|im_start|>assistant\n"):
    """Return the text after the last assistant marker, minus ``<|...|>`` tokens."""
    reply = generated_text.split(assistant_marker)[-1]
    return re.sub(r'<\|.*?\|>', '', reply).strip()


def _format_story(text, max_chars=600):
    """Tidy sentences and cap the story length without cutting mid-sentence.

    Each sentence is stripped, given a closing full stop if it lacks end
    punctuation, and capitalised. Whole sentences are then kept while the
    joined result stays within ``max_chars``.

    Args:
        text: Cleaned model output.
        max_chars: Upper bound on the length of the returned string.

    Returns:
        str: The formatted story, never longer than ``max_chars``; empty if
        ``text`` contains no sentences.
    """
    sentences = []
    for sent in re.split(r'(?<=[.!?]) +', text):
        sent = sent.strip()
        if not sent:
            continue
        if len(sent) > 1 and not sent.endswith(('.', '!', '?')):
            sent += '.'
        sentences.append(sent[0].upper() + sent[1:])

    kept = []
    used = 0
    for sent in sentences:
        cost = len(sent) + (1 if kept else 0)  # +1 for the joining space
        if used + cost > max_chars:
            break
        kept.append(sent)
        used += cost
    if kept:
        return ' '.join(kept)
    # A single over-long sentence: hard cut so the length cap still holds.
    return sentences[0][:max_chars].rstrip() if sentences else ''


if uploaded_image:
    # Process image; a corrupt or mislabelled upload should produce a
    # friendly message instead of a raw traceback.
    try:
        image = Image.open(uploaded_image).convert("RGB")
    except Exception as e:
        st.error(f"❌ Couldn't read this image: {str(e)}")
        st.stop()
    st.image(image, use_column_width=True)

    # Generate caption
    with st.spinner("🔍 Analyzing image..."):
        try:
            caption_result = caption_pipe(image)
            image_caption = caption_result[0].get("generated_text", "").strip()
        except Exception as e:
            st.error(f"❌ Image analysis failed: {str(e)}")
            st.stop()

    if not image_caption:
        st.error("❌ Couldn't understand this image. Please try another!")
        st.stop()

    st.success(f"**Image Understanding:** {image_caption}")

    # Create story prompt (ChatML-style markers; the prompt ends with an open
    # assistant turn for the model to complete)
    story_prompt = (
        f"<|im_start|>system\n"
        f"You're a children's author. Create a short story (100-150 words) based on: {image_caption}\n"
        f"Use simple language and include a moral lesson.<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )

    # Generate story with progress
    progress_bar = st.progress(0)
    status_text = st.empty()

    def update_progress(step):
        """Show ``step`` out of 5 on the progress bar and status line."""
        progress = min(step/5, 1.0)  # Simulate progress steps
        progress_bar.progress(progress)
        status_text.text(f"Step {int(step)}/5: {'📖'*int(step)}")

    try:
        with st.spinner("📝 Crafting magical story..."):
            start_time = time.time()
            update_progress(1)

            story_result = story_pipe(
                story_prompt,
                do_sample=True,
                num_return_sequences=1
            )
            update_progress(4)

            generation_time = time.time() - start_time
            st.info(f"Story generated in {generation_time:.1f} seconds")

            # Process output
            raw_story = story_result[0]['generated_text']
            clean_story = _extract_story(raw_story)

            # Format story text and limit length at a sentence boundary
            final_story = _format_story(clean_story, max_chars=600)

            update_progress(5)
            time.sleep(0.5)  # Final progress pause
    except Exception as e:
        st.error(f"❌ Story generation failed: {str(e)}")
        st.stop()
    finally:
        # Always clear the progress widgets, whether generation worked or not
        progress_bar.empty()
        status_text.empty()

    # Nothing to show or speak if the model produced no usable text
    if not final_story:
        st.error("❌ The model returned an empty story. Please try again!")
        st.stop()

    # Display story
    st.subheader("✨ Your Magical Story")
    st.write(final_story)

    # Audio conversion
    with st.spinner("🔊 Creating audio version..."):
        try:
            audio = gTTS(text=final_story, lang="en", slow=False)
            # Write into an auto-deleted temp file and hand Streamlit the
            # bytes, so no .mp3 is left behind on disk after each run.
            with tempfile.TemporaryFile() as tmp_file:
                audio.write_to_fp(tmp_file)
                tmp_file.seek(0)
                st.audio(tmp_file.read(), format="audio/mp3")
        except Exception as e:
            st.error(f"❌ Audio conversion failed: {str(e)}")
# Footer: a horizontal rule followed by the credits line
for footer_markdown in (
    "---",
    "📚 Made with ♥ by The Story Wizard • [Report Issues](https://example.com)",
):
    st.markdown(footer_markdown)