File size: 4,984 Bytes
f5e5ccb
c726440
f5e5ccb
 
 
 
 
 
 
 
 
458b338
 
 
f5e5ccb
458b338
 
 
 
 
 
 
 
f5e5ccb
458b338
f5e5ccb
458b338
 
 
 
 
 
f5e5ccb
 
 
 
 
 
 
 
 
 
 
 
 
458b338
f5e5ccb
 
 
 
 
 
 
 
 
 
 
 
 
 
337280a
 
 
 
 
 
f5e5ccb
 
 
 
 
ee5bae7
f5e5ccb
 
 
 
ee5bae7
 
f5e5ccb
 
 
 
 
 
 
 
 
 
 
 
 
458b338
 
f5e5ccb
 
 
 
 
 
 
 
 
 
 
13cf722
f5e5ccb
 
 
 
13cf722
 
 
34798a4
13cf722
 
 
 
 
 
 
 
 
 
 
f5e5ccb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import configparser
import glob
import json
import logging
import os

from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from torch import cuda
from transformers import AutoTokenizer

from auditqa.reports import files, report_list

# Module-level settings shared by the functions below:
# the folder holding the per-report chunk files, and the torch device
# (GPU when one is available, otherwise CPU) used by the embedding model.
path_to_data = "./reports/"
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


##--------------------- functions ------------------------------------------##
def getconfig(configfile_path:str):
    """
    Read a .cfg file into a ConfigParser.

    configfile_path: file path of .cfg file

    Returns the populated ``configparser.ConfigParser``, or ``None`` (after
    logging a warning) when the file cannot be opened or parsed.
    """

    config = configparser.ConfigParser()

    try:
        # the context manager guarantees the handle is closed after parsing
        with open(configfile_path) as config_file:
            config.read_file(config_file)
    except OSError:
        logging.warning("config file not found")
        return None
    except configparser.Error:
        logging.warning("config file could not be parsed")
        return None
    return config
        
def open_file(filepath):
    """
    Load a JSON file and return the parsed content.

    filepath: path of the JSON file to read

    The file is decoded explicitly as UTF-8 (the standard JSON encoding) so
    that non-ASCII text is read the same way regardless of the platform's
    default locale encoding.
    """
    with open(filepath, encoding="utf-8") as file:
        return json.load(file)

def load_chunks():
    """
    Read the pre-chunked reports listed in ``files`` and build the local
    Qdrant vector database.

    ``files`` maps a category ('source') to its subtypes, and each subtype to
    a list of report names. For every report the file
    ``<path_to_data><name>/<name>.chunks.json`` is loaded and each chunk is
    wrapped in a ``Document`` whose metadata (source, subtype, year, filename,
    page, headings) is used in the UI for document selection and later for
    filtering the database.

    A report whose chunk file cannot be read or parsed is reported and
    skipped; the remaining reports are still processed.

    Returns a dict with a single key, 'allreports', holding the Qdrant
    vector store built from the chunks of every category.
    """
    config = getconfig("./model_params.cfg")
    all_documents = {}

    # iterate through 'source'
    for category in files:
        print("documents splitting in source:",category)
        category_documents = []
        # iterate through 'subtype' within the source
        # example source/category == 'District', has subtypes which is district names
        for subtype, report_names in files[category].items():
            print("document splitting for subtype:",subtype)
            for file in report_names:
                # load the chunks
                try:
                    doc_processed = open_file(f"{path_to_data}{file}/{file}.chunks.json")
                except (OSError, ValueError) as e:
                    # Skip this report: without the `continue` the code below
                    # would either fail on an undefined name or re-add the
                    # previous report's chunks under this report's metadata.
                    print("Exception: ", e)
                    continue
                print("chunks in subtype:",subtype, "are:",len(doc_processed))

                # add metadata information
                category_documents.extend(
                    Document(
                        page_content=doc['content'],
                        metadata={"source": category,
                                  "subtype": subtype,
                                  # last four characters of the report name
                                  "year": file[-4:],
                                  "filename": file,
                                  "page": doc['metadata']['page'],
                                  "headings": doc['metadata']['headings']},
                    )
                    for doc in doc_processed
                )
        all_documents[category] = category_documents

    for key, docs_processed in all_documents.items():
        print("length of chunks in source:",key, "are:",len(docs_processed))

    # one flat list with the chunks of every category
    all_documents['allreports'] = [
        doc for docs in all_documents.values() for doc in docs
    ]

    # define embedding model
    embeddings = HuggingFaceEmbeddings(
        model_kwargs = {'device': device},
        encode_kwargs = {'normalize_embeddings': bool(int(config.get('retriever','NORMALIZE')))},
        model_name=config.get('retriever','MODEL')
    )

    # only the combined 'allreports' collection is embedded and stored
    collection_name = "allreports"
    print("emebddings for:",collection_name)
    qdrant_collections = {
        collection_name: Qdrant.from_documents(
            all_documents[collection_name],
            embeddings,
            path="/data/local_qdrant",
            collection_name=collection_name,
        )
    }
    print(qdrant_collections)
    print("vector embeddings done")
    return qdrant_collections

def get_local_qdrant(): 
    """
    Connect to the existing local Qdrant storage created by ``load_chunks``.

    The embedding model is configured from ./model_params.cfg with the same
    MODEL and NORMALIZE settings that ``load_chunks`` uses when building the
    collection, so that query embeddings are produced the same way as the
    stored ones.

    Returns a dict with a single key, 'allreports', holding the Qdrant
    vector store bound to the on-disk collection of that name.
    """
    config = getconfig("./model_params.cfg")
    embeddings = HuggingFaceEmbeddings(
        model_kwargs = {'device': device},
        # keep in sync with load_chunks: normalisation must match the index
        encode_kwargs = {'normalize_embeddings': bool(int(config.get('retriever','NORMALIZE')))},
        model_name=config.get('retriever','MODEL'))
    client = QdrantClient(path="/data/local_qdrant") 
    print(client.get_collections())
    qdrant_collections = {
        'allreports': Qdrant(client=client, collection_name='allreports', embeddings=embeddings),
    }
    return qdrant_collections