# Magic Story Generator: a Streamlit app that turns an uploaded image into a
# short children's story. The image is captioned with a BLIP model via the
# Hugging Face Inference API, the caption seeds a text-generation prompt, and
# the resulting story is shown on the page and read aloud with gTTS.

import streamlit as st
from PIL import Image
from io import BytesIO
from huggingface_hub import InferenceApi
from gtts import gTTS
import requests
import tempfile
import time
import threading

# ——————— Page Config ———————
st.set_page_config(page_title="Magic Story Generator", layout="centered")
st.title("📖✨ Turn Images into Children's Stories")

# ——————— Constants ———————
# Text-generation endpoint used by both the keep-alive thread and
# generate_story; defined once so the two cannot drift apart.
STORY_API_URL = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
CAPTION_MODEL = "Salesforce/blip-image-captioning-base"
KEEP_ALIVE_INTERVAL_S = 600  # 10 minutes between warm-up pings
MAX_STORY_RETRIES = 5  # retries allowed while the story model answers HTTP 503
MAX_RETRY_WAIT_S = 60  # cap on a single back-off, whatever the server estimates


def _auth_headers(token):
    """Return the bearer-token header dict for Hugging Face API requests."""
    return {"Authorization": f"Bearer {token}"}


# ——————— Clients (cached) ———————
@st.cache_resource
def load_clients():
    """Build the caption client and start the story-model keep-alive thread.

    Cached with ``st.cache_resource`` so the client is created, and the
    background thread started, only once per server process.

    Returns:
        tuple: ``(caption_client, hf_token)``.
    """
    hf_token = st.secrets["HF_TOKEN"]
    caption_client = InferenceApi(CAPTION_MODEL, token=hf_token)

    headers = _auth_headers(hf_token)
    keep_alive_payload = {"inputs": "Hello!", "parameters": {"max_new_tokens": 1}}

    def ping():
        # Best-effort: a failed ping only means the next real request may hit
        # a cold start, so network errors are deliberately ignored.
        try:
            requests.post(STORY_API_URL, headers=headers, json=keep_alive_payload, timeout=10)
        except requests.RequestException:
            pass

    def keep_model_warm():
        ping()  # initial warm-up
        while True:
            time.sleep(KEEP_ALIVE_INTERVAL_S)
            ping()

    # Daemon thread so it never blocks interpreter shutdown.
    threading.Thread(target=keep_model_warm, daemon=True).start()
    return caption_client, hf_token


caption_client, hf_token = load_clients()


# ——————— Helper: Generate Caption ———————
def generate_caption(img):
    """Return a caption for *img* from the BLIP model, or "" on failure.

    Errors are reported on the page with ``st.error``; they never propagate.
    """
    img_bytes = BytesIO()
    img.save(img_bytes, format="JPEG")
    try:
        result = caption_client(data=img_bytes.getvalue())
        if isinstance(result, list) and result:
            return result[0].get("generated_text", "").strip()
        # NOTE(review): the Inference API appears to report failures (e.g. a
        # cold model) as a JSON object with an "error" key; surface it.
        if isinstance(result, dict) and "error" in result:
            st.error(f"Caption generation error: {result['error']}")
    except Exception as e:
        st.error(f"Caption generation error: {type(e).__name__}: {e}")
    return ""


# ——————— Helper: Process Image ———————
def process_image(uploaded_file):
    """Open *uploaded_file* as an RGB image at most 2048px on its long side.

    Shows an error and stops the script run if the file cannot be decoded.
    """
    try:
        img = Image.open(uploaded_file).convert("RGB")
        if max(img.size) > 2048:
            # thumbnail() resizes in place and preserves the aspect ratio.
            img.thumbnail((2048, 2048))
        return img
    except Exception as e:
        st.error(f"Image processing error: {type(e).__name__}: {e}")
        st.stop()


# ——————— Helper: Generate Story ———————
def _extract_story(generated_text):
    """Pull the story out of raw model output and drop a trailing partial sentence."""
    text = generated_text.strip()
    # R1-distilled models may emit a "<think>...</think>" reasoning preamble;
    # keep only what follows it. This is a no-op when the tag is absent.
    if "</think>" in text:
        text = text.rsplit("</think>", 1)[1]
    # The output is expected to echo the prompt, which ends with "Story:".
    story = text.split("Story:")[-1].strip()
    # Generation stops at max_new_tokens, often mid-sentence; cut back to the
    # last full stop.
    if "." in story:
        story = story.rsplit(".", 1)[0] + "."
    return story


