Karthikeyen92 committed on
Commit f761616 · verified · 1 Parent(s): 141b0a0

Update py/db_storage.py

Files changed (1)
  1. py/db_storage.py +182 -182
py/db_storage.py CHANGED
@@ -1,183 +1,183 @@
 import os
 import warnings
 import shutil
 from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
 from langchain_community.vectorstores import Chroma
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.chains import RetrievalQA
 from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader, WikipediaLoader
 from typing import List, Optional, Dict, Any
 from langchain.schema import Document
 import chromadb
 # from langchain_community.embeddings.sentence_transformer import (SentenceTransformerEmbeddings)
 from langchain_community.vectorstores import FAISS
 
 
 
 warnings.filterwarnings("ignore")
-CHROMA_DB_PATH = os.path.join(os.getcwd(), "Stock Sentiment Analysis", "chroma_db")
-# FAISS_DB_PATH = os.path.join(os.getcwd(), "Stock Sentiment Analysis", "faiss_index")
+CHROMA_DB_PATH = os.path.join(os.getcwd(), "chroma_db")
+# FAISS_DB_PATH = os.path.join(os.getcwd(), "faiss_index")
 tesla_10k_collection = 'tesla-10k-2019-to-2023'
 embedding_model = ""
 # embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')
 
 
 class DBStorage:
     def __init__(self):
         self.CHROMA_PATH = CHROMA_DB_PATH
         self.vector_store = None
         self.client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
         print(self.client.list_collections())
         self.collection = self.client.get_or_create_collection(name=tesla_10k_collection)
         print(self.collection.count())
 
     def chunk_data(self, data, chunk_size=10000):
         text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0)
         return text_splitter.split_documents(data)
 
     def create_embeddings(self, chunks):
         embeddings = AzureOpenAIEmbeddings(
             model=os.getenv("AZURE_OPENAI_EMBEDDING_NAME"),
             api_key=os.getenv("AZURE_OPENAI_EMBEDDING_API_KEY"),
             api_version=os.getenv("AZURE_OPENAI_EMBEDDING_API_VERSION"),
             azure_endpoint=os.getenv("AZURE_OPENAI_EMBEDDING_ENDPOINT")
         )
 
         self.vector_store = Chroma.from_documents(documents=chunks,
                                                   # embedding=embeddings,
                                                   embedding=embedding_model,
                                                   collection_name=tesla_10k_collection,
                                                   persist_directory=self.CHROMA_PATH)
         print("Here B")
         self.collection = self.client.get_or_create_collection(name=tesla_10k_collection)
         print("here"+str(self.collection.count()))
         # return self.vector_store
 
     def create_vector_store(self, chunks):
         embeddings = AzureOpenAIEmbeddings(
             model=os.getenv("AZURE_OPENAI_EMBEDDING_NAME"),
             api_key=os.getenv("AZURE_OPENAI_EMBEDDING_API_KEY"),
             api_version=os.getenv("AZURE_OPENAI_EMBEDDING_API_VERSION"),
             azure_endpoint=os.getenv("AZURE_OPENAI_EMBEDDING_ENDPOINT")
         )
         return FAISS.from_documents(chunks, embedding=embeddings)
         # vector_store.save_local(FAISS_DB_PATH)
 
 
     def load_embeddings(self):
         embeddings = AzureOpenAIEmbeddings(
             model=os.getenv("AZURE_OPENAI_EMBEDDING_NAME"),
             api_key=os.getenv("AZURE_OPENAI_EMBEDDING_API_KEY"),
             api_version=os.getenv("AZURE_OPENAI_EMBEDDING_API_VERSION"),
             azure_endpoint=os.getenv("AZURE_OPENAI_EMBEDDING_ENDPOINT")
         )
 
         self.vector_store = Chroma(collection_name=tesla_10k_collection,
                                    persist_directory=CHROMA_DB_PATH,
                                    # embedding_function=embeddings
                                    embedding_function=embedding_model
                                    )
         print("loaded vector store: ")
         print(self.vector_store)
         # return self.vector_store
 
     def load_vectors(self,FAISS_DB_PATH):
         embeddings = AzureOpenAIEmbeddings(
             model=os.getenv("AZURE_OPENAI_EMBEDDING_NAME"),
             api_key=os.getenv("AZURE_OPENAI_EMBEDDING_API_KEY"),
             api_version=os.getenv("AZURE_OPENAI_EMBEDDING_API_VERSION"),
             azure_endpoint=os.getenv("AZURE_OPENAI_EMBEDDING_ENDPOINT")
         )
 
         self.vector_store = FAISS.load_local(folder_path=FAISS_DB_PATH,
                                              embeddings=embeddings,
                                              allow_dangerous_deserialization=True)
 
 
 
     def fetch_documents(self, metadata_filter: Dict[str, Any]) -> List[Document]:
         results = self.collection.get(
             where=metadata_filter,
             include=["documents", "metadatas"],
         )
 
         documents = []
         for content, metadata in zip(results['documents'][0], results['metadatas'][0]):
             documents.append(Document(page_content=content, metadata=metadata))
 
         return documents
 
 
     def get_context_for_query(self, question, k=3):
         print(self.vector_store)
         # if not self.vector_store:
         #     raise ValueError("Vector store not initialized. Call create_embeddings() or load_embeddings() first.")
 
         # relevant_document_chunks=self.fetch_documents({"company": question})
 
         # retriever = self.vector_store.as_retriever(search_type='similarity', search_kwargs={'k': k})
         # relevant_document_chunks = retriever.get_relevant_documents(question)
 
         relevant_document_chunks = self.vector_store.similarity_search(question)
         # chain = get_conversational_chain(models.llm)
         # response = chain({"input_documents": docs, "question": user_question}, return_only_outputs=True)
         # print(response)
 
         print(relevant_document_chunks)
         context_list = [d.page_content for d in relevant_document_chunks]
         context_for_query = ". ".join(context_list)
         print("context_for_query: "+ str(len(context_for_query)))
 
         return context_for_query
 
     # def ask_question(self, question, k=3):
     #     if not self.vector_store:
     #         raise ValueError("Vector store not initialized. Call create_embeddings() or load_embeddings() first.")
 
     #     llm = AzureChatOpenAI(
     #         temperature=0,
     #         api_key=os.getenv("AZURE_OPENAI_API_KEY"),
     #         api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
     #         azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
     #         model=os.getenv("AZURE_OPENAI_MODEL_NAME")
     #     )
 
     #     retriever = self.vector_store.as_retriever(search_type='similarity', search_kwargs={'k': k})
     #     chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
 
     #     return chain.invoke(question)
 
     def embed_vectors(self,social_media_document,FAISS_DB_PATH):
         print("here A")
         chunks = self.chunk_data(social_media_document)
         print(len(chunks))
         # self.create_embeddings(chunks)
         vector_store = self.create_vector_store(chunks)
         check_and_delete(FAISS_DB_PATH)
         vector_store.save_local(FAISS_DB_PATH)
 
 def check_and_delete(PATH):
     if os.path.isdir(PATH):
         shutil.rmtree(PATH, onexc=lambda func, path, exc: os.chmod(path, 0o777))
         print(f'Deleted {PATH}')
 
 def clear_db():
     check_and_delete(CHROMA_DB_PATH)
     # check_and_delete(FAISS_DB_PATH)
 
 
 # Usage example
 if __name__ == "__main__":
     qa_system = DBStorage()
 
     # Load and process document
     social_media_document = []
     chunks = qa_system.chunk_data(social_media_document)
 
     # Create embeddings
     qa_system.create_embeddings(chunks)
 
     # # Ask a question
     # question = 'Summarize the whole input in 150 words'
     # answer = qa_system.ask_question(question)
     # print(answer)
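
Note: the commit leaves the module-level placeholder embedding_model = "" wired into both Chroma.from_documents and the Chroma(...) constructor, while the AzureOpenAIEmbeddings objects built inside create_embeddings and load_embeddings go unused; an empty string is not a valid embedding function, so the Chroma code paths would likely fail at runtime. A minimal sketch of re-enabling the commented-out wiring, assuming the same AZURE_OPENAI_EMBEDDING_* environment variables the file already reads (this is a suggested fix, not the committed behaviour):

import os
from langchain_openai import AzureOpenAIEmbeddings
from langchain_community.vectorstores import Chroma

CHROMA_DB_PATH = os.path.join(os.getcwd(), "chroma_db")
tesla_10k_collection = 'tesla-10k-2019-to-2023'

# Build the embedding function once, from the same env vars the file uses.
embeddings = AzureOpenAIEmbeddings(
    model=os.getenv("AZURE_OPENAI_EMBEDDING_NAME"),
    api_key=os.getenv("AZURE_OPENAI_EMBEDDING_API_KEY"),
    api_version=os.getenv("AZURE_OPENAI_EMBEDDING_API_VERSION"),
    azure_endpoint=os.getenv("AZURE_OPENAI_EMBEDDING_ENDPOINT"),
)

# Pass the real embedding function instead of the "" placeholder.
vector_store = Chroma(
    collection_name=tesla_10k_collection,
    persist_directory=CHROMA_DB_PATH,
    embedding_function=embeddings,
)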
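
fetch_documents indexes results['documents'][0], but chromadb's Collection.get() returns flat, parallel lists (only Collection.query() nests results per query), so [0] selects a single document string and the zip then iterates over its characters. A corrected sketch, assuming flat get() results (the standalone function name here is illustrative):

from typing import Any, Dict, List
from langchain.schema import Document

def fetch_documents(collection, metadata_filter: Dict[str, Any]) -> List[Document]:
    # Collection.get() returns {'ids': [...], 'documents': [...],
    # 'metadatas': [...]} as flat lists aligned by index.
    results = collection.get(
        where=metadata_filter,
        include=["documents", "metadatas"],
    )
    return [
        Document(page_content=content, metadata=metadata)
        for content, metadata in zip(results["documents"], results["metadatas"])
    ]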
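
Two caveats on check_and_delete: shutil.rmtree's onexc parameter only exists on Python 3.12+ (older versions use the deprecated onerror), and the lambda here chmods the offending path without retrying the delete, so a read-only file still blocks removal. A sketch of a retrying handler, assuming Python 3.12 or later:

import os
import shutil

def check_and_delete(path: str) -> None:
    """Delete a directory tree, fixing permissions and retrying on failure."""
    if not os.path.isdir(path):
        return

    def _retry(func, failed_path, exc):
        # onexc handler (Python 3.12+): clear the read-only bit, then
        # re-invoke the failed operation (os.unlink, os.rmdir, ...).
        os.chmod(failed_path, 0o777)
        func(failed_path)

    shutil.rmtree(path, onexc=_retry)
    print(f'Deleted {path}')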
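
For reference, a minimal round trip through the FAISS path of this class, assuming the Azure embedding variables are set; the document content and "faiss_index" path below are illustrative, not part of the commit. Note that get_context_for_query accepts k but never forwards it, since similarity_search is called without arguments beyond the question:

from langchain.schema import Document

docs = [Document(page_content="TSLA deliveries beat estimates.",
                 metadata={"company": "tesla"})]

storage = DBStorage()
storage.embed_vectors(docs, "faiss_index")   # chunk, embed via Azure, save locally
storage.load_vectors("faiss_index")          # reload into storage.vector_store
print(storage.get_context_for_query("What happened to Tesla deliveries?"))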