Spaces:
Runtime error
Runtime error
File size: 4,264 Bytes
04f287e 7ae21d5 04f287e 6702158 04f287e 6702158 04f287e 6702158 04f287e 6702158 04f287e 6702158 04f287e 857b56b 3b6480c 6702158 3b6480c 857b56b 3b6480c 6702158 3b6480c 6702158 3b6480c 04f287e 3b6480c 6702158 5ded842 347dbcf 6702158 5ded842 6702158 5ded842 6702158 5ded842 1a3a52c 5d0e7b5 1a3a52c 6702158 5ded842 5930e3b 7ae21d5 5930e3b 7ae21d5 08bcc7a 9e55f11 7ae21d5 5930e3b 04f287e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import glob
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
from transformers import AutoTokenizer
from torch import cuda
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from auditqa.reports import files, report_list
# Run the embedding model on the GPU when one is visible to torch, else on CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Directory holding the source PDFs; file names from report_list are appended
# to it together with a '.pdf' suffix.
path_to_data = "./data/pdf/"
def process_pdf():
    """
    Build the vector databases from the report PDFs.

    Reads every report named in ``report_list``, splits it into token-sized
    chunks and tags each chunk with the category ('source'), 'subtype' and
    'year' taken from the nested ``files`` mapping
    (``files[category][subtype]`` -> iterable of file names). These tags are
    what the UI later uses to filter documents.

    One Qdrant collection is created per category, plus an ``'allreports'``
    collection holding every chunk. Each is persisted under
    ``./data/<collection name>``.

    Returns:
        dict mapping collection name to its langchain ``Qdrant`` store.
        Collections that would contain no chunks are skipped.
    """
    docs = _load_pdfs()
    all_documents = _chunk_by_category(docs)
    return _build_collections(all_documents)


def _load_pdfs():
    """Load each report in ``report_list`` with PyMuPDFLoader, keyed by name; failures are printed and skipped."""
    import os

    docs = {}
    for file in report_list:
        pdf_path = os.path.join(path_to_data, file + '.pdf')
        try:
            docs[file] = PyMuPDFLoader(pdf_path).load()
        except Exception as e:
            # Best effort: one unreadable PDF should not abort the whole build.
            # Callers must tolerate the file being absent from the result.
            print("Exception: ", pdf_path, e)
    return docs


def _chunk_by_category(docs):
    """Split loaded documents into chunks grouped by category, plus an 'allreports' list of all chunks."""
    # Measure chunk length with the embedding model's own tokenizer so that
    # chunks fit the transformer's context window.
    # langchain text splitters: https://python.langchain.com/docs/modules/data_connection/document_transformers/
    chunk_size = 256
    text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5"),
        chunk_size=chunk_size,
        chunk_overlap=10,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n"],
    )

    all_documents = {}
    # Example: category 'District' has one subtype per district name.
    for category, subtypes in files.items():
        print(category)
        category_chunks = []
        for subtype, subtype_files in subtypes.items():
            print(subtype)
            for file in subtype_files:
                if file not in docs:
                    # Loading failed, or the file is not in report_list.
                    # Previously this raised KeyError.
                    print("Skipping (not loaded):", file)
                    continue
                doc_processed = text_splitter.split_documents(docs[file])
                for doc in doc_processed:
                    doc.metadata["source"] = category
                    doc.metadata["subtype"] = subtype
                    # NOTE(review): assumes file names end with a 4-digit year - confirm.
                    doc.metadata["year"] = file[-4:]
                category_chunks.extend(doc_processed)
        all_documents[category] = category_chunks

    # Combined collection across every category, in category order.
    combined = [doc for chunks in all_documents.values() for doc in chunks]
    all_documents['allreports'] = combined
    return all_documents


def _build_collections(all_documents):
    """Embed each chunk list and persist it as a local Qdrant collection named after its key."""
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        model_name="BAAI/bge-small-en-v1.5",
    )
    qdrant_collections = {}
    for name, documents in all_documents.items():
        if not documents:
            # Qdrant.from_documents needs at least one document to create a
            # collection, so empty categories are skipped instead of crashing.
            print("Skipping empty collection:", name)
            continue
        print("emebddings for:", name)
        qdrant_collections[name] = Qdrant.from_documents(
            documents,
            embeddings,
            path=f"./data/{name}",
            collection_name=name,
        )
    print("done")
    return qdrant_collections
def get_local_qdrant():
    """
    Open the locally persisted Qdrant collections as langchain vector stores.

    Each collection is read from its own on-disk store at ``./data/<name>``.
    Queries against it are embedded with the same HuggingFace model
    configuration (BAAI/bge-small-en-v1.5, normalized embeddings) that
    ``process_pdf`` uses when building the collections.

    Returns:
        dict mapping collection name ('Consolidated', 'District', 'Ministry',
        'allreports') to its ``Qdrant`` store.
    """
    embedder = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        model_name="BAAI/bge-small-en-v1.5",
    )

    stores = {}
    for name in ('Consolidated', 'District', 'Ministry', 'allreports'):
        local_client = QdrantClient(path=f"./data/{name}")
        # Diagnostic: show which collections the on-disk store holds.
        print(local_client.get_collections())
        stores[name] = Qdrant(
            client=local_client,
            collection_name=name,
            embeddings=embedder,
        )
    return stores
|