|
import logging |
|
import json |
|
import pandas as pd |
|
import streamlit as st |
|
from pinecone import Pinecone |
|
from llama_index.vector_stores.pinecone import PineconeVectorStore |
|
from llama_index.core import ( |
|
StorageContext, VectorStoreIndex, SimpleDirectoryReader, |
|
get_response_synthesizer, Settings |
|
) |
|
from llama_index.core.node_parser import SentenceSplitter |
|
from llama_index.core.retrievers import ( |
|
VectorIndexRetriever, RouterRetriever |
|
) |
|
from llama_index.retrievers.bm25 import BM25Retriever |
|
from llama_index.core.tools import RetrieverTool |
|
from llama_index.core.query_engine import ( |
|
RetrieverQueryEngine, FLAREInstructQueryEngine, MultiStepQueryEngine |
|
) |
|
from llama_index.core.indices.query.query_transform import ( |
|
StepDecomposeQueryTransform |
|
) |
|
from llama_index.llms.groq import Groq |
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding |
|
from llama_index.llms.azure_openai import AzureOpenAI |
|
from llama_index.embeddings.openai import OpenAIEmbedding |
|
from llama_index.readers.file import PyMuPDFReader |
|
import traceback |
|
from oauth2client.service_account import ServiceAccountCredentials |
|
import gspread |
|
import uuid |
|
from dotenv import load_dotenv |
|
import os |
|
from datetime import datetime |
|
|
|
|
|
load_dotenv()

logging.basicConfig(level=logging.INFO)

# --- Google Sheets feedback/logging backend ------------------------------
# Service-account credentials are reassembled from individual environment
# variables (one env var per field of a service-account JSON keyfile).
scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
creds_dict = {
    "type": os.getenv("type"),
    "project_id": os.getenv("project_id"),
    "private_key_id": os.getenv("private_key_id"),
    "private_key": os.getenv("private_key"),
    "client_email": os.getenv("client_email"),
    "client_id": os.getenv("client_id"),
    "auth_uri": os.getenv("auth_uri"),
    "token_uri": os.getenv("token_uri"),
    "auth_provider_x509_cert_url": os.getenv("auth_provider_x509_cert_url"),
    "client_x509_cert_url": os.getenv("client_x509_cert_url")
}
# .env files typically store the PEM key with literal "\n" escapes; restore
# real newlines so the key parses.
# NOTE(review): raises AttributeError when the private_key env var is unset
# (os.getenv returns None) — confirm deployment always provides it.
creds_dict['private_key'] = creds_dict['private_key'].replace('\\n', '\n')
creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
client = gspread.authorize(creds)
# All question/answer/feedback rows go to the first worksheet of the
# spreadsheet named "RAG".
sheet = client.open("RAG").sheet1

# Azure OpenAI deployment settings (used only when the 'azure' API is chosen).
AZURE_DEPLOYMENT_NAME = os.getenv("AZURE_DEPLOYMENT_NAME")
AZURE_API_VERSION = os.getenv("AZURE_API_VERSION")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")

# Lazily-initialized singletons, populated by initialize_apis() and shared
# across Streamlit reruns via module state.
llm = None
pinecone_index = None
|
|
|
def log_and_exit(message):
    """Record *message* as an error and terminate the process.

    Raises:
        SystemExit: always, carrying *message* as its payload.
    """
    logging.error(message)
    raise SystemExit(message)
|
|
|
def initialize_apis(api, model, pinecone_api_key, groq_api_key, azure_api_key):
    """Populate the module-level ``llm`` and ``pinecone_index`` singletons.

    Globals that are already set are left untouched, so repeated calls
    (e.g. across Streamlit reruns) are cheap. Any failure aborts the
    process through :func:`log_and_exit`.
    """
    global llm, pinecone_index
    try:
        if llm is None:
            llm = initialize_llm(api, model, groq_api_key, azure_api_key)
        if pinecone_index is None:
            pc = Pinecone(pinecone_api_key)
            pinecone_index = pc.Index("ll144")
        logging.info("Initialized LLM and Pinecone.")
    except Exception as exc:
        log_and_exit(f"Error initializing APIs: {exc}")
