mayf committed on
Commit
422a749
·
verified ·
1 Parent(s): 96d517c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -135
app.py CHANGED
@@ -2,12 +2,9 @@ import streamlit as st
2
  from PIL import Image
3
  from io import BytesIO
4
  from huggingface_hub import InferenceApi
 
5
  from gtts import gTTS
6
- import requests
7
- from requests.exceptions import ReadTimeout
8
  import tempfile
9
- import time
10
- import threading
11
 
12
  # —––––––– Page Config —–––––––
13
  st.set_page_config(page_title="Magic Story Generator", layout="centered")
@@ -17,137 +14,69 @@ st.title("📖✨ Turn Images into Children's Stories")
17
  @st.cache_resource
18
  def load_clients():
19
  hf_token = st.secrets["HF_TOKEN"]
 
20
  caption_client = InferenceApi("Salesforce/blip-image-captioning-base", token=hf_token)
 
 
 
 
 
 
 
 
 
21
 
22
- # Keep-alive thread to avoid cold starts for story model
23
- api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
24
- headers = {"Authorization": f"Bearer {hf_token}"}
25
- warm_payload = {"inputs": "Hello!", "parameters": {"max_new_tokens": 1}}
26
- def keep_model_warm():
27
- try:
28
- requests.post(api_url, headers=headers, json=warm_payload, timeout=10)
29
- except:
30
- pass
31
- while True:
32
- time.sleep(600)
33
- try:
34
- requests.post(api_url, headers=headers, json=warm_payload, timeout=10)
35
- except:
36
- pass
37
- threading.Thread(target=keep_model_warm, daemon=True).start()
38
-
39
- return caption_client, hf_token
40
-
41
- caption_client, hf_token = load_clients()
42
 
43
  # —––––––– Helper: Generate Caption —–––––––
44
  def generate_caption(img):
45
- img_bytes = BytesIO()
46
- img.save(img_bytes, format="JPEG")
47
- try:
48
- result = caption_client(data=img_bytes.getvalue())
49
- if isinstance(result, list) and result:
50
- return result[0].get("generated_text", "").strip()
51
- except Exception as e:
52
- st.error(f"Caption generation error: {type(e).__name__}: {e}")
53
- return ""
54
-
55
- # —––––––– Helper: Process Image —–––––––
56
- def process_image(uploaded_file):
57
  try:
58
- img = Image.open(uploaded_file).convert("RGB")
59
- if max(img.size) > 2048:
60
- img.thumbnail((2048, 2048))
61
- return img
62
  except Exception as e:
63
- st.error(f"Image processing error: {type(e).__name__}: {e}")
64
- st.stop()
65
-
66
- # —––––––– Helper: Generate Story with improved retry and timeout —–––––––
67
- def generate_story(prompt: str, caption: str) -> str:
68
- api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
69
- headers = {"Authorization": f"Bearer {hf_token}"}
70
- payload = {
71
- "inputs": prompt,
72
- "parameters": {
73
- "max_new_tokens": 200,
74
- "temperature": 0.8,
75
- "top_p": 0.95,
76
- "repetition_penalty": 1.15,
77
- "do_sample": True,
78
- "no_repeat_ngram_size": 2
79
- }
80
- }
81
- retries = 0
82
- max_retries = 5
83
- timeout = 60 # allow up to 60s for large model
84
- while True:
85
- try:
86
- resp = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
87
- except ReadTimeout:
88
- if retries < max_retries:
89
- wait = 2 ** retries
90
- st.info(f"Request timed out; retrying in {wait}s (attempt {retries+1}/{max_retries})")
91
- time.sleep(wait)
92
- retries += 1
93
- continue
94
- st.error("🚨 Story magic failed: request timed out after multiple attempts.")
95
- st.stop()
96
- except Exception as e:
97
- st.error(f"🚨 Story magic failed: {type(e).__name__}: {e}")
98
- st.stop()
99
-
100
- # Successful generation
101
- if resp.status_code == 200:
102
- data = resp.json()
103
- if isinstance(data, list) and data:
104
- text = data[0].get("generated_text", "").strip()
105
- story = text.split("Story:")[-1].strip()
106
- if "." in story:
107
- story = story.rsplit(".", 1)[0] + "."
108
- return story
109
- st.error("🚨 Story magic failed: invalid response format")
110
- st.stop()
111
-
112
- # Model loading (cold start)
113
- if resp.status_code == 503 and retries < max_retries:
114
- wait = int(resp.json().get("estimated_time", 5))
115
- st.info(f"Model loading; retrying in {wait}s (attempt {retries+1}/{max_retries})")
116
- time.sleep(wait)
117
- retries += 1
118
- continue
119
-
120
- # Server-side generation error
121
- if resp.status_code in (424, 500, 502) and retries < max_retries:
122
- wait = 2 ** retries
123
- st.info(f"Server error {resp.status_code}; retrying in {wait}s (attempt {retries+1}/{max_retries})")
124
- time.sleep(wait)
125
- retries += 1
126
- continue
127
- if resp.status_code in (424, 500, 502):
128
- return (f"One day, {caption} woke up under a bright sky and decided to explore the garden. "
129
- "It met a friendly ladybug and together they played hide-and-seek among the flowers. "
130
- "At sunset, {caption} curled up by a daisy, purring happily as it dreamed of new adventures.")
131
-
132
- # Other errors
133
- st.error(f"🚨 Story magic failed: HTTP {resp.status_code} - {resp.text}")
134
- st.stop()
135
 
