File size: 6,995 Bytes
8367fb2 7d2ac1c 6b1de29 bb9cd5a 8367fb2 e1594b2 8367fb2 c5def56 8e5f097 8367fb2 c4ae250 8367fb2 7d2ac1c e1594b2 2ddeb06 e1594b2 2ddeb06 e1594b2 2ddeb06 e1594b2 2ddeb06 e1594b2 121e41f bb9cd5a 8367fb2 c5def56 c4ae250 bb9cd5a c4ae250 c5def56 301b896 c4ae250 301b896 bb9cd5a 301b896 c4ae250 2ddeb06 bb9cd5a e1594b2 2ddeb06 e1594b2 2ddeb06 e1594b2 2ddeb06 e1594b2 2ddeb06 e1594b2 2ddeb06 e1594b2 2ddeb06 bb9cd5a c5def56 301b896 c4ae250 e1594b2 c5def56 301b896 8e5f097 e1594b2 8e5f097 dfb3989 cc355a8 2ddeb06 8e5f097 2ddeb06 8e5f097 258bc7e 8e5f097 c2c4e19 8e5f097 c2c4e19 bb9cd5a 7d2ac1c c5def56 8e5f097 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
import streamlit as st
from PIL import Image
from io import BytesIO
from huggingface_hub import InferenceApi
from gtts import gTTS
import requests
import tempfile
import time
import threading
# —––––––– Page Config —–––––––
# Set the browser-tab title and a centered layout, then render the heading.
# NOTE(review): Streamlit has historically required set_page_config to run
# before any other st.* call — keep it first; confirm for the pinned version.
st.set_page_config(page_title="Magic Story Generator", layout="centered")
st.title("📖✨ Turn Images into Children's Stories")
# —––––––– Clients (cached) —–––––––
@st.cache_resource
def load_clients():
    """Build the captioning client and start the story-model keep-alive thread.

    The function is wrapped in ``st.cache_resource``; the intent (per the
    "Clients (cached)" section header) is that the client and the background
    thread are created once and reused, rather than once per script rerun.

    Returns:
        tuple: ``(caption_client, hf_token)`` where ``caption_client`` is the
        ``InferenceApi`` client for ``Salesforce/blip-image-captioning-base``
        and ``hf_token`` is the value read from ``st.secrets["HF_TOKEN"]``.
    """
    hf_token = st.secrets["HF_TOKEN"]
    caption_client = InferenceApi("Salesforce/blip-image-captioning-base", token=hf_token)

    # Keep-alive thread to avoid cold starts for the story model.
    api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    headers = {"Authorization": f"Bearer {hf_token}"}
    warm_payload = {"inputs": "Hello!", "parameters": {"max_new_tokens": 1}}
    ping_interval_s = 600

    def keep_model_warm():
        """Ping the story model once right away, then every ``ping_interval_s`` seconds."""
        while True:
            try:
                requests.post(api_url, headers=headers, json=warm_payload, timeout=10)
            except requests.RequestException:
                # Best effort only: a failed ping is ignored so the thread
                # survives and simply tries again after the next sleep.
                pass
            time.sleep(ping_interval_s)

    # Daemon thread: it must never keep the process alive on shutdown.
    threading.Thread(target=keep_model_warm, daemon=True).start()
    return caption_client, hf_token


caption_client, hf_token = load_clients()
# —––––––– Helper: Generate Caption —–––––––
def generate_caption(img):
    """Ask the module-level ``caption_client`` for a caption of *img*.

    The image is encoded as JPEG in memory and the raw bytes are sent to the
    client. The stripped ``generated_text`` of the first result is returned.
    An empty string is returned when the response is not a non-empty list, or
    when anything raises (the error is then shown with ``st.error``).
    """
    jpeg_buffer = BytesIO()
    img.save(jpeg_buffer, format="JPEG")
    try:
        response = caption_client(data=jpeg_buffer.getvalue())
        if not isinstance(response, list) or not response:
            return ""
        first_hit = response[0]
        return first_hit.get("generated_text", "").strip()
    except Exception as err:
        st.error(f"Caption generation error: {type(err).__name__}: {err}")
    return ""
# —––––––– Helper: Process Image —–––––––
def process_image(uploaded_file):
    """Open *uploaded_file* as an RGB image, shrunk to fit within 2048x2048.

    The image is converted to RGB and, if either side exceeds 2048 pixels,
    reduced in place with ``thumbnail``. On any failure the error is shown
    with ``st.error`` and the script run is halted with ``st.stop``.
    """
    max_side = 2048
    try:
        picture = Image.open(uploaded_file).convert("RGB")
        width, height = picture.size
        if width > max_side or height > max_side:
            picture.thumbnail((max_side, max_side))
    except Exception as err:
        st.error(f"Image processing error: {type(err).__name__}: {err}")
        st.stop()
    else:
        return picture
# —––––––– Helper: Generate Story with fallback —–––––––
def _extract_story(generated_text: str) -> str:
    """Return the text after the last "Story:" marker, cut at the final full stop."""
    story = generated_text.strip().split("Story:")[-1].strip()
    if "." in story:
        # Drop a trailing partial sentence left by the token limit.
        story = story.rsplit(".", 1)[0] + "."
    return story


