xzyun2011 commited on
Commit
06c933c
·
1 Parent(s): 5d1ff8c

add agent code final

Browse files
Files changed (1) hide show
  1. app.py +78 -37
app.py CHANGED
@@ -6,7 +6,6 @@ from PIL import Image
6
  import os
7
  import sys
8
  sys.path.append(os.path.dirname(__file__))
9
- import torch
10
  from download_models import download_model
11
 
12
 
@@ -41,8 +40,15 @@ def main(cfg):
41
  config_dict = OmegaConf.to_container(cfg, resolve=True)
42
 
43
  ## download model from modelscope
44
- if not os.path.exists(config_dict["llm_model"]):
45
  download_model(llm_model_path =config_dict["llm_model"])
 
 
 
 
 
 
 
46
 
47
  if cfg.use_rag:
48
  ## load rag model
@@ -72,50 +78,85 @@ def main(cfg):
72
  # 遍历session_state中的所有消息,并显示在聊天界面上
73
  for msg in st.session_state.messages:
74
  st.chat_message("user").write(msg["user"])
75
- st.chat_message("assistant").write(msg["assistant"])
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  # Get user input
78
  if prompt := st.chat_input("请输入你的问题,换行使用Shfit+Enter。"):
79
  # Display user input
80
  st.chat_message("user").write(prompt)
81
- # 流式显示, used streaming result
82
- if cfg.stream_response:
83
- # rag
84
- ## 初始化完整的回答字符串
85
- full_answer = ""
86
  with st.chat_message('robot'):
87
  message_placeholder = st.empty()
88
- if cfg.use_rag:
89
- for cur_response in wulewule_model.query_stream(prompt):
90
- full_answer += cur_response
91
- # Display robot response in chat message container
92
- message_placeholder.markdown(full_answer + '▌')
93
- elif cfg.use_lmdepoly:
94
- # gen_config = GenerationConfig(top_p=0.8,
95
- # top_k=40,
96
- # temperature=0.8,
97
- # max_new_tokens=2048,
98
- # repetition_penalty=1.05)
99
- messages = [{'role': 'user', 'content': f'{prompt}'}]
100
- for response in wulewule_model.stream_infer(messages):
101
- full_answer += response.text
102
- # Display robot response in chat message container
103
- message_placeholder.markdown(full_answer + '▌')
104
-
105
  message_placeholder.markdown(full_answer)
106
- # 一次性显示结果
 
 
 
 
 
 
 
 
 
 
 
107
  else:
108
- if cfg.use_lmdepoly:
109
- messages = [{'role': 'user', 'content': f'{prompt}'}]
110
- full_answer = wulewule_model(messages).text
111
- elif cfg.use_rag:
112
- full_answer = wulewule_model.query(prompt)
113
- # 显示回答
114
- st.chat_message("assistant").write(full_answer)
115
-
116
- # 将问答结果添加到 session_state 的消息历史中
117
- st.session_state.messages.append({"user": prompt, "assistant": full_answer})
118
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
 
121
  if __name__ == "__main__":
 
6
  import os
7
  import sys
8
  sys.path.append(os.path.dirname(__file__))
 
9
  from download_models import download_model
10
 
11
 
 
40
  config_dict = OmegaConf.to_container(cfg, resolve=True)
41
 
42
  ## download model from modelscope
43
+ if not config_dict["use_remote"] and not os.path.exists(config_dict["llm_model"]):
44
  download_model(llm_model_path =config_dict["llm_model"])
45
+
46
+ ## agent mode, used llama-index, rturn off lmdeloy and chroma rag
47
+ if cfg.agent_mode:
48
+ ## load wulewule agent
49
+ wulewule_assistant = load_wulewule_agent(config_dict)
50
+ cfg.use_rag = False
51
+ cfg.use_lmdepoly = False
52
 
53
  if cfg.use_rag:
54
  ## load rag model
 
78
  # 遍历session_state中的所有消息,并显示在聊天界面上
