mayf commited on
Commit
8087810
·
verified ·
1 Parent(s): 313c831

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -57
app.py CHANGED
@@ -2,42 +2,38 @@ import os
2
  import time
3
  import streamlit as st
4
  from PIL import Image
5
- from io import BytesIO
6
  from transformers import pipeline
7
- from huggingface_hub import login
8
  from gtts import gTTS
9
  import tempfile
10
 
11
  # —––––––– Requirements —–––––––
12
- # pip install streamlit pillow gTTS transformers huggingface_hub
13
-
14
- # —––––––– Page Config —–––––––
15
- st.set_page_config(page_title="Magic Story Generator (Local Pipeline)", layout="centered")
 
 
 
 
 
16
  st.title("📖✨ Turn Images into Children's Stories")
17
 
18
- # —––––––– Load Clients & Pipelines (cached) —–––––––
19
  @st.cache_resource(show_spinner=False)
20
- def load_clients():
21
- # Authenticate to pull private or remote-code models if needed
22
- hf_token = st.secrets.get("HF_TOKEN")
23
- if hf_token:
24
- os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
25
- login(hf_token)
26
-
27
  # 1) Image-captioning pipeline (BLIP)
28
  captioner = pipeline(
29
  task="image-to-text",
30
  model="Salesforce/blip-image-captioning-base",
31
- device=-1 # CPU; change to 0 for GPU
32
  )
33
-
34
  # 2) Story-generation pipeline (DeepSeek-R1-Distill-Qwen)
35
  storyteller = pipeline(
36
- task="text-generation",
37
  model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
38
  tokenizer="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
39
  trust_remote_code=True,
40
- device=-1, # CPU; set 0+ for GPU
41
  temperature=0.6,
42
  top_p=0.9,
43
  repetition_penalty=1.1,
@@ -45,68 +41,54 @@ def load_clients():
45
  max_new_tokens=120,
46
  return_full_text=False
47
  )
48
-
49
  return captioner, storyteller
50
 
51
- captioner, storyteller = load_clients()
52
-
53
- # —––––––– Helpers —–––––––
54
- def generate_caption(img: Image.Image) -> str:
55
- # Use the BLIP pipeline to generate a caption
56
- result = captioner(img)
57
- if isinstance(result, list) and result:
58
- return result[0].get("generated_text", "").strip()
59
- return ""
60
-
61
-
62
- def generate_story(caption: str) -> str:
63
- # Build a simple prompt incorporating the caption
64
- prompt = (
65
- f"Image description: {caption}\n"
66
- "Write a coherent 50-100 word children's story that flows naturally."
67
- )
68
-
69
- t0 = time.time()
70
- outputs = storyteller(
71
- prompt
72
- )
73
- gen_time = time.time() - t0
74
- st.text(f"⏱ Generated in {gen_time:.1f}s")
75
-
76
- story = outputs[0].get("generated_text", "").strip()
77
- # Truncate to 100 words
78
- words = story.split()
79
- if len(words) > 100:
80
- story = " ".join(words[:100]) + ('.' if not story.endswith('.') else '')
81
- return story
82
 
83
  # —––––––– Main App —–––––––
84
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
85
  if uploaded:
 
86
  img = Image.open(uploaded).convert("RGB")
87
- if max(img.size) > 2048:
88
- img.thumbnail((2048, 2048))
89
  st.image(img, use_container_width=True)
90
 
 
91
  with st.spinner("🔍 Generating caption..."):
92
- caption = generate_caption(img)
 
93
  if not caption:
94
  st.error("😢 Couldn't understand this image. Try another one!")
95
  st.stop()
96
  st.success(f"**Caption:** {caption}")
97
 
 
 
 
 
 
98
  with st.spinner("📝 Writing story..."):
99
- story = generate_story(caption)
 
 
 
 
 
 
 
 
 
100
 
101
  st.subheader("📚 Your Magical Story")
102
  st.write(story)
