# from langchain.chains import ConversationalRetrievalChain
# from langchain.chains.question_answering import load_qa_chain
# from langchain.chains import RetrievalQA
# from langchain.memory import ConversationBufferMemory
# from langchain.memory import ConversationTokenBufferMemory
# from langchain.llms import HuggingFacePipeline
# # from langchain import PromptTemplate
# from langchain.prompts import PromptTemplate
# from langchain.embeddings import HuggingFaceEmbeddings
# from langchain.text_splitter import RecursiveCharacterTextSplitter
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
# from langchain.vectorstores import Chroma
# from chromadb.utils import embedding_functions
# from langchain.embeddings import SentenceTransformerEmbeddings
# from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.llms import LlamaCpp  # class is LlamaCpp, not Llamacpp
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.document_loaders import (
    CSVLoader,
    DirectoryLoader,
    GitLoader,
    NotebookLoader,
    OnlinePDFLoader,
    PythonLoader,
    TextLoader,
    UnstructuredFileLoader,
    UnstructuredHTMLLoader,
    UnstructuredPDFLoader,
    UnstructuredWordDocumentLoader,
    WebBaseLoader,
    PyPDFLoader,
    UnstructuredMarkdownLoader,
    UnstructuredEPubLoader,
    UnstructuredPowerPointLoader,
    UnstructuredODTLoader,
)
# from transformers import (
#     AutoModelForCausalLM,
#     AutoTokenizer,
#     StoppingCriteria,
#     StoppingCriteriaList,
#     pipeline,
#     GenerationConfig,
#     TextStreamer,
# )
# from langchain.llms import HuggingFaceHub
import torch
# from transformers import BitsAndBytesConfig
import os
# from langchain.llms import CTransformers
import streamlit as st
# from langchain.document_loaders.base import BaseLoader
# from langchain.schema import Document
# import gradio as gr
import tempfile
import timeit
import textwrap
# from chromadb.utils import embedding_functions
# from tqdm import tqdm
# tqdm(disable=True, total=0)  # initialise internal lock
# tqdm.write("test")
from langchain.prompts import PromptTemplate  # top-level `from langchain import ...` is deprecated
from langchain.chains import LLMChain
from langchain.llms import CTransformers
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceBgeEmbeddings
from io import BytesIO
from langchain.document_loaders import PyPDFLoader
from langchain.vectorstores import FAISS
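
# Likely runtime dependencies for this script (an assumption inferred from the
# imports above, not a pinned requirements list): streamlit, langchain,
# langchain-community, llama-cpp-python, faiss-cpu, sentence-transformers,
# pypdf, unstructured, torch.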
# def load_model():
#     config = {'max_new_tokens': 1024,
#               'repetition_penalty': 1.1,
#               'temperature': 0.1,
#               'top_k': 50,
#               'top_p': 0.9,
#               'stream': True,
#               'threads': int(os.cpu_count() / 2)
#               }
#     llm = CTransformers(
#         model="TheBloke/zephyr-7B-beta-GGUF",
#         model_file="zephyr-7b-beta.Q4_0.gguf",
#         callbacks=[StreamingStdOutCallbackHandler()],
#         lib="avx2",  # for CPU use
#         **config
#         # model_type=model_type,
#         # max_new_tokens=max_new_tokens,  # type: ignore
#         # temperature=temperature,  # type: ignore
#     )
#     return llm
# def create_vector_database(loaded_documents):
#     # DB_DIR: str = os.path.join(ABS_PATH, "db")
#     """
#     Creates a vector database using document loaders and embeddings.
#     This function loads data from PDF, markdown and text files in the 'data/' directory,
#     splits the loaded documents into chunks, transforms them into embeddings using HuggingFace,
#     and finally persists the embeddings into a Chroma vector database.
#     """
#     # Split loaded documents into chunks
#     text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=30, length_function=len)
#     chunked_documents = text_splitter.split_documents(loaded_documents)
#     # embeddings = HuggingFaceEmbeddings(
#     #     model_name="sentence-transformers/all-MiniLM-L6-v2"
#     #     # model_name="sentence-transformers/all-mpnet-base-v2"
#     # )
#     embeddings = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
#     # embeddings = HuggingFaceBgeEmbeddings(
#     #     model_name="BAAI/bge-large-en"
#     # )
#     # model_name = "BAAI/bge-large-en"
#     # model_kwargs = {'device': 'cpu'}
#     # encode_kwargs = {'normalize_embeddings': False}
#     # embeddings = HuggingFaceBgeEmbeddings(
#     #     model_name=model_name,
#     #     model_kwargs=model_kwargs,
#     #     encode_kwargs=encode_kwargs
#     # )
#     persist_directory = 'db'
#     # Create and persist a Chroma vector database from the chunked documents
#     db = Chroma.from_documents(
#         documents=chunked_documents,
#         embedding=embeddings,
#         persist_directory=persist_directory
#         # persist_directory=DB_DIR,
#     )
#     db.persist()
#     # db = Chroma(persist_directory=persist_directory,
#     #             embedding_function=embedding)
#     return db
# def set_custom_prompt():
#     """
#     Prompt template for retrieval for each vectorstore
#     """
#     prompt_template = """Use the following pieces of information to answer the user's question.
# If you don't know the answer, just say that you don't know, don't try to make up an answer.
# Context: {context}
# Question: {question}
# Only return the helpful answer below and nothing else.
# Helpful answer:
# """
#     prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
#     return prompt
# def create_chain(llm, prompt, db):
#     """
#     Creates a Retrieval Question-Answering (QA) chain using a given language model, prompt, and database.
#     This function initializes a RetrievalQA chain with a specific chain type and configuration,
#     and returns it. The retriever is set up to return the top 3 results (k=3).
#     Args:
#         llm (any): The language model to be used in the RetrievalQA.
#         prompt (str): The prompt to be used in the chain type.
#         db (any): The database to be used as the retriever.
#     Returns:
#         RetrievalQA: The initialized retrieval QA chain.
#     """
#     # memory is only consumed by the commented-out ConversationalRetrievalChain variant below
#     memory = ConversationTokenBufferMemory(llm=llm, memory_key="chat_history", return_messages=True, input_key='question', output_key='answer')
#     # chain = ConversationalRetrievalChain.from_llm(
#     #     llm=llm,
#     #     chain_type="stuff",
#     #     retriever=db.as_retriever(search_kwargs={"k": 3}),
#     #     return_source_documents=True,
#     #     max_tokens_limit=256,
#     #     combine_docs_chain_kwargs={"prompt": prompt},
#     #     condense_question_prompt=CONDENSE_QUESTION_PROMPT,
#     #     memory=memory,
#     # )
#     # chain = RetrievalQA.from_chain_type(llm=llm,
#     #                                     chain_type='stuff',
#     #                                     retriever=db.as_retriever(search_kwargs={'k': 3}),
#     #                                     return_source_documents=True,
#     #                                     chain_type_kwargs={'prompt': prompt}
#     #                                     )
#     chain = RetrievalQA.from_chain_type(llm=llm,
#                                         chain_type='stuff',
#                                         retriever=db.as_retriever(search_kwargs={'k': 3}),
#                                         return_source_documents=True
#                                         )
#     return chain
# def create_retrieval_qa_bot(loaded_documents):
#     # if not os.path.exists(persist_dir):
#     #     raise FileNotFoundError(f"No directory found at {persist_dir}")
#     try:
#         llm = load_model()  # Assuming this function exists and works as expected
#     except Exception as e:
#         raise Exception(f"Failed to load model: {str(e)}")
#     try:
#         prompt = set_custom_prompt()  # Assuming this function exists and works as expected
#     except Exception as e:
#         raise Exception(f"Failed to get prompt: {str(e)}")
#     # try:
#     #     CONDENSE_QUESTION_PROMPT = set_custom_prompt_condense()  # Assuming this function exists and works as expected
#     # except Exception as e:
#     #     raise Exception(f"Failed to get condense prompt: {str(e)}")
#     try:
#         db = create_vector_database(loaded_documents)  # Assuming this function exists and works as expected
#     except Exception as e:
#         raise Exception(f"Failed to get database: {str(e)}")
#     try:
#         # qa = create_chain(
#         #     llm=llm, prompt=prompt, CONDENSE_QUESTION_PROMPT=CONDENSE_QUESTION_PROMPT, db=db
#         # )  # Assuming this function exists and works as expected
#         qa = create_chain(
#             llm=llm, prompt=prompt, db=db
#         )  # Assuming this function exists and works as expected
#     except Exception as e:
#         raise Exception(f"Failed to create retrieval QA chain: {str(e)}")
#     return qa
# def wrap_text_preserve_newlines(text, width=110):
#     # Split the input text into lines based on newline characters
#     lines = text.split('\n')
#     # Wrap each line individually
#     wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
#     # Join the wrapped lines back together using newline characters
#     wrapped_text = '\n'.join(wrapped_lines)
#     return wrapped_text
# def retrieve_bot_answer(query, loaded_documents):
#     """
#     Retrieves the answer to a given query using a QA bot.
#     This function creates an instance of a QA bot, passes the query to it,
#     and returns the bot's response.
#     Args:
#         query (str): The question to be answered by the QA bot.
#     Returns:
#         dict: The QA bot's response, typically a dictionary with response details.
#     """
#     qa_bot_instance = create_retrieval_qa_bot(loaded_documents)
#     # bot_response = qa_bot_instance({"question": query})
#     bot_response = qa_bot_instance({"query": query})
#     # Check if the 'answer' key exists in the bot_response dictionary
#     # if 'answer' in bot_response:
#     #     # answer = bot_response['answer']
#     #     return bot_response
#     # else:
#     #     raise KeyError("Expected 'answer' key in bot_response, but it was not found.")
#     # result = bot_response['answer']
#     # result = bot_response['result']
#     result = wrap_text_preserve_newlines(bot_response['result'])
#     sources = []  # initialise before appending; without this the loop below raises NameError
#     for source in bot_response["source_documents"]:
#         sources.append(source.metadata['source'])
#     return result, sources
def main():
    FILE_LOADER_MAPPING = {
        "csv": (CSVLoader, {"encoding": "utf-8"}),
        "doc": (UnstructuredWordDocumentLoader, {}),
        "docx": (UnstructuredWordDocumentLoader, {}),
        "epub": (UnstructuredEPubLoader, {}),
        "html": (UnstructuredHTMLLoader, {}),
        "md": (UnstructuredMarkdownLoader, {}),
        "odt": (UnstructuredODTLoader, {}),
        "pdf": (PyPDFLoader, {}),
        "ppt": (UnstructuredPowerPointLoader, {}),
        "pptx": (UnstructuredPowerPointLoader, {}),
        "txt": (TextLoader, {"encoding": "utf8"}),
        "ipynb": (NotebookLoader, {}),
        "py": (PythonLoader, {}),
        # Add more mappings for other file extensions and loaders as needed
    }
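    # Each entry maps a file extension to (loader class, keyword arguments); the
    # upload loop below uses it to dispatch each file to the right LangChain loader.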
st.title("Docuverse") | |
# Upload files | |
uploaded_files = st.file_uploader("Upload your documents", type=["pdf", "md", "txt", "csv", "py", "epub", "html", "ppt", "pptx", "doc", "docx", "odt", "ipynb"], accept_multiple_files=True) | |
loaded_documents = [] | |
if uploaded_files: | |
# Create a temporary directory | |
with tempfile.TemporaryDirectory() as td: | |
# Move the uploaded files to the temporary directory and process them | |
for uploaded_file in uploaded_files: | |
st.write(f"Uploaded: {uploaded_file.name}") | |
ext = os.path.splitext(uploaded_file.name)[-1][1:].lower() | |
st.write(f"Uploaded: {ext}") | |
# Check if the extension is in FILE_LOADER_MAPPING | |
if ext in FILE_LOADER_MAPPING: | |
loader_class, loader_args = FILE_LOADER_MAPPING[ext] | |
# st.write(f"loader_class: {loader_class}") | |
# Save the uploaded file to the temporary directory | |
file_path = os.path.join(td, uploaded_file.name) | |
with open(file_path, 'wb') as temp_file: | |
temp_file.write(uploaded_file.read()) | |
# Use Langchain loader to process the file | |
loader = loader_class(file_path, **loader_args) | |
loaded_documents.extend(loader.load()) | |
else: | |
st.warning(f"Unsupported file extension: {ext}") | |
# st.write(f"loaded_documents: {loaded_documents}") | |
st.write("Chat with the Document:") | |
    query = st.text_input("Ask a question:")
    if st.button("Get Answer"):
        if query:
            # Load model, set prompts, create vector database, and retrieve answer
            try:
                start = timeit.default_timer()
                # Generation settings for the commented-out CTransformers path below;
                # the active LlamaCpp call sets its own parameters directly.
                config = {
                    'max_new_tokens': 1024,
                    'repetition_penalty': 1.1,
                    'temperature': 0.1,
                    'top_k': 50,
                    'top_p': 0.9,
                    'stream': True,
                    'threads': int(os.cpu_count() / 2)
                }
                # llm = CTransformers(
                #     model="TheBloke/zephyr-7B-beta-GGUF",
                #     model_file="zephyr-7b-beta.Q4_0.gguf",
                #     model_type="mistral",
                #     lib="avx2",  # for CPU use
                #     **config
                # )
                # LlamaCpp loads a local GGUF file via model_path (it does not download
                # from the Hub); the filename below assumes the Q4_K_M quantisation of
                # TheBloke/Mistral-7B-Instruct-v0.2-GGUF has already been downloaded.
                llm = LlamaCpp(
                    model_path="mistral-7b-instruct-v0.2.Q4_K_M.gguf",
                    temperature=0.75,
                    max_tokens=2000,
                    top_p=1,
                )
st.write("LLM Initialized:") | |
model_name = "BAAI/bge-large-en" | |
model_kwargs = {'device': 'cpu'} | |
encode_kwargs = {'normalize_embeddings': False} | |
embeddings = HuggingFaceBgeEmbeddings( | |
model_name=model_name, | |
model_kwargs=model_kwargs, | |
encode_kwargs=encode_kwargs | |
) | |
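                # Note: the BGE model card generally recommends normalize_embeddings=True
                # so similarity search behaves like cosine similarity; False is kept here
                # to preserve the original behaviour.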
                # embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl",
                #                                            model_kwargs={"device": "cpu"})
                # llm = load_model()
                # prompt = set_custom_prompt()
                # CONDENSE_QUESTION_PROMPT = set_custom_prompt_condense()
                # db = create_vector_database(loaded_documents)
                text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=30, length_function=len)
                chunked_documents = text_splitter.split_documents(loaded_documents)
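                # chunk_size=500 with chunk_overlap=30 is a trade-off: smaller chunks give
                # more precise retrieval hits, while the overlap keeps sentences that
                # straddle a chunk boundary from losing their context.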
                persist_directory = 'db'  # only used by the commented-out Chroma path below
                # Build an in-memory FAISS vector store from the chunked documents
                db = FAISS.from_documents(chunked_documents, embeddings)
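                # The FAISS index above lives only for this request; to reuse it across
                # runs, one option (an assumption, not part of the original flow) is
                # db.save_local("faiss_index") and FAISS.load_local("faiss_index", embeddings).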
                # db = Chroma.from_documents(documents=chunked_documents, embedding=embeddings, persist_directory=persist_directory)
                # db.persist()
                retriever = db.as_retriever(search_kwargs={"k": 1})
                qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True, verbose=True)
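                # chain_type="stuff" simply concatenates the retrieved chunks into one
                # prompt; with k=1 only the single closest chunk is passed to the LLM.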
                bot_response = qa(query)
                lines = bot_response['result'].split('\n')
                wrapped_lines = [textwrap.fill(line, width=50) for line in lines]
                wrapped_text = '\n'.join(wrapped_lines)
                # Collect every source path rather than overwriting one variable per iteration
                sources = [source.metadata['source'] for source in bot_response["source_documents"]]
                end = timeit.default_timer()
                st.write("Elapsed time:")
                st.write(end - start)
                # st.write(f"response: {response}")
                # Display bot response
                st.write("Bot Response:")
                st.write(wrapped_text)
                st.write(sources)
            except Exception as e:
                st.error(f"An error occurred: {str(e)}")
        else:
            st.warning("Please enter a question.")

if __name__ == "__main__":
    main()