rk68 committed on
Commit 13e12b7 · verified · 1 Parent(s): c0d7c30

Create app.py

Files changed (1)
app.py +331 -0
app.py ADDED
@@ -0,0 +1,331 @@
+ import os
+ import logging
+ import pandas as pd
+ import argparse
+ import streamlit as st
+ from pinecone import Pinecone
+ from llama_index.llms.gemini import Gemini
+ from llama_index.vector_stores.pinecone import PineconeVectorStore
+ from llama_index.core import StorageContext, VectorStoreIndex, SimpleDirectoryReader, get_response_synthesizer
+ from llama_index.core import Settings
+ from dotenv import load_dotenv
+ from llama_index.core.node_parser import SentenceSplitter
+ from llama_parse import LlamaParse
+ from llama_index.core.retrievers import VectorIndexRetriever, RouterRetriever
+ from llama_index.retrievers.bm25 import BM25Retriever
+ from llama_index.core.tools import RetrieverTool
+ from llama_index.core.query_engine import (
+     RetrieverQueryEngine,
+     FLAREInstructQueryEngine,
+     MultiStepQueryEngine,
+     TransformQueryEngine,
+ )
+ from langchain_groq import ChatGroq
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+ from llama_index.llms.azure_openai import AzureOpenAI
+ from llama_index.embeddings.openai import OpenAIEmbedding
+ from sklearn.metrics.pairwise import cosine_similarity
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ from llama_index.core.indices.query.query_transform import HyDEQueryTransform
+ from llama_index.core.indices.query.query_transform.base import (
+     StepDecomposeQueryTransform,
+ )
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+
+ # Load environment variables from .env file
+ load_dotenv()
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ # Fetch API keys from environment variables
+ pinecone_api_key = os.getenv('PINECONE_API_KEY')
+ parse_api_key = os.getenv('PARSE_API_KEY')
+ azure_api_key = os.getenv('AZURE_API_KEY')
+ azure_endpoint = os.getenv('AZURE_ENDPOINT')
+ azure_deployment_name = os.getenv('AZURE_DEPLOYMENT_NAME')
+ azure_api_version = os.getenv('AZURE_API_VERSION')
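+ # A .env template matching the lookups above (values are placeholders, not
+ # the author's actual configuration):
+ #   PINECONE_API_KEY=<your-pinecone-key>
+ #   PARSE_API_KEY=<your-llamaparse-key>
+ #   AZURE_API_KEY=<your-azure-openai-key>
+ #   AZURE_ENDPOINT=https://<resource>.openai.azure.com/
+ #   AZURE_DEPLOYMENT_NAME=gpt35
+ #   AZURE_API_VERSION=<api-version, e.g. 2024-02-01>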
+
+ # Global variables for lazy loading
+ llm = None
+ pinecone_index = None
+ query_engine = None
+
+ def log_and_exit(message):
+     logging.error(message)
+     raise SystemExit(message)
+
+ def initialize_apis(api, model):
+     global llm, pinecone_index
+     try:
+         if llm is None:
+             if api == 'groq':
+                 if model == 'mixtral-8x7b':
+                     llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
+                 elif model == 'llama3-8b':
+                     llm = ChatGroq(model="llama3-8b-8192", temperature=0)
+                 elif model == "llama3-70b":
+                     llm = ChatGroq(model="llama3-70b-8192", temperature=0)
+                 elif model == "gemma-7b":
+                     llm = ChatGroq(model="gemma-7b-it", temperature=0)
+             elif api == 'azure':
+                 if model == 'gpt35':
+                     llm = AzureOpenAI(
+                         deployment_name=azure_deployment_name or "gpt35",
+                         temperature=0,
+                         api_key=azure_api_key,
+                         azure_endpoint=azure_endpoint,
+                         api_version=azure_api_version
+                     )
+             if llm is None:
+                 # No branch matched (e.g. api='ollama' is accepted by the CLI
+                 # but not implemented here yet).
+                 log_and_exit(f"Unsupported api/model combination: {api}/{model}")
+
+         if pinecone_index is None:
+             index_name = "demo"
+             pinecone_client = Pinecone(pinecone_api_key)
+             pinecone_index = pinecone_client.Index(index_name)
+         logging.info("Initialized LLM and Pinecone.")
+     except Exception as e:
+         log_and_exit(f"Error initializing APIs: {e}")
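+ # Note: the Groq model IDs above encode the context window in their suffix
+ # (e.g. "llama3-8b-8192" = 8k tokens, "mixtral-8x7b-32768" = 32k tokens).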
+
+ def load_pdf_data():
+     PDF_FILE_PATH = "policy.pdf"
+     try:
+         parser = LlamaParse(api_key=parse_api_key, result_type="markdown")
+         file_extractor = {".pdf": parser}
+         documents = SimpleDirectoryReader(input_files=[PDF_FILE_PATH], file_extractor=file_extractor).load_data()
+         logging.info(f"Loaded {len(documents)} documents from PDF.")
+         return documents
+     except Exception as e:
+         log_and_exit(f"Error loading PDF file: {e}")
+
+ def create_index(documents, embedding_model_type="HF", embedding_model="BAAI/bge-large-en-v1.5", retriever_method="BM25"):
+     global llm, pinecone_index
+     try:
+         if embedding_model_type == "HF":
+             embed_model = HuggingFaceEmbedding(model_name=embedding_model)
+         elif embedding_model_type == "OAI":
+             # TODO: implement OpenAI embeddings, e.g. embed_model = OpenAIEmbedding()
+             log_and_exit("OAI embeddings are not implemented yet.")
+
+         Settings.llm = llm
+         Settings.embed_model = embed_model
+         Settings.chunk_size = 512
+
+         if retriever_method == "BM25" or retriever_method == "BM25+Vector":
+             splitter = SentenceSplitter(chunk_size=512)
+             nodes = splitter.get_nodes_from_documents(documents)
+             return None, nodes  # Return None for index when using BM25
+         else:
+             vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
+             storage_context = StorageContext.from_defaults(vector_store=vector_store)
+             index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
+             logging.info("Created index from documents.")
+             return index, None  # Return None for nodes when not using BM25
+     except Exception as e:
+         log_and_exit(f"Error creating index: {e}")
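+ # The two return shapes reflect two indexing paths: BM25 works directly over
+ # sentence-split nodes held in memory (no embeddings needed), while the
+ # default path embeds the documents and stores the vectors in the Pinecone
+ # index configured in initialize_apis().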
+
+ def setup_query_engine(index, response_mode, nodes=None, query_engine_method=None, retriever_method=None):
+     global llm
+     try:
+         logging.info(f"Setting up query engine with retriever_method: {retriever_method} and query_engine_method: {query_engine_method}")
+
+         if retriever_method == 'BM25':
+             retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=2)
+         elif retriever_method == "BM25+Vector":
+             vector_retriever = VectorIndexRetriever(index=index)
+             bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=2)
+
+             retriever_tools = [
+                 RetrieverTool.from_defaults(
+                     retriever=vector_retriever,
+                     description='Useful in most cases',
+                 ),
+                 RetrieverTool.from_defaults(
+                     retriever=bm25_retriever,
+                     description='Useful if searching about specific information',
+                 ),
+             ]
+             # Let the LLM route each query to one or both retrievers.
+             retriever = RouterRetriever.from_defaults(
+                 retriever_tools=retriever_tools,
+                 llm=llm,
+                 select_multi=True
+             )
+         else:
+             retriever = VectorIndexRetriever(index=index, similarity_top_k=2)
+
+         response_synthesizer = get_response_synthesizer(response_mode=response_mode)
+         index_query_engine = index.as_query_engine(similarity_top_k=2) if index else None
+
+         if query_engine_method in ("FLARE", "MS") and index_query_engine is None:
+             # FLARE and MS wrap the index's own query engine, which only
+             # exists when a vector index was built (not for pure BM25 runs).
+             log_and_exit(f"{query_engine_method} requires a vector index.")
+
+         if query_engine_method == "FLARE":
+             query_engine = FLAREInstructQueryEngine(
+                 query_engine=index_query_engine,
+                 max_iterations=7,
+                 verbose=True
+             )
+         elif query_engine_method == "MS":
+             step_decompose_transform = StepDecomposeQueryTransform(llm=llm, verbose=True)
+             index_summary = "Used to answer questions about the regulation"
+             query_engine = MultiStepQueryEngine(
+                 query_engine=index_query_engine,
+                 query_transform=step_decompose_transform,
+                 index_summary=index_summary
+             )
+         else:
+             query_engine = RetrieverQueryEngine(retriever=retriever, response_synthesizer=response_synthesizer)
+         return query_engine
+     except Exception as e:
+         log_and_exit(f"Error setting up query engine: {e}")
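+ # Background on the two wrapper engines (both are LlamaIndex components):
+ # FLAREInstructQueryEngine drafts an answer and, wherever the draft looks
+ # uncertain, issues follow-up retrieval queries before continuing (up to
+ # max_iterations). MultiStepQueryEngine decomposes a complex question into
+ # sequential sub-questions, answering each against the index.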
+
+ def retrieve_and_update_contexts(api, model, file_path):
+     global query_engine
+
+     if query_engine is None:
+         initialize_apis(api, model)
+         documents = load_pdf_data()
+         _, nodes = create_index(documents, retriever_method='BM25')
+         query_engine = setup_query_engine(None, response_mode="compact_accumulate", nodes=nodes, retriever_method='BM25')
+
+     df = pd.read_csv(file_path)
+
+     for idx, row in df.iterrows():
+         question = row['question']
+         response = query_engine.query(question)
+         retrieved_nodes = response.source_nodes
+         chunks = [node.text for node in retrieved_nodes]
+         logging.info(f"Context response for question {idx}: {response}")
+         df.at[idx, 'contexts'] = " ".join(chunks)
+
+     df.to_csv(file_path, index=False)
+     logging.info(f"Processed questions and updated the CSV file: {file_path}")
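+ # Assumes the input CSV has a 'question' column; the retrieved chunks are
+ # written back into a 'contexts' column (created if it does not exist).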
+
+ def retrieve_answers_for_modes(api, model, file_path):
+     global query_engine
+
+     df = pd.read_csv(file_path)
+     initialize_apis(api, model)
+     documents = load_pdf_data()
+     # Explicitly request the vector path; create_index() defaults to BM25,
+     # which would return None for the index.
+     index, _ = create_index(documents, retriever_method='Default')
+
+     response_modes = ["refine", "compact", "tree_summarize", "simple_summarize"]
+
+     for idx, row in df.iterrows():
+         question = row['question']
+         for mode in response_modes:
+             query_engine = setup_query_engine(index, response_mode=mode, retriever_method='Default')
+             response = query_engine.query(question)
+             answer_column = f"{mode}_answer"
+             df.at[idx, answer_column] = response.response
+
+     df.to_csv(file_path, index=False)
+     logging.info(f"Processed questions and updated the CSV file with answers: {file_path}")
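+ # The four modes are LlamaIndex response-synthesizer strategies: "refine"
+ # revises the answer chunk by chunk, "compact" packs chunks into fewer LLM
+ # calls, "tree_summarize" builds a summary tree over the chunks, and
+ # "simple_summarize" concatenates everything and answers in one call.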
+
+ def run_streamlit_app(api, model):
+     global query_engine
+
+     initialize_apis(api, model)
+     documents = load_pdf_data()
+     # Build both representations up front so every retriever option below
+     # has what it needs: BM25 nodes and a Pinecone-backed vector index.
+     _, nodes = create_index(documents, retriever_method='BM25')
+     index, _ = create_index(documents, retriever_method='Default')
+
+     if 'chat_history' not in st.session_state:
+         st.session_state.chat_history = []
+
+     st.title("RAG Chat Application")
+
+     query_method = st.selectbox("Select Query Engine Method", ["Default", "FLARE", "MS"])
+     retriever_method = st.selectbox("Select Retriever Method", ["Default", "BM25", "BM25+Vector"])
+     # Note: the API/model/embedding selections below are not wired up yet;
+     # the app keeps using the api/model it was launched with and the default
+     # embedding settings.
+     selected_api = st.selectbox("Select API", ["azure", "ollama", "groq"])
+     selected_model = st.selectbox("Select Model", ["llama3-8b", "llama3-70b", "mixtral-8x7b", "gemma-7b", "gpt35"])
+     embedding_model_type = st.selectbox("Select Embedding Model Type", ["HF", "OAI"])
+     embedding_model = st.selectbox("Select Embedding Model", ["BAAI/bge-large-en-v1.5", "other_model"])  # Add your embedding models here
+
+     qe_method = None if query_method == "Default" else query_method
+     if retriever_method in ["BM25", "BM25+Vector"]:
+         query_engine = setup_query_engine(index, response_mode="tree_summarize", nodes=nodes, query_engine_method=qe_method, retriever_method=retriever_method)
+     else:
+         query_engine = setup_query_engine(index, response_mode="tree_summarize", query_engine_method=qe_method, retriever_method=retriever_method)
+
+     for chat in st.session_state.chat_history:
+         with st.chat_message("user"):
+             st.markdown(chat['user'])
+         with st.chat_message("bot"):
+             st.markdown(chat['response'])
+
+     if question := st.chat_input("Enter your question"):
+         response = query_engine.query(question)
+         st.session_state.chat_history.append({'user': question, 'response': response.response})
+         st.rerun()
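+ # To use the UI, launch this file through Streamlit rather than plain Python:
+ #   streamlit run app.py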
+
+ def run_terminal_app(api, model, query_method, retriever_method):
+     global query_engine
+
+     initialize_apis(api, model)
+     documents = load_pdf_data()
+     index, nodes = None, None
+     if retriever_method in ("BM25", "BM25+Vector"):
+         _, nodes = create_index(documents, retriever_method=retriever_method)
+     if retriever_method != "BM25":
+         # BM25+Vector and Default both need a vector index as well.
+         index, _ = create_index(documents, retriever_method="Default")
+     query_engine = setup_query_engine(index, response_mode="compact_accumulate", nodes=nodes, query_engine_method=query_method, retriever_method=retriever_method)
+
+     while True:
+         question = input("Enter your question (or type 'exit' to quit): ")
+         if question.lower() == 'exit':
+             break
+         # First pass: show the retrieved contexts.
+         response = query_engine.query(question)
+         retrieved_nodes = response.source_nodes
+         chunks = [node.text for node in retrieved_nodes]
+         print("Contexts:")
+         for chunk in chunks:
+             print(chunk)
+         # Second pass: answer the same question with a summarizing response
+         # mode, without clobbering the context-printing engine above.
+         final_engine = setup_query_engine(index, response_mode="tree_summarize", nodes=nodes, query_engine_method=query_method, retriever_method=retriever_method)
+         final_response = final_engine.query(question)
+         print("Final Answer:", final_response.response)
+
+ def main():
+     parser = argparse.ArgumentParser(description="Run the RAG app.")
+     parser.add_argument('--mode', type=str, choices=['terminal', 'benchmark', 'retrieve_contexts', 'retrieve_answers'], default='terminal', help="Mode to run the application in: 'terminal', 'benchmark', 'retrieve_contexts', 'retrieve_answers'")
+     parser.add_argument('--api', type=str, choices=['azure', 'ollama', 'groq'], default='azure', help='Which API to use to call LLMs: ollama, groq, or azure (OpenAI)')
+     parser.add_argument('--model', type=str, choices=['llama3-8b', 'llama3-70b', 'mixtral-8x7b', 'gemma-7b', 'gpt35'], default='gpt35')
+     parser.add_argument('--embedding_model_type', type=str, choices=['HF'], default="HF")
+     parser.add_argument('--embedding_model', type=str, default="BAAI/bge-large-en-v1.5")
+     parser.add_argument('--csv_file', type=str, help='Path to the CSV file containing questions')
+     parser.add_argument('--query_method', type=str, choices=['Default', 'FLARE', 'MS'], default='Default', help='Query engine method to use')
+     parser.add_argument('--retriever_method', type=str, choices=['Default', 'BM25', 'BM25+Vector'], default='Default', help='Retriever method to use')
+     args = parser.parse_args()
+
+     if args.mode == 'terminal':
+         run_terminal_app(args.api, args.model, args.query_method, args.retriever_method)
+     elif args.mode == 'retrieve_contexts':
+         if args.csv_file:
+             retrieve_and_update_contexts(args.api, args.model, args.csv_file)
+         else:
+             log_and_exit("CSV file path is required for retrieve_contexts mode.")
+     elif args.mode == 'retrieve_answers':
+         if args.csv_file:
+             retrieve_answers_for_modes(args.api, args.model, args.csv_file)
+         else:
+             log_and_exit("CSV file path is required for retrieve_answers mode.")
+     elif args.mode == 'benchmark':
+         # Benchmark mode is not implemented yet.
+         pass
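+ # Example invocations (the CSV file name is illustrative):
+ #   python app.py --mode terminal --api groq --model llama3-8b --retriever_method BM25
+ #   python app.py --mode retrieve_contexts --csv_file questions.csv
+ #   python app.py --mode retrieve_answers --api azure --model gpt35 --csv_file questions.csv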
+
+ if __name__ == "__main__":
+     import sys
+     # Heuristic: when launched via `streamlit run`, argv[0] points at the
+     # Streamlit entry script; when run directly, it is this file.
+     if 'streamlit' in sys.argv[0]:
+         run_streamlit_app('azure', 'gpt35')
+     else:
+         main()