|
import streamlit as st |
|
from PIL import Image |
|
from io import BytesIO |
|
from huggingface_hub import InferenceApi |
|
from gtts import gTTS |
|
import requests |
|
import tempfile |
|
import time |
|
import threading |
|
|
|
|
|
# Page chrome: set the browser-tab title and a centered layout, then render
# the visible heading at the top of the app.
st.set_page_config(
    layout="centered",
    page_title="Magic Story Generator",
)
st.title("📖✨ Turn Images into Children's Stories")
|
|
|
|
|
@st.cache_resource
def load_clients():
    """Create the captioning client and start a background keep-warm pinger.

    Reads the Hugging Face token from ``st.secrets["HF_TOKEN"]``, builds an
    ``InferenceApi`` client for the BLIP captioning model, and launches a
    daemon thread that sends a 1-token request to the story model right away
    and then every 600 seconds.

    The function is wrapped in ``st.cache_resource``; the intent is that the
    client and the thread are created once and reused across script reruns.

    Returns:
        tuple: ``(caption_client, hf_token)``.

    Raises:
        KeyError: if ``HF_TOKEN`` is missing from the Streamlit secrets
            (presumably; depends on how ``st.secrets`` reports missing keys).
    """
    hf_token = st.secrets["HF_TOKEN"]
    caption_client = InferenceApi("Salesforce/blip-image-captioning-base", token=hf_token)

    api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    headers = {"Authorization": f"Bearer {hf_token}"}
    warm_payload = {"inputs": "Hello!", "parameters": {"max_new_tokens": 1}}

    def keep_model_warm():
        """Ping the story model forever: once immediately, then every 10 minutes."""
        while True:
            try:
                requests.post(api_url, headers=headers, json=warm_payload, timeout=10)
            except requests.RequestException:
                # Best-effort only: a failed ping must never stop the loop,
                # but unrelated errors should no longer be silently hidden.
                pass
            time.sleep(600)

    # Daemon thread so the pinger never blocks interpreter shutdown.
    threading.Thread(target=keep_model_warm, daemon=True).start()

    return caption_client, hf_token
|
|
|
# Module-level handles used by the helper functions in this script: the
# captioning client and the raw API token returned by load_clients().
(caption_client, hf_token) = load_clients()
|
|
|
|
|
def generate_caption(img):
    """Return a text caption for ``img``, or ``""`` when none can be produced.

    The image is serialized to JPEG in memory and sent to the module-level
    ``caption_client``. A non-empty list response yields the stripped
    ``generated_text`` of its first element. Any exception raised while
    calling the client or reading the response is shown with ``st.error``
    and results in an empty string.

    Args:
        img: an object with a ``save(fp, format=...)`` method (a PIL image
            in this script).

    Returns:
        str: the caption, or an empty string on failure / unexpected response.
    """
    buffer = BytesIO()
    img.save(buffer, format="JPEG")

    try:
        response = caption_client(data=buffer.getvalue())
        has_items = isinstance(response, list) and len(response) > 0
        if has_items:
            first = response[0]
            return first.get("generated_text", "").strip()
    except Exception as e:
        st.error(f"Caption generation error: {type(e).__name__}: {e}")

    # Reached on an error or when the response was not a non-empty list.
    return ""
|
|
|
|
|
def process_image(uploaded_file):
    """Open an uploaded file as an RGB image, shrinking it if it is very large.

    Images whose longest side exceeds 2048 pixels are reduced in place with
    ``thumbnail((2048, 2048))``. If opening or converting fails, the error is
    shown with ``st.error`` and the script run is halted via ``st.stop()``.

    Args:
        uploaded_file: the file-like object handed over by the uploader.

    Returns:
        The opened image converted to RGB.
    """
    try:
        picture = Image.open(uploaded_file)
        picture = picture.convert("RGB")
        longest_side = max(picture.size)
        if longest_side > 2048:
            picture.thumbnail((2048, 2048))
    except Exception as e:
        st.error(f"Image processing error: {type(e).__name__}: {e}")
        st.stop()
    else:
        return picture