def _retry_wait(resp, attempt):
    """Return the seconds to wait before retrying after a 503 from the story model.

    Uses the server's ``estimated_time`` when present (default 5). If the body
    cannot be parsed, it falls back to exponential back-off. The result is
    clamped to ``[1, MAX_RETRY_WAIT_S]`` so a huge or zero estimate neither
    stalls the page nor burns retries instantly.
    """
    try:
        wait = int(resp.json().get("estimated_time", 5))
    except (ValueError, TypeError, AttributeError):
        wait = 5 * (2 ** attempt)
    return max(1, min(wait, MAX_RETRY_WAIT_S))


def generate_story(prompt: str) -> str:
    """Generate a short story for *prompt* with the DeepSeek text-generation model.

    Retries up to ``MAX_STORY_RETRIES`` times while the endpoint answers
    HTTP 503 (model loading). The script run is stopped via ``st.stop()``,
    after showing an error, in any of these cases:

    - a network error;
    - any other non-200 status;
    - an unparseable response;
    - an empty story.
    """
    headers = _auth_headers(hf_token)
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 200,
            "temperature": 0.8,
            "top_p": 0.95,
            "repetition_penalty": 1.15,
            "do_sample": True,
            "no_repeat_ngram_size": 2,
        },
    }

    # One initial attempt plus MAX_STORY_RETRIES retries.
    for attempt in range(MAX_STORY_RETRIES + 1):
        try:
            resp = requests.post(STORY_API_URL, headers=headers, json=payload, timeout=30)
        except requests.RequestException as e:
            st.error(f"🚨 Story magic failed: {type(e).__name__}: {e}")
            st.stop()

        if resp.status_code == 200:
            try:
                data = resp.json()
            except ValueError:  # body was not valid JSON
                data = None
            if not (isinstance(data, list) and data and isinstance(data[0], dict)):
                st.error("🚨 Story magic failed: invalid response format")
                st.stop()
            story = _extract_story(data[0].get("generated_text") or "")
            if not story:
                st.error("🚨 Story magic failed: the model returned an empty story")
                st.stop()
            return story

        if resp.status_code == 503 and attempt < MAX_STORY_RETRIES:
            wait = _retry_wait(resp, attempt)
            st.info(f"Model loading; retrying in {wait}s (attempt {attempt+1}/{MAX_STORY_RETRIES})")
            time.sleep(wait)
            continue

        # Any other status, or 503 after the last retry, is fatal.
        break

    st.error(f"🚨 Story magic failed: HTTP {resp.status_code} - {resp.text}")
    st.stop()


# ——————— Main App Flow ———————
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])

if uploaded:
    img = process_image(uploaded)
    st.image(img, use_container_width=True)

    # Generate Caption
    with st.spinner("🔍 Discovering image secrets..."):
        caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    # Prepare Story Prompt
    story_prompt = (
        f"Image description: {caption}\n\n"
        "Write a 50-100 word children's story that:\n"
        "1. Features the main subject as a friendly character\n"
        "2. Includes a simple adventure or discovery\n"
        "3. Ends with a happy or funny conclusion\n"
        "4. Uses simple language for ages 3-8\n\n"
        "Story:\n"
    )

    # Generate and Display Story
    with st.spinner("📝 Writing magical story..."):
        story = generate_story(story_prompt)
    st.subheader("📚 Your Magical Story")
    st.write(story)

    # Audio Conversion
    with st.spinner("🔊 Adding story voice..."):
        try:
            # Render the MP3 in memory so no temporary file is left on disk
            # for every story generated.
            audio_buffer = BytesIO()
            gTTS(text=story, lang="en", slow=False).write_to_fp(audio_buffer)
            st.audio(audio_buffer.getvalue(), format="audio/mp3")
        except Exception as e:
            st.warning(f"⚠️ Couldn't make audio version: {type(e).__name__}: {e}")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")