wwwlll commited on
Commit
17f4221
·
1 Parent(s): 7b030d6

Fix retrieval API error and add multi-kb search (#1928)

Browse files

### What problem does this PR solve?
Type of change
Bug Fix (Import necessary class for retrieval API )
New Feature (Add multi-KB search to retrieval API)

Files changed (1) hide show
  1. api/apps/api_app.py +16 -16
api/apps/api_app.py CHANGED
@@ -18,9 +18,10 @@ import os
18
  import re
19
  from datetime import datetime, timedelta
20
  from flask import request, Response
 
21
  from flask_login import login_required, current_user
22
 
23
- from api.db import FileType, ParserType, FileSource
24
  from api.db.db_models import APIToken, API4Conversation, Task, File
25
  from api.db.services import duplicate_name
26
  from api.db.services.api_service import APITokenService, API4ConversationService
@@ -37,6 +38,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge
37
  from itsdangerous import URLSafeTimedSerializer
38
 
39
  from api.utils.file_utils import filename_type, thumbnail
 
40
  from rag.utils.minio_conn import MINIO
41
 
42
  from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
@@ -694,7 +696,7 @@ def retrieval():
694
  data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
695
 
696
  req = request.json
697
- kb_id = req.get("kb_id")
698
  doc_ids = req.get("doc_ids", [])
699
  question = req.get("question")
700
  page = int(req.get("page", 1))
@@ -704,32 +706,30 @@ def retrieval():
704
  top = int(req.get("top_k", 1024))
705
 
706
  try:
707
- e, kb = KnowledgebaseService.get_by_id(kb_id)
708
- if not e:
709
- return get_data_error_result(retmsg="Knowledgebase not found!")
 
 
710
 
711
  embd_mdl = TenantLLMService.model_instance(
712
- kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
713
-
714
  rerank_mdl = None
715
  if req.get("rerank_id"):
716
  rerank_mdl = TenantLLMService.model_instance(
717
- kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
718
-
719
  if req.get("keyword", False):
720
- chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
721
  question += keyword_extraction(chat_mdl, question)
722
-
723
- ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
724
- similarity_threshold, vector_similarity_weight, top,
725
- doc_ids, rerank_mdl=rerank_mdl)
726
  for c in ranks["chunks"]:
727
  if "vector" in c:
728
  del c["vector"]
729
-
730
  return get_json_result(data=ranks)
731
  except Exception as e:
732
  if str(e).find("not_found") > 0:
733
  return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
734
  retcode=RetCode.DATA_ERROR)
735
- return server_error_response(e)
 
18
  import re
19
  from datetime import datetime, timedelta
20
  from flask import request, Response
21
+ from api.db.services.llm_service import TenantLLMService
22
  from flask_login import login_required, current_user
23
 
24
+ from api.db import FileType, LLMType, ParserType, FileSource
25
  from api.db.db_models import APIToken, API4Conversation, Task, File
26
  from api.db.services import duplicate_name
27
  from api.db.services.api_service import APITokenService, API4ConversationService
 
38
  from itsdangerous import URLSafeTimedSerializer
39
 
40
  from api.utils.file_utils import filename_type, thumbnail
41
+ from rag.nlp import keyword_extraction
42
  from rag.utils.minio_conn import MINIO
43
 
44
  from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
 
696
  data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
697
 
698
  req = request.json
699
+ kb_ids = req.get("kb_id",[])
700
  doc_ids = req.get("doc_ids", [])
701
  question = req.get("question")
702
  page = int(req.get("page", 1))
 
706
  top = int(req.get("top_k", 1024))
707
 
708
  try:
709
+ kbs = KnowledgebaseService.get_by_ids(kb_ids)
710
+ embd_nms = list(set([kb.embd_id for kb in kbs]))
711
+ if len(embd_nms) != 1:
712
+ return get_json_result(
713
+ data=False, retmsg='Knowledge bases use different embedding models or does not exist."', retcode=RetCode.AUTHENTICATION_ERROR)
714
 
715
  embd_mdl = TenantLLMService.model_instance(
716
+ kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
 
717
  rerank_mdl = None
718
  if req.get("rerank_id"):
719
  rerank_mdl = TenantLLMService.model_instance(
720
+ kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
 
721
  if req.get("keyword", False):
722
+ chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
723
  question += keyword_extraction(chat_mdl, question)
724
+ ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
725
+ similarity_threshold, vector_similarity_weight, top,
726
+ doc_ids, rerank_mdl=rerank_mdl)
 
727
  for c in ranks["chunks"]:
728
  if "vector" in c:
729
  del c["vector"]
 
730
  return get_json_result(data=ranks)
731
  except Exception as e:
732
  if str(e).find("not_found") > 0:
733
  return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
734
  retcode=RetCode.DATA_ERROR)
735
+ return server_error_response(e)