79
  for msg in st.session_state.messages:
80
  st.chat_message("user").write(msg["user"])
81
+ assistant_res = msg["assistant"]
82
+ if isinstance(assistant_res, str):
83
+ st.chat_message("assistant").write(assistant_res)
84
+ elif cfg.agent_mode and isinstance(assistant_res, dict):
85
+ image_url = assistant_res["image_url"]
86
+ audio_text = assistant_res["audio_text"]
87
+ st.chat_message("assistant").write(assistant_res["response"])
88
+ if image_url:
89
+ # 使用st.image展示URL图像,并设置使用列宽
90
+ st.image( image_url, width=256 )
91
+ if audio_text:
92
+ # 使用st.audio函数播放音频
93
+ st.audio("audio.mp3")
94
+ st.write(f"语音内容为: \n\n{audio_text}")
95
 
96
  # Get user input
97
  if prompt := st.chat_input("请输入你的问题,换行使用Shfit+Enter。"):
98
  # Display user input
99
  st.chat_message("user").write(prompt)
100
+ ## 初始化完整的回答字符串
101
+ full_answer = ""
102
+ if cfg.agent_mode:
 
 
103
  with st.chat_message('robot'):
104
  message_placeholder = st.empty()
105
+ response_dict = wulewule_assistant.chat(prompt)
106
+ image_url = response_dict["image_url"]
107
+ audio_text = response_dict["audio_text"]
108
+ for cur_response in response_dict["response"]:
109
+ full_answer += cur_response
110
+ # Display robot response in chat message container
111
+ message_placeholder.markdown(full_answer + '▌')
 
 
 
 
 
 
 
 
 
 
112
  message_placeholder.markdown(full_answer)
113
+ # 将问答结果添加到 session_state 的消息历史中
114
+ st.session_state.messages.append({"user": prompt, "assistant": response_dict})
115
+ if image_url:
116
+ # 使用st.image展示URL图像,并设置使用列宽
117
+ st.image( image_url, width=256 )
118
+
119
+ if audio_text:
120
+ # 使用st.audio函数播放音频
121
+ st.audio("audio.mp3")
122
+ st.write(f"语音内容为: \n\n{audio_text}")
123
+
124
+ # 流式显示, used streaming result
125
  else:
126
+ if cfg.stream_response:
127
+ # rag
128
+ with st.chat_message('robot'):
129
+ message_placeholder = st.empty()
130
+ if cfg.use_rag:
131
+ for cur_response in wulewule_model.query_stream(prompt):
132
+ full_answer += cur_response
133
+ # Display robot response in chat message container
134
+ message_placeholder.markdown(full_answer + '▌')
135
+ elif cfg.use_lmdepoly:
136
+ # gen_config = GenerationConfig(top_p=0.8,
137
+ # top_k=40,
138
+ # temperature=0.8,
139
+ # max_new_tokens=2048,
140
+ # repetition_penalty=1.05)
141
+ messages = [{'role': 'user', 'content': f'{prompt}'}]
142
+ for response in wulewule_model.stream_infer(messages):
143
+ full_answer += response.text
144
+ # Display robot response in chat message container
145
+ message_placeholder.markdown(full_answer + '▌')
146
+
147
+ message_placeholder.markdown(full_answer)
148
+ # 一次性显示结果
149
+ else:
150
+ if cfg.use_lmdepoly:
151
+ messages = [{'role': 'user', 'content': f'{prompt}'}]
152
+ full_answer = wulewule_model(messages).text
153
+ elif cfg.use_rag:
154
+ full_answer = wulewule_model.query(prompt)
155
+ # 显示回答
156
+ st.chat_message("assistant").write(full_answer)
157
+
158
+ # 将问答结果添加到 session_state 的消息历史中
159
+ st.session_state.messages.append({"user": prompt, "assistant": full_answer})
160
 
161
 
162
  if __name__ == "__main__":