add agent code final
app.py CHANGED
@@ -6,7 +6,6 @@ from PIL import Image
 import os
 import sys
 sys.path.append(os.path.dirname(__file__))
-import torch
 from download_models import download_model
 
 
@@ -41,8 +40,15 @@ def main(cfg):
     config_dict = OmegaConf.to_container(cfg, resolve=True)
 
     ## download model from modelscope
-    if not os.path.exists(config_dict["llm_model"]):
+    if not config_dict["use_remote"] and not os.path.exists(config_dict["llm_model"]):
         download_model(llm_model_path=config_dict["llm_model"])
+
+    ## agent mode: use llama-index and turn off lmdeploy and the chroma rag
+    if cfg.agent_mode:
+        ## load wulewule agent
+        wulewule_assistant = load_wulewule_agent(config_dict)
+        cfg.use_rag = False
+        cfg.use_lmdepoly = False
 
     if cfg.use_rag:
         ## load rag model
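Note: the agent branch above makes the three backends mutually exclusive by forcing the other flags off once agent_mode is set. Below is a minimal sketch of that gating idea, assuming a plain dict config; resolve_backend is a hypothetical helper and the priority order is an assumption, only the flag names come from the diff.

# Hypothetical sketch of the mutual-exclusion gating used above.
def resolve_backend(cfg: dict) -> str:
    if cfg.get("agent_mode"):
        # agent mode wins: switch the other paths off, as the diff does
        cfg["use_rag"] = False
        cfg["use_lmdepoly"] = False
        return "agent"
    if cfg.get("use_rag"):
        return "rag"
    if cfg.get("use_lmdepoly"):
        return "lmdeploy"
    return "none"

assert resolve_backend({"agent_mode": True, "use_rag": True}) == "agent"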
@@ -72,50 +78,85 @@ def main(cfg):
     # iterate over all messages in session_state and display them in the chat UI
     for msg in st.session_state.messages:
         st.chat_message("user").write(msg["user"])
-
+        assistant_res = msg["assistant"]
+        if isinstance(assistant_res, str):
+            st.chat_message("assistant").write(assistant_res)
+        elif cfg.agent_mode and isinstance(assistant_res, dict):
+            image_url = assistant_res["image_url"]
+            audio_text = assistant_res["audio_text"]
+            st.chat_message("assistant").write(assistant_res["response"])
+            if image_url:
+                # show the image from its URL with st.image
+                st.image(image_url, width=256)
+            if audio_text:
+                # play the audio with st.audio
+                st.audio("audio.mp3")
+                st.write(f"语音内容为: \n\n{audio_text}")
 
     # Get user input
    if prompt := st.chat_input("请输入你的问题,换行使用Shift+Enter。"):
         # Display user input
         st.chat_message("user").write(prompt)
-
-
-
-        ## initialize the full answer string
-        full_answer = ""
+        ## initialize the full answer string
+        full_answer = ""
+        if cfg.agent_mode:
         with st.chat_message('robot'):
             message_placeholder = st.empty()
-
-
-
-
-
-
-
-            # top_k=40,
-            # temperature=0.8,
-            # max_new_tokens=2048,
-            # repetition_penalty=1.05)
-            messages = [{'role': 'user', 'content': f'{prompt}'}]
-            for response in wulewule_model.stream_infer(messages):
-                full_answer += response.text
-                # Display robot response in chat message container
-                message_placeholder.markdown(full_answer + '▌')
-
+                response_dict = wulewule_assistant.chat(prompt)
+                image_url = response_dict["image_url"]
+                audio_text = response_dict["audio_text"]
+                for cur_response in response_dict["response"]:
+                    full_answer += cur_response
+                    # Display robot response in chat message container
+                    message_placeholder.markdown(full_answer + '▌')
             message_placeholder.markdown(full_answer)
-
+            # append the Q&A result to the session_state message history
+            st.session_state.messages.append({"user": prompt, "assistant": response_dict})
+            if image_url:
+                # show the image from its URL with st.image
+                st.image(image_url, width=256)
+
+            if audio_text:
+                # play the audio with st.audio
+                st.audio("audio.mp3")
+                st.write(f"语音内容为: \n\n{audio_text}")
+
+        # streaming display: show the result as it is generated
         else:
-            if cfg.
-
-
-
-
-
-
-
-
-
-
+            if cfg.stream_response:
+                # rag
+                with st.chat_message('robot'):
+                    message_placeholder = st.empty()
+                    if cfg.use_rag:
+                        for cur_response in wulewule_model.query_stream(prompt):
+                            full_answer += cur_response
+                            # Display robot response in chat message container
+                            message_placeholder.markdown(full_answer + '▌')
+                    elif cfg.use_lmdepoly:
+                        # gen_config = GenerationConfig(top_p=0.8,
+                        #                               top_k=40,
+                        #                               temperature=0.8,
+                        #                               max_new_tokens=2048,
+                        #                               repetition_penalty=1.05)
+                        messages = [{'role': 'user', 'content': f'{prompt}'}]
+                        for response in wulewule_model.stream_infer(messages):
+                            full_answer += response.text
+                            # Display robot response in chat message container
+                            message_placeholder.markdown(full_answer + '▌')
+
+                    message_placeholder.markdown(full_answer)
+            # show the result in one shot
+            else:
+                if cfg.use_lmdepoly:
+                    messages = [{'role': 'user', 'content': f'{prompt}'}]
+                    full_answer = wulewule_model(messages).text
+                elif cfg.use_rag:
+                    full_answer = wulewule_model.query(prompt)
+                # show the answer
+                st.chat_message("assistant").write(full_answer)
+
+            # append the Q&A result to the session_state message history
+            st.session_state.messages.append({"user": prompt, "assistant": full_answer})
 
 
 if __name__ == "__main__":
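Note: every streaming branch in this diff uses the same Streamlit idiom: a single st.empty() placeholder that is re-rendered with the accumulated text plus a cursor glyph, then once more without it. Below is a self-contained sketch of the idiom, assuming a hypothetical fake_stream generator in place of the real model calls.

import time
import streamlit as st

def fake_stream():
    # hypothetical stand-in for wulewule_model.query_stream(prompt)
    # or wulewule_model.stream_infer(messages)
    for token in ["one ", "token ", "at ", "a ", "time"]:
        time.sleep(0.1)
        yield token

with st.chat_message("assistant"):
    placeholder = st.empty()  # a single slot, re-rendered on every token
    answer = ""
    for token in fake_stream():
        answer += token
        placeholder.markdown(answer + "▌")  # cursor while streaming
    placeholder.markdown(answer)  # final render without the cursor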
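Note: the replay loop at the top of the last hunk works because Streamlit re-runs the whole script on every interaction; history survives only through st.session_state, and each turn must be appended after it is answered. A minimal sketch of that round trip follows; the message schema mirrors the diff, while the init guard and the echo stand-in are assumptions.

import streamlit as st

# initialize the history once per session; later re-runs keep the list alive
if "messages" not in st.session_state:
    st.session_state.messages = []

# replay past turns on every re-run, as the diff does
for msg in st.session_state.messages:
    st.chat_message("user").write(msg["user"])
    st.chat_message("assistant").write(msg["assistant"])

if prompt := st.chat_input("Ask a question"):
    st.chat_message("user").write(prompt)
    answer = f"echo: {prompt}"  # hypothetical stand-in for the model call
    st.chat_message("assistant").write(answer)
    # persist the turn so the next re-run can replay it
    st.session_state.messages.append({"user": prompt, "assistant": answer})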