mayf committed on
Commit
bb9cd5a
·
verified ·
1 Parent(s): d1beffa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -34
app.py CHANGED
@@ -3,6 +3,7 @@ from PIL import Image
3
  from io import BytesIO
4
  from huggingface_hub import InferenceApi
5
  from gtts import gTTS
 
6
  import tempfile
7
 
8
  # —––––––– Page Config —–––––––
@@ -15,26 +16,22 @@ def load_clients():
15
  hf_token = st.secrets["HF_TOKEN"]
16
  return (
17
  InferenceApi("Salesforce/blip-image-captioning-base", token=hf_token),
18
- InferenceApi("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", token=hf_token)
19
  )
20
 
21
- caption_client, story_client = load_clients()
22
 
23
  # —––––––– Helper: Generate Caption —–––––––
24
  def generate_caption(img):
25
- """
26
- Runs the BLIP caption model on a PIL.Image and returns the generated text.
27
- """
28
  img_bytes = BytesIO()
29
  img.save(img_bytes, format="JPEG")
30
  try:
31
  result = caption_client(data=img_bytes.getvalue())
32
  if isinstance(result, list) and result:
33
  return result[0].get("generated_text", "").strip()
34
- return ""
35
  except Exception as e:
36
- st.error(f"Caption generation error: {e}")
37
- return ""
38
 
39
  # —––––––– Helper: Process Image —–––––––
40
  def process_image(uploaded_file):
@@ -44,16 +41,51 @@ def process_image(uploaded_file):
44
  img.thumbnail((2048, 2048))
45
  return img
46
  except Exception as e:
47
- st.error(f"Image processing error: {e}")
48
  st.stop()
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  # —––––––– Main App Flow —–––––––
51
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
52
  if uploaded:
53
  img = process_image(uploaded)
54
  st.image(img, use_container_width=True)
55
 
56
- # Generate Caption
57
  with st.spinner("🔍 Discovering image secrets..."):
58
  caption = generate_caption(img)
59
  if not caption:
@@ -61,7 +93,7 @@ if uploaded:
61
  st.stop()
62
  st.success(f"**Caption:** {caption}")
63
 
64
- # Prepare Story Prompt
65
  story_prompt = (
66
  f"Image description: {caption}\n\n"
67
  "Write a 50-100 word children's story that:\n"
@@ -72,29 +104,9 @@ if uploaded:
72
  "Story:\n"
73
  )
74
 
75
- # Generate Story with full payload dict
76
- payload = {
77
- "inputs": story_prompt,
78
- "parameters": {
79
- "max_new_tokens": 200,
80
- "temperature": 0.8,
81
- "top_p": 0.95,
82
- "repetition_penalty": 1.15,
83
- "do_sample": True,
84
- "no_repeat_ngram_size": 2
85
- }
86
- }
87
-
88
  with st.spinner("📝 Writing magical story..."):
89
- try:
90
- story_response = story_client(payload)
91
- full_text = story_response[0].get("generated_text", "")
92
- story = full_text.split("Story:")[-1].strip()
93
- if "." in story:
94
- story = story.rsplit(".", 1)[0] + "."
95
- except Exception as e:
96
- st.error(f"🚨 Story magic failed: {e}")
97
- st.stop()
98
 
99
  # Display Story
100
  st.subheader("📚 Your Magical Story")
@@ -108,7 +120,7 @@ if uploaded:
108
  tts.save(fp.name)
109
  st.audio(fp.name, format="audio/mp3")
110
  except Exception as e:
111
- st.warning("⚠️ Couldn't make audio version: " + str(e))
112
 
113
  # Footer
114
  st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
 
3
  from io import BytesIO
4
  from huggingface_hub import InferenceApi
5
  from gtts import gTTS
6
+ import requests
7
  import tempfile
8
 
9
  # —––––––– Page Config —–––––––
 
16
  hf_token = st.secrets["HF_TOKEN"]
17
  return (
18
  InferenceApi("Salesforce/blip-image-captioning-base", token=hf_token),
19
+ hf_token # we'll use direct requests for story generation
20
  )
21
 
22
+ caption_client, hf_token = load_clients()
23
 
24
  # —––––––– Helper: Generate Caption —–––––––
25
  def generate_caption(img):
 
 
 
