File size: 7,395 Bytes
8367fb2 7d2ac1c 6b1de29 bb9cd5a 96d517c 8367fb2 e1594b2 8367fb2 c5def56 8e5f097 8367fb2 c4ae250 8367fb2 7d2ac1c e1594b2 2ddeb06 e1594b2 2ddeb06 e1594b2 2ddeb06 e1594b2 2ddeb06 e1594b2 121e41f bb9cd5a 8367fb2 c5def56 c4ae250 bb9cd5a c4ae250 c5def56 301b896 c4ae250 301b896 bb9cd5a 301b896 c4ae250 96d517c 2ddeb06 bb9cd5a e1594b2 2ddeb06 96d517c e1594b2 96d517c e1594b2 2ddeb06 e1594b2 2ddeb06 e1594b2 2ddeb06 e1594b2 96d517c e1594b2 2ddeb06 96d517c 2ddeb06 bb9cd5a c5def56 301b896 c4ae250 e1594b2 c5def56 301b896 8e5f097 e1594b2 8e5f097 dfb3989 cc355a8 2ddeb06 8e5f097 2ddeb06 8e5f097 258bc7e 8e5f097 c2c4e19 8e5f097 c2c4e19 bb9cd5a 7d2ac1c c5def56 8e5f097 96d517c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import streamlit as st
from PIL import Image
from io import BytesIO
from huggingface_hub import InferenceApi
from gtts import gTTS
import requests
from requests.exceptions import ReadTimeout
import tempfile
import time
import threading
# —––––––– Page Config —–––––––
# Browser-tab title and centered layout, followed by the on-page heading.
st.set_page_config(
    layout="centered",
    page_title="Magic Story Generator",
)
st.title("📖✨ Turn Images into Children's Stories")
# —––––––– Clients (cached) —–––––––
@st.cache_resource
def load_clients():
    """Build the captioning client and start the story-model keep-alive thread.

    Reads the Hugging Face token from ``st.secrets["HF_TOKEN"]``, creates an
    ``InferenceApi`` client for the BLIP captioning model, and launches a
    daemon thread that periodically sends a one-token request to the story
    model so it is less likely to be cold when a user asks for a story.

    Returns:
        tuple: ``(caption_client, hf_token)``.
    """
    hf_token = st.secrets["HF_TOKEN"]
    caption_client = InferenceApi("Salesforce/blip-image-captioning-base", token=hf_token)

    # Keep-alive thread to avoid cold starts for story model
    api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    headers = {"Authorization": f"Bearer {hf_token}"}
    warm_payload = {"inputs": "Hello!", "parameters": {"max_new_tokens": 1}}
    ping_interval_seconds = 600

    def keep_model_warm():
        """Ping the story model once immediately, then every 10 minutes."""
        while True:
            try:
                requests.post(api_url, headers=headers, json=warm_payload, timeout=10)
            except requests.RequestException:
                # Best-effort ping: a network hiccup must not end the thread,
                # but unrelated errors are no longer silently swallowed.
                pass
            time.sleep(ping_interval_seconds)

    threading.Thread(target=keep_model_warm, daemon=True).start()
    return caption_client, hf_token


caption_client, hf_token = load_clients()
# —––––––– Helper: Generate Caption —–––––––
def generate_caption(img):
    """Ask the captioning model to describe *img*.

    The image is serialized to JPEG in memory and sent to ``caption_client``.
    Any exception raised while calling the client or reading its reply is
    reported through ``st.error``.

    Returns:
        The stripped ``generated_text`` of the first result, or ``""`` when
        the call fails or the reply is not a non-empty list.
    """
    jpeg_buffer = BytesIO()
    img.save(jpeg_buffer, format="JPEG")
    try:
        response = caption_client(data=jpeg_buffer.getvalue())
        if not isinstance(response, list) or not response:
            return ""
        first_item = response[0]
        caption_text = first_item.get("generated_text", "")
        return caption_text.strip()
    except Exception as e:
        st.error(f"Caption generation error: {type(e).__name__}: {e}")
        return ""
# —––––––– Helper: Process Image —–––––––
def process_image(uploaded_file):
    """Open an uploaded file as an RGB image, shrinking it if it is very large.

    Images whose longer side exceeds 2048 pixels are reduced in place with
    ``thumbnail`` to fit within 2048x2048. If opening, converting or resizing
    raises, the error is shown with ``st.error`` and ``st.stop()`` is called.

    Returns:
        The loaded (and possibly downsized) image.
    """
    size_limit = 2048
    try:
        picture = Image.open(uploaded_file).convert("RGB")
        longest_side = max(picture.size)
        if longest_side > size_limit:
            picture.thumbnail((size_limit, size_limit))
    except Exception as e:
        st.error(f"Image processing error: {type(e).__name__}: {e}")
        st.stop()
    else:
        return picture
# —––––––– Helper: Generate Story with improved retry and timeout —–––––––
def _story_from_generation(generated_text: str) -> str:
    """Keep the text after the last "Story:" marker, cut at the final full stop."""
    story = generated_text.strip().split("Story:")[-1].strip()
    if "." in story:
        # Drop a trailing unfinished sentence left by the token limit.
        story = story.rsplit(".", 1)[0] + "."
    return story


