# story_generator.py
"""Streamlit app that turns an uploaded picture into a short children's story.

The image is captioned with BLIP, the caption is turned into a story with
Qwen3, and the story is converted to speech with gTTS.
"""
import io
import logging
import re
import tempfile
import time

import streamlit as st
from PIL import Image
from gtts import gTTS
from transformers import pipeline

logger = logging.getLogger(__name__)

# Sampling settings for the storyteller. They are passed when the pipeline is
# *called*: anything given through ``pipeline(model_kwargs=...)`` is forwarded
# to ``from_pretrained`` (model loading), not to ``generate()``, so sampling
# options placed there are not applied to generation.
STORY_GENERATION_KWARGS = {
    "max_new_tokens": 300,
    "do_sample": True,
    "temperature": 0.7,
    "top_p": 0.9,
    "repetition_penalty": 1.1,
    # Value carried over from the original code; presumably Qwen's
    # <|im_end|> token id -- TODO confirm against the tokenizer.
    "pad_token_id": 151645,
}

# A complete Qwen3 reasoning block (may span several lines).
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", flags=re.DOTALL)
# Leading "assistant" role labels the model sometimes echoes.
_ROLE_PREFIX_RE = re.compile(r"^(assistant[\s\-\:>]*)+", flags=re.IGNORECASE)
# Sentence boundary: terminal punctuation followed by one or more spaces.
_SENTENCE_SPLIT_RE = re.compile(r"(?<=[.!?]) +")
_SENTENCE_END = (".", "!", "?")
# Characters that may legitimately follow terminal punctuation, as in dialogue.
_TRAILING_CLOSERS = "\"'”’)"

# --- Initialize Streamlit Config ---
st.set_page_config(
    page_title="Magic Story Generator",
    layout="centered",
    page_icon="📖",
)


# --- Model Loading (Cached) ---
@st.cache_resource(show_spinner=False)
def load_models():
    """Load the captioning and story-generation pipelines.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the loaded
    pipelines instead of reloading them.

    Returns:
        tuple: ``(captioner, storyteller)`` Hugging Face pipelines.
    """
    # Image captioning model, placed on CPU (device=-1).
    captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=-1,
    )
    # Story generation model; sampling settings live in STORY_GENERATION_KWARGS
    # and are supplied per call.
    storyteller = pipeline(
        "text-generation",
        model="Qwen/Qwen3-1.7B",
        revision="main",
        device_map="auto",
        trust_remote_code=True,
        torch_dtype="auto",
    )
    return captioner, storyteller


# --- Text Processing Utilities ---
def clean_generated_text(raw_text):
    """Turn raw model output into tidy story text.

    Keeps only the assistant's reply, strips chat markers and Qwen3
    ``<think>`` reasoning, then makes each sentence chunk start with a capital
    letter and end with terminal punctuation.

    Args:
        raw_text: Text from the text-generation pipeline, with or without the
            prompt echoed in front of it.

    Returns:
        The cleaned story with sentence chunks re-joined by single spaces, or
        ``""`` when nothing usable remains.
    """
    # Keep what follows the first assistant marker (the whole text if absent).
    clean_text = raw_text.split("<|im_start|>assistant\n", 1)[-1]
    # Remove any subsequent chat turns.
    clean_text = clean_text.split("<|im_start|>")[0]
    clean_text = clean_text.replace("<|im_end|>", "")
    # Drop Qwen3 reasoning: complete blocks, orphaned closers, and a block
    # left unterminated because generation hit the token limit.
    clean_text = _THINK_BLOCK_RE.sub("", clean_text)
    clean_text = clean_text.replace("</think>", "")
    clean_text = clean_text.split("<think>", 1)[0].strip()
    # Remove leftover role labels.
    clean_text = _ROLE_PREFIX_RE.sub("", clean_text).strip()

    sentences = []
    for sent in _SENTENCE_SPLIT_RE.split(clean_text):
        sent = sent.strip()
        if not sent:
            continue
        # Look past closing quotes/brackets (e.g. `...friends."`) so dialogue
        # does not get an extra period appended.
        if not sent.rstrip(_TRAILING_CLOSERS).endswith(_SENTENCE_END):
            sent += "."
        sentences.append(sent[0].upper() + sent[1:])
    return " ".join(sentences)


