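"""Gradio demo for chatting with PDF documents.

Uploaded PDFs are split into chunks, embedded into an in-memory Chroma
vector store, and queried through a LangChain ConversationalRetrievalChain
backed by an LLM from the Hugging Face Hub.
"""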
import gradio as gr
import os
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferMemory
from pathlib import Path
import chromadb
# List of available LLM models
list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1",
"google/gemma-7b-it", "google/gemma-2b-it",
"HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2",
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct",
"google/flan-t5-xxl"
]
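# Short display names for the UI, e.g. "Mistral-7B-Instruct-v0.2"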
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
# Load PDF document and create doc splits
def load_doc(list_file_path, chunk_size, chunk_overlap):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits
# Create vector database
def create_db(splits, collection_name):
    # With no model name given, HuggingFaceEmbeddings uses its default sentence-transformers model
    embedding = HuggingFaceEmbeddings()
    # EphemeralClient keeps the collection in memory; nothing is persisted to disk
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name
    )
    return vectordb
# Initialize langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        model_kwargs = {"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
    elif llm_model == "microsoft/phi-2":
        raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
    elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
        model_kwargs = {"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
    else:
        model_kwargs = {"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
    llm = HuggingFaceHub(
        repo_id=llm_model,
        model_kwargs=model_kwargs
    )
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False
    )
    progress(0.9, desc="Done!")
    return qa_chain
def initialize_demo(list_file_obj, chunk_size, chunk_overlap, db_progress):
    list_file_path = [file.name for file in list_file_obj if file is not None]
    collection_name = Path(list_file_path[0]).stem.replace(" ", "-")[:50]
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    vector_db = create_db(doc_splits, collection_name)
    # Build the QA chain with default settings. Note that `db_progress` is only
    # the textbox value (a string) and must not be passed as the progress tracker.
    qa_chain = initialize_llmchain(
        list_llm[0],  # Mistral-7B-Instruct-v0.2 as the LLM model
        0.7,          # Temperature
        1024,         # Max new tokens
        3,            # Top K
        vector_db,
    )
    return vector_db, collection_name, qa_chain, "Complete!"
def upload_file(file_obj):
    list_file_path = []
    for file in file_obj:
        if file is not None:
            file_path = file.name
            list_file_path.append(file_path)
    return list_file_path
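# The source file is truncated before the chat handler it wires up below, so the
# following is a minimal sketch rather than the author's exact code. It assumes
# the ConversationalRetrievalChain returns a dict with "answer" and
# "source_documents" keys (the behavior of return_source_documents=True) and
# that the Chatbot history is a list of (user, assistant) tuples.
def conversation(qa_chain, message, history):
    # The chain's memory supplies "chat_history", so only the question is passed
    response = qa_chain({"question": message})
    answer = response["answer"]
    # LangChain's default QA prompt ends with "Helpful Answer:"; some models
    # echo the prompt, so keep only the text after that marker
    if "Helpful Answer:" in answer:
        answer = answer.split("Helpful Answer:")[-1]
    # Pad to three sources so every reference textbox and page field gets a value
    sources = response["source_documents"][:3]
    texts = [doc.page_content.strip() for doc in sources] + [""] * (3 - len(sources))
    pages = [doc.metadata.get("page", 0) + 1 for doc in sources] + [0] * (3 - len(sources))
    new_history = history + [(message, answer)]
    # Clear the message box and update the chatbot plus the three source displays
    return qa_chain, "", new_history, texts[0], pages[0], texts[1], pages[1], texts[2], pages[2]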
def demo():
    with gr.Blocks(theme="base") as demo:
        vector_db = gr.State()
        collection_name = gr.State()
        qa_chain = gr.State()
        with gr.Tab("Step 1 - Document pre-processing"):
            document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
            slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size", info="Number of characters per chunk", interactive=True)
            slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap", info="Number of overlapping characters between chunks", interactive=True)
            db_progress = gr.Textbox(label="Vector database initialization", value="None")
            db_btn = gr.Button("Generate vector database...")
        with gr.Tab("Step 2 - QA chain initialization"):
            llm_progress = gr.Textbox(value="None", label="QA chain initialization")
            qachain_btn = gr.Button("Initialize question-answering chain...")
        with gr.Tab("Step 3 - Conversation with chatbot"):
            chatbot = gr.Chatbot(height=300)
            doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
            source1_page = gr.Number(label="Page", scale=1)
            doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
            source2_page = gr.Number(label="Page", scale=1)
            doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
            source3_page = gr.Number(label="Page", scale=1)
            msg = gr.Textbox(placeholder="Type message", container=True)
            submit_btn = gr.Button("Submit")
            clear_btn = gr.ClearButton([msg, chatbot])
        # Build the vector database and a default QA chain as soon as files are uploaded
        document.upload(initialize_demo,
                        inputs=[document, slider_chunk_size, slider_chunk_overlap, db_progress],
                        outputs=[vector_db, collection_name, qa_chain, db_progress])
        # Rebuild the QA chain from the stored vector database with default settings;
        # the original inputs ([qa_chain, llm_progress]) did not match the function signature
        qachain_btn.click(lambda vdb: (initialize_llmchain(list_llm[0], 0.7, 1024, 3, vdb), "Complete!"),
                          inputs=[vector_db],
                          outputs=[qa_chain, llm_progress])
        # The source is truncated mid-line here; the submit wiring is completed to
        # call the conversation() sketch above, updating the chat and source views
        submit_btn.click(conversation,
                         inputs=[qa_chain, msg, chatbot],
                         outputs=[qa_chain, msg, chatbot, doc_source1, source1_page,
                                  doc_source2, source2_page, doc_source3, source3_page])
    # Standard Blocks closing, restored since the original file is truncated
    demo.queue().launch(debug=True)


if __name__ == "__main__":
    demo()