mayf commited on
Commit
e9fb854
·
verified ·
1 Parent(s): eb25a05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -19
app.py CHANGED
@@ -10,26 +10,26 @@ import tempfile
10
  st.set_page_config(page_title="Storyteller for Kids", layout="centered")
11
  st.title("🖼️ ➡️ 📖 Interactive Storyteller")
12
 
13
- # —––––––– Model loading + warm-up
14
  @st.cache_resource
15
  def load_pipelines():
16
- # 1) Original BLIP-base for captions
17
  captioner = pipeline(
18
  "image-to-text",
19
  model="Salesforce/blip-image-captioning-base",
20
  device=0
21
  )
22
- # 2) Instruction-tuned Flan-T5 small for stories
23
  storyteller = pipeline(
24
  "text2text-generation",
25
- model="google/flan-t5-small",
26
  device=0
27
  )
28
 
29
- # Warm up so first real request is faster
30
- dummy = Image.new("RGB", (384, 384), color=(128, 128, 128))
31
  captioner(dummy)
32
- storyteller("Tell me something", max_new_tokens=1)
33
 
34
  return captioner, storyteller
35
 
@@ -38,8 +38,9 @@ captioner, storyteller = load_pipelines()
38
  # —––––––– Main UI
39
  uploaded = st.file_uploader("Upload an image:", type=["jpg","jpeg","png"])
40
  if uploaded:
41
- # 1) Load + downsize
42
- image = Image.open(uploaded).convert("RGB").resize((384, 384), Image.LANCZOS)
 
43
  st.image(image, caption="Your image", use_container_width=True)
44
 
45
  # 2) Caption
@@ -47,32 +48,37 @@ if uploaded:
47
  cap = captioner(image)[0]["generated_text"].strip()
48
  st.markdown(f"**Caption:** {cap}")
49
 
50
- # 3) Story generation
51
  prompt = (
52
- f"Here is an image description: “{cap}”.\n"
53
- "Write a playful, 80–100 word story for 3–10 year-olds\n\n"
 
 
 
54
  "Story:"
55
  )
56
- with st.spinner("✍️ Generating story..."):
57
  out = storyteller(
58
  prompt,
59
- max_new_tokens=150,
60
  do_sample=True,
61
  temperature=0.7,
62
  top_p=0.9,
63
- repetition_penalty=1.2,
 
64
  no_repeat_ngram_size=3
65
  )
66
- # strip off the prompt so you only get the story
67
- story = out[0]["generated_text"].split("Story:")[-1].strip()
 
68
 
69
  st.markdown("**Story:**")
70
  st.write(story)
71
 
72
- # 4) Text-to-Speech
73
  with st.spinner("🔊 Converting to speech..."):
74
  tts = gTTS(text=story, lang="en")
75
  tmp = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
76
  tts.write_to_fp(tmp)
77
  tmp.flush()
78
- st.audio(tmp.name, format="audio/mp3")
 
10
  st.set_page_config(page_title="Storyteller for Kids", layout="centered")
11
  st.title("🖼️ ➡️ 📖 Interactive Storyteller")
12
 
13
+ # —––––––– Load & warm models
14
  @st.cache_resource
15
  def load_pipelines():
16
+ # 1) BLIP-base for image captions
17
  captioner = pipeline(
18
  "image-to-text",
19
  model="Salesforce/blip-image-captioning-base",
20
  device=0
21
  )
22
+ # 2) Flan-T5-Large for instruction following
23
  storyteller = pipeline(
24
  "text2text-generation",
25
+ model="google/flan-t5-large",
26
  device=0
27
  )
28
 
29
+ # Warm up so first real call is faster
30
+ dummy = Image.new("RGB", (384, 384), color=(128,128,128))
31
  captioner(dummy)
32
+ storyteller("Hello", max_new_tokens=1)
33
 
34
  return captioner, storyteller
35
 
 
38
  # —––––––– Main UI
39
  uploaded = st.file_uploader("Upload an image:", type=["jpg","jpeg","png"])
40
  if uploaded:
41
+ # 1) Preprocess + display
42
+ image = Image.open(uploaded).convert("RGB")
43
+ image = image.resize((384,384), Image.LANCZOS)
44
  st.image(image, caption="Your image", use_container_width=True)
45
 
46
  # 2) Caption
 
48
  cap = captioner(image)[0]["generated_text"].strip()
49
  st.markdown(f"**Caption:** {cap}")
50
 
51
+ # 3) Story — stronger, clearer prompt
52
  prompt = (
53
+ f"Here’s an image description: “{cap}”.\n\n"
54
+ "Write a playful, 80–100 word story for 3–10 year-old children.\n"
55
+ "- Focus only on the panda and what it’s doing.\n"
56
+ "- Do not introduce any other characters (no kids, no parents).\n"
57
+ "- Be vivid: mention the panda’s feelings or the crunchy meat.\n\n"
58
  "Story:"
59
  )
60
+ with st.spinner("✍️ Writing story..."):
61
  out = storyteller(
62
  prompt,
63
+ max_new_tokens=130,
64
  do_sample=True,
65
  temperature=0.7,
66
  top_p=0.9,
67
+ top_k=50,
68
+ repetition_penalty=1.3,
69
  no_repeat_ngram_size=3
70
  )
71
+ # strip prompt prefix keep only the generated story
72
+ raw = out[0]["generated_text"]
73
+ story = raw.split("Story:")[-1].strip()
74
 
75
  st.markdown("**Story:**")
76
  st.write(story)
77
 
78
+ # 4) Text-to-Speech (gTTS)
79
  with st.spinner("🔊 Converting to speech..."):
80
  tts = gTTS(text=story, lang="en")
81
  tmp = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
82
  tts.write_to_fp(tmp)
83
  tmp.flush()
84
+ st.audio(tmp.name, format="audio/mp3")