import streamlit as st
from PIL import Image
from io import BytesIO
from huggingface_hub import InferenceApi
from gtts import gTTS
import requests
import tempfile
import time
import threading
# —––––––– Page Config —–––––––
st.set_page_config(page_title="Magic Story Generator", layout="centered")
st.title("📖✨ Turn Images into Children's Stories")
# —––––––– Clients (cached) —–––––––
@st.cache_resource
def load_clients():
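    """Create the BLIP caption client and start a background keep-alive thread for the story model."""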
    hf_token = st.secrets["HF_TOKEN"]
    caption_client = InferenceApi("Salesforce/blip-image-captioning-base", token=hf_token)

    # Start background keep-alive for story model to avoid cold starts
    api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    headers = {"Authorization": f"Bearer {hf_token}"}
    keep_alive_payload = {
        "inputs": "Hello!",
        "parameters": {"max_new_tokens": 1}
    }

    def keep_model_warm():
        # Initial warm-up
        try:
            requests.post(api_url, headers=headers, json=keep_alive_payload, timeout=10)
        except requests.RequestException:
            pass
        # Periodic keep-alive every 10 minutes
        while True:
            time.sleep(600)
            try:
                requests.post(api_url, headers=headers, json=keep_alive_payload, timeout=10)
            except requests.RequestException:
                pass

    threading.Thread(target=keep_model_warm, daemon=True).start()
    return caption_client, hf_token

caption_client, hf_token = load_clients()
# —––––––– Helper: Generate Caption —–––––––
def generate_caption(img):
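    """Send the image to the BLIP captioning model and return the generated caption, or "" on failure."""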
    img_bytes = BytesIO()
    img.save(img_bytes, format="JPEG")
    try:
        result = caption_client(data=img_bytes.getvalue())
        if isinstance(result, list) and result:
            return result[0].get("generated_text", "").strip()
    except Exception as e:
        st.error(f"Caption generation error: {type(e).__name__}: {e}")
    return ""

# —––––––– Helper: Process Image —–––––––
def process_image(uploaded_file):
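    """Open the uploaded file as an RGB image, downscaling anything larger than 2048px on its longest side."""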
    try:
        img = Image.open(uploaded_file).convert("RGB")
        if max(img.size) > 2048:
            img.thumbnail((2048, 2048))
        return img
    except Exception as e:
        st.error(f"Image processing error: {type(e).__name__}: {e}")
        st.stop()

# —––––––– Helper: Generate Story —–––––––
def generate_story(prompt: str) -> str:
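    """Call the DeepSeek story model via the HF Inference API, retrying on 503 (model loading) responses."""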
api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
headers = {"Authorization": f"Bearer {hf_token}"}
payload = {
"inputs": prompt,
"parameters": {
"max_new_tokens": 200,
"temperature": 0.8,
"top_p": 0.95,
"repetition_penalty": 1.15,
"do_sample": True,
"no_repeat_ngram_size": 2
}
}
max_retries = 5
retries = 0
while True:
try:
resp = requests.post(api_url, headers=headers, json=payload, timeout=30)
except Exception as e:
st.error(f"🚨 Story magic failed: {type(e).__name__}: {e}")
st.stop()
if resp.status_code == 200:
data = resp.json()
if isinstance(data, list) and data:
text = data[0].get("generated_text", "").strip()
story = text.split("Story:")[-1].strip()
if "." in story:
story = story.rsplit(".", 1)[0] + "."
return story
else:
st.error("🚨 Story magic failed: invalid response format")
st.stop()
if resp.status_code == 503 and retries < max_retries:
try:
wait = int(resp.json().get("estimated_time", 5))
except:
wait = 5 * (2 ** retries)
st.info(f"Model loading; retrying in {wait}s (attempt {retries+1}/{max_retries})")
time.sleep(wait)
retries += 1
continue
st.error(f"🚨 Story magic failed: HTTP {resp.status_code} - {resp.text}")
st.stop()
# —––––––– Main App Flow —–––––––
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    img = process_image(uploaded)
    st.image(img, use_container_width=True)

    # Generate Caption
    with st.spinner("🔍 Discovering image secrets..."):
        caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    # Prepare Story Prompt
    story_prompt = (
        f"Image description: {caption}\n\n"
        "Write a 50-100 word children's story that:\n"
        "1. Features the main subject as a friendly character\n"
        "2. Includes a simple adventure or discovery\n"
        "3. Ends with a happy or funny conclusion\n"
        "4. Uses simple language for ages 3-8\n\n"
        "Story:\n"
    )

    # Generate and Display Story
    with st.spinner("📝 Writing magical story..."):
        story = generate_story(story_prompt)
    st.subheader("📚 Your Magical Story")
    st.write(story)

    # Audio Conversion
    with st.spinner("🔊 Adding story voice..."):
        try:
            tts = gTTS(text=story, lang="en", slow=False)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
                tts.save(fp.name)
                st.audio(fp.name, format="audio/mp3")
        except Exception as e:
            st.warning(f"⚠️ Couldn't make audio version: {type(e).__name__}: {e}")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")