zhang-3000 committed on
Commit
bb28d20
·
1 Parent(s): bd1f73e

Initial commit with blog generation code

Browse files
Files changed (2) hide show
  1. app.py +217 -0
  2. requirements.txt +176 -0
app.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ import json
4
+ import re
5
+ import requests
6
+ import streamlit as st
7
+
8
+ from lagent.agents import Agent
9
+ from lagent.prompts.parsers import PluginParser
10
+ from lagent.agents.stream import PLUGIN_CN, get_plugin_prompt
11
+ from lagent.schema import AgentMessage
12
+ from lagent.actions import ArxivSearch
13
+ from lagent.hooks import Hook
14
+ from lagent.llms import GPTAPI
15
+
16
# API token for the LLM backend, read from the ``token`` environment
# variable. The application refuses to start when it is missing or empty.
YOUR_TOKEN_HERE = os.environ.get("token")
if not YOUR_TOKEN_HERE:
    raise EnvironmentError("未找到环境变量 'token',请设置后再运行程序。")
19
+
20
+ # Hook类,用于对消息添加前缀
21
class PrefixedMessageHook(Hook):
    """Hook that prepends a fixed prefix to messages from selected senders."""

    def __init__(self, prefix, senders=None):
        """Store the prefix and the sender names it applies to.

        :param prefix: text placed in front of each matching message's content
        :param senders: sender names whose messages receive the prefix; a
            falsy value (including ``None``) means no sender matches
        """
        self.prefix = prefix
        self.senders = senders if senders else []

    def before_agent(self, agent, messages, session_id):
        """Prefix, in place, the content of every message from a listed sender.

        :param agent: the agent about to process the messages (unused here)
        :param messages: iterable of message objects exposing ``sender`` and
            ``content`` attributes
        :param session_id: session identifier (unused here)
        """
        for msg in messages:
            if msg.sender not in self.senders:
                continue
            msg.content = self.prefix + msg.content