def _loading_wait_seconds(resp, default: int = 5) -> int:
    """Read ``estimated_time`` from a 503 body, falling back to *default*."""
    try:
        return int(resp.json().get("estimated_time", default))
    except (ValueError, TypeError, AttributeError):
        # Body was not JSON, not a dict, or the value was not numeric.
        return default


def _fallback_story(caption: str) -> str:
    """Return a canned story about *caption* for when the model keeps failing."""
    return (
        f"One day, {caption} woke up under a bright sky and decided to explore the garden. "
        "It met a friendly ladybug and together they played hide-and-seek among the flowers. "
        # This segment needs its own f-prefix; without it "{caption}" is shown literally.
        f"At sunset, {caption} curled up by a daisy, purring happily as it dreamed of new adventures."
    )


def generate_story(prompt: str, caption: str) -> str:
    """Generate a children's story from *prompt* via the hosted story model.

    The request is retried up to five times, sharing one retry counter across
    read timeouts (exponential backoff), HTTP 503 "model loading" replies
    (waiting the server's ``estimated_time``) and HTTP 424/500/502 replies
    (exponential backoff).

    Args:
        prompt: Full text prompt sent to the model.
        caption: Image caption, used only to build the fallback story.

    Returns:
        The extracted story, or a canned fallback story built from *caption*
        when 424/500/502 errors persist after all retries.

    On repeated timeouts, any other request exception, a malformed 200
    response, or any other HTTP status, an error is shown with ``st.error``
    and ``st.stop()`` is called.
    """
    api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    headers = {"Authorization": f"Bearer {hf_token}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 200,
            "temperature": 0.8,
            "top_p": 0.95,
            "repetition_penalty": 1.15,
            "do_sample": True,
            "no_repeat_ngram_size": 2,
        },
    }
    retries = 0
    max_retries = 5
    timeout = 60  # allow up to 60s for large model
    server_error_codes = (424, 500, 502)

    while True:
        try:
            resp = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
        except ReadTimeout:
            if retries < max_retries:
                wait = 2 ** retries
                st.info(f"Request timed out; retrying in {wait}s (attempt {retries+1}/{max_retries})")
                time.sleep(wait)
                retries += 1
                continue
            st.error("🚨 Story magic failed: request timed out after multiple attempts.")
            st.stop()
        except Exception as e:
            st.error(f"🚨 Story magic failed: {type(e).__name__}: {e}")
            st.stop()

        # Successful generation
        if resp.status_code == 200:
            try:
                data = resp.json()
            except ValueError:
                # A 200 with a non-JSON body is reported as a bad format below.
                data = None
            if isinstance(data, list) and data and isinstance(data[0], dict):
                return _story_from_generation(data[0].get("generated_text", ""))
            st.error("🚨 Story magic failed: invalid response format")
            st.stop()

        # Model loading (cold start)
        if resp.status_code == 503 and retries < max_retries:
            wait = _loading_wait_seconds(resp)
            st.info(f"Model loading; retrying in {wait}s (attempt {retries+1}/{max_retries})")
            time.sleep(wait)
            retries += 1
            continue

        # Server-side generation error
        if resp.status_code in server_error_codes:
            if retries < max_retries:
                wait = 2 ** retries
                st.info(f"Server error {resp.status_code}; retrying in {wait}s (attempt {retries+1}/{max_retries})")
                time.sleep(wait)
                retries += 1
                continue
            # Retries exhausted: give the user a canned story instead of an error.
            return _fallback_story(caption)

        # Other errors
        st.error(f"🚨 Story magic failed: HTTP {resp.status_code} - {resp.text}")
        st.stop()
# —––––––– Main App Flow —–––––––
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    img = process_image(uploaded)
    st.image(img, use_container_width=True)

    # Generate Caption
    with st.spinner("🔍 Discovering image secrets..."):
        caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    # Prepare Story Prompt
    story_prompt = (
        f"Image description: {caption}\n\n"
        "Write a 50-100 word children's story that:\n"
        "1. Features the main subject as a friendly character\n"
        "2. Includes a simple adventure or discovery\n"
        "3. Ends with a happy or funny conclusion\n"
        "4. Uses simple language for ages 3-8\n\n"
        "Story:\n"
    )

    # Generate and validate Story: up to three tries for a 50-100 word result
    with st.spinner("📝 Writing magical story..."):
        story = None
        candidate = ""
        for _ in range(3):
            candidate = generate_story(story_prompt, caption)
            if 50 <= len(candidate.split()) <= 100:
                story = candidate
                break
        if story is None:
            st.warning("⚠️ Couldn't generate a story within 50-100 words after multiple tries. Showing last attempt.")
            story = candidate

    st.subheader("📚 Your Magical Story")
    st.write(story)

    # Audio Conversion
    with st.spinner("🔊 Adding story voice..."):
        try:
            # Synthesize into memory so no temporary .mp3 is left on disk
            # after each run.
            audio_buffer = BytesIO()
            gTTS(text=story, lang="en", slow=False).write_to_fp(audio_buffer)
            st.audio(audio_buffer.getvalue(), format="audio/mp3")
        except Exception as e:
            st.warning(f"⚠️ Couldn't make audio version: {type(e).__name__}: {e}")

# Footer
# NOTE(review): placed at top level so it shows before an upload too; the
# original indentation was lost, so confirm this matches the intended layout.
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
|