Leo Liu committed on
Commit
15e4fc0
·
verified ·
1 Parent(s): fae883c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -23
app.py CHANGED
@@ -3,6 +3,7 @@ import streamlit as st
3
  from transformers import pipeline
4
  from gtts import gTTS
5
  import io
 
6
 
7
 
8
  # function part
@@ -15,14 +16,20 @@ def img2text(url):
15
 
16
  # text2story
17
  def text2story(text):
18
- # 定义提示词模板(包含变量占位符)
19
- prompt_template = f"""Write a children's story for ages 3-10 based on: {text}
20
- Requirements:
21
- - Use simple words
22
- - Include a happy ending
23
- """
 
 
 
 
 
 
24
 
25
- # 填充模板中的变量
26
  full_prompt = prompt_template.format(text=text)
27
 
28
  # 初始化生成管道
@@ -31,29 +38,29 @@ def text2story(text):
31
  model="pranavpsv/genre-story-generator-v2",
32
  max_new_tokens=180,
33
  min_new_tokens=130,
34
- temperature=0.7
 
35
  )
36
 
37
  # 生成原始文本
38
  raw_output = pipe(full_prompt, return_full_text=False)[0]['generated_text']
39
 
40
- # 增强版提示词移除功能
41
- def clean_output(generated_text, prompt):
42
- # 方法1:精确匹配移除
43
- if generated_text.startswith(prompt):
44
- return generated_text[len(prompt):].strip()
45
-
46
- # 方法2:正则表达式模糊匹配
47
- import re
48
- pattern = re.compile(r'Write a children\'s story.*?based on:.*?\n', re.DOTALL)
49
- cleaned = re.sub(pattern, '', generated_text, count=1)
50
 
51
- # 移除残留的提示词片段
52
- cleaned = cleaned.split("Requirements:")[0].strip()
53
- return cleaned
 
54
 
55
- # 返回处理后的干净文本
56
- return clean_output(raw_output, full_prompt)
57
 
58
 
59
  # text2audio
 
3
  from transformers import pipeline
4
  from gtts import gTTS
5
  import io
6
+ import re
7
 
8
 
9
  # function part
 
16
 
17
  # text2story
18
  def text2story(text):
19
+ # 优化提示词模板(使用统一的分隔符)
20
+ prompt_template = """[PROMPT_START]
21
+ Write a children's story for ages 3-10 based on: {text}
22
+
23
+ Requirements:
24
+ 1. Use simple words (1st-3rd grade level)
25
+ 2. Main character must be an animal
26
+ 3. Include magic elements
27
+ 4. Have a happy ending
28
+ 5. Story length: 100-120 words
29
+ [PROMPT_END]
30
+ """
31
 
32
+ # 生成完整提示词(避免重复插入)
33
  full_prompt = prompt_template.format(text=text)
34
 
35
  # 初始化生成管道
 
38
  model="pranavpsv/genre-story-generator-v2",
39
  max_new_tokens=180,
40
  min_new_tokens=130,
41
+ temperature=0.7,
42
+ pad_token_id=50256
43
  )
44
 
45
  # 生成原始文本
46
  raw_output = pipe(full_prompt, return_full_text=False)[0]['generated_text']
47
 
48
+ # 增强版清洗逻辑
49
+ def clean_output(generated_text):
50
+ # 使用正则表达式匹配提示词块
51
+ prompt_pattern = re.compile(
52
+ r'\[PROMPT_START\].*?\[PROMPT_END\]',
53
+ re.DOTALL # 匹配多行内容
54
+ )
55
+ # 移除整个提示词块
56
+ cleaned = re.sub(prompt_pattern, '', generated_text)
 
57
 
58
+ # 二次清理残留内容
59
+ cleaned = re.sub(r'^Write a children.*?\n', '', cleaned) # 处理可能的开头残留
60
+ cleaned = re.sub(r'Requirements:.*?\n', '', cleaned) # 移除要求残留
61
+ return cleaned.strip()
62
 
63
+ return clean_output(raw_output)
 
64
 
65
 
66
  # text2audio