|
|
|
|
|
def generate_story(prompt: str, caption: str) -> str:
    """Ask the hosted text-generation model for a story and return its text.

    ``prompt`` is posted to the Hugging Face Inference API endpoint for
    DeepSeek-R1-Distill-Qwen-1.5B, authenticated with the module-level
    ``hf_token``.

    Behavior by response:
      * 200 with a non-empty list: take ``generated_text`` of the first item,
        keep what follows the last ``"Story:"`` marker, and cut it at the
        final full stop so the story does not end mid-sentence.
      * 503: retry up to 5 times, sleeping for the server-provided
        ``estimated_time`` when the body is JSON, otherwise with exponential
        backoff.
      * 424 / 500 / 502: retry up to 5 times with exponential backoff; if the
        retries are exhausted, return a canned story built from ``caption``.
      * Anything else (including a request exception or an unusable 200
        body): show the problem with ``st.error`` and halt via ``st.stop()``.

    Args:
        prompt: full text prompt sent to the model.
        caption: image caption, used only for the canned fallback story.

    Returns:
        str: the generated (or fallback) story.
    """
    api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    headers = {"Authorization": f"Bearer {hf_token}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 200,
            "temperature": 0.8,
            "top_p": 0.95,
            "repetition_penalty": 1.15,
            "do_sample": True,
            "no_repeat_ngram_size": 2,
        },
    }
    transient_errors = (424, 500, 502)
    max_retries = 5
    retries = 0

    while True:
        try:
            resp = requests.post(api_url, headers=headers, json=payload, timeout=30)
        except Exception as e:
            st.error(f"🚨 Story magic failed: {type(e).__name__}: {e}")
            st.stop()

        if resp.status_code == 200:
            try:
                data = resp.json()
            except ValueError:
                # Body was not valid JSON; report it as a bad format below
                # instead of crashing with a decode error.
                data = None
            if isinstance(data, list) and data and isinstance(data[0], dict):
                text = data[0].get("generated_text", "").strip()
                # The model may echo the prompt; keep only what follows the
                # last "Story:" marker.
                story = text.split("Story:")[-1].strip()
                if "." in story:
                    # Drop a trailing partial sentence cut off by the token limit.
                    story = story.rsplit(".", 1)[0] + "."
                return story
            st.error("🚨 Story magic failed: invalid response format")
            st.stop()

        if resp.status_code == 503 and retries < max_retries:
            # Default to exponential backoff; prefer the server's own estimate
            # when it sends a usable JSON body.
            wait = 5 * (2 ** retries)
            if resp.headers.get("Content-Type", "").startswith("application/json"):
                try:
                    wait = int(resp.json().get("estimated_time", 5))
                except (ValueError, TypeError, AttributeError, OverflowError):
                    pass
            st.info(f"Model loading; retrying in {wait}s (attempt {retries+1}/{max_retries})")
            time.sleep(wait)
            retries += 1
            continue

        if resp.status_code in transient_errors and retries < max_retries:
            st.info(f"Server error {resp.status_code}; retrying (attempt {retries+1}/{max_retries})")
            time.sleep(2 ** retries)
            retries += 1
            continue

        if resp.status_code in transient_errors:
            # Retries exhausted: fall back to a canned story. Every fragment
            # that mentions the caption must be an f-string, otherwise the
            # literal text "{caption}" leaks into the story.
            return (
                f"One day, {caption} woke up under a bright sky and decided to explore the garden. "
                "It met a friendly ladybug and together they played hide-and-seek among the flowers. "
                f"At sunset, {caption} curled up by a daisy, purring happily as it dreamed of new adventures."
            )

        st.error(f"🚨 Story magic failed: HTTP {resp.status_code} - {resp.text}")
        st.stop()
|
|
|
|
|
# Main page flow: upload -> caption -> story -> narration.
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    img = process_image(uploaded)
    st.image(img, use_container_width=True)

    # Step 1: describe the image; without a caption there is nothing to write about.
    with st.spinner("🔍 Discovering image secrets..."):
        caption = generate_caption(img)
        if not caption:
            st.error("😢 Couldn't understand this image. Try another one!")
            st.stop()
        st.success(f"**Caption:** {caption}")

    # The prompt ends with "Story:" because generate_story() keeps only the
    # text after the last "Story:" marker in the model output.
    story_prompt = (
        f"Image description: {caption}\n\n"
        "Write a 50-100 word children's story that:\n"
        "1. Features the main subject as a friendly character\n"
        "2. Includes a simple adventure or discovery\n"
        "3. Ends with a happy or funny conclusion\n"
        "4. Uses simple language for ages 3-8\n\n"
        "Story:\n"
    )

    # Step 2: try up to three times to get a story of 50-100 words; if none
    # qualifies, show the last attempt with a warning.
    with st.spinner("📝 Writing magical story..."):
        story = None
        candidate = ""
        for _ in range(3):
            candidate = generate_story(story_prompt, caption)
            if 50 <= len(candidate.split()) <= 100:
                story = candidate
                break
        if story is None:
            st.warning("⚠️ Couldn't generate a story within 50-100 words after multiple tries. Showing last attempt.")
            story = candidate

    st.subheader("📚 Your Magical Story")
    st.write(story)

    # Step 3: narrate the story. The MP3 is synthesized into memory so no
    # temporary file is left behind on disk after each run.
    with st.spinner("🔊 Adding story voice..."):
        try:
            audio_buffer = BytesIO()
            gTTS(text=story, lang="en", slow=False).write_to_fp(audio_buffer)
            st.audio(audio_buffer.getvalue(), format="audio/mp3")
        except Exception as e:
            # Audio is optional; the written story has already been shown.
            st.warning(f"⚠️ Couldn't make audio version: {type(e).__name__}: {e}")

st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
|
|