liuhua
liuhua
commited on
Commit
·
fea9976
1
Parent(s):
11bef16
Add parameters for ask_chat and fix bugs in list_sessions (#4119)
Browse files### What problem does this PR solve?
Add parameters for ask_chat and fix bugs in list_sessions
#4105
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
Co-authored-by: liuhua <[email protected]>
api/apps/sdk/session.py
CHANGED
@@ -65,20 +65,24 @@ def create(tenant_id, chat_id):
|
|
65 |
@manager.route('/agents/<agent_id>/sessions', methods=['POST']) # noqa: F821
|
66 |
@token_required
|
67 |
def create_agent_session(tenant_id, agent_id):
|
|
|
68 |
e, cvs = UserCanvasService.get_by_id(agent_id)
|
69 |
if not e:
|
70 |
return get_error_data_result("Agent not found.")
|
71 |
|
|
|
|
|
|
|
72 |
if not isinstance(cvs.dsl, str):
|
73 |
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
74 |
|
75 |
canvas = Canvas(cvs.dsl, tenant_id)
|
76 |
if canvas.get_preset_param():
|
77 |
-
return get_error_data_result("The agent
|
78 |
conv = {
|
79 |
"id": get_uuid(),
|
80 |
"dialog_id": cvs.id,
|
81 |
-
"user_id":
|
82 |
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
83 |
"source": "agent",
|
84 |
"dsl": json.loads(cvs.dsl)
|
@@ -199,17 +203,15 @@ def list_session(tenant_id, chat_id):
|
|
199 |
chunks = conv["reference"][chunk_num]["chunks"]
|
200 |
for chunk in chunks:
|
201 |
new_chunk = {
|
202 |
-
"id": chunk
|
203 |
-
"content": chunk
|
204 |
-
"document_id": chunk
|
205 |
-
"document_name": chunk
|
206 |
-
"dataset_id": chunk
|
207 |
-
"image_id": chunk.get("image_id", ""),
|
208 |
-
"
|
209 |
-
"vector_similarity": chunk["vector_similarity"],
|
210 |
-
"term_similarity": chunk["term_similarity"],
|
211 |
-
"positions": chunk["positions"],
|
212 |
}
|
|
|
213 |
chunk_list.append(new_chunk)
|
214 |
chunk_num += 1
|
215 |
messages[message_num]["reference"] = chunk_list
|
@@ -254,16 +256,13 @@ def list_agent_session(tenant_id, agent_id):
|
|
254 |
chunks = conv["reference"][chunk_num]["chunks"]
|
255 |
for chunk in chunks:
|
256 |
new_chunk = {
|
257 |
-
"id": chunk
|
258 |
-
"content": chunk
|
259 |
-
"document_id": chunk
|
260 |
-
"document_name": chunk
|
261 |
-
"dataset_id": chunk
|
262 |
-
"image_id": chunk.get("image_id", ""),
|
263 |
-
"
|
264 |
-
"vector_similarity": chunk["vector_similarity"],
|
265 |
-
"term_similarity": chunk["term_similarity"],
|
266 |
-
"positions": chunk["positions"],
|
267 |
}
|
268 |
chunk_list.append(new_chunk)
|
269 |
chunk_num += 1
|
|
|
65 |
@manager.route('/agents/<agent_id>/sessions', methods=['POST']) # noqa: F821
|
66 |
@token_required
|
67 |
def create_agent_session(tenant_id, agent_id):
|
68 |
+
req = request.json
|
69 |
e, cvs = UserCanvasService.get_by_id(agent_id)
|
70 |
if not e:
|
71 |
return get_error_data_result("Agent not found.")
|
72 |
|
73 |
+
if not UserCanvasService.query(user_id=tenant_id,id=agent_id):
|
74 |
+
return get_error_data_result("You cannot access the agent.")
|
75 |
+
|
76 |
if not isinstance(cvs.dsl, str):
|
77 |
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
78 |
|
79 |
canvas = Canvas(cvs.dsl, tenant_id)
|
80 |
if canvas.get_preset_param():
|
81 |
+
return get_error_data_result("The agent cannot create a session directly")
|
82 |
conv = {
|
83 |
"id": get_uuid(),
|
84 |
"dialog_id": cvs.id,
|
85 |
+
"user_id": req.get("usr_id","") if isinstance(req, dict) else "",
|
86 |
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
87 |
"source": "agent",
|
88 |
"dsl": json.loads(cvs.dsl)
|
|
|
203 |
chunks = conv["reference"][chunk_num]["chunks"]
|
204 |
for chunk in chunks:
|
205 |
new_chunk = {
|
206 |
+
"id": chunk.get("chunk_id", chunk.get("id")),
|
207 |
+
"content": chunk.get("content_with_weight", chunk.get("content")),
|
208 |
+
"document_id": chunk.get("doc_id", chunk.get("document_id")),
|
209 |
+
"document_name": chunk.get("docnm_kwd", chunk.get("document_name")),
|
210 |
+
"dataset_id": chunk.get("kb_id", chunk.get("dataset_id")),
|
211 |
+
"image_id": chunk.get("image_id", chunk.get("img_id")),
|
212 |
+
"positions": chunk.get("positions", chunk.get("position_int")),
|
|
|
|
|
|
|
213 |
}
|
214 |
+
|
215 |
chunk_list.append(new_chunk)
|
216 |
chunk_num += 1
|
217 |
messages[message_num]["reference"] = chunk_list
|
|
|
256 |
chunks = conv["reference"][chunk_num]["chunks"]
|
257 |
for chunk in chunks:
|
258 |
new_chunk = {
|
259 |
+
"id": chunk.get("chunk_id", chunk.get("id")),
|
260 |
+
"content": chunk.get("content_with_weight", chunk.get("content")),
|
261 |
+
"document_id": chunk.get("doc_id", chunk.get("document_id")),
|
262 |
+
"document_name": chunk.get("docnm_kwd", chunk.get("document_name")),
|
263 |
+
"dataset_id": chunk.get("kb_id", chunk.get("dataset_id")),
|
264 |
+
"image_id": chunk.get("image_id", chunk.get("img_id")),
|
265 |
+
"positions": chunk.get("positions", chunk.get("position_int")),
|
|
|
|
|
|
|
266 |
}
|
267 |
chunk_list.append(new_chunk)
|
268 |
chunk_num += 1
|
sdk/python/ragflow_sdk/modules/session.py
CHANGED
@@ -17,11 +17,11 @@ class Session(Base):
|
|
17 |
self.__session_type = "agent"
|
18 |
super().__init__(rag, res_dict)
|
19 |
|
20 |
-
def ask(self, question,stream=True):
|
21 |
if self.__session_type == "agent":
|
22 |
res=self._ask_agent(question,stream)
|
23 |
elif self.__session_type == "chat":
|
24 |
-
res=self._ask_chat(question,stream)
|
25 |
for line in res.iter_lines():
|
26 |
line = line.decode("utf-8")
|
27 |
if line.startswith("{"):
|
@@ -45,9 +45,11 @@ class Session(Base):
|
|
45 |
yield message
|
46 |
|
47 |
|
48 |
-
def _ask_chat(self, question: str, stream: bool):
|
|
|
|
|
49 |
res = self.post(f"/chats/{self.chat_id}/completions",
|
50 |
-
|
51 |
return res
|
52 |
def _ask_agent(self,question:str,stream:bool):
|
53 |
res = self.post(f"/agents/{self.agent_id}/completions",
|
|
|
17 |
self.__session_type = "agent"
|
18 |
super().__init__(rag, res_dict)
|
19 |
|
20 |
+
def ask(self, question,stream=True,**kwargs):
|
21 |
if self.__session_type == "agent":
|
22 |
res=self._ask_agent(question,stream)
|
23 |
elif self.__session_type == "chat":
|
24 |
+
res=self._ask_chat(question,stream,**kwargs)
|
25 |
for line in res.iter_lines():
|
26 |
line = line.decode("utf-8")
|
27 |
if line.startswith("{"):
|
|
|
45 |
yield message
|
46 |
|
47 |
|
48 |
+
def _ask_chat(self, question: str, stream: bool,**kwargs):
|
49 |
+
json_data={"question": question, "stream": True,"session_id":self.id}
|
50 |
+
json_data.update(kwargs)
|
51 |
res = self.post(f"/chats/{self.chat_id}/completions",
|
52 |
+
json_data, stream=stream)
|
53 |
return res
|
54 |
def _ask_agent(self,question:str,stream:bool):
|
55 |
res = self.post(f"/agents/{self.agent_id}/completions",
|