# Document_RAG_QA / app.py
# Last commit 20ca471 ("Update app.py", verified) by VishnuRamDebyez.
import streamlit as st
import os
from langchain_groq import ChatGroq
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from dotenv import load_dotenv
import time
# Read settings such as API keys from a local .env file into os.environ.
load_dotenv()

# Browser-tab title plus the full-width page layout.
st.set_page_config(layout="wide", page_title="Legal Assistant")

# Counter that the question box's widget key is derived from; bumping it makes
# Streamlit create a brand-new input widget (used by the "Clear Chat" reset).
if "reset_key" not in st.session_state:
    st.session_state["reset_key"] = 0
# Function to reset the entire session state
def reset_session_state():
# Increment reset key to force a complete reset
st.session_state['reset_key'] += 1
# Reset specific session state variables
st.session_state['last_response'] = None
st.session_state['current_question'] = ''
# Page title
st.title("Legal Assistant")
# Sidebar heading for the chat-history list
st.sidebar.title("Chat History")

# API key configuration.
# The Groq key is read from the `groqapi` environment variable (or .env entry).
groq_api_key = os.getenv('groqapi')
# The Google embeddings need GOOGLE_API_KEY. Assigning None into os.environ
# raises TypeError, so a missing key is reported in the UI and the run stops
# cleanly instead of crashing with a traceback.
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
    st.error("GOOGLE_API_KEY is not set. Add it to the environment or to a .env file.")
    st.stop()
os.environ["GOOGLE_API_KEY"] = google_api_key

# Initialize chat history if not exists
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []

# LLM and prompt setup
llm = ChatGroq(groq_api_key=groq_api_key, model_name="Llama3-8b-8192")
# Retrieved text is wrapped in a properly closed <context>...</context> section
# so the model can tell it apart from the user's question.
prompt = ChatPromptTemplate.from_template(
    """
Answer the questions based on the provided context only.
Please provide the most accurate response based on the question
<context>
{context}
</context>
Questions:{input}
"""
)
def vector_embedding(directory="./new", max_docs=20, chunk_size=1000, chunk_overlap=200):
    """Build the FAISS vector store for the PDFs in *directory*, once per session.

    Loads the PDFs under ``directory`` with ``PyPDFDirectoryLoader`` and keeps
    the first ``max_docs`` loaded documents. It splits them into overlapping
    chunks with ``RecursiveCharacterTextSplitter`` and embeds the chunks with
    Google's ``models/embedding-001`` model. The resulting FAISS index is
    stored in ``st.session_state.vectors``.

    The intermediate objects (``embeddings``, ``loader``, ``docs``,
    ``text_splitter``, ``final_documents``) are kept in session state as well.

    Args:
        directory: Folder to load PDFs from.
        max_docs: Maximum number of loaded documents to index, or ``None`` for
            no limit. NOTE(review): presumably one document per PDF page with
            this loader — confirm.
        chunk_size: Maximum chunk size passed to the text splitter.
        chunk_overlap: Overlap between consecutive chunks.

    Does nothing when ``st.session_state`` already contains ``vectors``.

    If no chunks are produced (e.g. the folder has no readable PDFs), an error
    is shown and the script run is stopped. Without this check the run would
    fail inside ``FAISS.from_documents``.
    """
    if "vectors" in st.session_state:
        return  # already built for this session

    st.session_state.embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
    st.session_state.loader = PyPDFDirectoryLoader(directory)  # Data ingestion
    st.session_state.docs = st.session_state.loader.load()  # Document loading
    st.session_state.text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )  # Chunk creation
    # Slicing with max_docs=None keeps every document.
    st.session_state.final_documents = st.session_state.text_splitter.split_documents(
        st.session_state.docs[:max_docs]
    )
    if not st.session_state.final_documents:
        st.error(
            f"No text could be extracted from the PDFs in '{directory}'; "
            "the search index cannot be built."
        )
        st.stop()
    st.session_state.vectors = FAISS.from_documents(
        st.session_state.final_documents, st.session_state.embeddings
    )
