mayf commited on
Commit
a2fa6c1
Β·
verified Β·
1 Parent(s): c4140bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -27
app.py CHANGED
@@ -1,63 +1,56 @@
1
  import streamlit as st
2
  from PIL import Image
3
- from io import BytesIO
4
  from transformers import pipeline
5
  from gtts import gTTS
6
  import tempfile
7
 
8
- # —––––––– Page config and title
9
  st.set_page_config(page_title="Storyteller for Kids", layout="centered")
10
  st.title("πŸ–ΌοΈ ➑️ πŸ“– Interactive Storyteller")
11
 
12
- # —––––––– Load pipelines (cached)
13
- @st.experimental_singleton
14
  def load_pipelines():
15
- # 1. Image captioning
16
  captioner = pipeline(
17
  "image-captioning",
18
- model="Salesforce/blip-image-captioning-base",
19
- device=0 if not st.session_state.get("CPU_ONLY", False) else -1
20
  )
21
- # 2. Story generation (you can swap to a kid-friendly fine-tuned model)
22
  storyteller = pipeline(
23
- "text-generation",
24
- model="gpt2",
25
- device=0 if not st.session_state.get("CPU_ONLY", False) else -1
26
  )
27
  return captioner, storyteller
28
 
29
  captioner, storyteller = load_pipelines()
30
 
31
- # —––––––– Sidebar: CPU/GPU toggle (optional)
32
- st.sidebar.write("### Settings")
33
- st.sidebar.checkbox("Force CPU only", key="CPU_ONLY")
34
-
35
- # —––––––– Main UI: image upload
36
- uploaded = st.file_uploader("Upload an image:", type=["jpg","jpeg","png"])
37
  if uploaded:
38
  image = Image.open(uploaded).convert("RGB")
39
- st.image(image, caption="Your picture", use_column_width=True)
40
 
41
  # —––––––– 1. Caption
42
  with st.spinner("πŸ” Looking at the image..."):
43
- caption = captioner(image)[0]["generated_text"]
44
- st.markdown(f"**Caption:** {caption}")
45
 
46
  # —––––––– 2. Story generation
47
  prompt = (
48
- f"Use the following description to write a playful story (50–100 words) "
49
- f"for 3–10 year-old children:\n\nβ€œ{caption}”\n\nStory:"
50
  )
51
  with st.spinner("✍️ Writing a story..."):
52
- output = storyteller(
53
  prompt,
54
- max_length= prompt.count(" ") + 100, # approx ~100 words
55
- num_return_sequences=1,
56
  do_sample=True,
57
  top_p=0.9,
58
- temperature=0.8
 
59
  )
60
- story = output[0]["generated_text"].split("Story:")[-1].strip()
61
  st.markdown("**Story:**")
62
  st.write(story)
63
 
 
1
  import streamlit as st
2
  from PIL import Image
 
3
  from transformers import pipeline
4
  from gtts import gTTS
5
  import tempfile
6
 
7
+ # —––––––– Page config
8
  st.set_page_config(page_title="Storyteller for Kids", layout="centered")
9
  st.title("πŸ–ΌοΈ ➑️ πŸ“– Interactive Storyteller")
10
 
11
+ # —––––––– Cache model loading
12
+ @st.cache_resource
13
  def load_pipelines():
14
+ # 1) Image captioning
15
  captioner = pipeline(
16
  "image-captioning",
17
+ model="Salesforce/blip-image-captioning-base"
 
18
  )
19
+ # 2) Story generation with Flan-T5
20
  storyteller = pipeline(
21
+ "text2text-generation",
22
+ model="google/flan-t5-base"
 
23
  )
24
  return captioner, storyteller
25
 
26
  captioner, storyteller = load_pipelines()
27
 
28
+ # —––––––– Image upload
29
+ uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
 
 
 
 
30
  if uploaded:
31
  image = Image.open(uploaded).convert("RGB")
32
+ st.image(image, caption="Your image", use_column_width=True)
33
 
34
  # —––––––– 1. Caption
35
  with st.spinner("πŸ” Looking at the image..."):
36
+ cap = captioner(image)[0]["generated_text"]
37
+ st.markdown(f"**Caption:** {cap}")
38
 
39
  # —––––––– 2. Story generation
40
  prompt = (
41
+ "Write a playful, 50–100 word story for 3–10 year-old children "
42
+ f"based on this description:\n\nβ€œ{cap}”\n\nStory:"
43
  )
44
  with st.spinner("✍️ Writing a story..."):
45
+ out = storyteller(
46
  prompt,
47
+ max_length=200,
 
48
  do_sample=True,
49
  top_p=0.9,
50
+ temperature=0.8,
51
+ num_return_sequences=1
52
  )
53
+ story = out[0]["generated_text"].strip()
54
  st.markdown("**Story:**")
55
  st.write(story)
56