mayf committed on
Commit
ed4df47
·
verified ·
1 Parent(s): 0dcd353

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -46
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import streamlit as st
3
  from PIL import Image
4
  from io import BytesIO
@@ -8,31 +9,39 @@ from gtts import gTTS
8
  import tempfile
9
 
10
  # —––––––– Page Config —–––––––
11
- st.set_page_config(page_title="Magic Story Generator", layout="centered")
12
- st.title("📖✨ Turn Images into Children's Stories")
13
 
14
  # —––––––– Clients (cached) —–––––––
15
- @st.cache_resource
16
  def load_clients():
17
  hf_token = st.secrets["HF_TOKEN"]
18
 
19
- # 1) Authenticate so transformers can pick up your token automatically
20
  os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
21
  login(hf_token)
22
 
23
- # 2) BLIP-based image captioning client
 
 
 
 
 
24
  caption_client = InferenceApi(
25
  repo_id="Salesforce/blip-image-captioning-base",
26
  token=hf_token
27
  )
28
 
29
- # 3) Text-generation pipeline for story creation
 
30
  story_generator = pipeline(
31
  task="text-generation",
32
  model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
33
  tokenizer="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
34
- device=0 # set to -1 to run on CPU
 
35
  )
 
36
 
37
  return caption_client, story_generator
38
 
@@ -44,91 +53,84 @@ def generate_caption(img: Image.Image) -> str:
44
  buf = BytesIO()
45
  img.save(buf, format="JPEG")
46
  try:
47
- response = caption_client(data=buf.getvalue())
48
- if isinstance(response, list) and response:
49
- return response[0].get("generated_text", "").strip()
50
  except Exception as e:
51
  st.error(f"Caption generation error: {e}")
52
  return ""
53
 
54
 
55
  # —––––––– Helper: Generate Story via pipeline —–––––––
