1
File size: 7,395 Bytes
8367fb2
 
7d2ac1c
 
6b1de29
bb9cd5a
96d517c
8367fb2
e1594b2
 
8367fb2
c5def56
8e5f097
 
8367fb2
c4ae250
8367fb2
7d2ac1c
 
e1594b2
 
2ddeb06
e1594b2
 
2ddeb06
e1594b2
 
2ddeb06
e1594b2
 
 
 
 
2ddeb06
e1594b2
 
 
 
 
121e41f
bb9cd5a
8367fb2
c5def56
c4ae250
 
 
 
 
 
 
 
bb9cd5a
 
c4ae250
c5def56
301b896
 
 
c4ae250
301b896
 
 
bb9cd5a
301b896
c4ae250
96d517c
2ddeb06
bb9cd5a
 
 
 
 
 
 
 
 
 
 
 
 
e1594b2
2ddeb06
96d517c
e1594b2
 
96d517c
 
 
 
 
 
 
 
 
 
e1594b2
 
 
 
2ddeb06
e1594b2
 
 
 
 
 
 
 
2ddeb06
 
e1594b2
2ddeb06
e1594b2
96d517c
e1594b2
 
 
 
 
2ddeb06
 
96d517c
 
 
2ddeb06
 
 
 
 
 
 
 
bb9cd5a
 
 
c5def56
301b896
 
 
c4ae250
 
e1594b2
c5def56
301b896
 
 
 
 
8e5f097
e1594b2
8e5f097
 
 
 
 
 
 
 
dfb3989
cc355a8
2ddeb06
8e5f097
2ddeb06
 
 
 
 
 
 
 
 
 
 
 
 
8e5f097
 
258bc7e
8e5f097
 
c2c4e19
8e5f097
 
 
 
c2c4e19
bb9cd5a
7d2ac1c
c5def56
8e5f097
96d517c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import streamlit as st
from PIL import Image
from io import BytesIO
from huggingface_hub import InferenceApi
from gtts import gTTS
import requests
from requests.exceptions import ReadTimeout
import tempfile
import time
import threading

# —––––––– Page Config —–––––––
# Set the browser-tab title and a centered layout, then render the page heading.
st.set_page_config(layout="centered", page_title="Magic Story Generator")
st.title("📖✨ Turn Images into Children's Stories")

# —––––––– Clients (cached) —–––––––
@st.cache_resource
def load_clients():
    hf_token = st.secrets["HF_TOKEN"]
    caption_client = InferenceApi("Salesforce/blip-image-captioning-base", token=hf_token)

    # Keep-alive thread to avoid cold starts for story model
    api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    headers = {"Authorization": f"Bearer {hf_token}"}
    warm_payload = {"inputs": "Hello!", "parameters": {"max_new_tokens": 1}}
    def keep_model_warm():
        try:
            requests.post(api_url, headers=headers, json=warm_payload, timeout=10)
        except:
            pass
        while True:
            time.sleep(600)
            try:
                requests.post(api_url, headers=headers, json=warm_payload, timeout=10)
            except:
                pass
    threading.Thread(target=keep_model_warm, daemon=True).start()

    return caption_client, hf_token

caption_client, hf_token = load_clients()

# —––––––– Helper: Generate Caption —–––––––
def generate_caption(img):
    """Return a text caption for *img*, or ``""`` when none is available.

    The image is JPEG-encoded in memory and sent to ``caption_client``.
    Any error raised while calling the client or reading its result is
    reported in the UI with ``st.error`` instead of being propagated.
    """
    buffer = BytesIO()
    img.save(buffer, format="JPEG")
    try:
        response = caption_client(data=buffer.getvalue())
        if not (isinstance(response, list) and response):
            return ""
        return response[0].get("generated_text", "").strip()
    except Exception as err:
        st.error(f"Caption generation error: {type(err).__name__}: {err}")
    return ""

# —––––––– Helper: Process Image —–––––––
def process_image(uploaded_file):
    """Open an uploaded image as RGB, shrinking it to fit inside 2048x2048.

    Aspect ratio is preserved by ``thumbnail``. If opening or converting the
    image fails, the error is shown with ``st.error`` and the script run is
    halted with ``st.stop()``.
    """
    max_side = 2048
    try:
        picture = Image.open(uploaded_file).convert("RGB")
        if max(picture.size) > max_side:
            picture.thumbnail((max_side, max_side))
    except Exception as err:
        st.error(f"Image processing error: {type(err).__name__}: {err}")
        st.stop()
    else:
        return picture

