#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import json
from copy import deepcopy

import pandas as pd

from rag.nlp.search import Dealer
from rag.utils.doc_store_conn import OrderByExpr, FusionExpr
class KGSearch(Dealer):
    """Retrieval over a knowledge-graph index.

    Runs three parallel retrieval passes over the doc store — graph
    entities, community reports, and the original text chunks — and
    collapses each pass into a single merged chunk before combining them
    into one ``SearchResult``.
    """

    @staticmethod
    def _merge_into_first(sres, title="") -> dict[str, str]:
        """Collapse all hits in *sres* into a single chunk.

        Structured hits (JSON-decodable ``content_with_weight``) are
        tabulated into one CSV via pandas; anything that fails to parse is
        treated as free text and joined with newlines. The merged content
        is stored on a copy of the first hit, so its metadata (doc id,
        kb id, ...) is preserved.

        :param sres: ordered mapping of chunk id -> chunk fields.
        :param title: section banner prepended to the merged content.
        :return: ``{first_chunk_id: merged_chunk}`` or ``{}`` if *sres* is empty.
        """
        if not sres:
            return {}
        rows, texts = [], []
        for d in sres.values():
            try:
                rows.append(json.loads(d["content_with_weight"]))
            except Exception:
                # Not JSON — keep it as plain text.
                texts.append(d["content_with_weight"])
        if rows:
            content_with_weight = title + "\n" + pd.DataFrame(rows).to_csv()
        else:
            content_with_weight = title + "\n" + "\n".join(texts)
        # Original code looped to grab the "first" item but assigned the
        # builtin `id` by mistake; take the first key directly instead.
        first_id = next(iter(sres))
        first_source = deepcopy(sres[first_id])
        first_source["content_with_weight"] = content_with_weight
        return {first_id: first_source}

    def _retrieve(self, src, condition, match_exprs, topn, idxnm, kb_ids, title):
        """Run one doc-store search pass and merge its hits.

        :return: ``(fields, ids, merged)`` where *fields* is the raw
            id -> field mapping, *ids* the chunk ids to report (collapsed
            to the single merged id when a merge happened), and *merged*
            the single-chunk dict from :meth:`_merge_into_first`.
        """
        res = self.dataStore.search(src, list(), condition, match_exprs,
                                    OrderByExpr(), 0, topn, idxnm, kb_ids)
        fields = self.dataStore.getFields(res, src)
        ids = self.dataStore.getChunkIds(res)
        merged = self._merge_into_first(fields, title)
        if merged:
            ids = list(merged.keys())
        return fields, ids, merged

    def search(self, req, idxnm: str | list[str], kb_ids: list[str], emb_mdl=None, highlight=False):
        """Answer *req* against the knowledge-graph index.

        :param req: request dict; reads ``question``, ``similarity`` and
            optional ``fields`` plus whatever ``get_filters`` consumes.
        :param idxnm: index name(s) to search.
        :param kb_ids: knowledge-base ids to restrict the search to.
        :param emb_mdl: embedding model used for the dense match (required).
        :param highlight: accepted for interface compatibility; unused here.
        :return: ``SearchResult`` combining entity, community-report and
            original-text hits.
        """
        qst = req.get("question", "")
        # NOTE(review): `keywords` is computed but the result below returns
        # keywords=[] — preserved as-is; confirm whether this is intentional.
        matchText, keywords = self.qryr.question(qst, min_match=0.05)

        # NOTE(review): assert is stripped under `python -O`; kept (rather
        # than raising ValueError) so callers catching AssertionError still work.
        assert emb_mdl, "No embedding model selected"
        matchDense = self.get_vector(qst, emb_mdl, 1024, req.get("similarity", 0.1))
        q_vec = matchDense.embedding_data
        src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
                                 "doc_id", f"q_{len(q_vec)}_vec", "position_int", "name_kwd",
                                 "available_int", "content_with_weight",
                                 "weight_int", "weight_flt"
                                 ])
        fusionExpr = FusionExpr("weighted_sum", 32, {"weights": "0.5, 0.5"})
        match_exprs = [matchText, matchDense, fusionExpr]

        ## Entity retrieval
        condition = self.get_filters(req)
        condition.update({"knowledge_graph_kwd": ["entity"]})
        ent_fields, ent_ids, ent_content = self._retrieve(
            src, condition, match_exprs, 32, idxnm, kb_ids, "-Entities-")
        # Entity names feed the community-report filter below.
        entities = [d["name_kwd"] for d in ent_fields.values() if d.get("name_kwd")]

        ## Community retrieval — reports mentioning the entities found above.
        condition = self.get_filters(req)
        condition.update({"entities_kwd": entities, "knowledge_graph_kwd": ["community_report"]})
        _, comm_ids, comm_content = self._retrieve(
            src, condition, match_exprs, 32, idxnm, kb_ids, "-Community Report-")

        ## Text content retrieval (smaller topn: raw chunks are verbose).
        condition = self.get_filters(req)
        condition.update({"knowledge_graph_kwd": ["text"]})
        _, txt_ids, txt_content = self._retrieve(
            src, condition, match_exprs, 6, idxnm, kb_ids, "-Original Content-")

        return self.SearchResult(
            total=len(ent_ids) + len(comm_ids) + len(txt_ids),
            ids=[*ent_ids, *comm_ids, *txt_ids],
            query_vector=q_vec,
            highlight=None,
            field={**ent_content, **comm_content, **txt_content},
            keywords=[]
        )