56
- def generate_story(prompt: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  outputs = story_generator(
58
  prompt,
59
- max_new_tokens=200,
60
- temperature=0.8,
61
- top_p=0.95,
62
- repetition_penalty=1.15,
63
- no_repeat_ngram_size=2,
64
  do_sample=True
65
  )
66
- text = outputs[0]["generated_text"].strip()
 
67
 
68
- # If prompt was echoed, remove it
 
69
  if text.startswith(prompt):
70
  text = text[len(prompt):].strip()
71
-
72
- # If you included a "Story:" marker, split it out
73
- if "Story:" in text:
74
- text = text.split("Story:", 1)[1].strip()
75
-
76
- # Truncate to at most 100 words
77
  words = text.split()
78
  if len(words) > 100:
79
  text = " ".join(words[:100])
80
  if not text.endswith("."):
81
  text += "."
82
-
83
  return text
84
 
85
 
86
  # —––––––– Main App Flow —–––––––
87
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
88
  if uploaded:
89
- # Load & resize
90
  img = Image.open(uploaded).convert("RGB")
91
  if max(img.size) > 2048:
92
  img.thumbnail((2048, 2048))
93
  st.image(img, use_container_width=True)
94
 
95
- # Caption
96
- with st.spinner("🔍 Discovering image secrets..."):
97
  caption = generate_caption(img)
98
  if not caption:
99
  st.error("😢 Couldn't understand this image. Try another one!")
100
  st.stop()
101
  st.success(f"**Caption:** {caption}")
102
 
103
- # Build prompt
104
- story_prompt = (
105
- f"Image description: {caption}\n\n"
106
- "Write a 50-100 word children's story that:\n"
107
- "1. Features the main subject as a friendly character\n"
108
- "2. Includes a simple adventure or discovery\n"
109
- "3. Ends with a happy or funny conclusion\n"
110
- "4. Uses simple language for ages 3-8\n\n"
111
- "Story:\n"
112
- )
113
-
114
- # Generate story
115
  with st.spinner("📝 Writing magical story..."):
116
- story = generate_story(story_prompt)
117
 
118
  st.subheader("📚 Your Magical Story")
119
  st.write(story)
120
 
121
- # Convert to audio
122
- with st.spinner("🔊 Adding story voice..."):
123
  try:
124
  tts = gTTS(text=story, lang="en", slow=False)
125
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
126
  tts.save(fp.name)
127
  st.audio(fp.name, format="audio/mp3")
128
  except Exception as e:
129
- st.warning(f"⚠️ Couldn't make audio version: {e}")
130
 
131
  # Footer
132
  st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
133
-
134
-
 
1
  import os
2
+ import time
3
  import streamlit as st
4
  from PIL import Image
5
  from io import BytesIO
 
9
  import tempfile
10
 
11
  # —––––––– Page Config —–––––––
12
# Page-level settings (browser-tab title, centered layout), then the heading.
st.set_page_config(
    page_title="Magic Story Generator (CPU)",
    layout="centered",
)
st.title("📖✨ Turn Images into Children's Stories (CPU)")
14
 
15
  # —––––––– Clients (cached) —–––––––
16
@st.cache_resource(show_spinner=False)
def load_clients():
    """Build the captioning client and the story-generation pipeline.

    Wrapped in ``st.cache_resource`` so the objects are created once and
    reused instead of being rebuilt on every script rerun.

    Returns:
        tuple: ``(caption_client, story_generator)`` where ``caption_client``
        is an ``InferenceApi`` bound to ``Salesforce/blip-image-captioning-base``
        and ``story_generator`` is a CPU text-generation ``pipeline`` for
        ``deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B``.

    Raises:
        KeyError: if ``HF_TOKEN`` is missing from the Streamlit secrets.
    """
    hf_token = st.secrets["HF_TOKEN"]

    # Authenticate once so later Hub downloads can reuse the token.
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
    login(hf_token)

    # Keep downloaded weights in a project-local directory to avoid re-downloads.
    cache_dir = "./hf_cache"
    os.makedirs(cache_dir, exist_ok=True)
    os.environ["TRANSFORMERS_CACHE"] = cache_dir

    # 1) BLIP-based image captioning client
    caption_client = InferenceApi(
        repo_id="Salesforce/blip-image-captioning-base",
        token=hf_token,
    )

    # 2) Text-generation pipeline forced onto CPU
    t0 = time.time()
    story_generator = pipeline(
        task="text-generation",
        model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        tokenizer="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        device=-1,  # CPU only
        # cache_dir is a from_pretrained() option, so it must travel through
        # model_kwargs. Given as a bare keyword to pipeline() it is treated as
        # a generation argument and rejected when the model generates.
        model_kwargs={"cache_dir": cache_dir},
    )
    st.text(f"✅ Story model loaded in {time.time() - t0:.1f}s (cached thereafter)")

    return caption_client, story_generator
47
 
 
53
  buf = BytesIO()
54
  img.save(buf, format="JPEG")
55
  try:
56
+ resp = caption_client(data=buf.getvalue())
57
+ if isinstance(resp, list) and resp:
58
+ return resp[0].get("generated_text", "").strip()
59
  except Exception as e:
60
  st.error(f"Caption generation error: {e}")
61
  return ""
62
 
63
 
64
  # —––––––– Helper: Generate Story via pipeline —–––––––
65
def _trim_story(raw: str, prompt: str, max_words: int = 100) -> str:
    """Drop an echoed *prompt* from *raw* and cap the result at *max_words* words."""
    text = raw.strip()
    # Compare stripped forms: the generated text has been stripped, so a
    # prompt with leading/trailing whitespace would otherwise never match.
    echoed = prompt.strip()
    if echoed and text.startswith(echoed):
        text = text[len(echoed):].strip()

    words = text.split()
    if len(words) > max_words:
        text = " ".join(words[:max_words])
        # Truncation usually cuts mid-sentence; close it off cleanly.
        if not text.endswith("."):
            text += "."
    return text


def generate_story(caption: str) -> str:
    """Generate a short children's story from an image caption.

    Builds an instruction prompt around *caption*, samples a continuation
    from the module-level ``story_generator`` pipeline, reports the elapsed
    time in the Streamlit page, and returns the cleaned story text.

    Args:
        caption: Description of the uploaded image.

    Returns:
        The generated story with any echoed prompt removed, limited to at
        most 100 words.
    """
    # Assembled line by line so the prompt carries no stray leading newline
    # or source-code indentation.
    prompt = (
        "You are a creative children’s-story author.\n"
        "Below is the description of an image:\n"
        f"“{caption}”\n"
        "\n"
        "Write a coherent, 50 to 100-word story that:\n"
        "1. Introduces the main character from the image.\n"
        "2. Shows a simple problem or discovery.\n"
        "3. Resolves it in a happy ending.\n"
        "4. Uses clear language for ages 3–8.\n"
        "5. Keeps each sentence under 20 words.\n"
        "Story:\n"
    )

    t0 = time.time()
    outputs = story_generator(
        prompt,
        max_new_tokens=120,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        no_repeat_ngram_size=3,
        do_sample=True,
        # Ask for the continuation only, so the prompt is not part of the
        # returned text; _trim_story still strips it defensively.
        return_full_text=False,
    )
    gen_time = time.time() - t0
    st.text(f"⏱ Generated in {gen_time:.1f}s on CPU")

    return _trim_story(outputs[0]["generated_text"], prompt)
103
 
104
 
105
  # —––––––– Main App Flow —–––––––
106
# Upload -> caption -> story -> audio; each stage renders its result in order.
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    img = Image.open(uploaded).convert("RGB")
    # Cap the longest side at 2048 px before captioning.
    if max(img.size) > 2048:
        img.thumbnail((2048, 2048))
    st.image(img, use_container_width=True)

    with st.spinner("🔍 Generating caption..."):
        caption = generate_caption(img)
    if not caption:
        # Without a caption there is nothing to build a story from.
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    with st.spinner("📝 Writing magical story..."):
        story = generate_story(caption)

    st.subheader("📚 Your Magical Story")
    st.write(story)

    with st.spinner("🔊 Converting to audio..."):
        try:
            tts = gTTS(text=story, lang="en", slow=False)
            # Render the MP3 in memory: no temp file is left behind on each
            # run, and nothing has to be reopened by name while still open.
            audio_buf = BytesIO()
            tts.write_to_fp(audio_buf)
            st.audio(audio_buf.getvalue(), format="audio/mp3")
        except Exception as e:
            # Audio is best-effort; the written story is already shown.
            st.warning(f"⚠️ TTS failed: {e}")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")