41
+
42
class AsyncBlogger:
    """Blog generator that chains a writer agent and a critic agent."""

    # Placeholder used when the critic's reply has no parsable suggestions.
    _NO_FEEDBACK = "未提供批评建议"
    # Placeholder used when no literature search was performed.
    _NO_REFERENCES = "未检索到推荐文献"

    # Text between the "critique" heading and the "keywords" heading.
    # Both the full-width and the ASCII colon are accepted after a heading.
    _FEEDBACK_RE = re.compile(r"批评建议[::]\s*(.*?)\s*推荐的关键词[::]", re.S)
    # Text between the "keywords" heading and the "literature" heading (or
    # the end of the reply when the literature heading is absent).
    _KEYWORD_SECTION_RE = re.compile(
        r"推荐的关键词[::](.*?)(?:推荐的文献[::]|\Z)", re.S
    )
    # A "- keyword" bullet at the start of a line; captures the leading
    # run of ASCII letters, hyphens and spaces (so multi-word phrases survive).
    _KEYWORD_BULLET_RE = re.compile(r"^[ \t]*-[ \t]*([A-Za-z][A-Za-z\- ]*)", re.M)

    def __init__(self, model_type, api_base, writer_prompt, critic_prompt, critic_prefix='', max_turn=2):
        """Create the LLM client and the writer / critic agents.

        :param model_type: model name passed to ``GPTAPI``
        :param api_base: API base URL passed to ``GPTAPI``
        :param writer_prompt: system prompt for the writer agent
        :param critic_prompt: system prompt for the critic agent
        :param critic_prefix: prefix added to the writer's messages before
            the critic sees them
        :param max_turn: stored on the instance; not used by the code in
            this class
        """
        self.model_type = model_type
        self.api_base = api_base
        self.llm = GPTAPI(
            model_type=model_type,
            api_base=api_base,
            key=YOUR_TOKEN_HERE,
            max_new_tokens=4096,
        )
        self.plugins = [dict(type='lagent.actions.ArxivSearch')]
        self.writer = Agent(
            self.llm,
            writer_prompt,
            name='写作者',
            output_format=dict(
                type=PluginParser,
                template=PLUGIN_CN,
                prompt=get_plugin_prompt(self.plugins)
            )
        )
        self.critic = Agent(
            self.llm,
            critic_prompt,
            name='批评者',
            hooks=[PrefixedMessageHook(critic_prefix, ['写作者'])]
        )
        self.max_turn = max_turn

    @classmethod
    def _parse_critic_feedback(cls, content):
        """Extract the critique text and the English keywords from a reply.

        Keywords are taken only from the keyword section of the reply so
        that bullets of the critique or of the literature list are not
        mistaken for keywords. When the keyword heading is missing, the
        whole reply is scanned instead.

        :param content: raw text of the critic's reply
        :return: ``(feedback, keywords)`` where ``feedback`` is a string and
            ``keywords`` is a list of strings (possibly empty)
        """
        feedback_match = cls._FEEDBACK_RE.search(content)
        feedback = feedback_match.group(1).strip() if feedback_match else cls._NO_FEEDBACK

        section_match = cls._KEYWORD_SECTION_RE.search(content)
        keyword_text = section_match.group(1) if section_match else content
        candidates = (
            raw.strip(" -") for raw in cls._KEYWORD_BULLET_RE.findall(keyword_text)
        )
        keywords = [kw for kw in candidates if kw]
        return feedback, keywords

    async def forward(self, message: "AgentMessage", update_placeholder):
        """Run the three-stage pipeline: draft, critique, revise.

        :param message: initial request for the writer agent
        :param update_placeholder: Streamlit placeholder used for progress
            output
        :return: the writer agent's final (revised) message
        """
        # NOTE(review): the caller in this file passes an ``st.empty()``
        # placeholder, which presumably holds a single element, so each
        # ``.container()`` call may replace the previous one -- confirm
        # that all three steps stay visible.
        step1_placeholder = update_placeholder.container()
        step2_placeholder = update_placeholder.container()
        step3_placeholder = update_placeholder.container()

        # Step 1: the writer produces the first draft.
        step1_placeholder.markdown("**Step 1: 生成初始内容...**")
        message = self.writer(message)
        if message.content:
            step1_placeholder.markdown(f"**生成的初始内容**:\n\n{message.content}")
        else:
            step1_placeholder.markdown("**生成的初始内容为空,请检查生成逻辑。**")

        # Defaults keep step 3 working when the critic returns nothing or
        # recommends no keywords (previously this raised UnboundLocalError).
        feedback = self._NO_FEEDBACK
        arxiv_results = self._NO_REFERENCES

        # Step 2: the critic reviews the draft and recommends keywords.
        step2_placeholder.markdown("**Step 2: 批评者正在提供反馈和文献推荐...**")
        message = self.critic(message)
        if message.content:
            feedback, keywords = self._parse_critic_feedback(message.content)

            # Show what was extracted from the critic's reply.
            st.write(f"提取的批评建议:\n{feedback}")
            st.write(f"提取的推荐关键词:\n{keywords}")

            # Only query Arxiv when there is something to search for.
            if keywords:
                arxiv_search = ArxivSearch()
                arxiv_results = arxiv_search.get_arxiv_article_information(', '.join(keywords))

            message.content = f"**批评建议**:\n{feedback}\n\n**推荐的文献**:\n{arxiv_results}"
            step2_placeholder.markdown(f"**批评和文献推荐**:\n\n{message.content}")
        else:
            step2_placeholder.markdown("**批评内容为空,请检查批评逻辑。**")

        # Step 3: the writer revises the draft using the critique.
        step3_placeholder.markdown("**Step 3: 根据反馈改进内容...**")
        improvement_prompt = AgentMessage(
            sender="critic",
            content=(
                f"根据以下批评建议和推荐文献对内容进行改进:\n\n"
                f"批评建议:\n{feedback}\n\n"
                f"推荐文献:\n{arxiv_results}\n\n"
                f"请优化初始内容,使其更加清晰、丰富,并符合专业水准。"
            ),
        )
        message = self.writer(improvement_prompt)
        if message.content:
            step3_placeholder.markdown(f"**最终优化的博客内容**:\n\n{message.content}")
        else:
            step3_placeholder.markdown("**最终优化的博客内容为空,请检查生成逻辑。**")

        return message
152
+
153
def setup_sidebar():
    """Render the sidebar inputs and return the selected model settings.

    :return: ``(model_name, api_base)`` as entered in the two sidebar
        text inputs
    """
    sidebar = st.sidebar
    model_name = sidebar.text_input('模型名称:', value='internlm2.5-latest')
    api_base = sidebar.text_input(
        'API Base 地址:',
        value='https://internlm-chat.intern-ai.org.cn/puyu/api/v1/chat/completions',
    )
    return model_name, api_base
