mayf committed on
Commit
c916589
·
verified ·
1 Parent(s): b3f64ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -28
app.py CHANGED
@@ -3,65 +3,76 @@
3
  import streamlit as st
4
  from PIL import Image
5
  from transformers import pipeline
6
- from gtts import gTTS
7
  import tempfile
8
 
9
  # —––––––– Page config
10
  st.set_page_config(page_title="Storyteller for Kids", layout="centered")
11
  st.title("🖼️ ➡️ 📖 Interactive Storyteller")
12
 
13
- # —––––––– Cache model loading
14
  @st.cache_resource
15
  def load_pipelines():
16
- # 1) Image-to-text (captioning)
17
  captioner = pipeline(
18
  "image-to-text",
19
- model="Salesforce/blip-image-captioning-base"
 
20
  )
21
- # 2) Story generation with GPT-Neo 2.7B
22
  storyteller = pipeline(
23
  "text-generation",
24
- model="EleutherAI/gpt-neo-2.7B",
25
- device=-1 # set to -1 if you only have CPU
26
  )
 
 
 
 
 
 
27
  return captioner, storyteller
28
 
 
 
 
 
 
 
 
29
  captioner, storyteller = load_pipelines()
 
30
 
31
- # —––––––– Image upload
32
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
33
  if uploaded:
 
34
  image = Image.open(uploaded).convert("RGB")
35
- st.image(image, caption="Your image", use_column_width=True)
 
36
 
37
- # —––––––– 1. Caption
38
- with st.spinner("🔍 Looking at the image..."):
39
- cap_outputs = captioner(image)
40
- cap = cap_outputs[0].get("generated_text", "").strip()
41
  st.markdown(f"**Caption:** {cap}")
42
 
43
- # —––––––– 2. Story generation
44
  prompt = (
45
- "Write a playful, 80–100 word story for 3–10 year-old children "
46
- f"based on this description:\n\n“{cap}”\n\nStory:"
47
  )
48
- with st.spinner("✍️ Writing a story..."):
49
  out = storyteller(
50
  prompt,
51
- max_new_tokens=120, # allow space for ~100 words
52
- do_sample=True,
53
- top_p=0.9,
54
- temperature=0.8,
55
- num_return_sequences=1
56
  )
57
  story = out[0]["generated_text"].strip()
58
  st.markdown("**Story:**")
59
  st.write(story)
60
 
61
- # —––––––– 3. Text-to-Speech
62
  with st.spinner("🔊 Converting to speech..."):
63
- tts = gTTS(story, lang="en")
64
- tmp = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
65
- tts.write_to_fp(tmp)
66
- tmp.flush()
67
- st.audio(tmp.name, format="audio/mp3")
 
3
  import streamlit as st
4
  from PIL import Image
5
  from transformers import pipeline
6
+ import pyttsx3
7
  import tempfile
8
 
9
  # —––––––– Page config
10
  st.set_page_config(page_title="Storyteller for Kids", layout="centered")
11
  st.title("🖼️ ➡️ 📖 Interactive Storyteller")
12
 
13
+ # —––––––– Model loading + warm-up
14
  @st.cache_resource
15
  def load_pipelines():
16
+ # 1) Smaller BLIP for captions
17
  captioner = pipeline(
18
  "image-to-text",
19
+ model="Salesforce/blip-image-captioning-small",
20
+ device=0
21
  )
22
+ # 2) Small GPT-Neo for stories
23
  storyteller = pipeline(
24
  "text-generation",
25
+ model="EleutherAI/gpt-neo-125M",
26
+ device=0
27
  )
28
+
29
+ # Warm up both models once
30
+ dummy = Image.new("RGB", (384,384), color=(128,128,128))
31
+ captioner(dummy)
32
+ storyteller("Hello", max_new_tokens=1)
33
+
34
  return captioner, storyteller
35
 
36
+ # —––––––– TTS engine init (offline)
37
+ @st.cache_resource
38
+ def init_tts_engine():
39
+ engine = pyttsx3.init()
40
+ engine.setProperty('rate', 150)
41
+ return engine
42
+
43
  captioner, storyteller = load_pipelines()
44
+ tts_engine = init_tts_engine()
45
 
46
+ # —––––––– Image upload & processing
47
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
48
  if uploaded:
49
+ # 1) Load + downsize
50
  image = Image.open(uploaded).convert("RGB")
51
+ image = image.resize((384, 384), Image.LANCZOS)
52
+ st.image(image, caption="Your image", use_container_width=True)
53
 
54
+ # 2) Caption
55
+ with st.spinner("🔍 Generating caption..."):
56
+ cap = captioner(image)[0]["generated_text"].strip()
 
57
  st.markdown(f"**Caption:** {cap}")
58
 
59
+ # 3) Story
60
  prompt = (
61
+ f"Write a fun, 80–100 word story for kids based on:\n\n“{cap}”\n\nStory:"
 
62
  )
63
+ with st.spinner("✍️ Generating story..."):
64
  out = storyteller(
65
  prompt,
66
+ max_new_tokens=120,
67
+ do_sample=False, # greedy = fastest
 
 
 
68
  )
69
  story = out[0]["generated_text"].strip()
70
  st.markdown("**Story:**")
71
  st.write(story)
72
 
73
+ # 4) TTS
74
  with st.spinner("🔊 Converting to speech..."):
75
+ tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
76
+ tts_engine.save_to_file(story, tmp.name)
77
+ tts_engine.runAndWait()
78
+ st.audio(tmp.name)