mayf commited on
Commit
cc355a8
·
verified ·
1 Parent(s): 60c225b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -23
app.py CHANGED
@@ -10,72 +10,71 @@ import tempfile
10
  st.set_page_config(page_title="Storyteller for Kids", layout="centered")
11
  st.title("🖼️ ➡️ 📖 Interactive Storyteller")
12
 
13
- # —––––––– Load & warm pipelines
14
  @st.cache_resource
15
  def load_pipelines():
16
- # 1) BLIP-base for captions
17
  captioner = pipeline(
18
  "image-to-text",
19
  model="Salesforce/blip-image-captioning-base",
20
- device=-1 # or -1 if you only have CPU
21
  )
22
- # 2) Flan-T5-Large for instruction-driven stories
23
  storyteller = pipeline(
24
  "text2text-generation",
25
  model="google/flan-t5-large",
26
  device=0
27
  )
28
-
29
- # Warm-up to avoid first-call lag
30
  dummy = Image.new("RGB", (384, 384), color=(128, 128, 128))
31
  captioner(dummy)
32
  storyteller("Warm up", max_new_tokens=1)
33
-
34
  return captioner, storyteller
35
 
36
  captioner, storyteller = load_pipelines()
37
 
38
  # —––––––– Main UI
39
- uploaded = st.file_uploader("Upload an image:", type=["jpg","jpeg","png"])
40
  if uploaded:
41
- # 1) Preprocess & display
42
  image = Image.open(uploaded).convert("RGB")
43
  image = image.resize((384, 384), Image.LANCZOS)
44
  st.image(image, caption="Your image", use_container_width=True)
45
 
46
  # 2) Caption
47
- with st.spinner("🔍 Generating caption..."):
48
  cap = captioner(image)[0]["generated_text"].strip()
49
  st.markdown(f"**Caption:** {cap}")
50
 
51
- # 3) Story
52
  prompt = (
53
- f"Image description: “{cap}”.\n"
54
  "Write an 80–100 word playful story for 3–10 year-old children that:\n"
55
- "1) Sets the scene around the panda.\n"
56
- "2) Describes what it’s doing and how it feels.\n"
57
- "3) Ends with a fun conclusion.\n\n"
58
  "Story:"
59
  )
60
- with st.spinner("✍️ Generating story..."):
61
- out = storyteller(
 
 
62
  prompt,
63
- max_new_tokens=130,
64
  do_sample=True,
65
  temperature=0.7,
66
  top_p=0.9,
67
  top_k=50,
68
- repetition_penalty=1.3,
69
  no_repeat_ngram_size=3
70
  )
71
- # text2text pipeline returns only the generated part
72
- story = out[0]["generated_text"].strip()
73
 
74
  st.markdown("**Story:**")
75
  st.write(story)
76
 
77
- # 4) Text-to-Speech via gTTS
78
- with st.spinner("🔊 Converting to speech..."):
79
  tts = gTTS(text=story, lang="en")
80
  tmp = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
81
  tts.write_to_fp(tmp)
 
10
  st.set_page_config(page_title="Storyteller for Kids", layout="centered")
11
  st.title("🖼️ ➡️ 📖 Interactive Storyteller")
12
 
13
+ # —––––––– Load and warm pipelines
14
  @st.cache_resource
15
  def load_pipelines():
16
+ # BLIP-base for captions
17
  captioner = pipeline(
18
  "image-to-text",
19
  model="Salesforce/blip-image-captioning-base",
20
+ device=-1 # GPU if available, else -1
21
  )
22
+ # Flan-T5-Large for stories
23
  storyteller = pipeline(
24
  "text2text-generation",
25
  model="google/flan-t5-large",
26
  device=0
27
  )
28
+ # Warm-up runs so user-facing calls are fast
 
29
  dummy = Image.new("RGB", (384, 384), color=(128, 128, 128))
30
  captioner(dummy)
31
  storyteller("Warm up", max_new_tokens=1)
 
32
  return captioner, storyteller
33
 
34
  captioner, storyteller = load_pipelines()
35
 
36
  # —––––––– Main UI
37
+ uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
38
  if uploaded:
39
+ # 1) Preprocess image
40
  image = Image.open(uploaded).convert("RGB")
41
  image = image.resize((384, 384), Image.LANCZOS)
42
  st.image(image, caption="Your image", use_container_width=True)
43
 
44
  # 2) Caption
45
+ with st.spinner("🔍 Generating caption"):
46
  cap = captioner(image)[0]["generated_text"].strip()
47
  st.markdown(f"**Caption:** {cap}")
48
 
49
+ # 3) Build a dynamic prompt
50
  prompt = (
51
+ f"Here is an image description: “{cap}”.\n"
52
  "Write an 80–100 word playful story for 3–10 year-old children that:\n"
53
+ "1) Describes the scene and subject from the description.\n"
54
+ "2) Explains what the subject is doing and how it feels.\n"
55
+ "3) Concludes with a fun, imaginative ending.\n\n"
56
  "Story:"
57
  )
58
+
59
+ # 4) Generate the story
60
+ with st.spinner("✍️ Writing the story…"):
61
+ output = storyteller(
62
  prompt,
63
+ max_new_tokens=120,
64
  do_sample=True,
65
  temperature=0.7,
66
  top_p=0.9,
67
  top_k=50,
68
+ repetition_penalty=1.2,
69
  no_repeat_ngram_size=3
70
  )
71
+ story = output[0]["generated_text"].strip()
 
72
 
73
  st.markdown("**Story:**")
74
  st.write(story)
75
 
76
+ # 5) Text-to-Speech
77
+ with st.spinner("🔊 Converting to speech"):
78
  tts = gTTS(text=story, lang="en")
79
  tmp = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
80
  tts.write_to_fp(tmp)