mayf commited on
Commit
33fead7
·
verified ·
1 Parent(s): 9a673a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -13,40 +13,40 @@ st.title("🖼️ ➡️ 📖 Interactive Storyteller")
13
  # —––––––– Model loading + warm-up
14
  @st.cache_resource
15
  def load_pipelines():
16
- # 1) Smaller BLIP for captions
17
  captioner = pipeline(
18
  "image-to-text",
19
- model="Salesforce/blip-image-captioning-small",
20
- device=0
21
  )
22
- # 2) Small GPT-Neo for stories
23
  storyteller = pipeline(
24
  "text-generation",
25
  model="EleutherAI/gpt-neo-125M",
26
  device=0
27
  )
28
 
29
- # Warm up both models once
30
- dummy = Image.new("RGB", (384,384), color=(128,128,128))
31
  captioner(dummy)
32
  storyteller("Hello", max_new_tokens=1)
33
 
34
  return captioner, storyteller
35
 
36
- # —––––––– TTS engine init (offline)
37
  @st.cache_resource
38
  def init_tts_engine():
39
  engine = pyttsx3.init()
40
- engine.setProperty('rate', 150)
41
  return engine
42
 
43
  captioner, storyteller = load_pipelines()
44
  tts_engine = init_tts_engine()
45
 
46
- # —––––––– Image upload & processing
47
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
48
  if uploaded:
49
- # 1) Load + downsize
50
  image = Image.open(uploaded).convert("RGB")
51
  image = image.resize((384, 384), Image.LANCZOS)
52
  st.image(image, caption="Your image", use_container_width=True)
@@ -56,21 +56,21 @@ if uploaded:
56
  cap = captioner(image)[0]["generated_text"].strip()
57
  st.markdown(f"**Caption:** {cap}")
58
 
59
- # 3) Story
60
  prompt = (
61
- f"Write a fun, 80–100 word story for kids based on:\n\n“{cap}”\n\nStory:"
62
  )
63
  with st.spinner("✍️ Generating story..."):
64
  out = storyteller(
65
  prompt,
66
  max_new_tokens=120,
67
- do_sample=False, # greedy = fastest
68
  )
69
  story = out[0]["generated_text"].strip()
70
  st.markdown("**Story:**")
71
  st.write(story)
72
 
73
- # 4) TTS
74
  with st.spinner("🔊 Converting to speech..."):
75
  tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
76
  tts_engine.save_to_file(story, tmp.name)
 
13
  # —––––––– Model loading + warm-up
14
  @st.cache_resource
15
  def load_pipelines():
16
+ # 1) Keep the original BLIP-base for captions
17
  captioner = pipeline(
18
  "image-to-text",
19
+ model="Salesforce/blip-image-captioning-base",
20
+ device=0 # if you have GPU; use -1 for CPU-only
21
  )
22
+ # 2) Switch to a lightweight story model
23
  storyteller = pipeline(
24
  "text-generation",
25
  model="EleutherAI/gpt-neo-125M",
26
  device=0
27
  )
28
 
29
+ # Warm up with a dummy run so first real call is fast
30
+ dummy = Image.new("RGB", (384, 384), color=(128, 128, 128))
31
  captioner(dummy)
32
  storyteller("Hello", max_new_tokens=1)
33
 
34
  return captioner, storyteller
35
 
36
+ # —––––––– Initialize local TTS (offline)
37
  @st.cache_resource
38
  def init_tts_engine():
39
  engine = pyttsx3.init()
40
+ engine.setProperty("rate", 150) # words per minute
41
  return engine
42
 
43
  captioner, storyteller = load_pipelines()
44
  tts_engine = init_tts_engine()
45
 
46
+ # —––––––– Main UI
47
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
48
  if uploaded:
49
+ # 1) Resize image to reduce BLIP load
50
  image = Image.open(uploaded).convert("RGB")
51
  image = image.resize((384, 384), Image.LANCZOS)
52
  st.image(image, caption="Your image", use_container_width=True)
 
56
  cap = captioner(image)[0]["generated_text"].strip()
57
  st.markdown(f"**Caption:** {cap}")
58
 
59
+ # 3) Story (greedy = fastest)
60
  prompt = (
61
+ f"Tell an 80–100 word fun story for 3–10 year-olds based on:\n\n“{cap}”\n\nStory:"
62
  )
63
  with st.spinner("✍️ Generating story..."):
64
  out = storyteller(
65
  prompt,
66
  max_new_tokens=120,
67
+ do_sample=False
68
  )
69
  story = out[0]["generated_text"].strip()
70
  st.markdown("**Story:**")
71
  st.write(story)
72
 
73
+ # 4) TTS (local, no network)
74
  with st.spinner("🔊 Converting to speech..."):
75
  tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
76
  tts_engine.save_to_file(story, tmp.name)