Delete app.py
app.py
DELETED
@@ -1,331 +0,0 @@
import os
import logging
import pandas as pd
import argparse
import streamlit as st
from pinecone import Pinecone
from llama_index.llms.gemini import Gemini
from llama_index.vector_stores.pinecone import PineconeVectorStore
from llama_index.core import StorageContext, VectorStoreIndex, SimpleDirectoryReader, get_response_synthesizer
from llama_index.core import Settings
from dotenv import load_dotenv
from llama_index.core.node_parser import SentenceSplitter
from llama_parse import LlamaParse
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.retrievers import RouterRetriever
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.tools import RetrieverTool
from llama_index.core.query_engine import RetrieverQueryEngine
from langchain_groq import ChatGroq
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
from llama_index.core.query_engine import FLAREInstructQueryEngine, MultiStepQueryEngine
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from llama_index.core.query_engine import TransformQueryEngine
from llama_index.core.indices.query.query_transform.base import (
    StepDecomposeQueryTransform,
)

# Configure logging
logging.basicConfig(level=logging.INFO)

# Load environment variables from .env file
load_dotenv()

os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Fetch API keys from environment variables
pinecone_api_key = os.getenv('PINECONE_API_KEY')
parse_api_key = os.getenv('PARSE_API_KEY')
azure_api_key = os.getenv('AZURE_API_KEY')
azure_endpoint = os.getenv('AZURE_ENDPOINT')
azure_deployment_name = os.getenv('AZURE_DEPLOYMENT_NAME')
azure_api_version = os.getenv('AZURE_API_VERSION')

# Global variables for lazy loading
llm = None
pinecone_index = None
query_engine = None

def log_and_exit(message):
    logging.error(message)
    raise SystemExit(message)

def initialize_apis(api, model):
    global llm, pinecone_index
    try:
        if llm is None:
            if api == 'groq':
                if model == 'mixtral-8x7b':
                    llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
                elif model == 'llama3-8b':
                    llm = ChatGroq(model="llama3-8b-8192", temperature=0)
                elif model == "llama3-70b":
                    llm = ChatGroq(model="llama3-70b-8192", temperature=0)
                elif model == "gemma-7b":
                    llm = ChatGroq(model="gemma-7b-it", temperature=0)

            elif api == 'azure':
                if model == 'gpt35':
                    llm = AzureOpenAI(
                        deployment_name="gpt35",
                        temperature=0,
                        api_key=azure_api_key,
                        azure_endpoint=azure_endpoint,
                        api_version=azure_api_version
                    )

        if pinecone_index is None:
            index_name = "demo"
            pinecone_client = Pinecone(pinecone_api_key)
            pinecone_index = pinecone_client.Index(index_name)
        logging.info("Initialized LLM and Pinecone.")
    except Exception as e:
        log_and_exit(f"Error initializing APIs: {e}")

def load_pdf_data():
    PDF_FILE_PATH = "policy.pdf"
    try:
        parser = LlamaParse(api_key=parse_api_key, result_type="markdown")
        file_extractor = {".pdf": parser}
        documents = SimpleDirectoryReader(input_files=[PDF_FILE_PATH], file_extractor=file_extractor).load_data()
        logging.info(f"Loaded {len(documents)} documents from PDF.")
        return documents
    except Exception as e:
        log_and_exit(f"Error loading PDF file: {e}")

def create_index(documents, embedding_model_type="HF", embedding_model="BAAI/bge-large-en-v1.5", retriever_method="BM25"):
    global llm, pinecone_index
    try:
        if embedding_model_type == "HF":
            embed_model = HuggingFaceEmbedding(model_name=embedding_model)
        elif embedding_model_type == "OAI":
            # embed_model = OpenAIEmbedding() implement oai EMBEDDING
            pass

        Settings.llm = llm
        Settings.embed_model = embed_model
        Settings.chunk_size = 512

        if retriever_method == "BM25" or retriever_method == "BM25+Vector":
            splitter = SentenceSplitter(chunk_size=512)
            nodes = splitter.get_nodes_from_documents(documents)
            storage_context = StorageContext.from_defaults()
            return None, nodes  # Return None for index when using BM25
        else:
            vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
            storage_context = StorageContext.from_defaults(vector_store=vector_store)
            index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
            logging.info("Created index from documents.")
            return index, None  # Return None for nodes when not using BM25
    except Exception as e:
        log_and_exit(f"Error creating index: {e}")

def setup_query_engine(index, response_mode, nodes=None, query_engine_method=None, retriever_method=None):
    global llm
    try:
        logging.info(f"Setting up query engine with retriever_method: {retriever_method} and query_engine_method: {query_engine_method}")

        if retriever_method == 'BM25':
            retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=2)
        elif retriever_method == "BM25+Vector":
            vector_retriever = VectorIndexRetriever(index=index)
            bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=2)

            retriever_tools = [
                RetrieverTool.from_defaults(
                    retriever=vector_retriever,
                    description='Useful in most cases',
                ),
                RetrieverTool.from_defaults(
                    retriever=bm25_retriever,
                    description='Useful if searching about specific information',
                ),
            ]
            retriever = RouterRetriever.from_defaults(
                retriever_tools=retriever_tools,
                llm=llm,
                select_multi=True
            )
        else:
            retriever = VectorIndexRetriever(index=index, similarity_top_k=2)

        response_synthesizer = get_response_synthesizer(response_mode=response_mode)
        index_query_engine = index.as_query_engine(similarity_top_k=2) if index else None

        if query_engine_method == "FLARE":
            query_engine = FLAREInstructQueryEngine(
                query_engine=index_query_engine,
                max_iterations=7,
                verbose=True
            )
        elif query_engine_method == "MS":
            step_decompose_transform = StepDecomposeQueryTransform(llm=llm, verbose=True)
            index_summary = "Used to answer questions about the regulation"
            query_engine = MultiStepQueryEngine(
                query_engine=index_query_engine,
                query_transform=step_decompose_transform,
                index_summary=index_summary
            )
        else:
            query_engine = RetrieverQueryEngine(retriever=retriever, response_synthesizer=response_synthesizer)
        return query_engine
    except Exception as e:
        log_and_exit(f"Error setting up query engine: {e}")

def retrieve_and_update_contexts(api, model, file_path):
    global query_engine

    if query_engine is None:
        initialize_apis(api, model)
        documents = load_pdf_data()
        _, nodes = create_index(documents, retriever_method='BM25')
        query_engine = setup_query_engine(None, response_mode="compact_accumulate", nodes=nodes, retriever_method='BM25')

    df = pd.read_csv(file_path)

    for idx, row in df.iterrows():
        question = row['question']
        response = query_engine.query(question)
        retrieved_nodes = response.source_nodes
        chunks = [node.text for node in retrieved_nodes]
        logging.info(f"Context response for question {idx}: {response}")
        df.at[idx, 'contexts'] = " ".join(chunks)

    df.to_csv(file_path, index=False)
    logging.info(f"Processed questions and updated the CSV file: {file_path}")

def retrieve_answers_for_modes(api, model, file_path):
    global query_engine

    df = pd.read_csv(file_path)
    initialize_apis(api, model)
    documents = load_pdf_data()
    index, _ = create_index(documents)

    response_modes = ["refine", "compact", "tree_summarize", "simple_summarize"]

    for idx, row in df.iterrows():
        question = row['question']
        for mode in response_modes:
            query_engine = setup_query_engine(index, response_mode=mode, retriever_method='Default')
            response = query_engine.query(question)
            answer_column = f"{mode}_answer"
            df.at[idx, answer_column] = response.response

    df.to_csv(file_path, index=False)
    logging.info(f"Processed questions and updated the CSV file with answers: {file_path}")

def run_streamlit_app(api, model):
    global query_engine

    if query_engine is None:
        initialize_apis(api, model)
        documents = load_pdf_data()
        index, nodes = create_index(documents)
        query_engine = setup_query_engine(index, response_mode="tree_summarize", nodes=nodes, retriever_method='BM25')

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    st.title("RAG Chat Application")

    query_method = st.selectbox("Select Query Engine Method", ["Default", "FLARE", "MS"])
    retriever_method = st.selectbox("Select Retriever Method", ["Default", "BM25", "BM25+Vector"])
    selected_api = st.selectbox("Select API", ["azure", "ollama", "groq"])
    selected_model = st.selectbox("Select Model", ["llama3-8b", "llama3-70b", "mixtral-8x7b", "gemma-7b", "gpt35"])
    embedding_model_type = st.selectbox("Select Embedding Model Type", ["HF", "OAI"])
    embedding_model = st.selectbox("Select Embedding Model", ["BAAI/bge-large-en-v1.5", "other_model"])  # Add your embedding models here

    if query_method == "FLARE":
        if retriever_method in ["BM25", "BM25+Vector"]:
            query_engine = setup_query_engine(None, response_mode="tree_summarize", nodes=nodes, query_engine_method="FLARE", retriever_method=retriever_method)
        else:
            query_engine = setup_query_engine(index, response_mode="tree_summarize", query_engine_method="FLARE", retriever_method=retriever_method)
    elif query_method == "MS":
        if retriever_method in ["BM25", "BM25+Vector"]:
            query_engine = setup_query_engine(None, response_mode="tree_summarize", nodes=nodes, query_engine_method="MS", retriever_method=retriever_method)
        else:
            query_engine = setup_query_engine(index, response_mode="tree_summarize", query_engine_method="MS", retriever_method=retriever_method)
    else:
        if retriever_method in ["BM25", "BM25+Vector"]:
            query_engine = setup_query_engine(None, response_mode="tree_summarize", nodes=nodes, retriever_method=retriever_method)
        else:
            query_engine = setup_query_engine(index, response_mode="tree_summarize", retriever_method=retriever_method)

    for chat in st.session_state.chat_history:
        with st.chat_message("user"):
            st.markdown(chat['user'])
        with st.chat_message("bot"):
            st.markdown(chat['response'])

    if question := st.chat_input("Enter your question"):
        response = query_engine.query(question)
        st.session_state.chat_history.append({'user': question, 'response': response.response})
        st.rerun()

def run_terminal_app(api, model, query_method, retriever_method):
    global query_engine

    if query_engine is None:
        initialize_apis(api, model)
        documents = load_pdf_data()
        if retriever_method == "BM25" or retriever_method == "BM25+Vector":
            _, nodes = create_index(documents, retriever_method=retriever_method)
            query_engine = setup_query_engine(None, response_mode="compact_accumulate", nodes=nodes, query_engine_method=query_method, retriever_method=retriever_method)
        else:
            index, _ = create_index(documents, retriever_method=retriever_method)
            query_engine = setup_query_engine(index, response_mode="compact_accumulate", query_engine_method=query_method, retriever_method=retriever_method)

    while True:
        question = input("Enter your question (or type 'exit' to quit): ")
        if question.lower() == 'exit':
            break
        response = query_engine.query(question)
        retrieved_nodes = response.source_nodes
        chunks = [node.text for node in retrieved_nodes]
        print("Contexts:")
        for chunk in chunks:
            print(chunk)
        if retriever_method == "BM25" or retriever_method == "BM25+Vector":
            query_engine = setup_query_engine(None, response_mode="tree_summarize", nodes=nodes, query_engine_method=query_method, retriever_method=retriever_method)
        else:
            query_engine = setup_query_engine(index, response_mode="tree_summarize", query_engine_method=query_method, retriever_method=retriever_method)
        final_response = query_engine.query(question)
        print("Final Answer:", final_response.response)

def main():
    parser = argparse.ArgumentParser(description="Run the RAG app.")
    parser.add_argument('--mode', type=str, choices=['terminal', 'benchmark', 'retrieve_contexts', 'retrieve_answers'], required=False, default='terminal', help="Mode to run the application in: 'terminal', 'benchmark', 'retrieve_contexts', 'retrieve_answers'")
    parser.add_argument('--api', type=str, choices=['azure', 'ollama', 'groq'], required=False, default='azure', help='Which api to use to call LLMs: ollama, groq or azure (openai)')
    parser.add_argument('--model', type=str, choices=['llama3-8b', 'llama3-70b', 'mixtral-8x7b', 'gemma-7b', 'gpt35'], default='gpt35')
    parser.add_argument('--embedding_model_type', type=str, choices=['HF'], required=False, default="HF")
    parser.add_argument('--embedding_model', type=str, default="BAAI/bge-large-en-v1.5")
    parser.add_argument('--csv_file', type=str, required=False, help='Path to the CSV file containing questions')
    parser.add_argument('--query_method', type=str, choices=['Default', 'FLARE', 'MS'], required=False, default='Default', help='Query Engine Method to use')
    parser.add_argument('--retriever_method', type=str, choices=['Default', 'BM25', 'BM25+Vector'], required=False, default='Default', help='Retriever Method to use')
    args = parser.parse_args()

    if args.mode == 'terminal':
        run_terminal_app(args.api, args.model, args.query_method, args.retriever_method)
    elif args.mode == 'retrieve_contexts':
        if args.csv_file:
            retrieve_and_update_contexts(args.api, args.model, args.csv_file)
        else:
            log_and_exit("CSV file path is required for retrieve_contexts mode.")
    elif args.mode == 'retrieve_answers':
        if args.csv_file:
            retrieve_answers_for_modes(args.api, args.model, args.csv_file)
        else:
            log_and_exit("CSV file path is required for retrieve_answers mode.")
    elif args.mode == 'benchmark':
        pass

if __name__ == "__main__":
    import sys
    if 'streamlit' in sys.argv[0]:
        run_streamlit_app('azure', 'gpt35')
    else:
        main()
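For reference, the commands below sketch how the deleted script would presumably have been invoked; the flags and their choices come directly from the argparse definition and the `sys.argv[0]` check in `__main__`, while the CSV file name is a placeholder, not a path from the repository.

# Terminal chat with a Groq-hosted Llama 3 model and BM25 retrieval (illustrative)
python app.py --mode terminal --api groq --model llama3-8b --retriever_method BM25

# Fill the 'contexts' column of a questions CSV; 'questions.csv' is a placeholder name
python app.py --mode retrieve_contexts --api azure --model gpt35 --csv_file questions.csv

# Launch the chat UI; 'streamlit' in sys.argv[0] routes to run_streamlit_app('azure', 'gpt35')
streamlit run app.py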