# —––––––– Helper: Generate Story with improved retry and timeout —–––––––
def generate_story(prompt: str, caption: str) -> str:
    """Generate a children's story from *prompt* with the hosted DeepSeek model.

    Sends *prompt* to the Hugging Face Inference API. Transient failures are
    retried; all retry kinds share one budget of ``max_retries`` retries:

    * request timeouts (read or connect): exponential backoff (1, 2, 4, ... s);
    * HTTP 503 (model loading): wait the server's ``estimated_time`` (>= 1 s);
    * HTTP 424/500/502: exponential backoff, then a canned fallback story
      built from *caption* once retries are exhausted.

    Any other failure is shown with ``st.error`` and halts the run via
    ``st.stop()``.

    Args:
        prompt: Full prompt sent as ``inputs``. The story is taken from the
            text after the last ``"Story:"`` marker in the generated text.
        caption: Image caption, used only to personalise the fallback story.

    Returns:
        The story text, trimmed to end at its last ``"."`` when it has one.
    """
    api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    headers = {"Authorization": f"Bearer {hf_token}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens":       200,
            "temperature":          0.8,
            "top_p":                0.95,
            "repetition_penalty":   1.15,
            "do_sample":            True,
            "no_repeat_ngram_size": 2,
        },
    }
    retries = 0
    max_retries = 5
    timeout = 60  # allow up to 60s for large model
    while True:
        try:
            resp = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
        except requests.exceptions.Timeout:
            # Base class of both ReadTimeout and ConnectTimeout; both are transient.
            if retries < max_retries:
                wait = 2 ** retries
                st.info(f"Request timed out; retrying in {wait}s (attempt {retries+1}/{max_retries})")
                time.sleep(wait)
                retries += 1
                continue
            st.error("🚨 Story magic failed: request timed out after multiple attempts.")
            st.stop()
        except Exception as e:
            st.error(f"🚨 Story magic failed: {type(e).__name__}: {e}")
            st.stop()

        # Successful generation
        if resp.status_code == 200:
            try:
                data = resp.json()
            except ValueError:  # body was not valid JSON
                data = None
            if isinstance(data, list) and data and isinstance(data[0], dict):
                text = (data[0].get("generated_text") or "").strip()
                # Keep only what follows the last "Story:" marker; the prompt ends
                # with one, and the generated text presumably echoes the prompt.
                story = text.split("Story:")[-1].strip()
                # Trim an unfinished trailing sentence (e.g. one cut off by the token limit).
                if "." in story:
                    story = story.rsplit(".", 1)[0] + "."
                return story
            st.error("🚨 Story magic failed: invalid response format")
            st.stop()

        # Model loading (cold start)
        if resp.status_code == 503 and retries < max_retries:
            try:
                wait = int(resp.json().get("estimated_time", 5))
            except (ValueError, TypeError, AttributeError):
                wait = 5  # body not JSON / not a dict / unusable estimate
            wait = max(wait, 1)  # int() truncates sub-second estimates to 0
            st.info(f"Model loading; retrying in {wait}s (attempt {retries+1}/{max_retries})")
            time.sleep(wait)
            retries += 1
            continue

        # Server-side generation error
        if resp.status_code in (424, 500, 502) and retries < max_retries:
            wait = 2 ** retries
            st.info(f"Server error {resp.status_code}; retrying in {wait}s (attempt {retries+1}/{max_retries})")
            time.sleep(wait)
            retries += 1
            continue
        if resp.status_code in (424, 500, 502):
            # Retries exhausted: fall back to a canned story so the user still gets one.
            return (f"One day, {caption} woke up under a bright sky and decided to explore the garden. "
                    "It met a friendly ladybug and together they played hide-and-seek among the flowers. "
                    f"At sunset, {caption} curled up by a daisy, purring happily as it dreamed of new adventures.")

        # Other errors
        st.error(f"🚨 Story magic failed: HTTP {resp.status_code} - {resp.text}")
        st.stop()

# —––––––– Main App Flow —–––––––
# Script body: runs top to bottom on every Streamlit rerun.
# Flow: upload -> caption -> story (word-count checked) -> spoken audio.
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    img = process_image(uploaded)
    st.image(img, use_container_width=True)

    # Generate Caption
    with st.spinner("🔍 Discovering image secrets..."):
        caption = generate_caption(img)
        if not caption:
            st.error("😢 Couldn't understand this image. Try another one!")
            st.stop()
    st.success(f"**Caption:** {caption}")

    # Prepare Story Prompt
    story_prompt = (
        f"Image description: {caption}\n\n"
        "Write a 50-100 word children's story that:\n"
        "1. Features the main subject as a friendly character\n"
        "2. Includes a simple adventure or discovery\n"
        "3. Ends with a happy or funny conclusion\n"
        "4. Uses simple language for ages 3-8\n\n"
        "Story:\n"
    )

    # Generate and validate Story: up to 3 generations to land within the word range.
    min_words, max_words, max_story_attempts = 50, 100, 3
    with st.spinner("📝 Writing magical story..."):
        for _ in range(max_story_attempts):
            candidate = generate_story(story_prompt, caption)
            if min_words <= len(candidate.split()) <= max_words:
                story = candidate
                break
        else:
            st.warning("⚠️ Couldn't generate a story within 50-100 words after multiple tries. Showing last attempt.")
            story = candidate

    st.subheader("📚 Your Magical Story")
    st.write(story)

    # Audio Conversion: synthesize into memory so no temp files accumulate on disk.
    with st.spinner("🔊 Adding story voice..."):
        try:
            tts = gTTS(text=story, lang="en", slow=False)
            audio_buf = BytesIO()
            tts.write_to_fp(audio_buf)
            st.audio(audio_buf.getvalue(), format="audio/mp3")
        except Exception as e:
            st.warning(f"⚠️ Couldn't make audio version: {type(e).__name__}: {e}")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")