mayf commited on
Commit
258bc7e
·
verified ·
1 Parent(s): 711740c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -22
app.py CHANGED
@@ -14,16 +14,12 @@ st.title("🖼️ ➡️ 📖 Interactive Storyteller")
14
  # —––––––– Inference clients (cached)
15
  @st.cache_resource
16
  def load_clients():
17
- # read your HF token from Space secrets
18
  hf_token = st.secrets["HF_TOKEN"]
19
-
20
- # caption client: BLIP-base via HF Image-to-Text API
21
  caption_client = InferenceApi(
22
  repo_id="Salesforce/blip-image-captioning-base",
23
  task="image-to-text",
24
  token=hf_token
25
  )
26
- # story client: DeepSeek-R1-Distill via HF Text-Generation API
27
  story_client = InferenceApi(
28
  repo_id="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
29
  task="text-generation",
@@ -34,27 +30,35 @@ def load_clients():
34
  caption_client, story_client = load_clients()
35
 
36
  # —––––––– Main UI
37
- uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
38
  if not uploaded:
39
- st.info("Please upload an image (JPG/PNG) to begin.")
40
  else:
41
- # 1) Display the image
42
  img = Image.open(uploaded).convert("RGB")
43
  st.image(img, use_container_width=True)
44
 
45
- # 2) Caption via HF Inference API
46
  with st.spinner("🔍 Generating caption..."):
47
  buf = BytesIO()
48
  img.save(buf, format="PNG")
49
- caption_output = caption_client(data=buf.getvalue())
50
- # handle API return formats
51
- if isinstance(caption_output, dict):
52
- cap_text = caption_output.get("generated_text", "").strip()
 
 
 
53
  else:
54
- cap_text = str(caption_output).strip()
 
 
 
 
 
55
  st.markdown(f"**Caption:** {cap_text}")
56
 
57
- # 3) Build prompt
58
  prompt = (
59
  f"Here’s an image description: “{cap_text}”.\n\n"
60
  "Write an 80–100 word playful story for 3–10 year-old children that:\n"
@@ -64,11 +68,11 @@ else:
64
  "Story:"
65
  )
66
 
67
- # 4) Story via HF Inference API
68
  with st.spinner("✍️ Generating story..."):
69
- story_output = story_client(
70
  inputs=prompt,
71
- params={
72
  "max_new_tokens": 120,
73
  "do_sample": True,
74
  "temperature": 0.7,
@@ -78,11 +82,17 @@ else:
78
  "no_repeat_ngram_size": 3
79
  }
80
  )
81
- # API returns list of generations or a dict
82
- if isinstance(story_output, list):
83
- story = story_output[0].get("generated_text", "").strip()
 
84
  else:
85
- story = story_output.get("generated_text", "").strip()
 
 
 
 
 
86
  st.markdown("**Story:**")
87
  st.write(story)
88
 
@@ -93,5 +103,4 @@ else:
93
  tts.write_to_fp(tmp)
94
  tmp.flush()
95
  st.audio(tmp.name, format="audio/mp3")
96
-
97
 
 
14
  # —––––––– Inference clients (cached)
15
  @st.cache_resource
16
  def load_clients():
 
17
  hf_token = st.secrets["HF_TOKEN"]
 
 
18
  caption_client = InferenceApi(
19
  repo_id="Salesforce/blip-image-captioning-base",
20
  task="image-to-text",
21
  token=hf_token
22
  )
 
23
  story_client = InferenceApi(
24
  repo_id="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
25
  task="text-generation",
 
30
  caption_client, story_client = load_clients()
31
 
32
  # —––––––– Main UI
33
+ uploaded = st.file_uploader("Upload an image:", type=["jpg","jpeg","png"])
34
  if not uploaded:
35
+ st.info("Please upload a JPG/PNG image to begin.")
36
  else:
37
+ # 1) Display image
38
  img = Image.open(uploaded).convert("RGB")
39
  st.image(img, use_container_width=True)
40
 
41
+ # 2) Generate caption
42
  with st.spinner("🔍 Generating caption..."):
43
  buf = BytesIO()
44
  img.save(buf, format="PNG")
45
+ cap_out = caption_client(data=buf.getvalue())
46
+
47
+ # Correctly extract from list/dict
48
+ if isinstance(cap_out, list) and cap_out:
49
+ cap_text = cap_out[0].get("generated_text", "").strip()
50
+ elif isinstance(cap_out, dict):
51
+ cap_text = cap_out.get("generated_text", "").strip()
52
  else:
53
+ cap_text = str(cap_out).strip()
54
+
55
+ if not cap_text:
56
+ st.error("😕 I couldn’t generate a caption. Try uploading a different image.")
57
+ st.stop()
58
+
59
  st.markdown(f"**Caption:** {cap_text}")
60
 
61
+ # 3) Build prompt for story
62
  prompt = (
63
  f"Here’s an image description: “{cap_text}”.\n\n"
64
  "Write an 80–100 word playful story for 3–10 year-old children that:\n"
 
68
  "Story:"
69
  )
70
 
71
+ # 4) Generate story
72
  with st.spinner("✍️ Generating story..."):
73
+ story_out = story_client(
74
  inputs=prompt,
75
+ parameters={ # must be `parameters`, not `params`
76
  "max_new_tokens": 120,
77
  "do_sample": True,
78
  "temperature": 0.7,
 
82
  "no_repeat_ngram_size": 3
83
  }
84
  )
85
+ if isinstance(story_out, list) and story_out:
86
+ story = story_out[0].get("generated_text", "").strip()
87
+ elif isinstance(story_out, dict):
88
+ story = story_out.get("generated_text", "").strip()
89
  else:
90
+ story = str(story_out).strip()
91
+
92
+ if not story:
93
+ st.error("😕 I couldn’t generate a story. Please try again!")
94
+ st.stop()
95
+
96
  st.markdown("**Story:**")
97
  st.write(story)
98
 
 
103
  tts.write_to_fp(tmp)
104
  tmp.flush()
105
  st.audio(tmp.name, format="audio/mp3")
 
106