File size: 3,802 Bytes
f913ab4 4d1f328 8151df4 9862828 8151df4 9862828 8151df4 c83a777 6523fb1 c83a777 6523fb1 8151df4 c83a777 f913ab4 6523fb1 8151df4 fd1d947 982555a 8151df4 6523fb1 8151df4 9862828 6523fb1 9862828 6523fb1 9862828 6523fb1 9862828 6523fb1 eec20c9 6523fb1 eec20c9 6523fb1 eec20c9 6523fb1 eec20c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
# Streamlit must be imported first: st.set_page_config has to be the very
# first Streamlit call executed in the script, before any other st.* usage.
import streamlit as st

st.set_page_config(
    page_title="Magic Story Generator",
    page_icon="📖",
    layout="centered",
)
# Other imports (stdlib first, then third-party)
import io
import re
import tempfile
import time

import torch
from gtts import gTTS
from PIL import Image
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
# --- Initialize Models First ---
@st.cache_resource(show_spinner=False)
def load_models():
    """Load both ML pipelines once per session.

    Cached by ``st.cache_resource`` so the (slow) model downloads and
    weight loading happen only on the first run.

    Returns:
        tuple: ``(caption_pipe, story_pipe)`` — a BLIP image-to-text
        pipeline and a Qwen3-0.6B text-generation pipeline.

    On any failure the error is surfaced in the UI and the app is halted.
    """
    try:
        use_gpu = torch.cuda.is_available()

        # 1. Image captioning: BLIP turns the uploaded picture into a
        #    short natural-language description.
        captioner = pipeline(
            "image-to-text",
            model="Salesforce/blip-image-captioning-base",
            device=0 if use_gpu else -1,
        )

        # 2. Story generation: Qwen3-0.6B expands the caption into prose.
        tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen3-0.6B",
            trust_remote_code=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen3-0.6B",
            device_map="auto",
            torch_dtype=torch.float16,  # halve memory; fine for inference
        )
        storyteller = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=300,
            temperature=0.7,
        )

        return captioner, storyteller
    except Exception as e:
        st.error(f"🚨 Model loading failed: {str(e)}")
        st.stop()
# Initialize models immediately when app starts; cached by
# st.cache_resource, so Streamlit reruns reuse the loaded pipelines.
caption_pipe, story_pipe = load_models()
# --- Rest of Application ---
st.title("📖✨ Turn Images into Children's Stories")
def clean_story_text(raw_text):
    """Strip model artifacts from generated story text.

    Removes chat-template special tokens (e.g. ``<|im_end|>``),
    Qwen3-style ``<think>...</think>`` reasoning blocks, and leaked
    "Okay, I need..." thinking chains, then trims whitespace.

    Args:
        raw_text: Raw text produced by the text-generation pipeline.

    Returns:
        str: The cleaned story.
    """
    clean = re.sub(r'<\|.*?\|>', '', raw_text)  # chat special tokens
    # Qwen3 emits its chain-of-thought inside <think>...</think>; the old
    # patterns never matched these tags, so reasoning leaked into stories.
    clean = re.sub(r'<think>.*?</think>', '', clean, flags=re.DOTALL)
    clean = re.sub(r'Okay, I need.*?(?=\n|$)', '', clean, flags=re.DOTALL)
    return clean.strip()
uploaded_image = st.file_uploader(
    "Upload a children's book style image:",
    type=["jpg", "jpeg", "png"]
)

if uploaded_image:
    image = Image.open(uploaded_image).convert("RGB")
    # use_container_width replaces the deprecated use_column_width
    st.image(image, use_container_width=True)

    # Step 1: caption the image; abort the run if captioning fails.
    with st.spinner("🔍 Analyzing image..."):
        try:
            caption_result = caption_pipe(image)
            image_caption = caption_result[0].get("generated_text", "")
            st.success(f"**Image Understanding:** {image_caption}")
        except Exception as e:
            st.error(f"❌ Image analysis failed: {str(e)}")
            st.stop()

    # Step 2: build the story prompt. The trailing "Story:" marker lets us
    # split off the echoed prompt from the generated continuation below.
    story_prompt = f"""Write a children's story about: {image_caption}
Rules:
- Use simple words (Grade 2 level)
- Exclude thinking processes
- 3 paragraphs maximum
Story:"""

    try:
        with st.spinner("📝 Crafting magical story..."):
            story_result = story_pipe(
                story_prompt,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2
            )
            raw_story = story_result[0]['generated_text']
            # The pipeline echoes the prompt; keep only text after "Story:".
            final_story = clean_story_text(raw_story.split("Story:")[-1])
            st.subheader("✨ Your Magical Story")
            st.write(final_story)

        # Step 3: text-to-speech. Render into an in-memory buffer instead
        # of a NamedTemporaryFile(delete=False) that was never removed —
        # the old code leaked one orphaned .mp3 per generated story.
        with st.spinner("🔊 Creating audio version..."):
            audio_buffer = io.BytesIO()
            gTTS(text=final_story, lang="en", slow=False).write_to_fp(audio_buffer)
            audio_buffer.seek(0)
            st.audio(audio_buffer, format="audio/mp3")
    except Exception as e:
        st.error(f"❌ Story generation failed: {str(e)}")

# Footer
st.markdown("---")
st.markdown("📚 Made with ♥ by The Story Wizard • [Report Issues](https://example.com)")
|