103
 
 
104
  with st.spinner("🔊 Converting to audio..."):
105
  try:
106
  tts = gTTS(text=story, lang="en", slow=False)
107
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
108
- tts.save(fp.name)
109
- st.audio(fp.name, format="audio/mp3")
110
  except Exception as e:
111
  st.warning(f"⚠️ TTS failed: {e}")
112
 
 
2
  import time
3
  import streamlit as st
4
  from PIL import Image
 
5
  from transformers import pipeline
 
6
  from gtts import gTTS
7
  import tempfile
8
 
9
  # —––––––– Requirements —–––––––
10
+ # streamlit>=1.20
11
+ # pillow>=9.0
12
+ # transformers>=4.30
13
+ # torch>=2.0.0
14
+ # sentencepiece>=0.1.97
15
+ # gTTS>=2.3.1
16
+
17
+ # —––––––– Page Setup —–––––––
18
+ st.set_page_config(page_title="Magic Story Generator", layout="centered")
19
  st.title("📖✨ Turn Images into Children's Stories")
20
 
21
+ # —––––––– Load Pipelines (cached) —–––––––
22
  @st.cache_resource(show_spinner=False)
23
+ def load_pipelines():
 
 
 
 
 
 
24
  # 1) Image-captioning pipeline (BLIP)
25
  captioner = pipeline(
26
  task="image-to-text",
27
  model="Salesforce/blip-image-captioning-base",
28
+ device=-1 # CPU; set to 0+ for GPU
29
  )
 
30
  # 2) Story-generation pipeline (DeepSeek-R1-Distill-Qwen)
31
  storyteller = pipeline(
32
+ task="text2text-generation",
33
  model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
34
  tokenizer="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
35
  trust_remote_code=True,
36
+ device=-1, # CPU; set to 0+ for GPU
37
  temperature=0.6,
38
  top_p=0.9,
39
  repetition_penalty=1.1,
 
41
  max_new_tokens=120,
42
  return_full_text=False
43
  )
 
44
  return captioner, storyteller
45
 
46
+ captioner, storyteller = load_pipelines()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  # —––––––– Main App —–––––––
49
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
50
  if uploaded:
51
+ # Display uploaded image
52
  img = Image.open(uploaded).convert("RGB")
53
+ img.thumbnail((2048, 2048), Image.ANTIALIAS)
 
54
  st.image(img, use_container_width=True)
55
 
56
+ # Generate caption
57
  with st.spinner("🔍 Generating caption..."):
58
+ cap = captioner(img)
59
+ caption = cap[0].get("generated_text", "").strip() if isinstance(cap, list) else ""
60
  if not caption:
61
  st.error("😢 Couldn't understand this image. Try another one!")
62
  st.stop()
63
  st.success(f"**Caption:** {caption}")
64
 
65
+ # Generate story
66
+ prompt = (
67
+ f"Image description: {caption}\n"
68
+ "Write a coherent, 50-100 word children’s story that flows naturally."
69
+ )
70
  with st.spinner("📝 Writing story..."):
71
+ start = time.time()
72
+ out = storyteller(prompt)
73
+ gen_time = time.time() - start
74
+ st.text(f"⏱ Generated in {gen_time:.1f}s")
75
+ story = out[0].get("generated_text", "").strip()
76
+
77
+ # Enforce ≤100 words
78
+ words = story.split()
79
+ if len(words) > 100:
80
+ story = " ".join(words[:100]) + ("" if story.endswith('.') else ".")
81
 
82
  st.subheader("📚 Your Magical Story")
83
  st.write(story)
84
 
85
+ # Convert to audio
86
  with st.spinner("🔊 Converting to audio..."):
87
  try:
88
  tts = gTTS(text=story, lang="en", slow=False)
89
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
90
+ tts.save(tmp.name)
91
+ st.audio(tmp.name, format="audio/mp3")
92
  except Exception as e:
93
  st.warning(f"⚠️ TTS failed: {e}")
94