File size: 7,203 Bytes
f5e5ccb
c726440
f5e5ccb
 
 
 
 
 
 
 
 
458b338
 
 
f5e5ccb
458b338
 
 
b3ec1fd
458b338
 
b3ec1fd
 
 
 
458b338
 
f5e5ccb
458b338
f5e5ccb
458b338
 
 
 
 
 
f5e5ccb
 
 
 
 
 
 
 
 
 
 
 
 
458b338
f5e5ccb
 
 
 
 
 
 
 
 
 
 
 
 
 
337280a
 
 
 
 
 
f5e5ccb
 
 
 
 
ee5bae7
f5e5ccb
 
 
 
ee5bae7
 
f5e5ccb
 
 
 
 
 
 
 
 
 
 
 
 
458b338
 
f5e5ccb
 
 
 
 
 
 
 
 
 
 
13cf722
f5e5ccb
 
 
 
13cf722
 
dc8b4b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13cf722
34798a4
13cf722
 
 
 
 
 
 
b3ec1fd
13cf722
f5e5ccb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import configparser
import glob
import json
import logging
import os

from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from torch import cuda
from transformers import AutoTokenizer

from auditqa.reports import files, report_list

# Module-level settings shared by the loader functions below.
# Embedding models run on the GPU when torch reports one is available,
# otherwise they fall back to the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Root folder that holds the per-report chunk files.
path_to_data = "./reports/"


##---------------------functions -------------------------------------------##
def getconfig(configfile_path: str):
    """
    Read the config file

    Params
    ----------------
    configfile_path: file path of .cfg file

    Returns
    ----------------
    configparser.ConfigParser holding the parsed file, or None (after logging
    a warning) when the file cannot be opened.
    """
    config = configparser.ConfigParser()
    try:
        # use a context manager so the file handle is always closed
        with open(configfile_path) as configfile:
            config.read_file(configfile)
    except OSError:
        # only "cannot open/read the file" is treated as a soft failure;
        # parse errors still propagate so a malformed config is not hidden
        logging.warning("config file not found")
        return None
    return config
        
def open_file(filepath):
    """
    Load and return the JSON document stored at ``filepath``.

    The file is decoded as UTF-8 explicitly (JSON text is UTF-8 per RFC 8259),
    so the result does not depend on the platform's default locale encoding.

    Params
    ----------------
    filepath: path of the .json file

    Returns
    ----------------
    The parsed JSON value (dict, list, ...).
    """
    with open(filepath, encoding="utf-8") as file:
        return json.load(file)

def load_chunks():
    """
    Build the 'allreports' Qdrant collection from the pre-chunked report files.

    Iterates over ``files`` (imported from ``auditqa.reports``), which maps each
    'source' (category) to its subtypes, and each subtype to a list of report
    names. For every report, ``<path_to_data>/<name>/<name>.chunks.json`` is
    read, and each chunk becomes a langchain ``Document``. The metadata on each
    document (source, subtype, year, filename, page, headings) is what the UI
    uses for document selection and filtering. Reports whose chunk file cannot
    be read are reported and skipped.

    All documents are embedded with the model configured in
    ``./model_params.cfg`` and stored in a local on-disk Qdrant collection
    named 'allreports'.

    Returns
    ----------------
    dict mapping 'allreports' to the created Qdrant vector store.
    """
    config = getconfig("./model_params.cfg")
    all_documents = {}
    # iterate through 'source'
    for category, subtypes in files.items():
        print("documents splitting in source:", category)
        category_docs = []
        # iterate through 'subtype' within the source
        # example source/category == 'District', has subtypes which is district names
        for subtype, report_names in subtypes.items():
            print("document splitting for subtype:", subtype)
            for file in report_names:
                chunk_path = os.path.join(path_to_data, file, file + ".chunks.json")
                try:
                    doc_processed = open_file(chunk_path)
                except (OSError, ValueError) as e:
                    # Skip this report. Falling through would reuse the
                    # previous report's chunks, or raise UnboundLocalError on
                    # the very first failure.
                    print("Exception: ", e)
                    continue
                print("chunks in subtype:", subtype, "are:", len(doc_processed))

                # add metadata information
                category_docs.extend(
                    Document(
                        page_content=doc['content'],
                        metadata={
                            "source": category,
                            "subtype": subtype,
                            # assumes report names end with a 4-digit year
                            "year": file[-4:],
                            "filename": file,
                            "page": doc['metadata']['page'],
                            "headings": doc['metadata']['headings'],
                        },
                    )
                    for doc in doc_processed
                )
        all_documents[category] = category_docs

    for category, docs in all_documents.items():
        print("length of chunks in source:", category, "are:", len(docs))
    # every category is merged into one collection
    all_reports = [doc for docs in all_documents.values() for doc in docs]

    # define embedding model
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': bool(int(config.get('retriever', 'NORMALIZE')))},
        model_name=config.get('retriever', 'MODEL')
    )

    print("emebddings for:", "allreports")
    qdrant_collections = {
        'allreports': Qdrant.from_documents(
            all_reports,
            embeddings,
            path="/data/local_qdrant",
            collection_name='allreports',
        )
    }
    print(qdrant_collections)
    print("vector embeddings done")
    return qdrant_collections

