liuhua liuhua commited on
Commit
fea9976
·
1 Parent(s): 11bef16

Add parameters for ask_chat and fix bugs in list_sessions (#4119)

Browse files

### What problem does this PR solve?

Add parameters for ask_chat and fix bugs in list_sessions
#4105
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: liuhua <[email protected]>

api/apps/sdk/session.py CHANGED
@@ -65,20 +65,24 @@ def create(tenant_id, chat_id):
65
  @manager.route('/agents/<agent_id>/sessions', methods=['POST']) # noqa: F821
66
  @token_required
67
  def create_agent_session(tenant_id, agent_id):
 
68
  e, cvs = UserCanvasService.get_by_id(agent_id)
69
  if not e:
70
  return get_error_data_result("Agent not found.")
71
 
 
 
 
72
  if not isinstance(cvs.dsl, str):
73
  cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
74
 
75
  canvas = Canvas(cvs.dsl, tenant_id)
76
  if canvas.get_preset_param():
77
- return get_error_data_result("The agent can't create a session directly")
78
  conv = {
79
  "id": get_uuid(),
80
  "dialog_id": cvs.id,
81
- "user_id": tenant_id,
82
  "message": [{"role": "assistant", "content": canvas.get_prologue()}],
83
  "source": "agent",
84
  "dsl": json.loads(cvs.dsl)
@@ -199,17 +203,15 @@ def list_session(tenant_id, chat_id):
199
  chunks = conv["reference"][chunk_num]["chunks"]
200
  for chunk in chunks:
201
  new_chunk = {
202
- "id": chunk["chunk_id"],
203
- "content": chunk["content_with_weight"],
204
- "document_id": chunk["doc_id"],
205
- "document_name": chunk["docnm_kwd"],
206
- "dataset_id": chunk["kb_id"],
207
- "image_id": chunk.get("image_id", ""),
208
- "similarity": chunk["similarity"],
209
- "vector_similarity": chunk["vector_similarity"],
210
- "term_similarity": chunk["term_similarity"],
211
- "positions": chunk["positions"],
212
  }
 
213
  chunk_list.append(new_chunk)
214
  chunk_num += 1
215
  messages[message_num]["reference"] = chunk_list
@@ -254,16 +256,13 @@ def list_agent_session(tenant_id, agent_id):
254
  chunks = conv["reference"][chunk_num]["chunks"]
255
  for chunk in chunks:
256
  new_chunk = {
257
- "id": chunk["chunk_id"],
258
- "content": chunk["content"],
259
- "document_id": chunk["doc_id"],
260
- "document_name": chunk["docnm_kwd"],
261
- "dataset_id": chunk["kb_id"],
262
- "image_id": chunk.get("image_id", ""),
263
- "similarity": chunk["similarity"],
264
- "vector_similarity": chunk["vector_similarity"],
265
- "term_similarity": chunk["term_similarity"],
266
- "positions": chunk["positions"],
267
  }
268
  chunk_list.append(new_chunk)
269
  chunk_num += 1
 
65
  @manager.route('/agents/<agent_id>/sessions', methods=['POST']) # noqa: F821
66
  @token_required
67
  def create_agent_session(tenant_id, agent_id):
68
+ req = request.json
69
  e, cvs = UserCanvasService.get_by_id(agent_id)
70
  if not e:
71
  return get_error_data_result("Agent not found.")
72
 
73
+ if not UserCanvasService.query(user_id=tenant_id,id=agent_id):
74
+ return get_error_data_result("You cannot access the agent.")
75
+
76
  if not isinstance(cvs.dsl, str):
77
  cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
78
 
79
  canvas = Canvas(cvs.dsl, tenant_id)
80
  if canvas.get_preset_param():
81
+ return get_error_data_result("The agent cannot create a session directly")
82
  conv = {
83
  "id": get_uuid(),
84
  "dialog_id": cvs.id,
85
+ "user_id": req.get("usr_id","") if isinstance(req, dict) else "",
86
  "message": [{"role": "assistant", "content": canvas.get_prologue()}],
87
  "source": "agent",
88
  "dsl": json.loads(cvs.dsl)
 
203
  chunks = conv["reference"][chunk_num]["chunks"]
204
  for chunk in chunks:
205
  new_chunk = {
206
+ "id": chunk.get("chunk_id", chunk.get("id")),
207
+ "content": chunk.get("content_with_weight", chunk.get("content")),
208
+ "document_id": chunk.get("doc_id", chunk.get("document_id")),
209
+ "document_name": chunk.get("docnm_kwd", chunk.get("document_name")),
210
+ "dataset_id": chunk.get("kb_id", chunk.get("dataset_id")),
211
+ "image_id": chunk.get("image_id", chunk.get("img_id")),
212
+ "positions": chunk.get("positions", chunk.get("position_int")),
 
 
 
213
  }
214
+
215
  chunk_list.append(new_chunk)
216
  chunk_num += 1
217
  messages[message_num]["reference"] = chunk_list
 
256
  chunks = conv["reference"][chunk_num]["chunks"]
257
  for chunk in chunks:
258
  new_chunk = {
259
+ "id": chunk.get("chunk_id", chunk.get("id")),
260
+ "content": chunk.get("content_with_weight", chunk.get("content")),
261
+ "document_id": chunk.get("doc_id", chunk.get("document_id")),
262
+ "document_name": chunk.get("docnm_kwd", chunk.get("document_name")),
263
+ "dataset_id": chunk.get("kb_id", chunk.get("dataset_id")),
264
+ "image_id": chunk.get("image_id", chunk.get("img_id")),
265
+ "positions": chunk.get("positions", chunk.get("position_int")),
 
 
 
266
  }
267
  chunk_list.append(new_chunk)
268
  chunk_num += 1
sdk/python/ragflow_sdk/modules/session.py CHANGED
@@ -17,11 +17,11 @@ class Session(Base):
17
  self.__session_type = "agent"
18
  super().__init__(rag, res_dict)
19
 
20
- def ask(self, question,stream=True):
21
  if self.__session_type == "agent":
22
  res=self._ask_agent(question,stream)
23
  elif self.__session_type == "chat":
24
- res=self._ask_chat(question,stream)
25
  for line in res.iter_lines():
26
  line = line.decode("utf-8")
27
  if line.startswith("{"):
@@ -45,9 +45,11 @@ class Session(Base):
45
  yield message
46
 
47
 
48
- def _ask_chat(self, question: str, stream: bool):
 
 
49
  res = self.post(f"/chats/{self.chat_id}/completions",
50
- {"question": question, "stream": True,"session_id":self.id}, stream=stream)
51
  return res
52
  def _ask_agent(self,question:str,stream:bool):
53
  res = self.post(f"/agents/{self.agent_id}/completions",
 
17
  self.__session_type = "agent"
18
  super().__init__(rag, res_dict)
19
 
20
+ def ask(self, question,stream=True,**kwargs):
21
  if self.__session_type == "agent":
22
  res=self._ask_agent(question,stream)
23
  elif self.__session_type == "chat":
24
+ res=self._ask_chat(question,stream,**kwargs)
25
  for line in res.iter_lines():
26
  line = line.decode("utf-8")
27
  if line.startswith("{"):
 
45
  yield message
46
 
47
 
48
+ def _ask_chat(self, question: str, stream: bool,**kwargs):
49
+ json_data={"question": question, "stream": True,"session_id":self.id}
50
+ json_data.update(kwargs)
51
  res = self.post(f"/chats/{self.chat_id}/completions",
52
+ json_data, stream=stream)
53
  return res
54
  def _ask_agent(self,question:str,stream:bool):
55
  res = self.post(f"/agents/{self.agent_id}/completions",