26
  img_bytes = BytesIO()
27
  img.save(img_bytes, format="JPEG")
28
  try:
29
  result = caption_client(data=img_bytes.getvalue())
30
  if isinstance(result, list) and result:
31
  return result[0].get("generated_text", "").strip()
 
32
  except Exception as e:
33
+ st.error(f"Caption generation error: {type(e).__name__}: {e}")
34
+ return ""
35
 
36
  # —––––––– Helper: Process Image —–––––––
37
  def process_image(uploaded_file):
 
41
  img.thumbnail((2048, 2048))
42
  return img
43
  except Exception as e:
44
+ st.error(f"Image processing error: {type(e).__name__}: {e}")
45
  st.stop()
46
 
47
+ # —––––––– Helper: Generate Story via HTTP —–––––––
48
def generate_story(prompt: str) -> str:
    """Generate a short story from *prompt* via the Hugging Face Inference API.

    Posts the prompt and sampling parameters to the hosted
    DeepSeek-R1-Distill-Qwen-1.5B endpoint, then keeps only the text after
    the last "Story:" marker, trimmed back to the last complete sentence.

    Args:
        prompt: Full text prompt; expected to end with a "Story:" marker so
            the echoed prompt can be stripped from the generated text.

    Returns:
        The extracted story text (may be empty if the model produced nothing
        after the marker).

    On any failure (network error, non-200 status, non-JSON body, or an
    unexpected payload shape) an error is shown with ``st.error`` and the
    script run is halted with ``st.stop()``; the function does not return.
    """
    api_url = "https://api-inference.huggingface.co/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    headers = {"Authorization": f"Bearer {hf_token}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 200,
            "temperature": 0.8,
            "top_p": 0.95,
            "repetition_penalty": 1.15,
            "do_sample": True,
            "no_repeat_ngram_size": 2,
        },
    }

    try:
        resp = requests.post(api_url, headers=headers, json=payload, timeout=30)
    except requests.RequestException as e:
        st.error(f"🚨 Story magic failed: {type(e).__name__}: {e}")
        st.stop()

    if resp.status_code != 200:
        st.error(f"🚨 Story magic failed: HTTP {resp.status_code} - {resp.text}")
        st.stop()

    # A 200 response is not guaranteed to carry valid JSON; treat a decode
    # failure like any other malformed response instead of crashing.
    try:
        data = resp.json()
    except ValueError:
        data = None

    # Expecting a non-empty list of generation dicts.
    if isinstance(data, list) and data and isinstance(data[0], dict):
        generated = data[0].get("generated_text", "")
        if isinstance(generated, str):
            # Keep only what follows the last "Story:" marker, if present.
            story = generated.strip().split("Story:")[-1].strip()
            # Drop a trailing partial sentence cut off by the token limit.
            if "." in story:
                story = story.rsplit(".", 1)[0] + "."
            return story

    st.error("🚨 Story magic failed: invalid response format")
    st.stop()
81
+
82
  # —––––––– Main App Flow —–––––––
83
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
84
  if uploaded:
85
  img = process_image(uploaded)
86
  st.image(img, use_container_width=True)
87
 
88
+ # Caption
89
  with st.spinner("🔍 Discovering image secrets..."):
90
  caption = generate_caption(img)
91
  if not caption:
 
93
  st.stop()
94
  st.success(f"**Caption:** {caption}")
95
 
96
+ # Prepare Prompt
97
  story_prompt = (
98
  f"Image description: {caption}\n\n"
99
  "Write a 50-100 word children's story that:\n"
 
104
  "Story:\n"
105
  )
106
 
107
+ # Generate Story
 
 
 
 
 
 
 
 
 
 
 
 
108
  with st.spinner("📝 Writing magical story..."):
109
+ story = generate_story(story_prompt)
 
 
 
 
 
 
 
 
110
 
111
  # Display Story
112
  st.subheader("📚 Your Magical Story")
 
120
  tts.save(fp.name)
121
  st.audio(fp.name, format="audio/mp3")
122
  except Exception as e:
123
+ st.warning(f"⚠️ Couldn't make audio version: {type(e).__name__}: {e}")
124
 
125
  # Footer
126
  st.markdown("---\n*Made with ❤️ by your friendly story wizard*")