Sujithanumala committed on
Commit
789ed70
·
verified ·
1 Parent(s): 9abed2e

Update Classes/Helper_Class.py

Browse files
Files changed (1) hide show
  1. Classes/Helper_Class.py +48 -57
Classes/Helper_Class.py CHANGED
@@ -1,57 +1,48 @@
1
- from typing import List
2
- from PyPDF2 import PdfReader
3
- from langchain.text_splitter import RecursiveCharacterTextSplitter
4
- from langchain_google_genai import GoogleGenerativeAIEmbeddings
5
- from langchain_community.vectorstores import FAISS
6
- import os
7
- import google.generativeai as genai
8
-
9
- os.environ["GOOGLE_API_KEY"] = "AIzaSyBoghqvvnMMS4bA61LjQkkPNdIRetqk438"
10
- genai.configure(api_key="AIzaSyBoghqvvnMMS4bA61LjQkkPNdIRetqk438")
11
-
12
-
13
- class GenerateFIASSDB:
14
- def __init__(self,pdf_docs : List[str], save_loc:str, model_embeddings: str = "models/embedding-001")-> None:
15
- self.save_loc = save_loc
16
- self.embedding = model_embeddings
17
- text = self.get_pdf_text(pdf_docs)
18
- text_chunks = self.get_text_chunks(text)
19
- self.get_vector_store(text_chunks)
20
- pass #configure gen ai key from config file
21
-
22
- def get_pdf_text(self,pdf_docs : List[str]) -> str:
23
- text = ""
24
- for pdf in pdf_docs:
25
- pdf_reader= PdfReader(pdf)
26
- for page in pdf_reader.pages:
27
- text+= page.extract_text()
28
- return text
29
-
30
- def get_text_chunks(self, text : str) -> List:
31
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=1000)
32
- chunks = text_splitter.split_text(text)
33
- return chunks
34
-
35
- def get_vector_store(self, text_chunks : List) -> None:
36
- embeddings = GoogleGenerativeAIEmbeddings(model = self.embedding)
37
- vector_store = FAISS.from_texts(text_chunks, embedding=embeddings)
38
- vector_store.save_local(self.save_loc)
39
-
40
-
41
- class DB_Retriever:
42
- def __init__(self, db_loc : str, model_embeddings : str = "models/embedding-001") -> None:
43
- self.db_loc = db_loc
44
- self.embeddings = GoogleGenerativeAIEmbeddings(model = model_embeddings)
45
- self.db = FAISS.load_local(self.db_loc, self.embeddings,allow_dangerous_deserialization = True)
46
-
47
- def retrieve(self, query : str) -> List[str]:
48
- # docs = self.db.similarity_search(query)
49
- retriver = self.db.as_retriever()
50
- # output_docs = retriver.invoke(query)
51
- # return output_docs
52
- return retriver
53
-
54
- if __name__ =="__main__":
55
- res = DB_Retriever("src/faiss_index").retrieve("What is cloud adapter in google connection?")
56
- print(len(res))
57
- print('\n\n\n\n',res[1])
 
1
+ from typing import List
2
+ from PyPDF2 import PdfReader
3
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
4
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
5
+ from langchain_community.vectorstores import FAISS
6
+ import google.generativeai as genai
7
+
8
+
9
class GenerateFIASSDB:
    """Build a FAISS vector index from PDF files and save it to disk.

    Instantiating the class runs the whole pipeline: the text of every page
    of every PDF is extracted, split into overlapping chunks, embedded with
    ``GoogleGenerativeAIEmbeddings`` and stored in a FAISS index that is
    written to ``save_loc``.

    NOTE(review): no API key is configured here; presumably the embeddings
    client picks it up from the environment (e.g. ``GOOGLE_API_KEY``) --
    confirm against the deployment configuration.
    """

    def __init__(self, pdf_docs: List[str], save_loc: str, model_embeddings: str = "models/embedding-001") -> None:
        """Run the extract -> split -> embed -> save pipeline.

        Args:
            pdf_docs: PDF inputs handed one by one to ``PdfReader``.
            save_loc: Location the FAISS index is saved to.
            model_embeddings: Name of the embedding model to use.

        Raises:
            ValueError: If no text could be extracted from ``pdf_docs``.
        """
        self.save_loc = save_loc
        self.embedding = model_embeddings
        text = self.get_pdf_text(pdf_docs)
        text_chunks = self.get_text_chunks(text)
        self.get_vector_store(text_chunks)

    def get_pdf_text(self, pdf_docs: List[str]) -> str:
        """Return the concatenated text of every page of every PDF."""
        page_texts = []
        for pdf in pdf_docs:
            for page in PdfReader(pdf).pages:
                # Guard against pages that yield no text (None or "") so the
                # join below never fails on a non-string value.
                page_texts.append(page.extract_text() or "")
        # Join once instead of repeated ``+=`` to avoid quadratic copying.
        return "".join(page_texts)

    def get_text_chunks(self, text: str, chunk_size: int = 10000, chunk_overlap: int = 1000) -> List[str]:
        """Split ``text`` into overlapping chunks.

        Args:
            text: The full text to split.
            chunk_size: Maximum chunk size passed to the splitter.
            chunk_overlap: Overlap between consecutive chunks.

        Returns:
            The chunks produced by ``RecursiveCharacterTextSplitter``.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        return text_splitter.split_text(text)

    def get_vector_store(self, text_chunks: List[str]) -> None:
        """Embed ``text_chunks`` and save the FAISS index to ``self.save_loc``.

        Raises:
            ValueError: If ``text_chunks`` is empty; an index cannot be built
                from nothing, so fail early with a clear message.
        """
        if not text_chunks:
            raise ValueError(
                "No text chunks to index; the PDFs yielded no extractable text."
            )
        embeddings = GoogleGenerativeAIEmbeddings(model=self.embedding)
        vector_store = FAISS.from_texts(text_chunks, embedding=embeddings)
        vector_store.save_local(self.save_loc)
35
+
36
+
37
class DB_Retriever:
    """Load a saved FAISS index and expose it as a retriever."""

    def __init__(self, db_loc: str, model_embeddings: str = "models/embedding-001") -> None:
        """Load the FAISS index stored at ``db_loc``.

        Args:
            db_loc: Location of the saved FAISS index.
            model_embeddings: Name of the embedding model used for queries;
                presumably must match the model the index was built with --
                confirm against the index-building code.
        """
        self.db_loc = db_loc
        self.embeddings = GoogleGenerativeAIEmbeddings(model=model_embeddings)
        # SECURITY: allow_dangerous_deserialization=True lets FAISS unpickle
        # the stored docstore. Only point db_loc at index files this
        # application created itself; never load an untrusted index.
        self.db = FAISS.load_local(
            self.db_loc,
            self.embeddings,
            allow_dangerous_deserialization=True,
        )

    def retrieve(self, query: str):
        """Return a retriever over the loaded index.

        ``query`` is accepted for interface compatibility but is not used:
        the method returns the retriever object produced by
        ``self.db.as_retriever()`` (not a list of documents), and the caller
        is expected to run the query through it.
        """
        return self.db.as_retriever()