xzyun2011 commited on
Commit
f199dd6
·
1 Parent(s): 06c933c

add agent code final

Browse files
Files changed (1) hide show
  1. agent/wulewule_agent.py +303 -0
agent/wulewule_agent.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import os
3
+ import requests
4
+ from typing import List, Dict, Any, Optional, Iterator
5
+ from PIL import Image
6
+ import re
7
+ import torch
8
+
9
+ from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
10
+ # from llama_index.core.postprocessor import LLMRerank
11
+ from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolMetadata
12
+ from llama_index.core.agent import ReActAgent
13
+
14
+ from openai import OpenAI
15
+
16
+ import sys
17
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
18
+
19
+
20
+ import streamlit as st
21
+
22
+
23
+ class PromptEngineerAgent:
24
+ """专门用于优化提示词的代理"""
25
+ def __init__(self, llm):
26
+ self.llm = llm
27
+
28
+ def optimize_image_prompt(self, user_input: str) -> str:
29
+ """
30
+ 将用户的图像需求转换为优化的stable-diffusion提示词
31
+ """
32
+ prompt_template = f"""
33
+ 请将以下用户的图像需求转换为stable-diffusion所需的文生图提示词。
34
+
35
+ 用户需求: {user_input}
36
+
37
+ 请生成一个优化的英文提示词,格式要求:
38
+ 1. 使用详细的描述性语言
39
+ 2. 包含具体的艺术风格
40
+ 3. 说明构图和视角
41
+ 4. 描述光影和氛围
42
+ 5. 添加相关的艺术家参考或风格类型
43
+
44
+ 提示词:
45
+ """
46
+
47
+ response = self.llm.complete(prompt_template)
48
+ return str(response)
49
+
50
+ def optimize_voice_prompt(self, user_input: str) -> Dict[str, str]:
51
+ """
52
+ 优化语音合成的参数
53
+ """
54
+ prompt_template = f"""
55
+ 请分析以下文本,并提供优化的语音合成参数。
56
+
57
+ 文本: {user_input}
58
+
59
+ 请考虑:
60
+ 1. 最适合的语言
61
+ 2. 说话的语速
62
+ 3. 语气特点
63
+ 4. 情感色彩
64
+
65
+ 以JSON格式返回参数:
66
+ """
67
+
68
+ response = self.llm.complete(prompt_template)
69
+ try:
70
+ params = eval(str(response))
71
+ return params
72
+ except:
73
+ return {"lang": "zh", "speed": 1.0}
74
+
75
+
76
+ class MultiModalAssistant:
77
+ def __init__(self, data_source_dir, llm, api_key):
78
+ """
79
+ 初始化助手,设置必要的API密钥和加载文档
80
+ """
81
+
82
+ # 初始化LLM
83
+ self.llm = llm
84
+ self.__api_key = api_key
85
+ # 初始化Prompt Engineer Agent
86
+ self.prompt_engineer = PromptEngineerAgent(self.llm)
87
+
88
+ # 加载文档并创建索引
89
+ documents = SimpleDirectoryReader(data_source_dir, recursive=False, required_exts=[".txt"]).load_data()
90
+ self.index = VectorStoreIndex.from_documents(
91
+ documents
92
+ )
93
+
94
+ # 创建rag 用于回答知识问题
95
+ self.query_engine = self.index.as_query_engine(similarity_top_k=3)
96
+
97
+ # 创建rag+reranker用于回答知识问题
98
+ # self.query_engine = self.index.as_query_engine(similarity_top_k=3,
99
+ # node_postprocessors=[
100
+ # LLMRerank(
101
+ # choice_batch_size=5,
102
+ # top_n=2,
103
+ # )],
104
+ # response_mode="tree_summarize",)
105
+ # 设置工具
106
+ tools = [
107
+ FunctionTool.from_defaults(
108
+ fn=self.rag_query,
109
+ name="rag_tool",
110
+ description="无法直接回答时,查询和《黑神话:悟空》有关知识的工具"
111
+ ),
112
+ FunctionTool.from_defaults(
113
+ fn=self.text_to_speech,
114
+ name="tts_tool",
115
+ description="将文本转换为语音的工具"
116
+ ),
117
+ FunctionTool.from_defaults(
118
+ fn=self.generate_image,
119
+ name="image_tool",
120
+ description="生成图像的工具"
121
+ )
122
+ ]
123
+
124
+ # 初始化Agent
125
+ self.agent = ReActAgent.from_tools(
126
+ tools,
127
+ llm=self.llm,
128
+ verbose=True,
129
+ max_function_calls=5,
130
+ )
131
+
132
+ ## 画图的url
133
+ self.image_url = None
134
+ self.audio_save_file = "audio.mp3"
135
+ self.audio_text = None
136
+
137
+ def rag_query(self, query: str) -> str:
138
+ """
139
+ 使用RAG系统查询知识库
140
+ """
141
+ response = self.query_engine.query(query)
142
+ return str(response)
143
+
144
+ def text_to_speech(self, text: str) -> str:
145
+ """
146
+ 将文本转换为语音
147
+ """
148
+ if not self.audio_text is None:
149
+ print(f"文本已转为语音: {self.audio_text}")
150
+ return
151
+ try:
152
+ client = OpenAI( api_key = self.__api_key, base_url="https://api.siliconflow.cn/v1")
153
+
154
+ with client.audio.speech.with_streaming_response.create(
155
+ model="fishaudio/fish-speech-1.5", # 目前仅支持 fishaudio 系列模型
156
+ voice="fishaudio/fish-speech-1.5:benjamin", # 系统预置音色
157
+ # 用户输入信息 "孙悟空身穿金色战甲,手持金箍棒,眼神锐利"
158
+ input=f"{text}",
159
+ response_format="mp3" # 支持 mp3, wav, pcm, opus 格式
160
+ ) as response:
161
+ response.stream_to_file(self.audio_save_file)
162
+
163
+ if response.status_code == 200:
164
+ self.audio_text = text
165
+ print(f"文本已转为语音: {self.audio_save_file}")
166
+ # return f"文本转语音已完成。"
167
+ else:
168
+ print("文本转语音失败,状态码:", response.status_code)
169
+ except Exception as e:
170
+ return f"文本转语音时出错: {str(e)}"
171
+
172
+ def generate_image(self, prompt: str) -> str:
173
+ """
174
+ 使用API生成图像
175
+ """
176
+ if not self.image_url is None:
177
+ print(f"图像已生成: {self.image_url}")
178
+ return
179
+ try:
180
+ # 使用Prompt Engineer优化提示词
181
+ optimized_prompt = self.prompt_engineer.optimize_image_prompt(prompt)
182
+ print(f"优化后的图像提示词: {optimized_prompt}")
183
+
184
+ ## create an image of superman in a tense, action-packed scene, with explosive energy and bold dynamic composition, in the style of Ross Tran
185
+ url = "https://api.siliconflow.cn/v1/images/generations"
186
+ payload = {
187
+ "model": "stabilityai/stable-diffusion-3-5-large",
188
+ "prompt": f"{optimized_prompt}",
189
+ "negative_prompt": "<string>",
190
+ "image_size": "1024x1024",
191
+ "batch_size": 1,
192
+ "seed": 4999999999,
193
+ "num_inference_steps": 20,
194
+ "guidance_scale": 7.5,
195
+ "prompt_enhancement": False
196
+ }
197
+
198
+ headers = {
199
+ "Authorization": f"Bearer {self.__api_key}",
200
+ "Content-Type": "application/json"
201
+ }
202
+ response = requests.request("POST", url, json=payload, headers=headers)
203
+ if response.status_code == 200:
204
+ data = response.json()
205
+ self.image_url = data['data'][0]['url']
206
+ print(f"图像已生成: {self.image_url}")
207
+ # return f"图像已生成。"
208
+ # return f"图像已生成已完成。继续下一个任务"
209
+ else:
210
+ print("生成图像失败,状态码:", response.status_code)
211
+
212
+ except Exception as e:
213
+ return f"生成图像时出错: {str(e)}"
214
+
215
+ def chat(self, user_input: str) -> dict:
216
+ """
217
+ 处理用户输入并返回适当的响应
218
+ """
219
+ # 创建提示来帮助agent理解如何处理不同类型的请求
220
+ prompt = f"""
221
+ 用户输入: {user_input}
222
+
223
+ 请根据以下规则处理这个请求:
224
+ 1. 如果是知识相关的问题,使用rag_tool查询知识库
225
+ 2. 如果用户要求语音输出,使用tts_tool转换文本
226
+ 3. 如果用户要求生成图像,使用image_tool生成
227
+
228
+ 根据需求请选择合适的工具并执行操作,可能需要多个工具。
229
+ """
230
+ self.image_url = None
231
+ self.audio_text = None
232
+ response = self.agent.chat(prompt)
233
+ response_dict = {"response": str(response), "image_url": self.image_url, "audio_text": self.audio_text }
234
+ return response_dict
235
+
236
+
237
+ if __name__ == "__main__":
238
+ ## load wulewule agent
239
+ wulewule_assistant = load_wulewule_agent()
240
+
241
+ ## streamlit setting
242
+ if "messages" not in st.session_state:
243
+ st.session_state["messages"] = []
244
+
245
+ # 在侧边栏中创建一个标题和一个链接
246
+ with st.sidebar:
247
+ st.markdown("## 悟了悟了💡")
248
+ logo_path = "assets/sd_wulewule.webp"
249
+ if os.path.exists(logo_path):
250
+ image = Image.open(logo_path)
251
+ st.image(image, caption='wulewule')
252
+ "[InternLM](https://github.com/InternLM)"
253
+ "[悟了悟了](https://github.com/xzyun2011/wulewule.git)"
254
+
255
+ # 创建一个标题
256
+ st.title("悟了悟了:黑神话悟空AI助手🐒")
257
+
258
+ # 遍历session_state中的所有消息,并显示在聊天界面上
259
+ for msg in st.session_state.messages:
260
+ st.chat_message("user").write(msg["user"])
261
+ assistant_res = msg["assistant"]
262
+ if isinstance(assistant_res, str):
263
+ st.chat_message("assistant").write(assistant_res)
264
+ elif isinstance(assistant_res, dict):
265
+ image_url = assistant_res["image_url"]
266
+ audio_text = assistant_res["audio_text"]
267
+ st.chat_message("assistant").write(assistant_res["response"])
268
+ if image_url:
269
+ # 使用st.image展示URL图像,并设置使用列宽
270
+ st.image( image_url, width=256 )
271
+ if audio_text:
272
+ # 使用st.audio函数播放音频
273
+ st.audio("audio.mp3")
274
+ st.write(f"语音内容为: {audio_text}")
275
+
276
+
277
+ # Get user input #你觉得悟空长啥样,按你的想法画一个
278
+ if prompt := st.chat_input("请输入你的问题,换行使用Shfit+Enter。"):
279
+ # Display user input
280
+ st.chat_message("user").write(prompt)
281
+ ## 初始化完整的回答字符串
282
+ full_answer = ""
283
+ with st.chat_message('robot'):
284
+ message_placeholder = st.empty()
285
+ response_dict = wulewule_assistant.chat(prompt)
286
+ image_url = response_dict["image_url"]
287
+ audio_text = response_dict["audio_text"]
288
+ for cur_response in response_dict["response"]:
289
+ full_answer += cur_response
290
+ # Display robot response in chat message container
291
+ message_placeholder.markdown(full_answer + '▌')
292
+ message_placeholder.markdown(full_answer)
293
+ # 将问答结果添加到 session_state 的消息历史中
294
+ st.session_state.messages.append({"user": prompt, "assistant": response_dict})
295
+ if image_url:
296
+ # 使用st.image展示URL图像,并设置使用列宽
297
+ st.image( image_url, width=256 )
298
+
299
+ if audio_text:
300
+ # 使用st.audio函数播放音频
301
+ st.audio("audio.mp3")
302
+ st.write(f"语音内容为: {audio_text}")
303
+ torch.cuda.empty_cache()