mayf committed on
Commit
0dcd353
·
verified ·
1 Parent(s): 1fca63f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -17
app.py CHANGED
@@ -1,7 +1,8 @@
 
1
  import streamlit as st
2
  from PIL import Image
3
  from io import BytesIO
4
- from huggingface_hub import InferenceApi
5
  from transformers import pipeline
6
  from gtts import gTTS
7
  import tempfile
@@ -14,34 +15,45 @@ st.title("📖✨ Turn Images into Children's Stories")
14
  @st.cache_resource
15
  def load_clients():
16
  hf_token = st.secrets["HF_TOKEN"]
17
- # image captioning client as before
18
- caption_client = InferenceApi("Salesforce/blip-image-captioning-base", token=hf_token)
19
- # text-generation pipeline for story
 
 
 
 
 
 
 
 
 
20
  story_generator = pipeline(
21
- "text-generation",
22
  model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
23
  tokenizer="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
24
- use_auth_token=hf_token,
25
- device=0 # or -1 for CPU
26
  )
 
27
  return caption_client, story_generator
28
 
29
  caption_client, story_generator = load_clients()
30
 
 
31
  # —––––––– Helper: Generate Caption —–––––––
32
- def generate_caption(img):
33
  buf = BytesIO()
34
  img.save(buf, format="JPEG")
35
  try:
36
- out = caption_client(data=buf.getvalue())
37
- return out[0].get("generated_text", "").strip()
 
38
  except Exception as e:
39
- st.error(f"Caption error: {e}")
40
- return ""
 
41
 
42
  # —––––––– Helper: Generate Story via pipeline —–––––––
43
  def generate_story(prompt: str) -> str:
44
- # generate up to ~200 tokens to cover 100 words margin
45
  outputs = story_generator(
46
  prompt,
47
  max_new_tokens=200,
@@ -52,31 +64,43 @@ def generate_story(prompt: str) -> str:
52
  do_sample=True
53
  )
54
  text = outputs[0]["generated_text"].strip()
55
- # everything after "Story:" (if you kept that in your prompt)
 
 
 
 
 
56
  if "Story:" in text:
57
  text = text.split("Story:", 1)[1].strip()
58
- # truncate to 100 words
 
59
  words = text.split()
60
  if len(words) > 100:
61
  text = " ".join(words[:100])
62
  if not text.endswith("."):
63
  text += "."
 
64
  return text
65
 
 
66
  # —––––––– Main App Flow —–––––––
67
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
68
  if uploaded:
 
69
  img = Image.open(uploaded).convert("RGB")
70
  if max(img.size) > 2048:
71
  img.thumbnail((2048, 2048))
72
  st.image(img, use_container_width=True)
73
 
74
- caption = generate_caption(img)
 
 
75
  if not caption:
76
  st.error("😢 Couldn't understand this image. Try another one!")
77
  st.stop()
78
  st.success(f"**Caption:** {caption}")
79
 
 
80
  story_prompt = (
81
  f"Image description: {caption}\n\n"
82
  "Write a 50-100 word children's story that:\n"
@@ -87,12 +111,14 @@ if uploaded:
87
  "Story:\n"
88
  )
89
 
 
90
  with st.spinner("📝 Writing magical story..."):
91
  story = generate_story(story_prompt)
 
92
  st.subheader("📚 Your Magical Story")
93
  st.write(story)
94
 
95
- # Audio Conversion
96
  with st.spinner("🔊 Adding story voice..."):
97
  try:
98
  tts = gTTS(text=story, lang="en", slow=False)
 
1
+ import os
2
  import streamlit as st
3
  from PIL import Image
4
  from io import BytesIO
5
+ from huggingface_hub import InferenceApi, login
6
  from transformers import pipeline
7
  from gtts import gTTS
8
  import tempfile
 
