#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import json
from copy import deepcopy

import pandas as pd
from elasticsearch_dsl import Q, Search

from rag.nlp.search import Dealer


class KGSearch(Dealer):
    def search(self, req, idxnm, emb_mdl=None):
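        """Knowledge-graph retrieval in three passes against `idxnm`:
        entities, community reports scoped to the matched entities, and the
        original text chunks. Each pass's hits are merged into one labelled
        chunk, and the combined ids/fields are returned as a SearchResult.
        """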
        def merge_into_first(sres, title=""):
            # Collapse all hits into the first hit's content: rows that parse
            # as JSON are rendered as a CSV table via pandas; anything else is
            # joined as plain text.
            df, texts = [], []
            for d in sres["hits"]["hits"]:
                try:
                    df.append(json.loads(d["_source"]["content_with_weight"]))
                except Exception:
                    texts.append(d["_source"]["content_with_weight"])
            if not df and not texts:
                return False
            if df:
                try:
                    sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + pd.DataFrame(df).to_csv()
                except Exception:
                    pass
            else:
                sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + "\n".join(texts)
            return True

        src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
                                 "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "name_kwd",
                                 "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight",
                                 "weight_int", "weight_flt", "rank_int"
                                 ])

        qst = req.get("question", "")
        binary_query, keywords = self.qryr.question(qst, min_match="5%")
        binary_query = self._add_filters(binary_query, req)

        ## Entity retrieval
        bqry = deepcopy(binary_query)
        bqry.filter.append(Q("terms", knowledge_graph_kwd=["entity"]))
        s = Search()
        s = s.query(bqry)[0: 32]

        s = s.to_dict()
        q_vec = []
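        # Optional dense retrieval: if the caller supplied a vector request, an
        # embedding model must be provided; the kNN clause reuses the boolean
        # query as its filter so vector hits stay within the entity subset.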
        if req.get("vector"):
            assert emb_mdl, "No embedding model selected"
            s["knn"] = self._vector(
                qst, emb_mdl, req.get(
                    "similarity", 0.1), 1024)
            s["knn"]["filter"] = bqry.to_dict()
            q_vec = s["knn"]["query_vector"]

        ent_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
        entities = [d["name_kwd"] for d in self.es.getSource(ent_res)]
        ent_ids = self.es.getDocIds(ent_res)
        if merge_into_first(ent_res, "-Entities-"):
            ent_ids = ent_ids[0:1]

        ## Community retrieval
        bqry = deepcopy(binary_query)
        bqry.filter.append(Q("terms", entities_kwd=entities))
        bqry.filter.append(Q("terms", knowledge_graph_kwd=["community_report"]))
        s = Search()
        s = s.query(bqry)[0: 32]
        s = s.to_dict()
        comm_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
        comm_ids = self.es.getDocIds(comm_res)
        if merge_into_first(comm_res, "-Community Report-"):
            comm_ids = comm_ids[0:1]

        ## Text content retrieval
        bqry = deepcopy(binary_query)
        bqry.filter.append(Q("terms", knowledge_graph_kwd=["text"]))
        s = Search()
        s = s.query(bqry)[0: 6]
        s = s.to_dict()
        txt_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
        txt_ids = self.es.getDocIds(txt_res)
        if merge_into_first(txt_res, "-Original Content-"):
            txt_ids = txt_ids[0:1]

        return self.SearchResult(
            total=len(ent_ids) + len(comm_ids) + len(txt_ids),
            ids=[*ent_ids, *comm_ids, *txt_ids],
            query_vector=q_vec,
            aggregation=None,
            highlight=None,
            field={**self.getFields(ent_res, src), **self.getFields(comm_res, src), **self.getFields(txt_res, src)},
            keywords=[]
        )
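

# Minimal usage sketch (hypothetical wiring: the `es_conn` handle, the index
# name, and the embedding model below are assumptions, not part of this module):
#
#   kg = KGSearch(es_conn)                      # es_conn: assumed ES connection
#   res = kg.search({"question": "What is RAGFlow?", "vector": True},
#                   idxnm="ragflow_tenant_idx", emb_mdl=my_emb_mdl)
#   print(res.total, res.ids[:3])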