def _build_story_prompt(image_caption):
    """Return the Qwen chat-format prompt asking for a story about *image_caption*."""
    return (
        "<|im_start|>system\n"
        f"You are a children's book author. Create a 150-word story based on: {image_caption}\n"
        "Include these elements:\n"
        "- Friendly characters\n"
        "- Simple vocabulary\n"
        "- Positive lesson\n"
        "- Clear story structure\n"
        "<|im_end|>\n"
        "<|im_start|>user\n"
        "Write an engaging story suitable for ages 6-8.<|im_end|>\n"
        "<|im_start|>assistant\n"
        # Empty reasoning block, as Qwen3's chat template emits for
        # enable_thinking=False: the model answers directly instead of spending
        # the token budget on <think> reasoning.
        "<think>\n\n</think>\n\n"
    )


def _synthesize_audio(text):
    """Return MP3 bytes of *text* spoken by gTTS, built entirely in memory."""
    buffer = io.BytesIO()
    gTTS(text=text, lang="en", slow=False).write_to_fp(buffer)
    return buffer.getvalue()


# --- Main Application UI ---
st.title("📖✨ Magic Story Generator")

uploaded_image = st.file_uploader(
    "Upload a children's book style image:",
    type=["jpg", "jpeg", "png"],
)

if uploaded_image is not None:
    # Decode and display the uploaded image.
    try:
        image = Image.open(uploaded_image).convert("RGB")
    except OSError:
        # PIL's UnidentifiedImageError is an OSError subclass.
        logger.exception("Could not decode uploaded image")
        st.error("❌ Failed to analyze image. Please try another.")
        st.stop()
    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favour of use_container_width -- confirm against the pinned version.
    st.image(image, use_column_width=True)

    # Load models only when needed.
    caption_pipe, story_pipe = load_models()

    # Generate image caption.
    with st.spinner("🔍 Analyzing image..."):
        try:
            caption_result = caption_pipe(image)
            image_caption = caption_result[0].get("generated_text", "").strip()
        except Exception:
            logger.exception("Image captioning failed")
            image_caption = ""
        if not image_caption:
            st.error("❌ Failed to analyze image. Please try another.")
            st.stop()
        st.success(f"**Image Understanding:** {image_caption}")

    # Generate story text; return_full_text=False keeps the prompt out of the output.
    with st.spinner("📝 Crafting magical story..."):
        try:
            story_result = story_pipe(
                _build_story_prompt(image_caption),
                num_return_sequences=1,
                return_full_text=False,
                **STORY_GENERATION_KWARGS,
            )
            raw_story = story_result[0]["generated_text"]
        except Exception:
            logger.exception("Story generation failed")
            st.error("❌ Story generation failed. Please try again.")
            st.stop()

    # Process and display story; an empty result is treated as a failure
    # (gTTS would also refuse empty text).
    final_story = clean_generated_text(raw_story)
    if not final_story:
        st.error("❌ Story generation failed. Please try again.")
        st.stop()

    st.subheader("✨ Your Story")
    st.write(final_story)

    # Generate audio version in memory (no temp file left behind).
    with st.spinner("🔊 Creating audio version..."):
        try:
            st.audio(_synthesize_audio(final_story), format="audio/mp3")
        except Exception:
            logger.exception("Audio conversion failed")
            st.warning("⚠️ Audio conversion failed. Text version still available.")

# Footer
st.markdown("---")
st.caption("Made with ♥ by The Story Wizard • [Report Issues](https://example.com)")