"""Streamlit app: turn an uploaded image into a short, narrated children's story.

Flow: a BLIP pipeline captions the image, Qwen3-1.7B writes a story from the
caption, and gTTS converts the story to MP3 audio.
"""

import io
import os
import re
import tempfile
import time

import streamlit as st
from gtts import gTTS
from PIL import Image
from transformers import pipeline

# --- Requirements ---
# requirements.txt should include:
#   streamlit>=1.39      # st.image(use_container_width=...) is not available in older releases
#   pillow>=9.0
#   torch>=2.0.0
#   transformers>=4.51   # older releases do not recognise the Qwen3 architecture
#   sentencepiece>=0.2.0
#   gTTS>=2.3.1
#   accelerate>=0.30     # needed for device_map="auto"

# Hard cap on the length of the story shown and narrated.
MAX_STORY_WORDS = 100
# Qwen tokenizer id of <|im_end|> (end of a chat turn), used as the stop token.
QWEN_IM_END_TOKEN_ID = 151645
_SENTENCE_END = (".", "!", "?")
_CLOSING_QUOTES = "\"'\u201d\u2019"
# Matches a (possibly unterminated) Qwen3 reasoning block.
_THINK_BLOCK = re.compile(r"<think>.*?(?:</think>|$)", re.DOTALL)

# --- Page Setup ---
# set_page_config must run before any other Streamlit element is rendered.
st.set_page_config(page_title="Magic Story Generator", layout="centered")
st.title("📖✨ Turn Images into Children's Stories")


# --- Load Pipelines (cached) ---
@st.cache_resource(show_spinner=False)
def load_pipelines():
    """Build the captioning and story-generation pipelines.

    Cached with ``st.cache_resource`` so the models are loaded once and reused
    across script reruns.

    Returns:
        tuple: ``(captioner, storyteller)`` Hugging Face pipelines.
    """
    # 1) Image captioning (BLIP), run on CPU (device=-1).
    captioner = pipeline(
        task="image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=-1,
    )
    # 2) Story generation with Qwen3-1.7B. Generation settings passed here act as
    # defaults for every call. Qwen3 is a built-in architecture in
    # transformers>=4.51, so trust_remote_code (which would allow executing code
    # shipped with the model repo) is not needed.
    storyteller = pipeline(
        task="text-generation",
        model="Qwen/Qwen3-1.7B",
        device_map="auto",
        torch_dtype="auto",
        max_new_tokens=150,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        eos_token_id=QWEN_IM_END_TOKEN_ID,
    )
    return captioner, storyteller


def _build_prompt(tokenizer, caption):
    """Render the chat prompt for the storyteller from an image caption."""
    messages = [
        {
            "role": "system",
            "content": (
                "You are a children's story writer. Create a 50-100 word story "
                f"based on this image description: {caption}"
            ),
        },
        {
            "role": "user",
            "content": (
                "Write a coherent, child-friendly story that flows naturally "
                "with simple vocabulary."
            ),
        },
    ]
    # enable_thinking=False turns off Qwen3's default "thinking" mode so the
    # limited new-token budget is spent on the story, not on hidden reasoning.
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )


def _finalize_story(raw, max_words=MAX_STORY_WORDS):
    """Clean generated text and cap it at *max_words*, ending on punctuation.

    Returns an empty string when no story text remains after cleaning.
    """
    # Defensive: drop any reasoning block and stray end-of-turn markers.
    text = _THINK_BLOCK.sub("", raw).replace("<|im_end|>", "").strip()
    words = text.split()
    if not words:
        return ""
    if len(words) > max_words:
        text = " ".join(words[:max_words])
        # Back off to the last complete sentence instead of stopping mid-sentence.
        last_end = max(text.rfind(mark) for mark in _SENTENCE_END)
        if last_end > 0:
            text = text[: last_end + 1]
    # Treat `."` / `!'` endings as already terminated.
    if not text.rstrip(_CLOSING_QUOTES).endswith(_SENTENCE_END):
        text += "."
    return text


def _story_to_mp3(text):
    """Synthesize *text* with gTTS and return the MP3 audio as bytes."""
    # In-memory buffer: avoids leaking an open, never-deleted temp file per run.
    buffer = io.BytesIO()
    gTTS(text=text, lang="en", slow=False).write_to_fp(buffer)
    return buffer.getvalue()


captioner, storyteller = load_pipelines()

# --- Main App ---
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])

if uploaded:
    # Load and display the image. UnidentifiedImageError (unreadable or
    # unsupported file) is a subclass of OSError.
    try:
        img = Image.open(uploaded).convert("RGB")
    except OSError:
        st.error("😢 Couldn't open this file as an image. Try another one!")
        st.stop()
    st.image(img, use_container_width=True)

    # Generate caption
    with st.spinner("🔍 Generating caption..."):
        cap = captioner(img)
    # Guard against a non-list or empty result before indexing.
    caption = cap[0].get("generated_text", "").strip() if isinstance(cap, list) and cap else ""
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    # Build prompt and generate story
    prompt = _build_prompt(storyteller.tokenizer, caption)
    with st.spinner("📝 Writing story..."):
        start = time.perf_counter()
        # return_full_text=False returns only the completion, not the echoed prompt.
        out = storyteller(
            prompt,
            do_sample=True,
            num_return_sequences=1,
            return_full_text=False,
        )
        gen_time = time.perf_counter() - start
    st.text(f"⏱ Generated in {gen_time:.1f}s")

    story = _finalize_story(out[0]["generated_text"])
    if not story:
        st.error("😢 The storyteller came up empty. Please try again!")
        st.stop()

    # Display story
    st.subheader("📚 Your Magical Story")
    st.write(story)

    # Convert to audio. Best-effort: report the failure instead of breaking the page.
    with st.spinner("🔊 Converting to audio..."):
        try:
            st.audio(_story_to_mp3(story), format="audio/mp3")
        except Exception as e:
            st.warning(f"⚠️ TTS failed: {e}")

# Footer
st.markdown("---\nMade with ❤️ by your friendly story wizard")