|
|
|
def initialize_llm(api, model, groq_api_key, azure_api_key):
    """Build and return the LLM client for the selected backend.

    Args:
        api: backend identifier, 'groq' or 'azure'.
        model: model alias; for 'groq' one of the mapped aliases below,
            for 'azure' only 'gpt35' is supported.
        groq_api_key: API key used for the Groq backend.
        azure_api_key: API key used for the Azure backend.

    Returns:
        A configured Groq or AzureOpenAI client.

    Raises:
        ValueError: for an unsupported api/model combination. (Previously
            an unknown Groq alias raised a bare KeyError and an unknown
            api/Azure model silently returned None, leaving ``llm`` unset.)
    """
    if api == 'groq':
        model_mappings = {
            'mixtral-8x7b': "mixtral-8x7b-32768",
            'llama3-8b': "llama3-8b-8192",
            'llama3-70b': "llama3-70b-8192",
            'gemma-7b': "gemma-7b-it"
        }
        if model not in model_mappings:
            raise ValueError(f"Unsupported Groq model: {model}")
        return Groq(model=model_mappings[model], api_key=groq_api_key)
    elif api == 'azure':
        if model == 'gpt35':
            return AzureOpenAI(
                deployment_name=AZURE_DEPLOYMENT_NAME,
                temperature=0,
                api_key=azure_api_key,
                azure_endpoint=AZURE_OPENAI_ENDPOINT,
                api_version=AZURE_API_VERSION
            )
        raise ValueError(f"Unsupported Azure model: {model}")
    raise ValueError(f"Unsupported API: {api}")
|
|
|
def load_pdf_data(chunk_size, input_files=('LL144.pdf', 'LL144_Definitions.pdf')):
    """Load the source PDFs with the PyMuPDF reader.

    Args:
        chunk_size: accepted for interface compatibility with callers but
            not used here — chunking happens downstream via
            ``Settings.chunk_size`` and :func:`create_bm25_nodes`.
        input_files: paths of the PDFs to ingest; defaults to the LL144
            regulation and its definitions (previously hard-coded).

    Returns:
        The list of loaded Document objects.
    """
    reader = PyMuPDFReader()
    file_extractor = {".pdf": reader}
    documents = SimpleDirectoryReader(
        input_files=list(input_files), file_extractor=file_extractor
    ).load_data()
    return documents
|
|
|
def create_index(documents, embedding_model_type="HF", embedding_model="BAAI/bge-large-en-v1.5", retriever_method="BM25", chunk_size=512):
    """Configure global Settings and build the index/nodes for a retriever method.

    Args:
        documents: loaded Document objects to index.
        embedding_model_type: "HF" or "OAI" (see select_embedding_model).
        embedding_model: embedding model name for the HF backend.
        retriever_method: "BM25", "BM25+Vector", or anything else for a
            pure vector index.
        chunk_size: chunk size applied both to Settings and BM25 splitting.

    Returns:
        (index, nodes) where either element may be None:
        - "BM25"        -> (None, nodes)
        - "BM25+Vector" -> (index, nodes)
        - otherwise     -> (index, None)

    Any failure aborts via log_and_exit (SystemExit).
    """
    global llm, pinecone_index
    try:
        embed_model = select_embedding_model(embedding_model_type, embedding_model)

        Settings.llm = llm
        Settings.embed_model = embed_model
        Settings.chunk_size = chunk_size

        def build_vector_index():
            # Shared Pinecone-backed index construction; previously this
            # code was duplicated in both branches below.
            vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
            storage_context = StorageContext.from_defaults(vector_store=vector_store)
            return VectorStoreIndex.from_documents(documents, storage_context=storage_context)

        if retriever_method in ["BM25", "BM25+Vector"]:
            nodes = create_bm25_nodes(documents, chunk_size)
            logging.info("Created BM25 nodes from documents.")
            if retriever_method == "BM25+Vector":
                index = build_vector_index()
                logging.info("Created index for BM25+Vector from documents.")
                return index, nodes
            return None, nodes
        index = build_vector_index()
        logging.info("Created index from documents.")
        return index, None
    except Exception as e:
        log_and_exit(f"Error creating index: {e}")