136
  # —––––––– Main App Flow —–––––––
137
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
138
  if uploaded:
139
- img = process_image(uploaded)
 
 
140
  st.image(img, use_container_width=True)
141
 
142
- # Generate Caption
143
- with st.spinner("🔍 Discovering image secrets..."):
144
- caption = generate_caption(img)
145
- if not caption:
146
- st.error("😢 Couldn't understand this image. Try another one!")
147
- st.stop()
148
  st.success(f"**Caption:** {caption}")
149
 
150
- # Prepare Story Prompt
151
  story_prompt = (
152
  f"Image description: {caption}\n\n"
153
  "Write a 50-100 word children's story that:\n"
@@ -158,21 +87,8 @@ if uploaded:
158
  "Story:\n"
159
  )
160
 
161
- # Generate and validate Story
162
  with st.spinner("📝 Writing magical story..."):
163
- story = None
164
- attempts = 0
165
- while attempts < 3:
166
- candidate = generate_story(story_prompt, caption)
167
- count = len(candidate.split())
168
- if 50 <= count <= 100:
169
- story = candidate
170
- break
171
- attempts += 1
172
- if story is None:
173
- st.warning("⚠️ Couldn't generate a story within 50-100 words after multiple tries. Showing last attempt.")
174
- story = candidate
175
-
176
  st.subheader("📚 Your Magical Story")
177
  st.write(story)
178
 
@@ -184,8 +100,9 @@ if uploaded:
184
  tts.save(fp.name)
185
  st.audio(fp.name, format="audio/mp3")
186
  except Exception as e:
187
- st.warning(f"⚠️ Couldn't make audio version: {type(e).__name__}: {e}")
188
 
189
  # Footer
190
  st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
191
 
 
 
2
  from PIL import Image
3
  from io import BytesIO
4
  from huggingface_hub import InferenceApi
5
+ from transformers import pipeline
6
  from gtts import gTTS
 
 
7
  import tempfile
 
 
8
 
9
  # —––––––– Page Config —–––––––
10
  st.set_page_config(page_title="Magic Story Generator", layout="centered")
 
@st.cache_resource
def load_clients():
    """Build the image-captioning client and the story-generation pipeline.

    Wrapped in ``st.cache_resource`` so the (expensive) model load is meant
    to happen once rather than on every script rerun.

    Returns:
        tuple: ``(caption_client, story_generator)`` where ``caption_client``
        is an ``InferenceApi`` for ``Salesforce/blip-image-captioning-base``
        and ``story_generator`` is a ``text-generation`` pipeline for
        ``deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B``.

    Raises:
        KeyError: if ``HF_TOKEN`` is missing from ``st.secrets``.
    """
    hf_token = st.secrets["HF_TOKEN"]

    # Image captioning still goes through the hosted inference API.
    caption_client = InferenceApi("Salesforce/blip-image-captioning-base", token=hf_token)

    # A hard-coded ``device=0`` fails on hosts without a CUDA GPU, so probe
    # for one and fall back to CPU (-1) when none is available.
    try:
        import torch  # local import: only needed for the GPU probe

        device = 0 if torch.cuda.is_available() else -1
    except ImportError:
        device = -1

    story_model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    # NOTE(review): newer transformers releases prefer ``token=`` over
    # ``use_auth_token=`` -- confirm against the pinned version.
    story_generator = pipeline(
        "text-generation",
        model=story_model,
        tokenizer=story_model,
        use_auth_token=hf_token,
        device=device,
    )
    return caption_client, story_generator


