Update app.py
Browse files
app.py
CHANGED
@@ -18,24 +18,19 @@ def load_clients():
|
|
18 |
hf_token = st.secrets["HF_TOKEN"]
|
19 |
caption_client = InferenceApi("Salesforce/blip-image-captioning-base", token=hf_token)
|
20 |
|
21 |
-
#
|
22 |
api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
|
23 |
headers = {"Authorization": f"Bearer {hf_token}"}
|
24 |
-
|
25 |
-
"inputs": "Hello!",
|
26 |
-
"parameters": {"max_new_tokens": 1}
|
27 |
-
}
|
28 |
def keep_model_warm():
|
29 |
-
# Initial warm-up
|
30 |
try:
|
31 |
-
requests.post(api_url, headers=headers, json=
|
32 |
except:
|
33 |
pass
|
34 |
-
# Periodic keep-alive every 10 minutes
|
35 |
while True:
|
36 |
time.sleep(600)
|
37 |
try:
|
38 |
-
requests.post(api_url, headers=headers, json=
|
39 |
except:
|
40 |
pass
|
41 |
threading.Thread(target=keep_model_warm, daemon=True).start()
|
@@ -67,8 +62,8 @@ def process_image(uploaded_file):
|
|
67 |
st.error(f"Image processing error: {type(e).__name__}: {e}")
|
68 |
st.stop()
|
69 |
|
70 |
-
# —––––––– Helper: Generate Story —–––––––
|
71 |
-
def generate_story(prompt: str) -> str:
|
72 |
api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
|
73 |
headers = {"Authorization": f"Bearer {hf_token}"}
|
74 |
payload = {
|
@@ -82,8 +77,8 @@ def generate_story(prompt: str) -> str:
|
|
82 |
"no_repeat_ngram_size": 2
|
83 |
}
|
84 |
}
|
85 |
-
max_retries = 5
|
86 |
retries = 0
|
|
|
87 |
while True:
|
88 |
try:
|
89 |
resp = requests.post(api_url, headers=headers, json=payload, timeout=30)
|
@@ -91,6 +86,7 @@ def generate_story(prompt: str) -> str:
|
|
91 |
st.error(f"🚨 Story magic failed: {type(e).__name__}: {e}")
|
92 |
st.stop()
|
93 |
|
|
|
94 |
if resp.status_code == 200:
|
95 |
data = resp.json()
|
96 |
if isinstance(data, list) and data:
|
@@ -99,20 +95,30 @@ def generate_story(prompt: str) -> str:
|
|
99 |
if "." in story:
|
100 |
story = story.rsplit(".", 1)[0] + "."
|
101 |
return story
|
102 |
-
|
103 |
-
|
104 |
-
st.stop()
|
105 |
|
|
|
106 |
if resp.status_code == 503 and retries < max_retries:
|
107 |
-
|
108 |
-
wait = int(resp.json().get("estimated_time", 5))
|
109 |
-
except:
|
110 |
-
wait = 5 * (2 ** retries)
|
111 |
st.info(f"Model loading; retrying in {wait}s (attempt {retries+1}/{max_retries})")
|
112 |
time.sleep(wait)
|
113 |
retries += 1
|
114 |
continue
|
115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
st.error(f"🚨 Story magic failed: HTTP {resp.status_code} - {resp.text}")
|
117 |
st.stop()
|
118 |
|
@@ -141,9 +147,21 @@ if uploaded:
|
|
141 |
"Story:\n"
|
142 |
)
|
143 |
|
144 |
-
# Generate and
|
145 |
with st.spinner("📝 Writing magical story..."):
|
146 |
-
story =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
st.subheader("📚 Your Magical Story")
|
148 |
st.write(story)
|
149 |
|
|
|
18 |
# Hugging Face credentials and model clients used throughout the app.
hf_token = st.secrets["HF_TOKEN"]  # token is stored in Streamlit secrets, never in code
caption_client = InferenceApi("Salesforce/blip-image-captioning-base", token=hf_token)