|
|
|
def select_embedding_model(embedding_model_type, embedding_model):
    """Return the embedding model instance for the given backend type.

    Args:
        embedding_model_type: "HF" (HuggingFace) or "OAI" (OpenAI).
        embedding_model: model name; only used by the HF backend.

    Raises:
        ValueError: for an unknown *embedding_model_type*. (Previously the
            function silently returned None, which surfaced later as an
            obscure failure inside create_index.)
    """
    if embedding_model_type == "HF":
        return HuggingFaceEmbedding(model_name=embedding_model)
    elif embedding_model_type == "OAI":
        return OpenAIEmbedding()
    raise ValueError(f"Unsupported embedding model type: {embedding_model_type}")
|
|
|
def create_bm25_nodes(documents, chunk_size):
    """Split *documents* into sentence-based nodes of roughly *chunk_size* tokens."""
    return SentenceSplitter(chunk_size=chunk_size).get_nodes_from_documents(documents)
|
|
|
def select_retriever(index, nodes, retriever_method, top_k):
    """Build the retriever for *retriever_method*.

    Args:
        index: vector index, required for vector-based methods (may be None).
        nodes: pre-split nodes, required for BM25-based methods (may be None).
        retriever_method: 'BM25', 'BM25+Vector', or 'Vector Search'.
        top_k: similarity_top_k applied to every underlying retriever.

    Returns:
        A retriever instance. Unsupported methods or missing prerequisites
        abort via log_and_exit (SystemExit).
    """
    logging.info(f"Selecting retriever with method: {retriever_method}")
    if nodes is not None:
        logging.info(f"Available document IDs: {list(range(len(nodes)))}")
    else:
        logging.warning("Nodes are None")

    if retriever_method == 'BM25':
        return BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=top_k)
    elif retriever_method == "BM25+Vector":
        if index is None:
            log_and_exit("Index must be initialized when using BM25+Vector retriever method.")

        bm25_retriever = RetrieverTool.from_defaults(
            retriever=BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=top_k),
            description="BM25 Retriever"
        )

        vector_retriever = RetrieverTool.from_defaults(
            # BUGFIX: similarity_top_k was previously omitted here, so the
            # vector arm ignored the user's Top K setting while the BM25
            # arm honored it.
            retriever=VectorIndexRetriever(index=index, similarity_top_k=top_k),
            description="Vector Retriever"
        )

        # Router lets the LLM pick one or both arms per query.
        router_retriever = RouterRetriever.from_defaults(
            retriever_tools=[bm25_retriever, vector_retriever],
            llm=llm,
            select_multi=True
        )
        return router_retriever
    elif retriever_method == "Vector Search":
        if index is None:
            log_and_exit("Index must be initialized when using Vector Search retriever method.")
        return VectorIndexRetriever(index=index, similarity_top_k=top_k)
    else:
        log_and_exit(f"Unsupported retriever method: {retriever_method}")
|
|
|
def setup_query_engine(index, response_mode, nodes=None, query_engine_method=None, retriever_method=None, top_k=2):
    """Assemble the query engine for the chosen retriever/engine combination.

    Args:
        index: vector index or None (required for FLARE/MS engines).
        response_mode: response synthesizer mode (e.g. "compact").
        nodes: pre-split nodes for BM25-based retrievers.
        query_engine_method: "FLARE", "MS" (multi-step), or None for a
            plain RetrieverQueryEngine.
        retriever_method: forwarded to select_retriever.
        top_k: similarity_top_k for retrieval.

    Returns:
        A query engine instance. Any failure aborts via log_and_exit.
    """
    global llm
    try:
        logging.info(f"Setting up query engine with retriever_method: {retriever_method} and query_engine_method: {query_engine_method}")
        retriever = select_retriever(index, nodes, retriever_method, top_k)

        if retriever is None:
            log_and_exit("Failed to create retriever. Index or nodes might be None.")

        response_synthesizer = get_response_synthesizer(response_mode=response_mode)
        index_query_engine = index.as_query_engine(similarity_top_k=top_k) if index else None

        if query_engine_method in ("FLARE", "MS") and index_query_engine is None:
            # BUGFIX: both FLARE and MultiStep wrap an index-backed engine;
            # previously a None index produced an engine built around None
            # that only failed later, at query time. Fail fast instead.
            log_and_exit(f"{query_engine_method} query engine requires a vector index.")

        if query_engine_method == "FLARE":
            query_engine = FLAREInstructQueryEngine(
                query_engine=index_query_engine,
                max_iterations=4,
                verbose=False
            )
        elif query_engine_method == "MS":
            query_engine = MultiStepQueryEngine(
                query_engine=index_query_engine,
                query_transform=StepDecomposeQueryTransform(llm=llm, verbose=False),
                index_summary="Used to answer questions about the regulation"
            )
        else:
            query_engine = RetrieverQueryEngine(retriever=retriever, response_synthesizer=response_synthesizer)

        # (The old `if query_engine is None` check was unreachable — every
        # branch above assigns a constructed engine — and has been removed.)
        return query_engine
    except Exception as e:
        # log_and_exit logs the message itself; keep only the traceback here
        # instead of the previous duplicate logging.error call.
        traceback.print_exc()
        log_and_exit(f"Error setting up query engine: {e}")