def load_new_chunks():
    """
    Build the 'allreports' Qdrant collection from the updated chunk manifest.

    Reads ``./axa_processed_chunks_update.json`` into a DataFrame. Each row
    describes one report and supplies the columns 'chunks_filepath',
    'filename', 'category' and 'year'. For each row, the chunk file
    ``<path_to_data>/chunks/<basename of chunks_filepath>`` is loaded, and
    every entry of its 'paragraphs' list becomes a langchain ``Document``
    carrying the metadata used by the UI for filtering. Rows whose chunk file
    cannot be read are reported and skipped.

    All documents are embedded with the model configured in
    ``./model_params.cfg`` and stored in a local on-disk Qdrant collection
    named 'allreports'.

    Returns
    ----------------
    dict mapping 'allreports' to the created Qdrant vector store.
    """
    # pandas was referenced but never imported; import locally so only
    # this loader needs it
    import pandas as pd

    config = getconfig("./model_params.cfg")
    # named to avoid shadowing the module-level `files` import
    reports = pd.read_json("./axa_processed_chunks_update.json")
    all_documents = []
    # iterate over the report rows (label-based, so any index works)
    for idx in reports.index:
        filename = reports.at[idx, 'filename']
        chunk_path = os.path.join(
            path_to_data, "chunks", os.path.basename(reports.at[idx, 'chunks_filepath'])
        )
        # load the chunks
        try:
            doc_processed = open_file(chunk_path)['paragraphs']
        except (OSError, ValueError, KeyError, TypeError) as e:
            # Skip this report instead of reusing the previous row's chunks
            # (or raising UnboundLocalError on the first failure).
            print("Exception: ", e)
            continue
        print("chunks in subtype:", filename, "are:", len(doc_processed))

        # add metadata information; the per-report part is shared by every
        # chunk of this row (previously row 0's filename was used for all)
        report_metadata = {
            "source": reports.at[idx, 'category'],
            "subtype": os.path.splitext(filename)[0],
            "year": reports.at[idx, 'year'],
            "filename": filename,
        }
        for doc in doc_processed:
            all_documents.append(Document(
                page_content=doc['content'],
                metadata={
                    **report_metadata,
                    "page": doc['metadata']['page'],
                    "headings": doc['metadata']['headings'],
                },
            ))

    print("length of chunks:", len(all_documents))

    # define embedding model
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': bool(int(config.get('retriever', 'NORMALIZE')))},
        model_name=config.get('retriever', 'MODEL')
    )
    qdrant_collections = {
        'allreports': Qdrant.from_documents(
            all_documents,
            embeddings,
            path="/data/local_qdrant",
            collection_name='allreports',
        )
    }
    print(qdrant_collections)
    print("vector embeddings done")
    return qdrant_collections

def get_local_qdrant():
    """
    Open the existing on-disk Qdrant store and wrap its 'allreports' collection.

    Intended for use after the store at /data/local_qdrant has been built
    (load_chunks / load_new_chunks write the 'allreports' collection there).
    The embedding model and normalization flag are read from
    ``./model_params.cfg``, the same settings the builders use. This keeps
    query vectors consistent with the stored vectors.

    Returns
    ----------------
    dict mapping 'allreports' to a Qdrant vector store bound to the client.
    """
    config = getconfig("./model_params.cfg")
    # Previously hard-coded to True. Read it like the builders do; fall back
    # to '1' so configs without NORMALIZE keep the old behaviour.
    normalize = bool(int(config.get('retriever', 'NORMALIZE', fallback='1')))
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': normalize},
        model_name=config.get('retriever', 'MODEL'))
    client = QdrantClient(path="/data/local_qdrant")
    print("Collections in local Qdrant:", client.get_collections())
    qdrant_collections = {
        'allreports': Qdrant(client=client, collection_name='allreports', embeddings=embeddings),
    }
    return qdrant_collections