mayf commited on
Commit
dfb3989
·
verified ·
1 Parent(s): db1550f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -99
app.py CHANGED
@@ -1,120 +1,67 @@
 
 
1
  import streamlit as st
2
  from PIL import Image
3
  from transformers import pipeline
4
  from gtts import gTTS
5
  import tempfile
6
- import os
7
 
8
- # —––––––– Page config —–––––––
9
- st.set_page_config(
10
- page_title="Storyteller for Kids",
11
- page_icon="📚",
12
- layout="centered",
13
- initial_sidebar_state="collapsed"
14
- )
15
- st.title("🖼️➡️📖 Interactive Storyteller")
16
 
17
- # —––––––– Cache model loading —–––––––
18
  @st.cache_resource
19
  def load_pipelines():
20
- # Image-to-text pipeline
21
  captioner = pipeline(
22
  "image-to-text",
23
- model="Salesforce/blip-image-captioning-base",
24
- max_new_tokens=50
25
  )
26
-
27
- # Story generation pipeline with better parameters
28
  storyteller = pipeline(
29
  "text2text-generation",
30
- model="google/flan-t5-xxl",
31
- device_map="auto",
32
- model_kwargs={"load_in_8bit": True}
33
  )
34
-
35
  return captioner, storyteller
36
 
37
- # —––––––– Main workflow —–––––––
38
- def main():
39
- captioner, storyteller = load_pipelines()
40
-
41
- # —––––––– Image upload —–––––––
42
- uploaded = st.file_uploader(
43
- "Upload an image:",
44
- type=["jpg", "jpeg", "png"],
45
- help="Max size: 5MB"
46
- )
47
-
48
- if uploaded:
49
- try:
50
- # —––––––– Display image —–––––––
51
- image = Image.open(uploaded).convert("RGB")
52
- st.image(image, caption="Your Image", use_column_width=True)
53
-
54
- # —––––––– Generate caption —–––––––
55
- with st.spinner("🔍 Analyzing image content..."):
56
- cap_outputs = captioner(image)
57
- cap = cap_outputs[0].get("generated_text", "").strip()
58
-
59
- st.subheader("Image Understanding")
60
- st.info(f"**Detected:** {cap}")
61
 
62
- # —––––––– Generate story —–––––––
63
- st.subheader("Story Creation")
64
- prompt = f"""Create a children's story (3-10 years old) based on this description:
65
-
66
- {cap}
67
-
68
- Requirements:
69
- - 50-100 words
70
- - Playful and imaginative
71
- - Positive message
72
- - Simple vocabulary
73
- - Include animal characters
74
-
75
- Story:"""
76
-
77
- with st.spinner("✍️ Crafting a magical story..."):
78
- story_output = storyteller(
79
- prompt,
80
- max_length=300,
81
- do_sample=True,
82
- top_p=0.95,
83
- temperature=0.85,
84
- num_beams=4,
85
- repetition_penalty=1.2
86
- )
87
- story = story_output[0]["generated_text"].strip()
88
-
89
- st.success("**Generated Story:**")
90
- st.write(story)
91
 
92
- # —––––––– Text-to-Speech —–––––––
93
- st.subheader("Audio Version")
94
- with st.spinner("🔊 Generating audio..."):
95
- try:
96
- tts = gTTS(text=story, lang="en", slow=False)
97
- with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
98
- tts.write_to_fp(tmp)
99
- tmp_path = tmp.name
100
-
101
- st.audio(tmp_path, format="audio/mp3")
102
-
103
- # Add download button
104
- with open(tmp_path, "rb") as f:
105
- st.download_button(
106
- label="Download Audio Story",
107
- data=f,
108
- file_name="kids_story.mp3",
109
- mime="audio/mpeg"
110
- )
111
-
112
- finally:
113
- if os.path.exists(tmp_path):
114
- os.remove(tmp_path)
115
 
116
- except Exception as e:
117
- st.error(f"Error processing your request: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- if __name__ == "__main__":
120
- main()
 
 
 
 
 
 
1
+ # app.py
2
+
3
  import streamlit as st
4
  from PIL import Image
5
  from transformers import pipeline
6
  from gtts import gTTS
7
  import tempfile
 
8
 
9
+ # —––––––– Page config
10
+ st.set_page_config(page_title="Storyteller for Kids", layout="centered")
11
+ st.title("🖼️ ➡️ 📖 Interactive Storyteller")
 
 
 
 
 
12
 
13
+ # —––––––– Cache model loading
14
  @st.cache_resource
15
  def load_pipelines():
16
+ # 1) Image-to-text (captioning)
17
  captioner = pipeline(
18
  "image-to-text",
19
+ model="Salesforce/blip-image-captioning-base"
 
20
  )
21
+ # 2) Story generation with Flan-T5
 
22
  storyteller = pipeline(
23
  "text2text-generation",
24
+ model="google/flan-t5-base"
 
 
25
  )
 
26
  return captioner, storyteller
27
 
28
+ captioner, storyteller = load_pipelines()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+ # —––––––– Image upload
31
+ uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
32
+ if uploaded:
33
+ image = Image.open(uploaded).convert("RGB")
34
+ st.image(image, caption="Your image", use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ # —––––––– 1. Caption
37
+ with st.spinner("🔍 Looking at the image..."):
38
+ cap_outputs = captioner(image)
39
+ # BLIP returns a list of dicts with key "generated_text"
40
+ cap = cap_outputs[0].get("generated_text", "").strip()
41
+ st.markdown(f"**Caption:** {cap}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ # —––––––– 2. Story generation
44
+ prompt = (
45
+ "Write a playful, 50–100 word story for 3–10 year-old children "
46
+ f"based on this description:\n\n“{cap}”\n\nStory:"
47
+ )
48
+ with st.spinner("✍️ Writing a story..."):
49
+ out = storyteller(
50
+ prompt,
51
+ max_length=200,
52
+ do_sample=True,
53
+ top_p=0.9,
54
+ temperature=0.8,
55
+ num_return_sequences=1
56
+ )
57
+ story = out[0]["generated_text"].strip()
58
+ st.markdown("**Story:**")
59
+ st.write(story)
60
 
61
+ # —––––––– 3. Text-to-Speech
62
+ with st.spinner("🔊 Converting to speech..."):
63
+ tts = gTTS(story, lang="en")
64
+ tmp = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
65
+ tts.write_to_fp(tmp)
66
+ tmp.flush()
67
+ st.audio(tmp.name, format="audio/mp3")