# Build the FAISS index for this session. vector_embedding skips the work when
# st.session_state already holds "vectors". main() reads
# st.session_state.vectors when answering questions.
vector_embedding()
def add_to_chat_history(question, answer):
    """Record one question/answer exchange in the session's chat history.

    Appends ``{'question': question, 'answer': answer}`` to
    ``st.session_state.chat_history``. Entries are kept oldest first.
    """
    history = st.session_state.chat_history
    history.append(dict(question=question, answer=answer))
def _answer_question(question):
    """Send *question* through the retrieval chain and record the answer.

    On success:
    - the answer is stored in ``st.session_state['last_response']``;
    - the pair is appended to the chat history;
    - the question is remembered in ``st.session_state['current_question']``,
      so later reruns do not resend it.

    Returns:
        The elapsed wall-clock seconds for the chain call. Returns ``None`` if
        an error occurred; the error is shown with ``st.error`` rather than
        raised.
    """
    try:
        # Create document and retrieval chains
        document_chain = create_stuff_documents_chain(llm, prompt)
        retriever = st.session_state.vectors.as_retriever()
        retrieval_chain = create_retrieval_chain(retriever, document_chain)
        # perf_counter measures elapsed time. process_time only counts this
        # process's CPU time, so it would miss the wait on the remote LLM call.
        start = time.perf_counter()
        response = retrieval_chain.invoke({'input': question})
        elapsed = time.perf_counter() - start
        answer = response['answer']
    except Exception as e:  # UI boundary: report instead of crashing the page
        st.error(f"An error occurred: {e}")
        return None
    st.session_state['last_response'] = answer
    st.session_state['current_question'] = question
    add_to_chat_history(question, answer)
    return elapsed


def _render_chat_history():
    """Draw the sidebar history.

    Shows a "Clear Chat History" button, then the past Q&A pairs, newest first.
    """
    # Clear chat history button
    if st.sidebar.button("Clear Chat History"):
        st.session_state.chat_history = []
    # Entries are numbered by insertion order (1 = oldest) but listed newest first
    st.sidebar.write("### Previous Questions")
    history = st.session_state.chat_history
    for number, chat in zip(range(len(history), 0, -1), reversed(history)):
        with st.sidebar.expander(f"Question {number}"):
            st.write(f"**Question:** {chat['question']}")
            st.write(f"**Answer:** {chat['answer']}")


def main():
    """Render the question box, the latest answer, and the sidebar history.

    Streamlit re-executes this script on every widget interaction. A question
    is therefore only sent to the LLM when it differs from the last one
    answered successfully (``st.session_state['current_question']``).

    Without this check, unrelated reruns (e.g. clicking "Clear Chat History")
    would re-query the model and append duplicate history entries. A question
    whose previous attempt failed is retried on the next rerun.
    """
    # "Clear Chat" resets the current question/answer, not the sidebar history
    if st.button("Clear Chat"):
        reset_session_state()

    # The widget key includes reset_key, so a reset creates a fresh input widget
    text_input_key = f'question_input_{st.session_state["reset_key"]}'
    prompt1 = st.text_input(
        "Enter Your Question",
        key=text_input_key,
        value=st.session_state.get('current_question', ''),
    )

    # Skip blank input and questions that were already answered
    question = (prompt1 or '').strip()
    response_time = None
    if question and question != st.session_state.get('current_question'):
        response_time = _answer_question(question)

    # Display the last response if exists
    if st.session_state.get('last_response'):
        st.write(st.session_state['last_response'])
        if response_time is not None:
            st.caption(f"Response time: {response_time:.2f} s")

    _render_chat_history()
# Run the app. `streamlit run` executes this script with __name__ set to
# "__main__", so main() still runs there. Importing this module elsewhere
# no longer calls main() as a side effect.
if __name__ == "__main__":
    main()