"""Streamlit app: caption an uploaded image, write a children's story, narrate it."""

import io
import re
import tempfile
import time

import streamlit as st
import torch
from gtts import gTTS
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# set_page_config must be the first Streamlit command executed in the script;
# plain imports above do not count as Streamlit commands.
st.set_page_config(
    page_title="Magic Story Generator",
    layout="centered",
    page_icon="📖",
)

CAPTION_MODEL_ID = "Salesforce/blip-image-captioning-base"
STORY_MODEL_ID = "Qwen/Qwen3-0.6B"

# Patterns used by clean_story_text, compiled once at import time.
# A <think> block may be left unterminated when generation hits the token
# limit, so the pattern also matches up to the end of the text.
_THINK_BLOCK_RE = re.compile(r"<think>.*?(?:</think>|\Z)", flags=re.DOTALL)
_STRAY_THINK_TAG_RE = re.compile(r"</?think>")
_SPECIAL_TOKEN_RE = re.compile(r"<\|.*?\|>")
_PREAMBLE_RE = re.compile(r"Okay, I need.*?(?=\n|$)", flags=re.DOTALL)


# --- Initialize Models First ---
@st.cache_resource(show_spinner=False)
def load_models():
    """Load the captioning and story-generation pipelines once per process.

    Returns:
        A ``(caption_pipe, story_pipe)`` tuple of transformers pipelines.

    On any loading failure an error is shown in the UI and the script run is
    halted with ``st.stop()``.
    """
    has_cuda = torch.cuda.is_available()
    try:
        # 1. Image captioning model
        caption_pipe = pipeline(
            "image-to-text",
            model=CAPTION_MODEL_ID,
            device=0 if has_cuda else -1,
        )

        # 2. Story generation model
        story_tokenizer = AutoTokenizer.from_pretrained(
            STORY_MODEL_ID,
            trust_remote_code=True,
        )
        story_model = AutoModelForCausalLM.from_pretrained(
            STORY_MODEL_ID,
            device_map="auto",
            # Half precision is only a win on GPU; on CPU it is slow or
            # unsupported for some ops, so fall back to float32 there.
            torch_dtype=torch.float16 if has_cuda else torch.float32,
        )
        story_pipe = pipeline(
            "text-generation",
            model=story_model,
            tokenizer=story_tokenizer,
            max_new_tokens=300,
            temperature=0.7,
        )
        return caption_pipe, story_pipe
    except Exception as e:
        st.error(f"🚨 Model loading failed: {e}")
        st.stop()


# Initialize models immediately when the app starts.
caption_pipe, story_pipe = load_models()

# --- Rest of Application ---
st.title("📖✨ Turn Images into Children's Stories")


def clean_story_text(raw_text: str) -> str:
    """Strip model artifacts from generated text and return the bare story.

    Removes, in order: ``<think>...</think>`` reasoning blocks (including one
    left unterminated at the end of the text), any stray think tags,
    ``<|...|>`` special tokens, and a leading-style "Okay, I need ..." planning
    sentence up to the end of its line. Surrounding whitespace is trimmed.
    """
    text = _THINK_BLOCK_RE.sub("", raw_text)
    text = _STRAY_THINK_TAG_RE.sub("", text)
    text = _SPECIAL_TOKEN_RE.sub("", text)
    text = _PREAMBLE_RE.sub("", text)
    return text.strip()


uploaded_image = st.file_uploader(
    "Upload a children's book style image:",
    type=["jpg", "jpeg", "png"],
)

if uploaded_image:
    image = Image.open(uploaded_image).convert("RGB")
    st.image(image, use_container_width=True)

    with st.spinner("🔍 Analyzing image..."):
        try:
            caption_result = caption_pipe(image)
            image_caption = caption_result[0].get("generated_text", "")
            st.success(f"**Image Understanding:** {image_caption}")
        except Exception as e:
            st.error(f"❌ Image analysis failed: {e}")
            st.stop()

    # Story generation prompt. The trailing "Story:" marker is what the
    # output is split on below to drop the echoed prompt.
    story_prompt = (
        f"Write a children's story about: {image_caption}\n"
        "Rules:\n"
        "- Use simple words (Grade 2 level)\n"
        "- Exclude thinking processes\n"
        "- 3 paragraphs maximum\n"
        "Story:"
    )

    try:
        with st.spinner("📝 Crafting magical story..."):
            story_result = story_pipe(
                story_prompt,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2,
            )

        raw_story = story_result[0]["generated_text"]
        final_story = clean_story_text(raw_story.split("Story:")[-1])

        st.subheader("✨ Your Magical Story")

        if final_story:
            st.write(final_story)

            # Audio conversion. The MP3 is rendered into memory so no
            # temporary file is left behind on disk after each run.
            with st.spinner("🔊 Creating audio version..."):
                audio_buffer = io.BytesIO()
                gTTS(text=final_story, lang="en", slow=False).write_to_fp(audio_buffer)
                st.audio(audio_buffer.getvalue(), format="audio/mp3")
        else:
            # gTTS raises on empty text; report the real cause instead.
            st.warning("The model returned an empty story. Please try again.")
    except Exception as e:
        st.error(f"❌ Story generation failed: {e}")

# Footer
st.markdown("---")
st.markdown("📚 Made with ♥ by The Story Wizard • [Report Issues](https://example.com)")