mayf commited on
Commit
2ddeb06
·
verified ·
1 Parent(s): e1594b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -21
app.py CHANGED
@@ -18,24 +18,19 @@ def load_clients():
18
  hf_token = st.secrets["HF_TOKEN"]
19
  caption_client = InferenceApi("Salesforce/blip-image-captioning-base", token=hf_token)
20
 
21
- # Start background keep-alive for story model to avoid cold starts
22
  api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
23
  headers = {"Authorization": f"Bearer {hf_token}"}
24
- keep_alive_payload = {
25
- "inputs": "Hello!",
26
- "parameters": {"max_new_tokens": 1}
27
- }
28
  def keep_model_warm():
29
- # Initial warm-up
30
  try:
31
- requests.post(api_url, headers=headers, json=keep_alive_payload, timeout=10)
32
  except:
33
  pass
34
- # Periodic keep-alive every 10 minutes
35
  while True:
36
  time.sleep(600)
37
  try:
38
- requests.post(api_url, headers=headers, json=keep_alive_payload, timeout=10)
39
  except:
40
  pass
41
  threading.Thread(target=keep_model_warm, daemon=True).start()
@@ -67,8 +62,8 @@ def process_image(uploaded_file):
67
  st.error(f"Image processing error: {type(e).__name__}: {e}")
68
  st.stop()
69
 
70
- # —––––––– Helper: Generate Story —–––––––
71
- def generate_story(prompt: str) -> str:
72
  api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
73
  headers = {"Authorization": f"Bearer {hf_token}"}
74
  payload = {
@@ -82,8 +77,8 @@ def generate_story(prompt: str) -> str:
82
  "no_repeat_ngram_size": 2
83
  }
84
  }
85
- max_retries = 5
86
  retries = 0
 
87
  while True:
88
  try:
89
  resp = requests.post(api_url, headers=headers, json=payload, timeout=30)
@@ -91,6 +86,7 @@ def generate_story(prompt: str) -> str:
91
  st.error(f"🚨 Story magic failed: {type(e).__name__}: {e}")
92
  st.stop()
93
 
 
94
  if resp.status_code == 200:
95
  data = resp.json()
96
  if isinstance(data, list) and data:
@@ -99,20 +95,30 @@ def generate_story(prompt: str) -> str:
99
  if "." in story:
100
  story = story.rsplit(".", 1)[0] + "."
101
  return story
102
- else:
103
- st.error("🚨 Story magic failed: invalid response format")
104
- st.stop()
105
 
 
106
  if resp.status_code == 503 and retries < max_retries:
107
- try:
108
- wait = int(resp.json().get("estimated_time", 5))
109
- except:
110
- wait = 5 * (2 ** retries)
111
  st.info(f"Model loading; retrying in {wait}s (attempt {retries+1}/{max_retries})")
112
  time.sleep(wait)
113
  retries += 1
114
  continue
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  st.error(f"🚨 Story magic failed: HTTP {resp.status_code} - {resp.text}")
117
  st.stop()
118
 
@@ -141,9 +147,21 @@ if uploaded:
141
  "Story:\n"
142
  )
143
 
144
- # Generate and Display Story
145
  with st.spinner("📝 Writing magical story..."):
146
- story = generate_story(story_prompt)
 
 
 
 
 
 
 
 
 
 
 
 
147
  st.subheader("📚 Your Magical Story")
148
  st.write(story)
149
 
 
18
  hf_token = st.secrets["HF_TOKEN"]
19
  caption_client = InferenceApi("Salesforce/blip-image-captioning-base", token=hf_token)
20
 
21
+ # Keep-alive thread to avoid cold starts for story model
22
  api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
23
  headers = {"Authorization": f"Bearer {hf_token}"}
24
+ warm_payload = {"inputs": "Hello!", "parameters": {"max_new_tokens": 1}}
 
 
 
25
  def keep_model_warm():
 
26
  try:
27
+ requests.post(api_url, headers=headers, json=warm_payload, timeout=10)
28
  except:
29
  pass
 
30
  while True:
31
  time.sleep(600)
32
  try:
33
+ requests.post(api_url, headers=headers, json=warm_payload, timeout=10)
34
  except:
35
  pass
36
  threading.Thread(target=keep_model_warm, daemon=True).start()
 
62
  st.error(f"Image processing error: {type(e).__name__}: {e}")
63
  st.stop()
64
 
65
+ # —––––––– Helper: Generate Story with fallback —–––––––
66
+ def generate_story(prompt: str, caption: str) -> str:
67
  api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
68
  headers = {"Authorization": f"Bearer {hf_token}"}
69
  payload = {
 
77
  "no_repeat_ngram_size": 2
78
  }
79
  }
 
80
  retries = 0
81
+ max_retries = 5
82
  while True:
83
  try:
84
  resp = requests.post(api_url, headers=headers, json=payload, timeout=30)
 
86
  st.error(f"🚨 Story magic failed: {type(e).__name__}: {e}")
87
  st.stop()
88
 
89
+ # Successful generation
90
  if resp.status_code == 200:
91
  data = resp.json()
92
  if isinstance(data, list) and data:
 
95
  if "." in story:
96
  story = story.rsplit(".", 1)[0] + "."
97
  return story
98
+ st.error("🚨 Story magic failed: invalid response format")
99
+ st.stop()
 
100
 
101
+ # Model loading (cold start)
102
  if resp.status_code == 503 and retries < max_retries:
103
+ wait = int(resp.json().get("estimated_time", 5)) if resp.headers.get('Content-Type','').startswith('application/json') else 5 * (2 ** retries)
 
 
 
104
  st.info(f"Model loading; retrying in {wait}s (attempt {retries+1}/{max_retries})")
105
  time.sleep(wait)
106
  retries += 1
107
  continue
108
 
109
+ # Server-side generation error
110
+ if resp.status_code in (424, 500, 502) and retries < max_retries:
111
+ st.info(f"Server error {resp.status_code}; retrying (attempt {retries+1}/{max_retries})")
112
+ time.sleep(2 ** retries)
113
+ retries += 1
114
+ continue
115
+ if resp.status_code in (424, 500, 502):
116
+ # Fallback story using the caption, ensuring ~70 words
117
+ return (f"One day, {caption} woke up under a bright sky and decided to explore the garden. "
118
+ "It met a friendly ladybug and together they played hide-and-seek among the flowers. "
119
+ "At sunset, {caption} curled up by a daisy, purring happily as it dreamed of new adventures.")
120
+
121
+ # Other errors
122
  st.error(f"🚨 Story magic failed: HTTP {resp.status_code} - {resp.text}")
123
  st.stop()
124
 
 
147
  "Story:\n"
148
  )
149
 
150
+ # Generate and validate Story
151
  with st.spinner("📝 Writing magical story..."):
152
+ story = None
153
+ attempts = 0
154
+ while attempts < 3:
155
+ candidate = generate_story(story_prompt, caption)
156
+ count = len(candidate.split())
157
+ if 50 <= count <= 100:
158
+ story = candidate
159
+ break
160
+ attempts += 1
161
+ if story is None:
162
+ st.warning("⚠️ Couldn't generate a story within 50-100 words after multiple tries. Showing last attempt.")
163
+ story = candidate
164
+
165
  st.subheader("📚 Your Magical Story")
166
  st.write(story)
167