rk68 committed
Commit ff8b6e7 · verified · 1 Parent(s): 6d23d33

Delete app.py

Files changed (1)
  1. app.py +0 -331

app.py DELETED
@@ -1,331 +0,0 @@
- import os
- import logging
- import pandas as pd
- import argparse
- import streamlit as st
- from pinecone import Pinecone
- from llama_index.llms.gemini import Gemini
- from llama_index.vector_stores.pinecone import PineconeVectorStore
- from llama_index.core import StorageContext, VectorStoreIndex, SimpleDirectoryReader, get_response_synthesizer
- from llama_index.core import Settings
- from dotenv import load_dotenv
- from llama_index.core.node_parser import SentenceSplitter
- from llama_parse import LlamaParse
- from llama_index.core.retrievers import VectorIndexRetriever
- from llama_index.core.retrievers import RouterRetriever
- from llama_index.retrievers.bm25 import BM25Retriever
- from llama_index.core.tools import RetrieverTool
- from llama_index.core.query_engine import RetrieverQueryEngine
- from langchain_groq import ChatGroq
- from llama_index.embeddings.huggingface import HuggingFaceEmbedding
- from llama_index.llms.azure_openai import AzureOpenAI
- from llama_index.embeddings.openai import OpenAIEmbedding
- from sklearn.metrics.pairwise import cosine_similarity
- from sklearn.feature_extraction.text import TfidfVectorizer
- from llama_index.core.query_engine import FLAREInstructQueryEngine, MultiStepQueryEngine
- from llama_index.core.indices.query.query_transform import HyDEQueryTransform
- from llama_index.core.query_engine import TransformQueryEngine
- from llama_index.core.indices.query.query_transform.base import (
-     StepDecomposeQueryTransform,
- )
-
- # Configure logging
- logging.basicConfig(level=logging.INFO)
-
- # Load environment variables from .env file
- load_dotenv()
-
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
- # Fetch API keys from environment variables
- pinecone_api_key = os.getenv('PINECONE_API_KEY')
- parse_api_key = os.getenv('PARSE_API_KEY')
- azure_api_key = os.getenv('AZURE_API_KEY')
- azure_endpoint = os.getenv('AZURE_ENDPOINT')
- azure_deployment_name = os.getenv('AZURE_DEPLOYMENT_NAME')
- azure_api_version = os.getenv('AZURE_API_VERSION')
-
- # Global variables for lazy loading
- llm = None
- pinecone_index = None
- query_engine = None
-
- def log_and_exit(message):
-     logging.error(message)
-     raise SystemExit(message)
-
- def initialize_apis(api, model):
-     global llm, pinecone_index
-     try:
-         if llm is None:
-             if api == 'groq':
-                 if model == 'mixtral-8x7b':
-                     llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
-                 elif model == 'llama3-8b':
-                     llm = ChatGroq(model="llama3-8b-8192", temperature=0)
-                 elif model == "llama3-70b":
-                     llm = ChatGroq(model="llama3-70b-8192", temperature=0)
-                 elif model == "gemma-7b":
-                     llm = ChatGroq(model="gemma-7b-it", temperature=0)
-
-             elif api == 'azure':
-                 if model == 'gpt35':
-                     llm = AzureOpenAI(
-                         deployment_name="gpt35",
-                         temperature=0,
-                         api_key=azure_api_key,
-                         azure_endpoint=azure_endpoint,
-                         api_version=azure_api_version
-                     )
-
-         if pinecone_index is None:
-             index_name = "demo"
-             pinecone_client = Pinecone(pinecone_api_key)
-             pinecone_index = pinecone_client.Index(index_name)
-         logging.info("Initialized LLM and Pinecone.")
-     except Exception as e:
-         log_and_exit(f"Error initializing APIs: {e}")
-
- def load_pdf_data():
-     PDF_FILE_PATH = "policy.pdf"
-     try:
-         parser = LlamaParse(api_key=parse_api_key, result_type="markdown")
-         file_extractor = {".pdf": parser}
-         documents = SimpleDirectoryReader(input_files=[PDF_FILE_PATH], file_extractor=file_extractor).load_data()
-         logging.info(f"Loaded {len(documents)} documents from PDF.")
-         return documents
-     except Exception as e:
-         log_and_exit(f"Error loading PDF file: {e}")
-
- def create_index(documents, embedding_model_type="HF", embedding_model="BAAI/bge-large-en-v1.5", retriever_method="BM25"):
-     global llm, pinecone_index
-     try:
-         if embedding_model_type == "HF":
-             embed_model = HuggingFaceEmbedding(model_name=embedding_model)
-         elif embedding_model_type == "OAI":
-             embed_model = OpenAIEmbedding()
-
-         Settings.llm = llm
-         Settings.embed_model = embed_model
-         Settings.chunk_size = 512
-
-         if retriever_method in ("BM25", "BM25+Vector"):
-             splitter = SentenceSplitter(chunk_size=512)
-             nodes = splitter.get_nodes_from_documents(documents)
-             return None, nodes  # Return None for index when using BM25
-         else:
-             vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
-             storage_context = StorageContext.from_defaults(vector_store=vector_store)
-             index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
-             logging.info("Created index from documents.")
-             return index, None  # Return None for nodes when not using BM25
-     except Exception as e:
-         log_and_exit(f"Error creating index: {e}")
-
- def setup_query_engine(index, response_mode, nodes=None, query_engine_method=None, retriever_method=None):
-     global llm
-     try:
-         logging.info(f"Setting up query engine with retriever_method: {retriever_method} and query_engine_method: {query_engine_method}")
-
-         if retriever_method == 'BM25':
-             retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=2)
-         elif retriever_method == "BM25+Vector":
-             vector_retriever = VectorIndexRetriever(index=index)
-             bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=2)
-
-             retriever_tools = [
-                 RetrieverTool.from_defaults(
-                     retriever=vector_retriever,
-                     description='Useful in most cases',
-                 ),
-                 RetrieverTool.from_defaults(
-                     retriever=bm25_retriever,
-                     description='Useful if searching about specific information',
-                 ),
-             ]
-             retriever = RouterRetriever.from_defaults(
-                 retriever_tools=retriever_tools,
-                 llm=llm,
-                 select_multi=True
-             )
-         else:
-             retriever = VectorIndexRetriever(index=index, similarity_top_k=2)
-
-         response_synthesizer = get_response_synthesizer(response_mode=response_mode)
-         index_query_engine = index.as_query_engine(similarity_top_k=2) if index else None
-
-         if query_engine_method == "FLARE":
-             query_engine = FLAREInstructQueryEngine(
-                 query_engine=index_query_engine,
-                 max_iterations=7,
-                 verbose=True
-             )
-         elif query_engine_method == "MS":
-             step_decompose_transform = StepDecomposeQueryTransform(llm=llm, verbose=True)
-             index_summary = "Used to answer questions about the regulation"
-             query_engine = MultiStepQueryEngine(
-                 query_engine=index_query_engine,
-                 query_transform=step_decompose_transform,
-                 index_summary=index_summary
-             )
-         else:
-             query_engine = RetrieverQueryEngine(retriever=retriever, response_synthesizer=response_synthesizer)
-         return query_engine
-     except Exception as e:
-         log_and_exit(f"Error setting up query engine: {e}")
-
- def retrieve_and_update_contexts(api, model, file_path):
-     global query_engine
-
-     if query_engine is None:
-         initialize_apis(api, model)
-         documents = load_pdf_data()
-         _, nodes = create_index(documents, retriever_method='BM25')
-         query_engine = setup_query_engine(None, response_mode="compact_accumulate", nodes=nodes, retriever_method='BM25')
-
-     df = pd.read_csv(file_path)
-
-     for idx, row in df.iterrows():
-         question = row['question']
-         response = query_engine.query(question)
-         retrieved_nodes = response.source_nodes
-         chunks = [node.text for node in retrieved_nodes]
-         logging.info(f"Context response for question {idx}: {response}")
-         df.at[idx, 'contexts'] = " ".join(chunks)
-
-     df.to_csv(file_path, index=False)
-     logging.info(f"Processed questions and updated the CSV file: {file_path}")
-
- def retrieve_answers_for_modes(api, model, file_path):
-     global query_engine
-
-     df = pd.read_csv(file_path)
-     initialize_apis(api, model)
-     documents = load_pdf_data()
-     # Request the vector index explicitly; create_index's default retriever_method ("BM25") returns no index.
-     index, _ = create_index(documents, retriever_method='Default')
-
-     response_modes = ["refine", "compact", "tree_summarize", "simple_summarize"]
-
-     for idx, row in df.iterrows():
-         question = row['question']
-         for mode in response_modes:
-             query_engine = setup_query_engine(index, response_mode=mode, retriever_method='Default')
-             response = query_engine.query(question)
-             answer_column = f"{mode}_answer"
-             df.at[idx, answer_column] = response.response
-
-     df.to_csv(file_path, index=False)
-     logging.info(f"Processed questions and updated the CSV file with answers: {file_path}")
-
- def run_streamlit_app(api, model):
-     global query_engine
-
-     if query_engine is None:
-         initialize_apis(api, model)
-         documents = load_pdf_data()
-         # Build both the vector index and the BM25 nodes up front so every retriever
-         # option below works; create_index's default ("BM25") returns no vector index.
-         index, _ = create_index(documents, retriever_method='Default')
-         _, nodes = create_index(documents, retriever_method='BM25')
-         query_engine = setup_query_engine(index, response_mode="tree_summarize", nodes=nodes, retriever_method='BM25')
-
-     if 'chat_history' not in st.session_state:
-         st.session_state.chat_history = []
-
-     st.title("RAG Chat Application")
-
-     query_method = st.selectbox("Select Query Engine Method", ["Default", "FLARE", "MS"])
-     retriever_method = st.selectbox("Select Retriever Method", ["Default", "BM25", "BM25+Vector"])
-     # NOTE: the API/model/embedding selections are currently display-only;
-     # the LLM and embeddings were initialized above from the function arguments.
-     selected_api = st.selectbox("Select API", ["azure", "ollama", "groq"])
-     selected_model = st.selectbox("Select Model", ["llama3-8b", "llama3-70b", "mixtral-8x7b", "gemma-7b", "gpt35"])
-     embedding_model_type = st.selectbox("Select Embedding Model Type", ["HF", "OAI"])
-     embedding_model = st.selectbox("Select Embedding Model", ["BAAI/bge-large-en-v1.5", "other_model"])  # Add your embedding models here
-
-     # Both the index and the BM25 nodes are available, so pass both regardless of method.
-     if query_method == "FLARE":
-         query_engine = setup_query_engine(index, response_mode="tree_summarize", nodes=nodes, query_engine_method="FLARE", retriever_method=retriever_method)
-     elif query_method == "MS":
-         query_engine = setup_query_engine(index, response_mode="tree_summarize", nodes=nodes, query_engine_method="MS", retriever_method=retriever_method)
-     else:
-         query_engine = setup_query_engine(index, response_mode="tree_summarize", nodes=nodes, retriever_method=retriever_method)
-
-     for chat in st.session_state.chat_history:
-         with st.chat_message("user"):
-             st.markdown(chat['user'])
-         with st.chat_message("bot"):
-             st.markdown(chat['response'])
-
-     if question := st.chat_input("Enter your question"):
-         response = query_engine.query(question)
-         st.session_state.chat_history.append({'user': question, 'response': response.response})
-         st.rerun()
-
- def run_terminal_app(api, model, query_method, retriever_method):
-     global query_engine
-
-     if query_engine is None:
-         initialize_apis(api, model)
-         documents = load_pdf_data()
-         if retriever_method in ("BM25", "BM25+Vector"):
-             _, nodes = create_index(documents, retriever_method=retriever_method)
-             query_engine = setup_query_engine(None, response_mode="compact_accumulate", nodes=nodes, query_engine_method=query_method, retriever_method=retriever_method)
-         else:
-             index, _ = create_index(documents, retriever_method=retriever_method)
-             query_engine = setup_query_engine(index, response_mode="compact_accumulate", query_engine_method=query_method, retriever_method=retriever_method)
-
-     while True:
-         question = input("Enter your question (or type 'exit' to quit): ")
-         if question.lower() == 'exit':
-             break
-         response = query_engine.query(question)
-         retrieved_nodes = response.source_nodes
-         chunks = [node.text for node in retrieved_nodes]
-         print("Contexts:")
-         for chunk in chunks:
-             print(chunk)
-         # Rebuild the engine with a summarizing response mode for the final answer.
-         if retriever_method in ("BM25", "BM25+Vector"):
-             query_engine = setup_query_engine(None, response_mode="tree_summarize", nodes=nodes, query_engine_method=query_method, retriever_method=retriever_method)
-         else:
-             query_engine = setup_query_engine(index, response_mode="tree_summarize", query_engine_method=query_method, retriever_method=retriever_method)
-         final_response = query_engine.query(question)
-         print("Final Answer:", final_response.response)
-
- def main():
-     parser = argparse.ArgumentParser(description="Run the RAG app.")
-     parser.add_argument('--mode', type=str, choices=['terminal', 'benchmark', 'retrieve_contexts', 'retrieve_answers'], required=False, default='terminal', help="Mode to run the application in: 'terminal', 'benchmark', 'retrieve_contexts', or 'retrieve_answers'")
-     parser.add_argument('--api', type=str, choices=['azure', 'ollama', 'groq'], required=False, default='azure', help='Which API to use to call LLMs: ollama, groq, or azure (OpenAI)')
-     parser.add_argument('--model', type=str, choices=['llama3-8b', 'llama3-70b', 'mixtral-8x7b', 'gemma-7b', 'gpt35'], default='gpt35')
-     parser.add_argument('--embedding_model_type', type=str, choices=['HF'], required=False, default="HF")
-     parser.add_argument('--embedding_model', type=str, default="BAAI/bge-large-en-v1.5")
-     parser.add_argument('--csv_file', type=str, required=False, help='Path to the CSV file containing questions')
-     parser.add_argument('--query_method', type=str, choices=['Default', 'FLARE', 'MS'], required=False, default='Default', help='Query engine method to use')
-     parser.add_argument('--retriever_method', type=str, choices=['Default', 'BM25', 'BM25+Vector'], required=False, default='Default', help='Retriever method to use')
-     args = parser.parse_args()
-
-     if args.mode == 'terminal':
-         run_terminal_app(args.api, args.model, args.query_method, args.retriever_method)
-     elif args.mode == 'retrieve_contexts':
-         if args.csv_file:
-             retrieve_and_update_contexts(args.api, args.model, args.csv_file)
-         else:
-             log_and_exit("CSV file path is required for retrieve_contexts mode.")
-     elif args.mode == 'retrieve_answers':
-         if args.csv_file:
-             retrieve_answers_for_modes(args.api, args.model, args.csv_file)
-         else:
-             log_and_exit("CSV file path is required for retrieve_answers mode.")
-     elif args.mode == 'benchmark':
-         pass
-
- if __name__ == "__main__":
-     import sys
-     if 'streamlit' in sys.argv[0]:
-         run_streamlit_app('azure', 'gpt35')
-     else:
-         main()