@st.cache_resource
def load_clients():
    """Build the image-captioning client and the story-generation pipeline.

    The function is wrapped in ``st.cache_resource`` so the (expensive)
    clients are created once and reused rather than rebuilt on every call.

    Returns:
        tuple: ``(caption_client, story_generator)`` where ``caption_client``
        is an ``InferenceApi`` bound to the BLIP captioning model and
        ``story_generator`` is a ``transformers`` text-generation pipeline.
    """
    hf_token = st.secrets["HF_TOKEN"]

    # 1) Authenticate. The token is exported in the environment and also
    #    registered through huggingface_hub.login().
    #    NOTE(review): confirm which libraries actually read
    #    HUGGINGFACEHUB_API_TOKEN; login() is presumably what makes the
    #    token visible to transformers.
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
    login(hf_token)

    # 2) BLIP-based image captioning client
    caption_client = InferenceApi(
        repo_id="Salesforce/blip-image-captioning-base",
        token=hf_token,
    )

    # 3) Choose the device at runtime instead of hard-coding GPU 0, so the
    #    app still starts on CPU-only hosts. -1 selects the CPU.
    try:
        import torch

        device = 0 if torch.cuda.is_available() else -1
    except ImportError:
        # No PyTorch available to query: fall back to CPU.
        device = -1

    # 4) Text-generation pipeline for story creation
    story_generator = pipeline(
        task="text-generation",
        model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        tokenizer="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        device=device,
    )

    return caption_client, story_generator
38
 
39
  caption_client, story_generator = load_clients()
40
 
41
+
42
  # —––––––– Helper: Generate Caption —–––––––
def generate_caption(img: Image.Image) -> str:
    """Return a caption for *img*, or an empty string when none is available.

    The image is encoded as JPEG in memory and sent to ``caption_client``.
    Any exception raised while encoding or calling the client is reported
    with ``st.error`` and results in ``""``; the caller treats an empty
    caption as "could not understand the image".

    Args:
        img: The PIL image to describe.

    Returns:
        The stripped ``generated_text`` of the first response item, or ``""``.
    """
    try:
        buf = BytesIO()
        # JPEG has no alpha/palette support, so normalise the mode first.
        if img.mode != "RGB":
            img = img.convert("RGB")
        img.save(buf, format="JPEG")
        response = caption_client(data=buf.getvalue())
    except Exception as e:
        st.error(f"Caption generation error: {e}")
        return ""

    # Expected shape: a non-empty list whose first item carries the caption.
    if isinstance(response, list) and response:
        first = response[0]
        if isinstance(first, dict):
            # `or ""` guards against an explicit null value for the key.
            return (first.get("generated_text") or "").strip()

    # NOTE(review): the client presumably reports failures as a dict with an
    # "error" key - surface it instead of dropping it silently.
    if isinstance(response, dict) and response.get("error"):
        st.error(f"Caption generation error: {response['error']}")

    # Every remaining path yields a string, never an implicit None.
    return ""
53
+
54
 
55
  # —––––––– Helper: Generate Story via pipeline —–––––––
56
  def generate_story(prompt: str) -> str:
 
57
  outputs = story_generator(
58
  prompt,
59
  max_new_tokens=200,
 
64
  do_sample=True
65
  )
66
  text = outputs[0]["generated_text"].strip()
67
+
68
+ # If prompt was echoed, remove it
69
+ if text.startswith(prompt):
70
+ text = text[len(prompt):].strip()
71
+
72
+ # If you included a "Story:" marker, split it out
73
  if "Story:" in text:
74
  text = text.split("Story:", 1)[1].strip()
75
+
76
+ # Truncate to at most 100 words
77
  words = text.split()
78
  if len(words) > 100:
79
  text = " ".join(words[:100])
80
  if not text.endswith("."):
81
  text += "."
82
+
83
  return text
84
 
85
+
86
  # —––––––– Main App Flow —–––––––
87
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
88
  if uploaded:
89
+ # Load & resize
90
  img = Image.open(uploaded).convert("RGB")
91
  if max(img.size) > 2048:
92
  img.thumbnail((2048, 2048))
93
  st.image(img, use_container_width=True)
94
 
95
+ # Caption
96
+ with st.spinner("🔍 Discovering image secrets..."):
97
+ caption = generate_caption(img)
98
  if not caption:
99
  st.error("😢 Couldn't understand this image. Try another one!")
100
  st.stop()
101
  st.success(f"**Caption:** {caption}")
102
 
103
+ # Build prompt
104
  story_prompt = (
105
  f"Image description: {caption}\n\n"
106
  "Write a 50-100 word children's story that:\n"
 
111
  "Story:\n"
112
  )
113
 
114
+ # Generate story
115
  with st.spinner("📝 Writing magical story..."):
116
  story = generate_story(story_prompt)
117
+
118
  st.subheader("📚 Your Magical Story")
119
  st.write(story)
120
 
121
+ # Convert to audio
122
  with st.spinner("🔊 Adding story voice..."):
123
  try:
124
  tts = gTTS(text=story, lang="en", slow=False)