def _fallback_story(caption: str) -> str:
    """Build the canned story used when the model keeps failing server-side."""
    # Every fragment that mentions the caption must be an f-string; a plain
    # literal would show the text "{caption}" to the user.
    return (
        f"One day, {caption} woke up under a bright sky and decided to explore the garden. "
        "It met a friendly ladybug and together they played hide-and-seek among the flowers. "
        f"At sunset, {caption} curled up by a daisy, purring happily as it dreamed of new adventures."
    )


def _cold_start_wait(resp, retries: int) -> int:
    """Return seconds to wait after a 503: the server's estimate, else backoff."""
    backoff = 5 * (2 ** retries)
    if not resp.headers.get("Content-Type", "").startswith("application/json"):
        return backoff
    try:
        body = resp.json()
    except ValueError:
        # Header claimed JSON but the body did not parse.
        return backoff
    if not isinstance(body, dict):
        return backoff
    try:
        return max(0, int(body.get("estimated_time", 5)))
    except (TypeError, ValueError, OverflowError):
        return backoff


def generate_story(prompt: str, caption: str) -> str:
    """Generate a story for *prompt* via the hosted model, with retries and a fallback.

    Args:
        prompt: Full text prompt sent to the model; the story is taken from
            whatever follows the last "Story:" marker in the generated text.
        caption: Image caption, used only to build the canned fallback story.

    Returns:
        The extracted story on HTTP 200, or the canned fallback story when the
        server still answers 424/500/502 after the retry budget is spent.

    Behavior on failure:
        A request error, an unusable 200 body, or any other HTTP status
        (including a 503 after the retry budget) is reported with ``st.error``
        and the script run is halted with ``st.stop``.
    """
    api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    headers = {"Authorization": f"Bearer {hf_token}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 200,
            "temperature": 0.8,
            "top_p": 0.95,
            "repetition_penalty": 1.15,
            "do_sample": True,
            "no_repeat_ngram_size": 2,
        },
    }
    server_errors = (424, 500, 502)
    max_retries = 5
    retries = 0  # one budget shared by cold-start and server-error retries

    while True:
        try:
            resp = requests.post(api_url, headers=headers, json=payload, timeout=30)
        except requests.RequestException as e:
            st.error(f"🚨 Story magic failed: {type(e).__name__}: {e}")
            st.stop()

        status = resp.status_code

        # Successful generation
        if status == 200:
            try:
                data = resp.json()
            except ValueError:
                data = None  # handled as an invalid response below
            if isinstance(data, list) and data and isinstance(data[0], dict):
                return _extract_story(str(data[0].get("generated_text") or ""))
            st.error("🚨 Story magic failed: invalid response format")
            st.stop()

        if retries < max_retries:
            # Model loading (cold start)
            if status == 503:
                wait = _cold_start_wait(resp, retries)
                st.info(f"Model loading; retrying in {wait}s (attempt {retries+1}/{max_retries})")
                time.sleep(wait)
                retries += 1
                continue
            # Server-side generation error
            if status in server_errors:
                st.info(f"Server error {status}; retrying (attempt {retries+1}/{max_retries})")
                time.sleep(2 ** retries)
                retries += 1
                continue

        # Retry budget exhausted on a server-side error: use the canned story.
        if status in server_errors:
            return _fallback_story(caption)

        # Other errors
        st.error(f"🚨 Story magic failed: HTTP {status} - {resp.text}")
        st.stop()
# —––––––– Main App Flow —–––––––
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    img = process_image(uploaded)
    st.image(img, use_container_width=True)

    # Generate Caption
    with st.spinner("🔍 Discovering image secrets..."):
        caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    # Prepare Story Prompt
    story_prompt = (
        f"Image description: {caption}\n\n"
        "Write a 50-100 word children's story that:\n"
        "1. Features the main subject as a friendly character\n"
        "2. Includes a simple adventure or discovery\n"
        "3. Ends with a happy or funny conclusion\n"
        "4. Uses simple language for ages 3-8\n\n"
        "Story:\n"
    )

    # Generate and validate Story: accept the first candidate whose word
    # count is within 50-100, otherwise fall back to the last one produced.
    max_attempts = 3
    with st.spinner("📝 Writing magical story..."):
        story = None
        candidate = ""
        for _ in range(max_attempts):
            candidate = generate_story(story_prompt, caption)
            if 50 <= len(candidate.split()) <= 100:
                story = candidate
                break
    if story is None:
        st.warning("⚠️ Couldn't generate a story within 50-100 words after multiple tries. Showing last attempt.")
        story = candidate
    st.subheader("📚 Your Magical Story")
    st.write(story)

    # Audio Conversion
    with st.spinner("🔊 Adding story voice..."):
        try:
            tts = gTTS(text=story, lang="en", slow=False)
            # Synthesize into memory instead of a delete=False temp file, so
            # no .mp3 is left behind on disk after each run.
            audio_buffer = BytesIO()
            tts.write_to_fp(audio_buffer)
            st.audio(audio_buffer.getvalue(), format="audio/mp3")
        except Exception as e:
            # Audio is optional: report the problem but keep the story visible.
            st.warning(f"⚠️ Couldn't make audio version: {type(e).__name__}: {e}")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
|