mayf commited on
Commit
e1594b2
·
verified ·
1 Parent(s): bb9cd5a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -29
app.py CHANGED
@@ -5,6 +5,8 @@ from huggingface_hub import InferenceApi
5
  from gtts import gTTS
6
  import requests
7
  import tempfile
 
 
8
 
9
  # —––––––– Page Config —–––––––
10
  st.set_page_config(page_title="Magic Story Generator", layout="centered")
@@ -14,10 +16,31 @@ st.title("📖✨ Turn Images into Children's Stories")
14
  @st.cache_resource
15
  def load_clients():
16
  hf_token = st.secrets["HF_TOKEN"]
17
- return (
18
- InferenceApi("Salesforce/blip-image-captioning-base", token=hf_token),
19
- hf_token # we'll use direct requests for story generation
20
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  caption_client, hf_token = load_clients()
23
 
@@ -44,7 +67,7 @@ def process_image(uploaded_file):
44
  st.error(f"Image processing error: {type(e).__name__}: {e}")
45
  st.stop()
46
 
47
- # —––––––– Helper: Generate Story via HTTP —–––––––
48
  def generate_story(prompt: str) -> str:
49
  api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
50
  headers = {"Authorization": f"Bearer {hf_token}"}
@@ -59,25 +82,39 @@ def generate_story(prompt: str) -> str:
59
  "no_repeat_ngram_size": 2
60
  }
61
  }
62
- try:
63
- resp = requests.post(api_url, headers=headers, json=payload, timeout=30)
64
- except Exception as e:
65
- st.error(f"🚨 Story magic failed: {type(e).__name__}: {e}")
66
- st.stop()
67
- if resp.status_code != 200:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  st.error(f"🚨 Story magic failed: HTTP {resp.status_code} - {resp.text}")
69
  st.stop()
70
- data = resp.json()
71
- # Expecting list of generations
72
- if isinstance(data, list) and data:
73
- text = data[0].get("generated_text", "").strip()
74
- # extract after "Story:" if present
75
- story = text.split("Story:")[-1].strip()
76
- if "." in story:
77
- story = story.rsplit(".", 1)[0] + "."
78
- return story
79
- st.error("🚨 Story magic failed: invalid response format")
80
- st.stop()
81
 
82
  # —––––––– Main App Flow —–––––––
83
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
@@ -85,7 +122,7 @@ if uploaded:
85
  img = process_image(uploaded)
86
  st.image(img, use_container_width=True)
87
 
88
- # Caption
89
  with st.spinner("🔍 Discovering image secrets..."):
90
  caption = generate_caption(img)
91
  if not caption:
@@ -93,7 +130,7 @@ if uploaded:
93
  st.stop()
94
  st.success(f"**Caption:** {caption}")
95
 
96
- # Prepare Prompt
97
  story_prompt = (
98
  f"Image description: {caption}\n\n"
99
  "Write a 50-100 word children's story that:\n"
@@ -104,11 +141,9 @@ if uploaded:
104
  "Story:\n"
105
  )
106
 
107
- # Generate Story
108
  with st.spinner("📝 Writing magical story..."):
109
  story = generate_story(story_prompt)
110
-
111
- # Display Story
112
  st.subheader("📚 Your Magical Story")
113
  st.write(story)
114
 
@@ -124,5 +159,3 @@ if uploaded:
124
 
125
  # Footer
126
  st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
127
-
128
-
 
5
  from gtts import gTTS
6
  import requests
7
  import tempfile
8
+ import time
9
+ import threading
10
 
11
  # —––––––– Page Config —–––––––
12
  st.set_page_config(page_title="Magic Story Generator", layout="centered")
 
16
  @st.cache_resource
17
  def load_clients():
18
  hf_token = st.secrets["HF_TOKEN"]
19
+ caption_client = InferenceApi("Salesforce/blip-image-captioning-base", token=hf_token)
20
+
21
+ # Start background keep-alive for story model to avoid cold starts
22
+ api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
23
+ headers = {"Authorization": f"Bearer {hf_token}"}
24
+ keep_alive_payload = {
25
+ "inputs": "Hello!",
26
+ "parameters": {"max_new_tokens": 1}
27
+ }
28
+ def keep_model_warm():
29
+ # Initial warm-up
30
+ try:
31
+ requests.post(api_url, headers=headers, json=keep_alive_payload, timeout=10)
32
+ except:
33
+ pass
34
+ # Periodic keep-alive every 10 minutes
35
+ while True:
36
+ time.sleep(600)
37
+ try:
38
+ requests.post(api_url, headers=headers, json=keep_alive_payload, timeout=10)
39
+ except:
40
+ pass
41
+ threading.Thread(target=keep_model_warm, daemon=True).start()
42
+
43
+ return caption_client, hf_token
44
 
45
  caption_client, hf_token = load_clients()
46
 
 
67
  st.error(f"Image processing error: {type(e).__name__}: {e}")
68
  st.stop()
69
 
70
+ # —––––––– Helper: Generate Story —–––––––
71
  def generate_story(prompt: str) -> str:
72
  api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
73
  headers = {"Authorization": f"Bearer {hf_token}"}
 
82
  "no_repeat_ngram_size": 2
83
  }
84
  }
85
+ max_retries = 5
86
+ retries = 0
87
+ while True:
88
+ try:
89
+ resp = requests.post(api_url, headers=headers, json=payload, timeout=30)
90
+ except Exception as e:
91
+ st.error(f"🚨 Story magic failed: {type(e).__name__}: {e}")
92
+ st.stop()
93
+
94
+ if resp.status_code == 200:
95
+ data = resp.json()
96
+ if isinstance(data, list) and data:
97
+ text = data[0].get("generated_text", "").strip()
98
+ story = text.split("Story:")[-1].strip()
99
+ if "." in story:
100
+ story = story.rsplit(".", 1)[0] + "."
101
+ return story
102
+ else:
103
+ st.error("🚨 Story magic failed: invalid response format")
104
+ st.stop()
105
+
106
+ if resp.status_code == 503 and retries < max_retries:
107
+ try:
108
+ wait = int(resp.json().get("estimated_time", 5))
109
+ except:
110
+ wait = 5 * (2 ** retries)
111
+ st.info(f"Model loading; retrying in {wait}s (attempt {retries+1}/{max_retries})")
112
+ time.sleep(wait)
113
+ retries += 1
114
+ continue
115
+
116
  st.error(f"🚨 Story magic failed: HTTP {resp.status_code} - {resp.text}")
117
  st.stop()
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  # —––––––– Main App Flow —–––––––
120
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
 
122
  img = process_image(uploaded)
123
  st.image(img, use_container_width=True)
124
 
125
+ # Generate Caption
126
  with st.spinner("🔍 Discovering image secrets..."):
127
  caption = generate_caption(img)
128
  if not caption:
 
130
  st.stop()
131
  st.success(f"**Caption:** {caption}")
132
 
133
+ # Prepare Story Prompt
134
  story_prompt = (
135
  f"Image description: {caption}\n\n"
136
  "Write a 50-100 word children's story that:\n"
 
141
  "Story:\n"
142
  )
143
 
144
+ # Generate and Display Story
145
  with st.spinner("📝 Writing magical story..."):
146
  story = generate_story(story_prompt)
 
 
147
  st.subheader("📚 Your Magical Story")
148
  st.write(story)
149
 
 
159
 
160
  # Footer
161
  st.markdown("---\n*Made with ❤️ by your friendly story wizard*")