caption_client, story_generator = load_clients()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  # —––––––– Helper: Generate Caption —–––––––
def generate_caption(img):
    """Return a caption for *img*, or ``""`` when captioning fails.

    The image is serialized to JPEG in memory and sent to ``caption_client``.
    Any failure is reported through ``st.error`` and signalled to the caller
    by an empty string.

    Args:
        img: an image object exposing ``save(fp, format=...)``; the caller
            passes a PIL image already converted to RGB.

    Returns:
        str: the stripped ``generated_text`` of the first result, or ``""``.
    """
    buf = BytesIO()
    img.save(buf, format="JPEG")
    try:
        out = caption_client(data=buf.getvalue())
    except Exception as e:
        st.error(f"Caption error: {e}")
        return ""

    # Validate the shape before indexing: a non-list payload (e.g. a dict
    # describing an error) would otherwise surface as the cryptic
    # "Caption error: 0" produced by a KeyError on ``out[0]``.
    if isinstance(out, list) and out and isinstance(out[0], dict):
        return (out[0].get("generated_text") or "").strip()

    st.error(f"Caption error: unexpected response: {out}")
    return ""
41
+
42
+ # —––––––– Helper: Generate Story via pipeline —–––––––
_SENTENCE_ENDINGS = (".", "!", "?")
_CLOSING_QUOTES = "\"'\u201d\u2019)"


def _tidy_story(text: str, max_words: int = 100) -> str:
    """Extract the story from raw model output and cap it at *max_words*."""
    text = text.strip()
    # The pipeline output is expected to echo the prompt, which ends with a
    # "Story:" marker; keep only what follows its first occurrence.
    if "Story:" in text:
        text = text.split("Story:", 1)[1].strip()

    words = text.split()
    if not words:
        # Nothing was generated: return "" rather than a lone ".".
        return ""

    if len(words) > max_words:
        text = " ".join(words[:max_words])
        # Prefer ending on the last complete sentence instead of cutting
        # mid-sentence and gluing a period onto a dangling phrase.
        cut = max(text.rfind(mark) for mark in _SENTENCE_ENDINGS)
        if cut != -1:
            text = text[:cut + 1]

    # Only append a period when the text does not already end a sentence
    # (ignoring trailing closing quotes), avoiding results like "!." or '.".'.
    if not text.rstrip(_CLOSING_QUOTES).endswith(_SENTENCE_ENDINGS):
        text += "."
    return text


def generate_story(prompt: str) -> str:
    """Generate a short story for *prompt* with the local pipeline.

    Args:
        prompt: the full instruction text; if it contains a ``"Story:"``
            marker, only the text after the first marker is returned.

    Returns:
        str: the cleaned story, at most 100 words, ending in sentence
        punctuation; ``""`` if the model produced no story text.
    """
    # Up to 200 new tokens leaves headroom above the 100-word cap.
    outputs = story_generator(
        prompt,
        max_new_tokens=200,
        temperature=0.8,
        top_p=0.95,
        repetition_penalty=1.15,
        no_repeat_ngram_size=2,
        do_sample=True,
    )
    return _tidy_story(outputs[0]["generated_text"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # —––––––– Main App Flow —–––––––
67
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
68
  if uploaded:
69
+ img = Image.open(uploaded).convert("RGB")
70
+ if max(img.size) > 2048:
71
+ img.thumbnail((2048, 2048))
72
  st.image(img, use_container_width=True)
73
 
74
+ caption = generate_caption(img)
75
+ if not caption:
76
+ st.error("😢 Couldn't understand this image. Try another one!")
77
+ st.stop()
 
 
78
  st.success(f"**Caption:** {caption}")
79
 
 
80
  story_prompt = (
81
  f"Image description: {caption}\n\n"
82
  "Write a 50-100 word children's story that:\n"
 
87
  "Story:\n"
88
  )
89
 
 
90
  with st.spinner("📝 Writing magical story..."):
91
+ story = generate_story(story_prompt)
 
 
 
 
 
 
 
 
 
 
 
 
92
  st.subheader("📚 Your Magical Story")
93
  st.write(story)
94
 
 
100
  tts.save(fp.name)
101
  st.audio(fp.name, format="audio/mp3")
102
  except Exception as e:
103
+ st.warning(f"⚠️ Couldn't make audio version: {e}")
104
 
105
  # Footer
106
  st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
107
 
108
+