|
|
|
def log_to_google_sheets(data):
    """Append *data* as one row to the RAG sheet.

    Failures are logged and swallowed — sheet logging is best-effort and
    must never break the chat flow.
    """
    try:
        sheet.append_row(data)
    except Exception as exc:
        logging.error(f"Error logging data to Google Sheets: {exc}")
    else:
        logging.info("Logged data to Google Sheets.")
|
|
|
def update_google_sheets(question_id, feedback=None, detailed_feedback=None, annotated_answer=None):
    """Update feedback columns of the row whose first cell equals *question_id*.

    Only the arguments that are not None are written; the first matching
    row wins. Failures are logged and swallowed (best-effort).
    """
    try:
        rows = sheet.get_all_values()
        headers = rows[0]
        for row_number, row in enumerate(rows, start=1):
            if row[0] != question_id:
                continue
            # (column header, value) pairs, written in the original order.
            pending = [
                ("Feedback", feedback),
                ("Detailed Feedback", detailed_feedback),
                ("annotated_answer", annotated_answer),
            ]
            for column_name, value in pending:
                if value is not None:
                    sheet.update_cell(row_number, headers.index(column_name) + 1, value)
            logging.info("Updated data in Google Sheets.")
            return
    except Exception as e:
        logging.error(f"Error updating data in Google Sheets: {e}")