161
+
162
def main():
    """Build the Streamlit page and run the blog pipeline on demand.

    The ``AsyncBlogger`` instance is cached in ``st.session_state`` and
    rebuilt only when the model name or API base entered in the sidebar
    changes. Clicking the button runs one generation for the entered topic.
    """
    st.set_page_config(layout='wide', page_title='Lagent Web Demo', page_icon='🤖')
    st.title("多代理博客优化助手")

    model_type, api_base = setup_sidebar()
    topic = st.text_input('输入一个话题:', 'Self-Supervised Learning')
    generate_button = st.button('生成博客内容')

    # (Re)create the blogger on first run or when the sidebar settings change.
    if (
        'blogger' not in st.session_state or
        st.session_state['model_type'] != model_type or
        st.session_state['api_base'] != api_base
    ):
        st.session_state['blogger'] = AsyncBlogger(
            model_type=model_type,
            api_base=api_base,
            writer_prompt="你是一位优秀的AI内容写作者,请撰写一篇有吸引力且信息丰富的博客内容。",
            critic_prompt="""
                作为一位严谨的批评者,请给出建设性的批评和改进建议,并基于相关主题使用已有的工具推荐一些参考文献,推荐的关键词应该是英语形式,简洁且切题。
                请按照以下格式提供反馈:
                1. 批评建议:
                    - (具体建议)
                2. 推荐的关键词:
                    - (关键词1)
                    - (关键词2)
                    - (...)
                3. 推荐的文献:
                    - (文献1)
                    - (文献2)
                    - (...)
            """,
            critic_prefix="请批评以下内容,并提供改进建议:\n\n"
        )
        st.session_state['model_type'] = model_type
        st.session_state['api_base'] = api_base

    if generate_button:
        update_placeholder = st.empty()

        async def run_async_blogger():
            """Send the topic request through the cached blogger."""
            message = AgentMessage(
                sender='user',
                content=f"请撰写一篇关于{topic}的博客文章,要求表达专业,生动有趣,并且易于理解。"
            )
            result = await st.session_state['blogger'].forward(message, update_placeholder)
            return result

        # asyncio.run creates a fresh event loop and closes it afterwards;
        # the previous manual new_event_loop() was never closed, leaking
        # one loop per button click.
        asyncio.run(run_async_blogger())


if __name__ == '__main__':
    main()
