mayf commited on
Commit
dd489ad
·
verified ·
1 Parent(s): c1e6c9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -14
app.py CHANGED
@@ -13,20 +13,20 @@ st.title("🖼️ ➡️ 📖 Interactive Storyteller")
13
  # —––––––– Model loading + warm-up
14
  @st.cache_resource
15
  def load_pipelines():
16
- # 1) Original BLIP-base captioner
17
  captioner = pipeline(
18
  "image-to-text",
19
  model="Salesforce/blip-image-captioning-base",
20
- device=0 # set to -1 if CPU-only
21
  )
22
- # 2) Lightweight GPT-Neo for stories
23
  storyteller = pipeline(
24
  "text-generation",
25
  model="EleutherAI/gpt-neo-125M",
26
  device=0
27
  )
28
 
29
- # Warm-up so first real request is fast
30
  dummy = Image.new("RGB", (384, 384), color=(128, 128, 128))
31
  captioner(dummy)
32
  storyteller("Hello", max_new_tokens=1)
@@ -35,10 +35,10 @@ def load_pipelines():
35
 
36
  captioner, storyteller = load_pipelines()
37
 
38
- # —––––––– Image upload & processing
39
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
40
  if uploaded:
41
- # 1) Load + downsize for faster vision encoding
42
  image = Image.open(uploaded).convert("RGB")
43
  image = image.resize((384, 384), Image.LANCZOS)
44
  st.image(image, caption="Your image", use_container_width=True)
@@ -48,26 +48,31 @@ if uploaded:
48
  cap = captioner(image)[0]["generated_text"].strip()
49
  st.markdown(f"**Caption:** {cap}")
50
 
51
- # 3) Story generation (greedy for speed)
52
  prompt = (
53
- f"Write an 80–100 word playful story for 3–10 year-olds "
54
- f"based on this description:\n\n“{cap}”\n\nStory:"
55
  )
56
  with st.spinner("✍️ Generating story..."):
57
  out = storyteller(
58
  prompt,
59
- max_new_tokens=120,
60
- do_sample=False
 
 
 
 
 
61
  )
62
- story = out[0]["generated_text"].strip()
 
63
  st.markdown("**Story:**")
64
  st.write(story)
65
 
66
- # 4) Text-to-Speech via gTTS (network-based)
67
  with st.spinner("🔊 Converting to speech..."):
68
  tts = gTTS(text=story, lang="en")
69
  tmp = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
70
  tts.write_to_fp(tmp)
71
  tmp.flush()
72
  st.audio(tmp.name, format="audio/mp3")
73
-
 
13
  # —––––––– Model loading + warm-up
14
  @st.cache_resource
15
  def load_pipelines():
16
+ # 1) Original BLIP-base for captions
17
  captioner = pipeline(
18
  "image-to-text",
19
  model="Salesforce/blip-image-captioning-base",
20
+ device=0 # change to -1 if you only have CPU
21
  )
22
+ # 2) Small GPT-Neo for quick stories
23
  storyteller = pipeline(
24
  "text-generation",
25
  model="EleutherAI/gpt-neo-125M",
26
  device=0
27
  )
28
 
29
+ # Warm up both so the first real call is faster
30
  dummy = Image.new("RGB", (384, 384), color=(128, 128, 128))
31
  captioner(dummy)
32
  storyteller("Hello", max_new_tokens=1)
 
35
 
36
  captioner, storyteller = load_pipelines()
37
 
38
+ # —––––––– Main UI
39
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
40
  if uploaded:
41
+ # 1) Load + resize for faster encoding
42
  image = Image.open(uploaded).convert("RGB")
43
  image = image.resize((384, 384), Image.LANCZOS)
44
  st.image(image, caption="Your image", use_container_width=True)
 
48
  cap = captioner(image)[0]["generated_text"].strip()
49
  st.markdown(f"**Caption:** {cap}")
50
 
51
+ # 3) Story generation (sampling + repetition control)
52
  prompt = (
53
+ f"Write an 80–100 word fun story for 3–10 year-old children "
54
+ f"based on this description:\n\n“{cap}”\n\nStory: "
55
  )
56
  with st.spinner("✍️ Generating story..."):
57
  out = storyteller(
58
  prompt,
59
+ max_new_tokens=120, # room for ~100 words
60
+ do_sample=True, # enable sampling
61
+ temperature=0.8, # creativity
62
+ top_p=0.9, # nucleus sampling
63
+ top_k=50, # limit to top 50 tokens
64
+ repetition_penalty=1.2, # discourage exact repeats
65
+ no_repeat_ngram_size=3 # prevent 3-gram repeats
66
  )
67
+ # strip off the prompt so only the story remains
68
+ story = out[0]["generated_text"][len(prompt):].strip()
69
  st.markdown("**Story:**")
70
  st.write(story)
71
 
72
+ # 4) Text-to-Speech via gTTS
73
  with st.spinner("🔊 Converting to speech..."):
74
  tts = gTTS(text=story, lang="en")
75
  tmp = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
76
  tts.write_to_fp(tmp)
77
  tmp.flush()
78
  st.audio(tmp.name, format="audio/mp3")