File size: 2,122 Bytes
789ed70
 
 
 
 
 
72f3806
789ed70
40fbbca
789ed70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40fbbca
 
 
 
789ed70
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from typing import List
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_community.vectorstores import FAISS
import google.generativeai as genai
import os



class GenerateFIASSDB:
    """Build a FAISS vector index from PDF files and save it to disk.

    All work happens in the constructor:
    1. The text of every page of every PDF is extracted.
    2. The text is split into overlapping chunks.
    3. Each chunk is embedded with a Google Generative AI embedding model.
    4. The resulting index is written to ``save_loc`` via ``FAISS.save_local``.

    The class name keeps its historical "FIASS" spelling so existing callers
    keep working.
    """

    def __init__(
        self,
        pdf_docs: List[str],
        save_loc: str,
        model_embeddings: str = "models/embedding-001",
        chunk_size: int = 10000,
        chunk_overlap: int = 1000,
    ) -> None:
        """Extract, chunk, embed and persist the given PDFs.

        Args:
            pdf_docs: Paths (or file-like objects accepted by ``PdfReader``)
                of the PDFs to index.
            save_loc: Directory passed to ``FAISS.save_local``.
            model_embeddings: Name of the Google embedding model. The same
                model must be used when the index is loaded again.
            chunk_size: Maximum characters per chunk (previously hard-coded).
            chunk_overlap: Characters shared between consecutive chunks
                (previously hard-coded).

        Raises:
            ValueError: If no text at all could be extracted from ``pdf_docs``.
        """
        self.save_loc = save_loc
        self.embedding = model_embeddings
        # TODO: configure the Gen AI key from a config file. For now,
        # GoogleGenerativeAIEmbeddings presumably reads GOOGLE_API_KEY from
        # the environment; confirm against the installed version.
        text = self.get_pdf_text(pdf_docs)
        text_chunks = self.get_text_chunks(
            text, chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        self.get_vector_store(text_chunks)

    def get_pdf_text(self, pdf_docs: List[str]) -> str:
        """Return the text of every page of every PDF, one page per line group.

        Pages are joined with a newline. Without it, the last word of one page
        is glued onto the first word of the next, which corrupts both words
        for chunking and embedding.
        """
        page_texts = []
        for pdf in pdf_docs:
            pdf_reader = PdfReader(pdf)
            for page in pdf_reader.pages:
                # Defensive: treat a page that yields no text as empty.
                page_texts.append(page.extract_text() or "")
        return "\n".join(page_texts)

    def get_text_chunks(
        self, text: str, chunk_size: int = 10000, chunk_overlap: int = 1000
    ) -> List[str]:
        """Split ``text`` into overlapping chunks suitable for embedding."""
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        return text_splitter.split_text(text)

    def get_vector_store(self, text_chunks: List[str]) -> None:
        """Embed ``text_chunks`` and save the FAISS index to ``self.save_loc``.

        Raises:
            ValueError: If ``text_chunks`` is empty. Otherwise FAISS fails
                later with an unhelpful indexing error.
        """
        if not text_chunks:
            raise ValueError(
                "No text to index: the PDFs yielded no extractable text "
                "(e.g. scanned/image-only PDFs need OCR first)."
            )
        embeddings = GoogleGenerativeAIEmbeddings(model=self.embedding)
        vector_store = FAISS.from_texts(text_chunks, embedding=embeddings)
        vector_store.save_local(self.save_loc)

class DB_Retriever:
    """Load a locally saved FAISS index and expose it as a LangChain retriever.

    The index is expected to have been written with ``FAISS.save_local``
    using the same embedding model name that is passed here.
    """

    def __init__(self, db_loc: str, model_embeddings: str = "models/embedding-001") -> None:
        """Configure the Gemini client and load the FAISS index from disk.

        Args:
            db_loc: Directory the FAISS index was saved to.
            model_embeddings: Embedding model name. It should match the model
                used to build the index, otherwise query vectors are not
                comparable to the stored ones.
        """
        self.db_loc = db_loc
        api_key = os.environ.get("GOOGLE_API_KEY")
        if api_key:
            genai.configure(api_key=api_key)
        else:
            # Best-effort, as before: don't abort here. Emit a clear warning
            # instead of printing a bare "'GOOGLE_API_KEY'" KeyError message.
            import warnings

            warnings.warn(
                "GOOGLE_API_KEY is not set; Google Generative AI calls may fail.",
                RuntimeWarning,
                stacklevel=2,
            )
        self.embeddings = GoogleGenerativeAIEmbeddings(model=model_embeddings)
        # SECURITY: allow_dangerous_deserialization=True lets load_local
        # unpickle the stored docstore. Only point db_loc at indexes you built
        # yourself, never at files from an untrusted source.
        self.db = FAISS.load_local(
            self.db_loc, self.embeddings, allow_dangerous_deserialization=True
        )

    def retrieve(self, query: str) -> object:
        """Return a retriever over the loaded index.

        NOTE: ``query`` is currently unused. The method returns the retriever
        object from ``self.db.as_retriever()``, not a list of documents.
        Callers run the search themselves, e.g. ``retriever.invoke(query)``.
        The parameter is kept for interface compatibility.
        """
        return self.db.as_retriever()