Document_RAG_QA / app.py
VishnuRamDebyez's picture
Update app.py
02ab853 verified
raw
history blame
4.86 kB
import streamlit as st
import os
from langchain_groq import ChatGroq
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from dotenv import load_dotenv
import time
# Pull settings from a local .env file into the process environment
load_dotenv()

# Page configuration (title and wide layout)
st.set_page_config(page_title="Legal Assistant", layout="wide")

# Seed the per-session values the rest of the script relies on; existing
# values are left untouched so they survive Streamlit reruns.
for state_key, initial_value in (
    ('chat_history', []),
    ('last_response', None),
    ('clear_chat', False),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

# Main page heading
st.title("Legal Assistant")

# Sidebar heading
st.sidebar.title("Chat History")
# API Key Configuration
# Both keys are read from the environment (populated by load_dotenv above).
groq_api_key = os.getenv('groqapi')
google_api_key = os.getenv("GOOGLE_API_KEY")

# Fail early with a readable message instead of crashing: assigning None to
# os.environ raises TypeError, and the LLM client cannot work without a key.
missing_keys = [
    name
    for name, value in (("groqapi", groq_api_key), ("GOOGLE_API_KEY", google_api_key))
    if not value
]
if missing_keys:
    st.error("Missing required environment variable(s): " + ", ".join(missing_keys))
    st.stop()

os.environ["GOOGLE_API_KEY"] = google_api_key

# LLM Setup
llm = ChatGroq(groq_api_key=groq_api_key, model_name="Llama3-8b-8192")
# Prompt template: {context} receives the retrieved document chunks and
# {input} receives the user's question. The context section is wrapped in a
# properly closed <context>...</context> pair so its end is unambiguous.
prompt = ChatPromptTemplate.from_template(
"""
Answer the questions based on the provided context only.
Please provide the most accurate response based on the question
<context>
{context}
</context>
Questions:{input}
"""
)
def vector_embedding(pdf_dir="./new", max_docs=20, chunk_size=1000, chunk_overlap=200):
    """Build the FAISS vector store for the session if it does not exist yet.

    The loader, loaded documents, splitter, chunks and the resulting store are
    all kept in ``st.session_state`` (``vectors`` holds the FAISS index), so
    the work is done only once per session.

    Args:
        pdf_dir: Directory the PDFs are loaded from.
        max_docs: Maximum number of loaded documents to split and index;
            ``None`` indexes everything that was loaded.
        chunk_size: Chunk size passed to the text splitter.
        chunk_overlap: Overlap between consecutive chunks.

    Raises:
        ValueError: If no chunks were produced, i.e. there is nothing to index.
    """
    if "vectors" in st.session_state:
        return

    st.session_state.embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
    st.session_state.loader = PyPDFDirectoryLoader(pdf_dir)  # Data Ingestion
    st.session_state.docs = st.session_state.loader.load()  # Document Loading
    st.session_state.text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )  # Chunk Creation
    st.session_state.final_documents = st.session_state.text_splitter.split_documents(
        st.session_state.docs[:max_docs]
    )  # Splitting

    # Building an index from zero chunks fails with an unhelpful low-level
    # error, so report the real problem instead.
    if not st.session_state.final_documents:
        raise ValueError(f"No text could be loaded from the PDFs in {pdf_dir!r}")

    st.session_state.vectors = FAISS.from_documents(
        st.session_state.final_documents,
        st.session_state.embeddings,
    )


# Perform vector embedding
vector_embedding()
# Function to add to chat history
def add_to_chat_history(question, answer):
    """Append one question/answer exchange to the session's chat history."""
    entry = dict(question=question, answer=answer)
    st.session_state.chat_history.append(entry)
def _clear_question_input():
    """Button callback: empty the question box and drop the shown answer.

    Streamlit runs callbacks before the script reruns, which is the only point
    at which a keyed widget's value may be changed through session state.
    """
    st.session_state.user_question = ""
    st.session_state.last_response = None


# Main input area with clear button
col1, col2 = st.columns([3, 1])
with col1:
    # The widget's value lives in st.session_state["user_question"]; a
    # `value=` default would be ignored after the first render, so none is
    # passed and clearing is done by the callback above.
    prompt1 = st.text_input("Enter Your Question From Documents",
                            key="user_question")
with col2:
    # Add clear chat button next to the input
    st.text("")  # Add some vertical space to align with input
    clear_button = st.button("Clear Chat", on_click=_clear_question_input)

# The click already triggers a rerun after the callback, so no explicit
# rerun call is needed. The flag is kept at False for any code that reads it.
st.session_state.clear_chat = False
# Process question and generate response
if not prompt1:
    # Empty input (for example after the box has been cleared): forget which
    # question was answered so the same question can be asked again later.
    st.session_state.pop('answered_question', None)
elif prompt1 != st.session_state.get('answered_question'):
    # Streamlit reruns this script on every interaction. Only query the LLM
    # when the question has not been answered yet; otherwise every button
    # click would repeat the request and add a duplicate history entry.
    try:
        # Store the current question
        st.session_state.last_question = prompt1

        # Create document and retrieval chains
        document_chain = create_stuff_documents_chain(llm, prompt)
        retriever = st.session_state.vectors.as_retriever()
        retrieval_chain = create_retrieval_chain(retriever, document_chain)

        # Generate response
        response = retrieval_chain.invoke({'input': prompt1})
        answer = response['answer']

        # Store the response and add it to the chat history
        st.session_state.last_response = answer
        add_to_chat_history(prompt1, answer)

        # Mark the question as answered only on success, so a failed attempt
        # is retried on the next rerun.
        st.session_state.answered_question = prompt1
    except Exception as e:
        # Top-level UI boundary: show the failure instead of crashing the app
        st.error(f"An error occurred: {e}")

# Display the last response if exists
if st.session_state.last_response:
    st.write(st.session_state.last_response)
# Sidebar content
# Button that wipes the stored chat history
history_cleared = st.sidebar.button("Clear Chat History")
if history_cleared:
    st.session_state.chat_history = []

# List earlier exchanges, newest first, each numbered by its position in the
# history (the oldest entry is "Question 1").
st.sidebar.write("### Previous Questions")
history = st.session_state.chat_history
for number in range(len(history), 0, -1):
    entry = history[number - 1]
    # One collapsible panel per exchange
    panel = st.sidebar.expander(f"Question {number}")
    with panel:
        st.write(f"**Question:** {entry['question']}")
        st.write(f"**Answer:** {entry['answer']}")