"""Magic Story Generator: turn an uploaded image into a short children's story.

Flow: BLIP (via the Hugging Face Inference API) captions the image,
Qwen2.5-Omni-7B writes a story from the caption, and gTTS reads it aloud.
"""

import importlib.util
import os
import tempfile  # noqa: F401
import time
from io import BytesIO

import streamlit as st
import torch
from gtts import gTTS
from huggingface_hub import InferenceApi, login
from PIL import Image
from transformers import (  # noqa: F401
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Qwen2_5OmniForConditionalGeneration,
    Qwen2_5OmniProcessor,
    pipeline,
)

# —––––––– Settings —–––––––
CAPTION_MODEL_ID = "Salesforce/blip-image-captioning-base"
STORY_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
MAX_IMAGE_SIDE = 2048  # larger uploads are downscaled before display/captioning
MAX_STORY_WORDS = 100
SENTENCE_ENDINGS = (".", "!", "?")
# Generation settings for the story model. generate_story forwards them with a
# "thinker_" prefix so they reach Qwen2.5-Omni's text-generating sub-model.
STORY_GENERATION_KWARGS = {
    "max_new_tokens": 120,
    # Sample explicitly: under greedy decoding temperature/top_p have no effect.
    "do_sample": True,
    "temperature": 0.7,
    "top_p": 0.9,
    "repetition_penalty": 1.2,
    "no_repeat_ngram_size": 3,
}

# —––––––– Page Config —–––––––
st.set_page_config(page_title="Magic Story Generator (Qwen2.5)", layout="centered")
st.title("📖✨ Turn Images into Children's Stories (Qwen2.5-Omni-7B)")


# —––––––– Load Clients & Models (cached) —–––––––
@st.cache_resource(show_spinner="🧙 Loading models (first run only)...")
def load_clients():
    """Create the caption client and load the story model (cached by Streamlit).

    Reads ``HF_TOKEN`` from ``st.secrets`` and logs in to the Hugging Face Hub.

    Returns:
        ``(caption_client, story_model, story_processor)``.
    """
    hf_token = st.secrets["HF_TOKEN"]

    # Authenticate for Hugging Face Hub
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
    login(hf_token)

    # 1) BLIP captioning via HTTP API.
    # NOTE(review): InferenceApi is deprecated in huggingface_hub in favour of
    # InferenceClient; consider migrating.
    caption_client = InferenceApi(repo_id=CAPTION_MODEL_ID, token=hf_token)

    # 2) Qwen2.5-Omni is a decoder-only multimodal model, not a seq2seq one, so
    # AutoModelForSeq2SeqLM / "text2text-generation" cannot load or run it.
    t0 = time.time()
    model_kwargs = {"device_map": "auto", "torch_dtype": torch.bfloat16}
    # FlashAttention-2 needs the optional flash_attn package and a CUDA GPU;
    # requesting it without them makes from_pretrained fail, so only opt in
    # when both are present and otherwise keep the default attention.
    if torch.cuda.is_available() and importlib.util.find_spec("flash_attn") is not None:
        model_kwargs["attn_implementation"] = "flash_attention_2"
    story_model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
        STORY_MODEL_ID, **model_kwargs
    )
    # Speech comes from gTTS, so the model's own audio head is not needed.
    story_model.disable_talker()
    story_processor = Qwen2_5OmniProcessor.from_pretrained(STORY_MODEL_ID)

    load_time = time.time() - t0
    st.text(f"✅ Story model loaded in {load_time:.1f}s (cached)")
    return caption_client, story_model, story_processor


caption_client, story_model, story_processor = load_clients()


# —––––––– Helpers —–––––––
def generate_caption(img: Image.Image) -> str:
    """Caption *img* with BLIP; return ``""`` when no caption comes back."""
    buf = BytesIO()
    img.save(buf, format="JPEG")
    resp = caption_client(data=buf.getvalue())
    # A usable response is a non-empty list whose first item carries
    # "generated_text"; anything else (e.g. an error payload) yields "".
    if isinstance(resp, list) and resp and isinstance(resp[0], dict):
        return resp[0].get("generated_text", "").strip()
    return ""


def _limit_words(story: str, max_words: int = MAX_STORY_WORDS) -> str:
    """Cut *story* to at most *max_words* words, ending a cut story cleanly.

    A story that is already short enough is returned unchanged. When words
    are dropped, dangling separators are stripped and a full stop is added
    unless the text already ends in sentence punctuation.
    """
    words = story.split()
    if len(words) <= max_words:
        return story
    trimmed = " ".join(words[:max_words]).rstrip(",;:-—")
    if not trimmed.endswith(SENTENCE_ENDINGS):
        trimmed += "."
    return trimmed


def generate_story(caption: str) -> str:
    """Write a short children's story about *caption* with the Qwen model.

    Returns the story text (at most ``MAX_STORY_WORDS`` words); it may be
    empty if the model produced nothing.
    """
    prompt = (
        "You are a creative children's-story author.\n"
        f"Image description: “{caption}”\n\n"
        "Write a coherent 50–100 word story that:\n"
        "1. Introduces the main character.\n"
        "2. Shows a simple problem or discovery.\n"
        "3. Has a happy resolution.\n"
        "4. Uses clear language for ages 3–8.\n"
        "5. Keeps each sentence under 20 words.\n"
    )
    conversation = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
    text = story_processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False
    )
    inputs = story_processor(text=text, return_tensors="pt", padding=True).to(
        story_model.device
    )

    t0 = time.time()
    # return_audio=False makes generate() return only the text token ids.
    # "thinker_"-prefixed kwargs go to the text sub-model; an unprefixed
    # max_new_tokens would be overridden by thinker_max_new_tokens' default.
    output_ids = story_model.generate(
        **inputs,
        return_audio=False,
        **{f"thinker_{key}": value for key, value in STORY_GENERATION_KWARGS.items()},
    )
    gen_time = time.time() - t0
    st.text(f"⏱ Generated in {gen_time:.1f}s on GPU/CPU")

    # Decoder-only generate() returns prompt + continuation; keep only new tokens.
    new_ids = output_ids[:, inputs["input_ids"].shape[1]:]
    story = story_processor.batch_decode(
        new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0].strip()

    # Enforce ≤100 words
    return _limit_words(story)


# —––––––– Main App —–––––––
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    try:
        img = Image.open(uploaded).convert("RGB")
    except OSError:  # PIL raises UnidentifiedImageError (an OSError) for bad files
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    if max(img.size) > MAX_IMAGE_SIDE:
        img.thumbnail((MAX_IMAGE_SIDE, MAX_IMAGE_SIDE))
    st.image(img, use_container_width=True)

    with st.spinner("🔍 Generating caption..."):
        caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    with st.spinner("📝 Writing story..."):
        story = generate_story(caption)
    if not story:
        # gTTS cannot speak empty text, so stop before the audio step.
        st.error("😢 Couldn't write a story for this image. Try another one!")
        st.stop()
    st.subheader("📚 Your Magical Story")
    st.write(story)

    with st.spinner("🔊 Converting to audio..."):
        # Render the MP3 in memory so no temporary files accumulate on disk.
        audio_buf = BytesIO()
        try:
            gTTS(text=story, lang="en", slow=False).write_to_fp(audio_buf)
        except Exception as e:  # UI boundary: report TTS failure, don't crash
            st.warning(f"⚠️ TTS failed: {e}")
        else:
            st.audio(audio_buf.getvalue(), format="audio/mp3")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")