mayf committed · verified
Commit 6b1de29 · Parent(s): 33fead7

Update app.py

Files changed (1)
  1. app.py +18 -23
app.py CHANGED
@@ -3,7 +3,7 @@
 import streamlit as st
 from PIL import Image
 from transformers import pipeline
-import pyttsx3
+from gtts import gTTS
 import tempfile
 
 # —––––––– Page config
@@ -13,52 +13,45 @@ st.title("🖼️ ➡️ 📖 Interactive Storyteller")
 # —––––––– Model loading + warm-up
 @st.cache_resource
 def load_pipelines():
-    # 1) Keep the original BLIP-base for captions
+    # 1) Original BLIP-base captioner
     captioner = pipeline(
         "image-to-text",
         model="Salesforce/blip-image-captioning-base",
-        device=0  # if you have GPU; use -1 for CPU-only
+        device=0  # set to -1 if CPU-only
     )
-    # 2) Switch to a lightweight story model
+    # 2) Lightweight GPT-Neo for stories
     storyteller = pipeline(
         "text-generation",
         model="EleutherAI/gpt-neo-125M",
         device=0
     )
 
-    # Warm up with a dummy run so first real call is fast
+    # Warm-up so first real request is fast
     dummy = Image.new("RGB", (384, 384), color=(128, 128, 128))
     captioner(dummy)
     storyteller("Hello", max_new_tokens=1)
 
     return captioner, storyteller
 
-# —––––––– Initialize local TTS (offline)
-@st.cache_resource
-def init_tts_engine():
-    engine = pyttsx3.init()
-    engine.setProperty("rate", 150)  # words per minute
-    return engine
-
 captioner, storyteller = load_pipelines()
-tts_engine = init_tts_engine()
 
-# —––––––– Main UI
+# —––––––– Image upload & processing
 uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
 if uploaded:
-    # 1) Resize image to reduce BLIP load
+    # 1) Load + downsize for faster vision encoding
     image = Image.open(uploaded).convert("RGB")
     image = image.resize((384, 384), Image.LANCZOS)
     st.image(image, caption="Your image", use_container_width=True)
 
-    # 2) Caption
+    # 2) Caption step
    with st.spinner("🔍 Generating caption..."):
        cap = captioner(image)[0]["generated_text"].strip()
    st.markdown(f"**Caption:** {cap}")
 
-    # 3) Story (greedy = fastest)
+    # 3) Story generation (greedy for speed)
     prompt = (
-        f"Tell an 80–100 word fun story for 3–10 year-olds based on:\n\n“{cap}”\n\nStory:"
+        f"Write an 80–100 word playful story for 3–10 year-olds "
+        f"based on this description:\n\n“{cap}”\n\nStory:"
     )
     with st.spinner("✍️ Generating story..."):
         out = storyteller(
@@ -70,9 +63,11 @@ if uploaded:
     st.markdown("**Story:**")
     st.write(story)
 
-    # 4) TTS (local, no network)
+    # 4) Text-to-Speech via gTTS (network-based)
     with st.spinner("🔊 Converting to speech..."):
-        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
-        tts_engine.save_to_file(story, tmp.name)
-        tts_engine.runAndWait()
-        st.audio(tmp.name)
+        tts = gTTS(text=story, lang="en")
+        tmp = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
+        tts.write_to_fp(tmp)
+        tmp.flush()
+        st.audio(tmp.name, format="audio/mp3")
+
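The new speech step can be sanity-checked outside Streamlit with a minimal sketch like the one below. This assumes the gTTS package is installed and network access is available (gTTS calls Google's online TTS service); `sample_story` is just a placeholder string, not part of the app.

```python
# Minimal standalone check of the gTTS step introduced in this commit.
# Assumes `pip install gTTS` and network access; `sample_story` is a placeholder.
import tempfile

from gtts import gTTS

sample_story = "Once upon a time, a curious fox found a shiny red balloon."

tts = gTTS(text=sample_story, lang="en")
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
    tts.write_to_fp(tmp)   # stream the MP3 bytes into the open file handle
    tmp.flush()
    print(f"Wrote audio to {tmp.name}")
```

Streaming with `write_to_fp` matches the app code; `gTTS.save(path)` would work equally well. Because the temp file is created with `delete=False`, a long-running Space may want to remove it (e.g. with `os.remove`) after `st.audio` has served it.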