import os
import tempfile
import timeit

import streamlit as st
from langchain.agents import AgentType, initialize_agent
from langchain.agents import Tool
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.llms import CTransformers
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS


def _file_extension(filename):
    """Return the lower-cased extension of *filename* without the leading dot.

    Returns an empty string when the name has no extension.
    """
    return os.path.splitext(filename)[-1][1:].lower()


def _default_thread_count():
    """Return half of the detected CPU cores, but never less than one."""
    # os.cpu_count() returns None when the count cannot be determined, and
    # halving a single core would otherwise yield zero threads.
    return max(1, (os.cpu_count() or 2) // 2)


def _load_uploaded_documents(uploaded_files, loader_mapping):
    """Persist each supported upload to a temporary directory and load it.

    Args:
        uploaded_files: Iterable of Streamlit uploaded-file objects.
        loader_mapping: Maps a lower-case extension to a
            ``(loader_class, loader_kwargs)`` pair.

    Returns:
        A flat list with the documents produced by every loader. Files whose
        extension is not in ``loader_mapping`` are skipped with a warning.
    """
    documents = []
    with tempfile.TemporaryDirectory() as td:
        for uploaded_file in uploaded_files:
            st.write(f"Uploaded: {uploaded_file.name}")
            ext = _file_extension(uploaded_file.name)
            st.write(f"Uploaded: {ext}")

            if ext not in loader_mapping:
                st.warning(f"Unsupported file extension: {ext}, the app currently only supports 'pdf'")
                continue

            loader_class, loader_args = loader_mapping[ext]

            # basename() keeps a crafted upload name from escaping the
            # temporary directory.
            file_path = os.path.join(td, os.path.basename(uploaded_file.name))
            with open(file_path, "wb") as temp_file:
                # getvalue() returns the whole buffer regardless of the
                # current stream position; read() yields nothing once the
                # buffer has already been consumed on an earlier rerun.
                temp_file.write(uploaded_file.getvalue())

            # Load inside the ``with`` block, while the file still exists.
            loader = loader_class(file_path, **loader_args)
            documents.extend(loader.load())
    return documents


def _build_agent(documents):
    """Build a zero-shot ReAct agent whose only tool is a QA chain over *documents*.

    Args:
        documents: Loaded documents to chunk, embed and index in FAISS.

    Returns:
        The agent created by ``initialize_agent``.
    """
    config = {
        'max_new_tokens': 1024,
        'repetition_penalty': 1.1,
        'temperature': 0.1,
        'top_k': 50,
        'top_p': 0.9,
        'stream': True,
        'threads': _default_thread_count(),
    }

    llm = CTransformers(
        # The repository must be the one that ships ``model_file``; the
        # Mistral 7B file below is not part of the Mixtral 8x7B repository.
        model="TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
        model_file="mistral-7b-instruct-v0.2.Q4_0.gguf",
        model_type="mistral",
        lib="avx2",
        **config
    )
    print("LLM Initialized...")

    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-large-en",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': False},
    )

    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    chunked_documents = text_splitter.split_documents(documents)
    # Index the chunks (not the unsplit documents) so retrieval returns
    # passages of bounded size.
    retriever = FAISS.from_documents(chunked_documents, embeddings).as_retriever()

    qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
    tools = [
        Tool(
            name="Comparison tool",
            description="useful when you want to answer questions about the uploaded documents",
            # Hand the agent a plain callable that returns the answer text.
            func=qa_chain.run,
        )
    ]

    return initialize_agent(
        tools=tools,
        llm=llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )


def main():
    """Render the Streamlit page: upload PDFs, ask a question, show the answer.

    The uploaded PDFs are loaded into documents, and when the user presses
    "Get Answer" an agent backed by a retrieval QA chain over those documents
    answers the question. The elapsed time and the response are written to
    the page; any exception raised while answering is shown as an error.
    """
    FILE_LOADER_MAPPING = {
        "pdf": (PyPDFLoader, {})
    }

    st.title("Document Comparison with Q&A using Agents")

    uploaded_files = st.file_uploader("Upload your documents", type=["pdf"], accept_multiple_files=True)
    loaded_documents = []

    if uploaded_files:
        loaded_documents = _load_uploaded_documents(uploaded_files, FILE_LOADER_MAPPING)

    st.write("Ask question to get comparison from the documents:")
    query = st.text_input("Ask a question:")

    if not st.button("Get Answer"):
        return

    if not query:
        st.warning("Please enter a question.")
        return

    if not loaded_documents:
        # Building a FAISS index from zero documents would fail with an
        # unhelpful error, so stop early with an actionable message.
        st.warning("Please upload at least one PDF document first.")
        return

    try:
        start = timeit.default_timer()

        agent = _build_agent(loaded_documents)
        response = agent.run(query)

        end = timeit.default_timer()
        st.write("Elapsed time:")
        st.write(end - start)

        st.write("Bot Response:")
        st.write(response)
    except Exception as e:
        # UI boundary: surface the failure on the page instead of crashing.
        st.error(f"An error occurred: {str(e)}")


# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()