mayf commited on
Commit
1c165f8
·
verified ·
1 Parent(s): 504dc12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -18,10 +18,11 @@ def load_pipelines():
18
  "image-to-text",
19
  model="Salesforce/blip-image-captioning-base"
20
  )
21
- # 2) Story generation with Flan-T5
22
  storyteller = pipeline(
23
  "text2text-generation",
24
- model="google/flan-t5-base"
 
25
  )
26
  return captioner, storyteller
27
 
@@ -36,22 +37,21 @@ if uploaded:
36
  # —––––––– 1. Caption
37
  with st.spinner("🔍 Looking at the image..."):
38
  cap_outputs = captioner(image)
39
- # BLIP returns a list of dicts with key "generated_text"
40
  cap = cap_outputs[0].get("generated_text", "").strip()
41
  st.markdown(f"**Caption:** {cap}")
42
 
43
  # —––––––– 2. Story generation
44
  prompt = (
45
- "Write a playful, 50–100 word story for 3–10 year-old children "
46
  f"based on this description:\n\n“{cap}”\n\nStory:"
47
  )
48
  with st.spinner("✍️ Writing a story..."):
49
  out = storyteller(
50
  prompt,
51
- max_length=200,
52
  do_sample=True,
53
- top_p=0.9,
54
- temperature=0.8,
55
  num_return_sequences=1
56
  )
57
  story = out[0]["generated_text"].strip()
@@ -65,3 +65,4 @@ if uploaded:
65
  tts.write_to_fp(tmp)
66
  tmp.flush()
67
  st.audio(tmp.name, format="audio/mp3")
 
 
18
  "image-to-text",
19
  model="Salesforce/blip-image-captioning-base"
20
  )
21
+ # 2) Story generation with a bigger Flan-T5
22
  storyteller = pipeline(
23
  "text2text-generation",
24
+ model="google/flan-t5-large",
25
+ device=0 # set to -1 if you only have CPU
26
  )
27
  return captioner, storyteller
28
 
 
37
  # —––––––– 1. Caption
38
  with st.spinner("🔍 Looking at the image..."):
39
  cap_outputs = captioner(image)
 
40
  cap = cap_outputs[0].get("generated_text", "").strip()
41
  st.markdown(f"**Caption:** {cap}")
42
 
43
  # —––––––– 2. Story generation
44
  prompt = (
45
+ "Write a playful, imaginative story of about 50–100 words for 3–10 year-olds, "
46
  f"based on this description:\n\n“{cap}”\n\nStory:"
47
  )
48
  with st.spinner("✍️ Writing a story..."):
49
  out = storyteller(
50
  prompt,
51
+ max_length=250, # give it a bit more room
52
  do_sample=True,
53
+ top_p=0.95,
54
+ temperature=0.7,
55
  num_return_sequences=1
56
  )
57
  story = out[0]["generated_text"].strip()
 
65
  tts.write_to_fp(tmp)
66
  tmp.flush()
67
  st.audio(tmp.name, format="audio/mp3")
68
+