Leo Liu commited on
Commit
3ae2d5d
·
verified ·
1 Parent(s): 59e88fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -104
app.py CHANGED
@@ -2,148 +2,163 @@
2
  import streamlit as st
3
  from transformers import pipeline
4
  import math
 
5
 
6
  # function part
7
- # 时间戳
8
- def split_story_with_delay(story_text, sampling_rate=16000):
9
- """将故事分割为带时间戳的段落"""
10
- words = story_text.split()
11
- chunk_size = max(1, len(words)//5) # 按词数均分5段
12
- chunks = [' '.join(words[i:i+chunk_size]) for i in range(0, len(words), chunk_size)]
13
- # 假设每段持续2秒(根据音频长度动态调整更佳)
14
- duration = len(audio_data["audio"]) / sampling_rate
15
- chunk_duration = duration / len(chunks)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  return list(zip(chunks, [chunk_duration]*len(chunks)))
17
 
18
- # img2text
19
  def img2text(url):
20
- image_to_text_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
21
- text = image_to_text_model(url)[0]["generated_text"]
 
 
22
  return text
23
 
24
- # text2story
25
  def text2story(text):
26
- # 添加儿童故事专用prompt模板
27
- prompt = f"""Generate a VERY SHORT fairy tale for children aged 3-10 based on: {text}
28
- Story must:
29
- 1. Have animal/fairy characters
30
- 2. Teach kindness or courage
31
- 3. Use simple words
32
- 4. Be 50-100 words
33
- Story:"""
34
-
35
- pipe = pipeline(
36
- "text-generation",
37
- model="pranavpsv/genre-story-generator-v2",
38
- # 优化生成参数
39
- max_new_tokens=150, # 严格控制输出长度
40
- min_new_tokens=50, # 确保最低字数
41
- do_sample=True,
42
- temperature=0.7, # 平衡创意与连贯性
43
- top_k=40, # 加速生成
44
- top_p=0.9,
45
- repetition_penalty=1.2,
46
- num_return_sequences=1 # 减少计算量
47
- )
48
-
49
- # 生成后处理
50
- raw_story = pipe(prompt)[0]['generated_text']
51
 
52
- # 提取核心故事内容(过滤prompt重复)
53
- story = raw_story.split("Story:")[-1].strip()
 
 
 
 
 
 
 
 
 
54
 
55
- # 精确截断至150字(中文按字符计算)
56
- return ' '.join(story.split()[:150]) if len(story) > 150 else story
 
57
 
58
- # text2audio
59
  def text2audio(story_text):
60
- pipe = pipeline("text-to-audio", model="Matthijs/mms-tts-eng")
61
- audio_data = pipe(story_text)
62
- return audio_data
63
-
 
 
 
 
64
 
65
  def main():
66
- st.set_page_config(page_title="Magic Storyteller", page_icon="🧚")
67
 
68
- # Optimize title area to attract children's attention
69
  st.markdown("""
70
  <style>
71
  @import url('https://fonts.googleapis.com/css2?family=Comic+Neue:wght@700&display=swap');
72
  .header {
73
- background-image: url('https://huggingface.co/spaces/Leo0129/classagm/resolve/main/background.jpg');
74
- background-size: cover;
75
  border-radius: 15px;
76
  padding: 2rem;
77
  text-align: center;
78
  box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
 
79
  }
80
- .header h1 {
81
- color: #FF9A6C;
82
  font-family: 'Comic Neue', cursive;
83
- font-size: 2.5rem;
84
- text-shadow: 2px 2px #FFF;
85
- margin-bottom: 0.5rem !important;
 
 
 
 
86
  }
87
-
88
  </style>
89
- """, unsafe_allow_html=True)
90
 
91
  st.markdown("""
92
  <div class="header">
93
- <h1>🪄 Magic Storyteller </h1>
 
94
  </div>
95
  """, unsafe_allow_html=True)
96
- uploaded_file = st.file_uploader("🌈 Choose your magic picture...", type=["jpg", "png"])
97
-
98
- if uploaded_file is not None:
99
- bytes_data = uploaded_file.getvalue()
100
- with open(uploaded_file.name, "wb") as file:
101
- file.write(bytes_data)
102
- st.image(uploaded_file, caption="Your Magic Picture ✨", use_container_width=True)
103
- status_container = st.empty()
104
- progress_bar = st.progress(0)
105
 
106
- # Stage 1: Image to Text
107
- with status_container.status("🔮 **Step 1/3**: Decoding picture magic...", expanded=True) as status: # 保持缩进
108
- progress_bar.progress(33)
109
- scenario = img2text(uploaded_file.name)
110
- status.update(label="✅ Picture decoded!", state="complete")
111
- st.write(f"**What I see:** {scenario}")
112
 
113
- #Stage 2: Text to Story
114
- with status_container.status("📚 **Step 2/3**: Writing your fairy tale...", expanded=True) as status:
115
- progress_bar.progress(66)
 
 
 
 
 
 
 
 
 
 
116
  story = text2story(scenario)
117
- status.update(label=" Story created!", state="complete")
118
- st.write(f"**Your Story:**\n{story}")
119
-
120
- #Stage 3: Story to Audio data
121
- with status_container.status("🎵 **Step 3/3**: Adding magic audio...", expanded=True) as status:
122
- progress_bar.progress(100)
123
- audio_data = text2audio(story)
124
 
125
- # 新增字幕处理
126
- subtitle_chunks = split_story_with_delay(story, audio_data['sampling_rate'])
127
- current_subtitle = st.empty()
 
128
 
129
- status.update(label=" Start playing the story!", state="complete")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- # 播放音频并更新字幕
132
- with st.audio(...): # 保持原有参数
133
- for text, duration in subtitle_chunks:
134
- current_subtitle.markdown(f"""
135
- <div style="
136
- background: rgba(255,255,255,0.9);
137
- padding: 1rem;
138
- border-radius: 10px;
139
- margin: 1rem 0;
140
- font-size: 1.2rem;
141
- color: #FF6B6B;
142
- text-align: center;
143
- font-family: 'Comic Neue', cursive;
144
- ">{text}</div>
145
- """, unsafe_allow_html=True)
146
- time.sleep(duration) # 需import time
147
 
148
  if __name__ == "__main__":
149
  main()
 
2
  import streamlit as st
3
  from transformers import pipeline
4
  import math
5
+ import time # 新增time模块
6
 
7
  # function part
8
+ def split_story_with_delay(story_text, total_duration, num_chunks=5):
9
+ """将故事分割为带时间戳的段落(优化版)"""
10
+ # 按句号分割更符合自然段落
11
+ sentences = [s.strip() for s in story_text.split('. ') if s]
12
+ if not sentences:
13
+ return [(story_text, total_duration)]
14
+
15
+ # 动态计算分段数量(每段最多2句话)
16
+ chunk_size = max(1, min(2, len(sentences)//num_chunks))
17
+ chunks = []
18
+ current_chunk = []
19
+
20
+ for sent in sentences:
21
+ current_chunk.append(sent)
22
+ if len(current_chunk) >= chunk_size:
23
+ chunks.append('. '.join(current_chunk) + '.')
24
+ current_chunk = []
25
+
26
+ if current_chunk:
27
+ chunks.append('. '.join(current_chunk) + '.')
28
+
29
+ # 计算每段持续时间
30
+ chunk_duration = total_duration / len(chunks)
31
  return list(zip(chunks, [chunk_duration]*len(chunks)))
32
 
 
33
  def img2text(url):
34
+ # 添加进度提示
35
+ with st.spinner("🖼️ Analyzing the magic picture..."):
36
+ image_to_text_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
37
+ text = image_to_text_model(url)[0]["generated_text"]
38
  return text
39
 
 
40
  def text2story(text):
41
+ # 优化prompt模板
42
+ prompt = f"""Create a magical children's story (for ages 3-8) based on: {text}
43
+ Story requirements:
44
+ 🐰 Animal/Fantasy characters
45
+ 🎁 Simple moral lesson
46
+ 🌈 Vivid descriptions
47
+ 80-120 words
48
+ 🌼 Use dialog between characters
49
+
50
+ Magical story begins:"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ # 添加模型加载进度
53
+ with st.spinner("📖 Brewing story magic..."):
54
+ pipe = pipeline(
55
+ "text-generation",
56
+ model="pranavpsv/genne-story-generator-v2",
57
+ max_new_tokens=200,
58
+ temperature=0.8,
59
+ top_p=0.95,
60
+ repetition_penalty=1.1
61
+ )
62
+ raw_story = pipe(prompt)[0]['generated_text']
63
 
64
+ # 优化故事提取逻辑
65
+ story = raw_story.split("Magical story begins:")[-1].strip()
66
+ return story[:500] # 确保长度限制
67
 
 
68
  def text2audio(story_text):
69
+ # 添加音频生成进度
70
+ with st.spinner("🔊 Mixing audio potion..."):
71
+ pipe = pipeline("text-to-audio", model="Matthijs/mms-tts-eng")
72
+ audio_data = pipe(story_text, return_tensors="pt") # 优化内存使用
73
+ return {
74
+ "array": audio_data["audio"][0].numpy(),
75
+ "sampling_rate": audio_data["sampling_rate"]
76
+ }
77
 
78
  def main():
79
+ st.set_page_config(page_title="Magic Storyteller", page_icon="🧚", layout="wide")
80
 
81
+ # 优化UI样式
82
  st.markdown("""
83
  <style>
84
  @import url('https://fonts.googleapis.com/css2?family=Comic+Neue:wght@700&display=swap');
85
  .header {
86
+ background: linear-gradient(45deg, #FF9A6C, #FF6B6B);
 
87
  border-radius: 15px;
88
  padding: 2rem;
89
  text-align: center;
90
  box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
91
+ margin-bottom: 2rem;
92
  }
93
+ .subtitle {
 
94
  font-family: 'Comic Neue', cursive;
95
+ color: #4B4B4B;
96
+ font-size: 1.2rem;
97
+ margin: 1rem 0;
98
+ padding: 1rem;
99
+ background: rgba(255,255,255,0.9);
100
+ border-radius: 10px;
101
+ border-left: 5px solid #FF6B6B;
102
  }
 
103
  </style>
104
+ """, unsafe_allow_html=True)
105
 
106
  st.markdown("""
107
  <div class="header">
108
+ <h1 style='margin:0;'>🪄 Magic Storyteller</h1>
109
+ <p style='color: white; font-size: 1.2rem;'>Turn your pictures into magical stories!</p>
110
  </div>
111
  """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
112
 
113
+ uploaded_file = st.file_uploader("🌈 Choose your magic picture...", type=["jpg", "png"])
 
 
 
 
 
114
 
115
+ if uploaded_file:
116
+ with st.expander(" Your Magic Picture", expanded=True):
117
+ st.image(uploaded_file, use_column_width=True)
118
+
119
+ # 流程进度管理
120
+ with st.status("🔮 Story Creation Progress", expanded=True) as status:
121
+ # Stage 1
122
+ st.subheader("Step 1: Decoding Picture Magic")
123
+ scenario = img2text(uploaded_file)
124
+ st.success(f"**Discovered Magic:** {scenario}")
125
+
126
+ # Stage 2
127
+ st.subheader("Step 2: Brewing Story Potion")
128
  story = text2story(scenario)
129
+ st.success("**Magical Story Created!**")
 
 
 
 
 
 
130
 
131
+ # Stage 3
132
+ st.subheader("Step 3: Mixing Audio Spell")
133
+ audio_data = text2audio(story)
134
+ st.success("**Audio Potion Ready!**")
135
 
136
+ status.update(label="🎉 All Magic Complete!", state="complete")
137
+
138
+ # 故事展示区域
139
+ with st.container():
140
+ st.subheader("📖 Your Magical Story")
141
+ st.write(story)
142
+
143
+ # 音频播放与字幕
144
+ st.subheader("🎧 Story Audio")
145
+ st.audio(
146
+ audio_data["array"],
147
+ sample_rate=audio_data["sampling_rate"]
148
+ )
149
+
150
+ # 字幕显示(静态版本)
151
+ st.subheader("📜 Story Subtitles")
152
+ total_duration = len(audio_data["array"]) / audio_data["sampling_rate"]
153
+ subtitle_chunks = split_story_with_delay(story, total_duration)
154
 
155
+ for idx, (text, duration) in enumerate(subtitle_chunks, 1):
156
+ st.markdown(f"""
157
+ <div class="subtitle">
158
+ <span style='color: #FF6B6B; font-size: 1.4rem;'>✨ Part {idx}:</span>
159
+ {text}
160
+ </div>
161
+ """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
162
 
163
  if __name__ == "__main__":
164
  main()