|
|
|
def run_streamlit_app():
    """Render the Streamlit RAG chat UI: setup controls, chat-history replay
    with per-answer feedback/annotation widgets, and the question input loop."""
    # The query engine must survive Streamlit reruns, so it lives in
    # session state; None until the user clicks "Initialize".
    if 'query_engine' not in st.session_state:
        st.session_state.query_engine = None

    st.title("RAG Chat Application")

    col1, col2 = st.columns(2)

    with col1:
        # NOTE(review): API keys are collected via plain text_input, so the
        # typed values are visible on screen — confirm whether
        # type="password" was intended.
        pinecone_api_key = st.text_input("Pinecone API Key")
        azure_api_key = st.text_input("Azure API Key")
        groq_api_key = st.text_input("Groq API Key")

    def update_api_based_on_model():
        # Keep the (disabled) API selectbox consistent with the model:
        # 'gpt35' is served by Azure, every other model by Groq.
        selected_model = st.session_state['selected_model']
        if selected_model == 'gpt35':
            st.session_state['selected_api'] = 'azure'
        else:
            st.session_state['selected_api'] = 'groq'

    with col2:
        selected_model = st.selectbox("Select Model", ["llama3-8b", "llama3-70b", "mixtral-8x7b", "gemma-7b", "gpt35"], index=4, key='selected_model', on_change=update_api_based_on_model)
        # Disabled on purpose: the API is derived from the model above.
        selected_api = st.selectbox("Select API", ["azure", "groq"], index=0, key='selected_api', disabled=True)
        embedding_model_type = "HF"
        embedding_model = st.selectbox("Select Embedding Model", ["BAAI/bge-large-en-v1.5", "other_model"])
        retriever_method = st.selectbox("Select Retriever Method", ["Vector Search", "BM25", "BM25+Vector"])

    col3, col4 = st.columns(2)
    with col3:
        chunk_size = st.selectbox("Select Chunk Size", [128, 256, 512, 1024], index=2)
    with col4:
        top_k = st.selectbox("Select Top K", [1, 2, 3, 5, 6], index=1)

    if st.button("Initialize"):
        # Full pipeline build: API clients -> PDF documents -> index/nodes
        # -> query engine stored in session state for later queries.
        initialize_apis(st.session_state['selected_api'], selected_model, pinecone_api_key, groq_api_key, azure_api_key)
        documents = load_pdf_data(chunk_size)
        index, nodes = create_index(documents, embedding_model_type=embedding_model_type, embedding_model=embedding_model, retriever_method=retriever_method, chunk_size=chunk_size)
        st.session_state.query_engine = setup_query_engine(index, response_mode="compact", nodes=nodes, query_engine_method=None, retriever_method=retriever_method, top_k=top_k)
        st.success("Initialization complete.")

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # Replay every past turn with its retrieved contexts, the answer, and
    # feedback/annotation controls. Widget keys are suffixed with the turn
    # index so they stay unique across reruns.
    for chat_index, chat in enumerate(st.session_state.chat_history):
        with st.chat_message("user"):
            st.markdown(chat['user'])
        with st.chat_message("bot"):
            st.markdown("### Retrieved Contexts")
            for node in chat.get('contexts', []):
                st.markdown(
                    f"<div style='border:1px solid #ccc; padding:10px; margin:10px 0; font-size:small;'>{node.text}</div>",
                    unsafe_allow_html=True
                )
            st.markdown("### Answer")
            st.markdown(chat['response'])

        col1, col2 = st.columns([1, 1])
        with col1:
            # Negative feedback (-1): records it and opens the annotation
            # box below on the rerun.
            # NOTE(review): the "π" in the label looks like a garbled emoji
            # (mojibake) — confirm the intended glyph before changing.
            if st.button("Annotate π", key=f"annotate_{chat_index}"):
                chat['annotate'] = True
                chat['feedback'] = -1
                st.session_state.chat_history[chat_index] = chat
                update_google_sheets(chat['id'], feedback=-1)
                st.rerun()
        with col2:
            # Positive feedback (+1); the model's own answer is stored as
            # the approved annotated answer.
            if st.button("Approve π", key=f"approve_{chat_index}"):
                chat['approved'] = True
                chat['feedback'] = 1
                st.session_state.chat_history[chat_index] = chat
                update_google_sheets(chat['id'], feedback=1, annotated_answer=chat['response'])

        if chat.get('annotate', False):
            # Annotation mode: let the reviewer edit the answer and submit
            # the corrected version to the sheet.
            annotated_answer = st.text_area("Annotate Answer", value=chat['response'], key=f"annotate_text_{chat_index}")
            if st.button("Submit Annotated Answer", key=f"submit_annotate_{chat_index}"):
                chat['annotated_answer'] = annotated_answer
                chat['annotate'] = False
                st.session_state.chat_history[chat_index] = chat
                update_google_sheets(chat['id'], annotated_answer=annotated_answer)

        # Free-form textual feedback for this turn.
        feedback = st.text_area("How was the response? Does it match the context? Does it answer the question fully?", key=f"textarea_{chat_index}")
        if st.button("Submit Feedback", key=f"submit_{chat_index}"):
            chat['detailed_feedback'] = feedback
            st.session_state.chat_history[chat_index] = chat
            update_google_sheets(chat['id'], detailed_feedback=feedback)

    if question := st.chat_input("Enter your question"):
        if st.session_state.query_engine:
            with st.spinner('Generating response...'):
                # Prepend the whole chat history so the engine sees prior
                # turns as conversational context for the new question.
                history = "\n".join([f"Q: {chat['user']}\nA: {chat['response']}" for chat in st.session_state.chat_history])
                full_query = f"{history}\nQ: {question}"
                response = st.session_state.query_engine.query(full_query)
                logging.info(f"Generated response: {response.response}")
                logging.info(f"Retrieved contexts: {[node.text for node in response.source_nodes]}")
                question_id = str(uuid.uuid4())
                timestamp = datetime.now().isoformat()
                st.session_state.chat_history.append({'id': question_id, 'user': question, 'response': response.response, 'contexts': response.source_nodes, 'feedback': 0, 'detailed_feedback': '', 'annotated_answer': '', 'timestamp': timestamp})

                # Mirror the turn plus the run configuration to Google Sheets.
                log_to_google_sheets([question_id, question, response.response, st.session_state['selected_api'], selected_model, embedding_model, retriever_method, chunk_size, top_k, 0, "", "", timestamp])

                st.rerun()
        else:
            st.error("Query engine is not initialized. Please initialize it first.")
|
|
|
# Script entry point: launch the Streamlit UI.
if __name__ == "__main__":
    run_streamlit_app()
|
|