# File-viewer header captured along with this source (kept as a comment so the module parses):
#   author: captain-awesome
#   commit: 7f721d2 "Update app.py" (verified)
#   size:   5.15 kB
import os
import tempfile
import timeit

import streamlit as st
from langchain.agents import AgentType, initialize_agent
from langchain.agents import Tool
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.llms import CTransformers
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
# Maps a lowercase file extension (without the dot) to a LangChain loader class
# and the extra keyword arguments passed to its constructor.
FILE_LOADER_MAPPING = {
    "pdf": (PyPDFLoader, {}),
    # Add more mappings for other file extensions and loaders as needed
}


def _load_uploaded_documents(uploaded_files):
    """Write each upload to a temporary directory and load it with its mapped loader.

    Args:
        uploaded_files: The non-empty list returned by ``st.file_uploader``.

    Returns:
        A flat list of LangChain documents from every supported file. Files whose
        extension has no entry in ``FILE_LOADER_MAPPING`` are skipped with a warning.
    """
    loaded_documents = []
    with tempfile.TemporaryDirectory() as td:
        for uploaded_file in uploaded_files:
            st.write(f"Uploaded: {uploaded_file.name}")
            ext = os.path.splitext(uploaded_file.name)[-1][1:].lower()
            st.write(f"Uploaded: {ext}")
            if ext not in FILE_LOADER_MAPPING:
                st.warning(
                    f"Unsupported file extension: {ext}, "
                    "the app currently only supports 'pdf'"
                )
                continue
            loader_class, loader_args = FILE_LOADER_MAPPING[ext]
            # basename() keeps a client-supplied name from pointing outside the temp dir.
            file_path = os.path.join(td, os.path.basename(uploaded_file.name))
            with open(file_path, "wb") as temp_file:
                # getvalue() returns the whole upload regardless of the current read
                # position, so a stream consumed by an earlier read still yields its bytes.
                temp_file.write(uploaded_file.getvalue())
            # Load now: the temporary directory is deleted when the `with` block exits.
            loader = loader_class(file_path, **loader_args)
            loaded_documents.extend(loader.load())
    return loaded_documents


def _build_llm():
    """Create the CPU-only CTransformers LLM used by both the QA chain and the agent."""
    config = {
        "max_new_tokens": 1024,
        "repetition_penalty": 1.1,
        "temperature": 0.1,
        "top_k": 50,
        "top_p": 0.9,
        "stream": True,
        # os.cpu_count() may return None; never request fewer than one thread.
        "threads": max(1, (os.cpu_count() or 2) // 2),
    }
    return CTransformers(
        # The GGUF file below is Mistral-7B-Instruct v0.2, which lives in the
        # Mistral-7B GGUF repo (the original pointed at the Mixtral-8x7B repo).
        model="TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
        model_file="mistral-7b-instruct-v0.2.Q4_0.gguf",
        model_type="mistral",
        lib="avx2",  # for CPU use
        # Generation settings must go through `config`; CTransformers forwards only
        # this dict (not arbitrary top-level kwargs) to the ctransformers model.
        config=config,
    )


def _build_agent(llm, documents):
    """Index *documents* in FAISS and wrap retrieval QA over them as a ReAct agent tool.

    Args:
        llm: The language model used for both the QA chain and the agent.
        documents: A non-empty list of LangChain documents to search.

    Returns:
        A zero-shot ReAct agent whose only tool answers questions via retrieval QA.
    """
    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-large-en",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": False},
    )
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    chunked_documents = text_splitter.split_documents(documents)
    retriever = FAISS.from_documents(chunked_documents, embeddings).as_retriever()
    qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
    tools = [
        Tool(
            name="Comparison tool",
            description="useful when you want to answer questions about the uploaded documents",
            # A Tool calls func(str); chain.run returns the answer text rather than
            # the raw output dict that calling the chain object directly would give.
            func=qa_chain.run,
        )
    ]
    return initialize_agent(
        tools=tools,
        llm=llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )


def main():
    """Render the Streamlit app: upload documents, then ask an agent to compare them."""
    st.title("Document Comparison with Q&A using Agents")

    # Upload files
    uploaded_files = st.file_uploader(
        "Upload your documents", type=["pdf"], accept_multiple_files=True
    )
    loaded_documents = _load_uploaded_documents(uploaded_files) if uploaded_files else []

    st.write("Ask question to get comparison from the documents:")
    query = st.text_input("Ask a question:")
    if not st.button("Get Answer"):
        return
    if not query:
        st.warning("Please enter a question.")
        return
    if not loaded_documents:
        # FAISS cannot build an index from an empty document list.
        st.warning("Please upload at least one PDF document.")
        return

    # Load model, create vector database, build the agent, and retrieve the answer.
    try:
        start = timeit.default_timer()
        llm = _build_llm()
        print("LLM Initialized...")
        agent = _build_agent(llm, loaded_documents)
        response = agent.run(query)
        end = timeit.default_timer()
        st.write("Elapsed time:")
        st.write(end - start)
        st.write("Bot Response:")
        st.write(response)
    except Exception as e:
        # Top-level UI boundary: show any failure (model load, PDF parsing, agent) to the user.
        st.error(f"An error occurred: {str(e)}")
# Script entry point: build the UI only when this file is executed directly
# (e.g. launched via `streamlit run`), never as a side effect of importing it.
if __name__ == '__main__':
    main()