requirements.txt ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohappyeyeballs==2.4.4
2
+ aiohttp==3.11.11
3
+ aiosignal==1.3.2
4
+ altair==5.5.0
5
+ annotated-types==0.7.0
6
+ anyio==4.7.0
7
+ argon2-cffi==23.1.0
8
+ argon2-cffi-bindings==21.2.0
9
+ arrow==1.3.0
10
+ arxiv==2.1.3
11
+ asttokens==3.0.0
12
+ async-lru==2.0.4
13
+ async-timeout==5.0.1
14
+ asyncache==0.3.1
15
+ asyncer==0.0.8
16
+ attrs==24.3.0
17
+ babel==2.16.0
18
+ backports.strenum==1.3.1
19
+ beautifulsoup4==4.12.3
20
+ bleach==6.2.0
21
+ blinker==1.9.0
22
+ Brotli @ file:///croot/brotli-split_1714483155106/work
23
+ cachetools==5.5.0
24
+ certifi @ file:///croot/certifi_1734473278428/work/certifi
25
+ cffi==1.17.1
26
+ charset-normalizer @ file:///croot/charset-normalizer_1721748349566/work
27
+ class-registry==2.1.2
28
+ click==8.1.8
29
+ colorama==0.4.6
30
+ comm==0.2.2
31
+ datasets==3.1.0
32
+ debugpy==1.8.11
33
+ decorator==5.1.1
34
+ defusedxml==0.7.1
35
+ dill==0.3.8
36
+ distro==1.9.0
37
+ duckduckgo_search==5.3.1b1
38
+ exceptiongroup==1.2.2
39
+ executing==2.1.0
40
+ fastjsonschema==2.21.1
41
+ feedparser==6.0.11
42
+ filelock @ file:///croot/filelock_1700591183607/work
43
+ fqdn==1.5.1
44
+ frozenlist==1.5.0
45
+ fsspec==2024.9.0
46
+ func_timeout==4.3.5
47
+ gitdb==4.0.11
48
+ GitPython==3.1.43
49
+ gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645455533097/work
50
+ griffe==0.48.0
51
+ h11==0.14.0
52
+ h2==4.1.0
53
+ hpack==4.0.0
54
+ httpcore==1.0.7
55
+ httpx==0.28.1
56
+ huggingface-hub==0.27.0
57
+ hyperframe==6.0.1
58
+ idna @ file:///croot/idna_1714398848350/work
59
+ ipykernel==6.29.5
60
+ ipython==8.31.0
61
+ ipywidgets==8.1.5
62
+ isoduration==20.11.0
63
+ jedi==0.19.2
64
+ Jinja2 @ file:///croot/jinja2_1730902924303/work
65
+ jiter==0.8.2
66
+ json5==0.10.0
67
+ jsonpointer==3.0.0
68
+ jsonschema==4.23.0
69
+ jsonschema-specifications==2024.10.1
70
+ jupyter==1.0.0
71
+ jupyter-console==6.6.3
72
+ jupyter-events==0.11.0
73
+ jupyter-lsp==2.2.5
74
+ jupyter_client==8.6.2
75
+ jupyter_core==5.7.2
76
+ jupyter_server==2.15.0
77
+ jupyter_server_terminals==0.5.3
78
+ jupyterlab==4.3.4
79
+ jupyterlab_pygments==0.3.0
80
+ jupyterlab_server==2.27.3
81
+ jupyterlab_widgets==3.0.13
82
+ -e git+https://github.com/InternLM/lagent.git@e304e5d323cdbb631257fac9187d16b99476bc2f#egg=lagent
83
+ markdown-it-py==3.0.0
84
+ MarkupSafe @ file:///croot/markupsafe_1704205993651/work
85
+ matplotlib-inline==0.1.7
86
+ mdurl==0.1.2
87
+ mistune==3.0.2
88
+ mkl-service==2.4.0
89
+ mkl_fft @ file:///io/mkl313/mkl_fft_1730824109137/work
90
+ mkl_random @ file:///io/mkl313/mkl_random_1730823916628/work
91
+ mpmath @ file:///croot/mpmath_1690848262763/work
92
+ multidict==6.1.0
93
+ multiprocess==0.70.16
94
+ narwhals==1.19.1
95
+ nbclient==0.10.2
96
+ nbconvert==7.16.4
97
+ nbformat==5.10.4
98
+ nest-asyncio==1.6.0
99
+ networkx @ file:///croot/networkx_1717597493534/work
100
+ notebook==7.3.2
101
+ notebook_shim==0.2.4
102
+ numpy @ file:///croot/numpy_and_numpy_base_1725470312869/work/dist/numpy-2.0.1-cp310-cp310-linux_x86_64.whl#sha256=120568f3fd675f59e4cb6de79a5c193b4067d90878d8bab2040fda8e3a2df6fa
103
+ overrides==7.7.0
104
+ packaging==24.2
105
+ pandas==2.2.3
106
+ pandocfilters==1.5.1
107
+ parso==0.8.4
108
+ pexpect==4.9.0
109
+ pillow==10.4.0
110
+ platformdirs==4.3.6
111
+ prometheus_client==0.21.1
112
+ prompt_toolkit==3.0.48
113
+ propcache==0.2.1
114
+ protobuf==5.29.2
115
+ psutil==6.1.1
116
+ ptyprocess==0.7.0
117
+ pure_eval==0.2.3
118
+ pyarrow==18.1.0
119
+ pycparser==2.22
120
+ pydantic==2.6.4
121
+ pydantic_core==2.16.3
122
+ pydeck==0.9.1
123
+ Pygments==2.18.0
124
+ PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work
125
+ python-dateutil==2.9.0.post0
126
+ python-json-logger==3.2.1
127
+ pytz==2024.2
128
+ PyYAML @ file:///croot/pyyaml_1728657952215/work
129
+ pyzmq==26.2.0
130
+ qtconsole==5.6.1
131
+ QtPy==2.4.2
132
+ referencing==0.35.1
133
+ regex==2024.11.6
134
+ requests @ file:///croot/requests_1730999120400/work
135
+ rfc3339-validator==0.1.4
136
+ rfc3986-validator==0.1.1
137
+ rich==13.9.4
138
+ rpds-py==0.22.3
139
+ Send2Trash==1.8.3
140
+ sgmllib3k==1.0.0
141
+ six==1.17.0
142
+ smmap==5.0.1
143
+ sniffio==1.3.1
144
+ socksio==1.0.0
145
+ soupsieve==2.6
146
+ stack-data==0.6.3
147
+ streamlit==1.39.0
148
+ sympy @ file:///croot/sympy_1734622612703/work
149
+ tenacity==9.0.0
150
+ termcolor==2.4.0
151
+ terminado==0.18.1
152
+ tiktoken==0.8.0
153
+ timeout-decorator==0.5.0
154
+ tinycss2==1.4.0
155
+ toml==0.10.2
156
+ tomli==2.2.1
157
+ torch==2.1.2
158
+ torchaudio==2.1.2
159
+ torchvision==0.16.2
160
+ tornado==6.4.2
161
+ tqdm==4.67.1
162
+ traitlets==5.14.3
163
+ triton==2.1.0
164
+ types-python-dateutil==2.9.0.20241206
165
+ typing_extensions @ file:///croot/typing_extensions_1734714854207/work
166
+ tzdata==2024.2
167
+ uri-template==1.3.0
168
+ urllib3 @ file:///croot/urllib3_1727769808118/work
169
+ watchdog==5.0.3
170
+ wcwidth==0.2.13
171
+ webcolors==24.11.1
172
+ webencodings==0.5.1
173
+ websocket-client==1.8.0
174
+ widgetsnbextension==4.0.13
175
+ xxhash==3.5.0
176
+ yarl==1.18.3