wwwlll
commited on
Commit
·
17f4221
1
Parent(s):
7b030d6
Fix retrieval API error and add multi-kb search (#1928)
Browse files### What problem does this PR solve?
Type of change
Bug Fix (Import necessary class for retrieval API )
New Feature (Add multi-KB search to retrieval API)
- api/apps/api_app.py +16 -16
api/apps/api_app.py
CHANGED
@@ -18,9 +18,10 @@ import os
|
|
18 |
import re
|
19 |
from datetime import datetime, timedelta
|
20 |
from flask import request, Response
|
|
|
21 |
from flask_login import login_required, current_user
|
22 |
|
23 |
-
from api.db import FileType, ParserType, FileSource
|
24 |
from api.db.db_models import APIToken, API4Conversation, Task, File
|
25 |
from api.db.services import duplicate_name
|
26 |
from api.db.services.api_service import APITokenService, API4ConversationService
|
@@ -37,6 +38,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge
|
|
37 |
from itsdangerous import URLSafeTimedSerializer
|
38 |
|
39 |
from api.utils.file_utils import filename_type, thumbnail
|
|
|
40 |
from rag.utils.minio_conn import MINIO
|
41 |
|
42 |
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
|
@@ -694,7 +696,7 @@ def retrieval():
|
|
694 |
data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
|
695 |
|
696 |
req = request.json
|
697 |
-
|
698 |
doc_ids = req.get("doc_ids", [])
|
699 |
question = req.get("question")
|
700 |
page = int(req.get("page", 1))
|
@@ -704,32 +706,30 @@ def retrieval():
|
|
704 |
top = int(req.get("top_k", 1024))
|
705 |
|
706 |
try:
|
707 |
-
|
708 |
-
|
709 |
-
|
|
|
|
|
710 |
|
711 |
embd_mdl = TenantLLMService.model_instance(
|
712 |
-
|
713 |
-
|
714 |
rerank_mdl = None
|
715 |
if req.get("rerank_id"):
|
716 |
rerank_mdl = TenantLLMService.model_instance(
|
717 |
-
|
718 |
-
|
719 |
if req.get("keyword", False):
|
720 |
-
chat_mdl = TenantLLMService.model_instance(
|
721 |
question += keyword_extraction(chat_mdl, question)
|
722 |
-
|
723 |
-
|
724 |
-
|
725 |
-
doc_ids, rerank_mdl=rerank_mdl)
|
726 |
for c in ranks["chunks"]:
|
727 |
if "vector" in c:
|
728 |
del c["vector"]
|
729 |
-
|
730 |
return get_json_result(data=ranks)
|
731 |
except Exception as e:
|
732 |
if str(e).find("not_found") > 0:
|
733 |
return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
|
734 |
retcode=RetCode.DATA_ERROR)
|
735 |
-
return server_error_response(e)
|
|
|
18 |
import re
|
19 |
from datetime import datetime, timedelta
|
20 |
from flask import request, Response
|
21 |
+
from api.db.services.llm_service import TenantLLMService
|
22 |
from flask_login import login_required, current_user
|
23 |
|
24 |
+
from api.db import FileType, LLMType, ParserType, FileSource
|
25 |
from api.db.db_models import APIToken, API4Conversation, Task, File
|
26 |
from api.db.services import duplicate_name
|
27 |
from api.db.services.api_service import APITokenService, API4ConversationService
|
|
|
38 |
from itsdangerous import URLSafeTimedSerializer
|
39 |
|
40 |
from api.utils.file_utils import filename_type, thumbnail
|
41 |
+
from rag.nlp import keyword_extraction
|
42 |
from rag.utils.minio_conn import MINIO
|
43 |
|
44 |
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
|
|
|
696 |
data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
|
697 |
|
698 |
req = request.json
|
699 |
+
kb_ids = req.get("kb_id",[])
|
700 |
doc_ids = req.get("doc_ids", [])
|
701 |
question = req.get("question")
|
702 |
page = int(req.get("page", 1))
|
|
|
706 |
top = int(req.get("top_k", 1024))
|
707 |
|
708 |
try:
|
709 |
+
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
710 |
+
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
711 |
+
if len(embd_nms) != 1:
|
712 |
+
return get_json_result(
|
713 |
+
data=False, retmsg='Knowledge bases use different embedding models or does not exist."', retcode=RetCode.AUTHENTICATION_ERROR)
|
714 |
|
715 |
embd_mdl = TenantLLMService.model_instance(
|
716 |
+
kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
|
|
|
717 |
rerank_mdl = None
|
718 |
if req.get("rerank_id"):
|
719 |
rerank_mdl = TenantLLMService.model_instance(
|
720 |
+
kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
|
|
|
721 |
if req.get("keyword", False):
|
722 |
+
chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
|
723 |
question += keyword_extraction(chat_mdl, question)
|
724 |
+
ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
|
725 |
+
similarity_threshold, vector_similarity_weight, top,
|
726 |
+
doc_ids, rerank_mdl=rerank_mdl)
|
|
|
727 |
for c in ranks["chunks"]:
|
728 |
if "vector" in c:
|
729 |
del c["vector"]
|
|
|
730 |
return get_json_result(data=ranks)
|
731 |
except Exception as e:
|
732 |
if str(e).find("not_found") > 0:
|
733 |
return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
|
734 |
retcode=RetCode.DATA_ERROR)
|
735 |
+
return server_error_response(e)
|