mayf committed on
Commit
db1550f
·
verified ·
1 Parent(s): 1c165f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -46
app.py CHANGED
@@ -1,68 +1,120 @@
1
- # app.py
2
-
3
  import streamlit as st
4
  from PIL import Image
5
  from transformers import pipeline
6
  from gtts import gTTS
7
  import tempfile
 
8
 
9
# Page setup: browser-tab title and a centered layout, followed by the
# heading rendered at the top of the app.
st.set_page_config(
    page_title="Storyteller for Kids",
    layout="centered",
)
st.title("🖼️ ➡️ 📖 Interactive Storyteller")
 
 
 
 
 
12
 
13
- # —––––––– Cache model loading
14
  @st.cache_resource
15
  def load_pipelines():
16
- # 1) Image-to-text (captioning)
17
  captioner = pipeline(
18
  "image-to-text",
19
- model="Salesforce/blip-image-captioning-base"
 
20
  )
21
- # 2) Story generation with a bigger Flan-T5
 
22
  storyteller = pipeline(
23
  "text2text-generation",
24
- model="google/flan-t5-large",
25
- device=0 # set to -1 if you only have CPU
 
26
  )
 
27
  return captioner, storyteller
28
 
29
# Obtain the captioning and story-generation pipelines.
captioner, storyteller = load_pipelines()

# ---- Image upload ----
uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
if uploaded:
    image = Image.open(uploaded).convert("RGB")
    st.image(image, caption="Your image", use_column_width=True)

    # ---- 1. Caption ----
    with st.spinner("🔍 Looking at the image..."):
        cap_outputs = captioner(image)
        # A missing "generated_text" key yields an empty caption, not an error.
        cap = cap_outputs[0].get("generated_text", "").strip()
    st.markdown(f"**Caption:** {cap}")

    # ---- 2. Story generation ----
    prompt = (
        "Write a playful, imaginative story of about 50–100 words for 3–10 year-olds, "
        f"based on this description:\n\n“{cap}”\n\nStory:"
    )
    with st.spinner("✍️ Writing a story..."):
        out = storyteller(
            prompt,
            max_length=250,  # give it a bit more room
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
            num_return_sequences=1,
        )
        story = out[0]["generated_text"].strip()
    st.markdown("**Story:**")
    st.write(story)

    # ---- 3. Text-to-Speech ----
    with st.spinner("🔊 Converting to speech..."):
        tts = gTTS(story, lang="en")
        # Previously a NamedTemporaryFile(delete=False) was left open and
        # never removed, leaking a handle and an .mp3 on every run. Write to
        # an auto-deleted temp file and hand the bytes to Streamlit instead.
        with tempfile.TemporaryFile(suffix=".mp3") as tmp:
            tts.write_to_fp(tmp)
            tmp.seek(0)
            audio_bytes = tmp.read()
    st.audio(audio_bytes, format="audio/mp3")
68
 
 
 
 
 
 
1
  import streamlit as st
2
  from PIL import Image
3
  from transformers import pipeline
4
  from gtts import gTTS
5
  import tempfile
6
+ import os
7
 
8
# Page configuration: tab title and icon, centered layout, sidebar collapsed
# on first load; then the heading shown at the top of the app.
st.set_page_config(
    page_title="Storyteller for Kids",
    page_icon="📚",
    layout="centered",
    initial_sidebar_state="collapsed",
)
st.title("🖼️➡️📖 Interactive Storyteller")
16
 
17
+ # —––––––– Cache model loading —–––––––
18
  @st.cache_resource
19
  def load_pipelines():
20
+ # Image-to-text pipeline
21
  captioner = pipeline(
22
  "image-to-text",
23
+ model="Salesforce/blip-image-captioning-base",
24
+ max_new_tokens=50
25
  )
26
+
27
+ # Story generation pipeline with better parameters
28
  storyteller = pipeline(
29
  "text2text-generation",
30
+ model="google/flan-t5-xxl",
31
+ device_map="auto",
32
+ model_kwargs={"load_in_8bit": True}
33
  )
34
+
35
  return captioner, storyteller
36
 
37
+ # —––––––– Main workflow —–––––––
38
def _build_story_prompt(caption):
    """Return the story-generation prompt built around *caption*."""
    # NOTE(review): the original triple-quoted prompt was recovered without
    # its indentation; lines are assumed to be flush-left — confirm.
    return (
        "Create a children's story (3-10 years old) based on this description:\n"
        "\n"
        f"{caption}\n"
        "\n"
        "Requirements:\n"
        "- 50-100 words\n"
        "- Playful and imaginative\n"
        "- Positive message\n"
        "- Simple vocabulary\n"
        "- Include animal characters\n"
        "\n"
        "Story:"
    )


def _synthesize_speech(story):
    """Return the MP3 bytes gTTS produces for *story* (English, normal speed)."""
    tts = gTTS(text=story, lang="en", slow=False)
    # An anonymous temp file is removed automatically on close, so there is
    # no path to track and nothing to clean up if gTTS fails part-way.
    with tempfile.TemporaryFile(suffix=".mp3") as tmp:
        tts.write_to_fp(tmp)
        tmp.seek(0)
        return tmp.read()


def main():
    """Run the app: upload an image, caption it, write a story, read it aloud.

    Any exception raised while processing an uploaded image is reported to
    the user through ``st.error`` rather than propagated.
    """
    captioner, storyteller = load_pipelines()

    # ---- Image upload ----
    # NOTE: the 5MB in the help text is advisory only; nothing below checks
    # the size of the uploaded file.
    uploaded = st.file_uploader(
        "Upload an image:",
        type=["jpg", "jpeg", "png"],
        help="Max size: 5MB",
    )
    if not uploaded:
        return

    try:
        # ---- Display image ----
        image = Image.open(uploaded).convert("RGB")
        st.image(image, caption="Your Image", use_column_width=True)

        # ---- Generate caption ----
        with st.spinner("🔍 Analyzing image content..."):
            cap_outputs = captioner(image)
            # A missing "generated_text" key yields an empty caption.
            cap = cap_outputs[0].get("generated_text", "").strip()

        st.subheader("Image Understanding")
        st.info(f"**Detected:** {cap}")

        # ---- Generate story ----
        st.subheader("Story Creation")
        prompt = _build_story_prompt(cap)

        with st.spinner("✍️ Crafting a magical story..."):
            story_output = storyteller(
                prompt,
                max_length=300,
                do_sample=True,
                top_p=0.95,
                temperature=0.85,
                num_beams=4,
                repetition_penalty=1.2,
            )
            story = story_output[0]["generated_text"].strip()

        st.success("**Generated Story:**")
        st.write(story)

        # Nothing to read aloud if the model returned only whitespace.
        if not story:
            st.warning("The story came back empty, so no audio was generated.")
            return

        # ---- Text-to-Speech ----
        st.subheader("Audio Version")
        with st.spinner("🔊 Generating audio..."):
            # Previously the cleanup code referenced the temp-file path even
            # when synthesis failed before the path was assigned, replacing
            # the real error with an UnboundLocalError. Working with
            # in-memory bytes removes that failure mode.
            audio_bytes = _synthesize_speech(story)

        st.audio(audio_bytes, format="audio/mp3")

        # Offer the same audio as a download.
        st.download_button(
            label="Download Audio Story",
            data=audio_bytes,
            file_name="kids_story.mp3",
            mime="audio/mpeg",
        )

    except Exception as e:
        # Top-level UI boundary: show the failure instead of crashing the app.
        st.error(f"Error processing your request: {str(e)}")


if __name__ == "__main__":
    main()