# Keep-alive thread to avoid cold starts for story model
api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
headers = {"Authorization": f"Bearer {hf_token}"}
# Minimal one-token request; used only to keep the endpoint warm, output is discarded.
warm_payload = {"inputs": "Hello!", "parameters": {"max_new_tokens": 1}}
|
|
|
|
|
|
|
25 |
def keep_model_warm():
    """Ping the story model so the HF inference endpoint stays loaded.

    Sends one warm-up request immediately, then one every 10 minutes,
    forever. Intended to run in a daemon thread; all request failures are
    swallowed because warming is strictly best-effort.
    """
    # Initial warm-up right after startup.
    try:
        requests.post(api_url, headers=headers, json=warm_payload, timeout=10)
    except Exception:
        # Best-effort only — never let a warm-up failure propagate.
        # (Was a bare `except:`, which would also swallow SystemExit /
        # KeyboardInterrupt; Exception is the correct breadth here.)
        pass
    # Periodic keep-alive every 10 minutes.
    while True:
        time.sleep(600)
        try:
            requests.post(api_url, headers=headers, json=warm_payload, timeout=10)
        except Exception:
            pass

threading.Thread(target=keep_model_warm, daemon=True).start()
|
|
|
62 |
st.error(f"Image processing error: {type(e).__name__}: {e}")
|
63 |
st.stop()
|
64 |
|
65 |
+
# —––––––– Helper: Generate Story with fallback —–––––––
|
66 |
+
def generate_story(prompt: str, caption: str) -> str:
|
67 |
api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
|
68 |
headers = {"Authorization": f"Bearer {hf_token}"}
|
69 |
payload = {
|
|
|
77 |
"no_repeat_ngram_size": 2
|
78 |
}
|
79 |
}
|
|
|
80 |
retries = 0
|
81 |
+
max_retries = 5
|
82 |
while True:
|
83 |
try:
|
84 |
resp = requests.post(api_url, headers=headers, json=payload, timeout=30)
|
|
|
86 |
st.error(f"🚨 Story magic failed: {type(e).__name__}: {e}")
|
87 |
st.stop()
|
88 |
|
89 |
+
# Successful generation
|
90 |
if resp.status_code == 200:
|
91 |
data = resp.json()
|
92 |
if isinstance(data, list) and data:
|
|
|
95 |
if "." in story:
|
96 |
story = story.rsplit(".", 1)[0] + "."
|
97 |
return story
|
98 |
+
st.error("🚨 Story magic failed: invalid response format")
|
99 |
+
st.stop()
|
|
|
100 |
|
101 |
+
# Model loading (cold start)
|
102 |
if resp.status_code == 503 and retries < max_retries:
|
103 |
+
wait = int(resp.json().get("estimated_time", 5)) if resp.headers.get('Content-Type','').startswith('application/json') else 5 * (2 ** retries)
|
|
|
|
|
|
|
104 |
st.info(f"Model loading; retrying in {wait}s (attempt {retries+1}/{max_retries})")
|
105 |
time.sleep(wait)
|
106 |
retries += 1
|
107 |
continue
|
108 |
|
109 |
+
# Server-side generation error
|
110 |
+
if resp.status_code in (424, 500, 502) and retries < max_retries:
|
111 |
+
st.info(f"Server error {resp.status_code}; retrying (attempt {retries+1}/{max_retries})")
|
112 |
+
time.sleep(2 ** retries)
|
113 |
+
retries += 1
|
114 |
+
continue
|
115 |
+
if resp.status_code in (424, 500, 502):
|
116 |
+
# Fallback story using the caption, ensuring ~70 words
|
117 |
+
return (f"One day, {caption} woke up under a bright sky and decided to explore the garden. "
|
118 |
+
"It met a friendly ladybug and together they played hide-and-seek among the flowers. "
|
119 |
+
"At sunset, {caption} curled up by a daisy, purring happily as it dreamed of new adventures.")
|
120 |
+
|
121 |
+
# Other errors
|
122 |
st.error(f"🚨 Story magic failed: HTTP {resp.status_code} - {resp.text}")
|
123 |
st.stop()
|
124 |
|
|
|
147 |
"Story:\n"
|
148 |
)
|
149 |
|
150 |
+
# Generate and validate Story
|
151 |
with st.spinner("📝 Writing magical story..."):
|
152 |
+
story = None
|
153 |
+
attempts = 0
|
154 |
+
while attempts < 3:
|
155 |
+
candidate = generate_story(story_prompt, caption)
|
156 |
+
count = len(candidate.split())
|
157 |
+
if 50 <= count <= 100:
|
158 |
+
story = candidate
|
159 |
+
break
|
160 |
+
attempts += 1
|
161 |
+
if story is None:
|
162 |
+
st.warning("⚠️ Couldn't generate a story within 50-100 words after multiple tries. Showing last attempt.")
|
163 |
+
story = candidate
|
164 |
+
|
165 |
st.subheader("📚 Your Magical Story")
|
166 |
st.write(story)
|
167 |
|