Ubai committed on
Commit d4b9831 · verified · 1 Parent(s): 0ae6c24

Update app.py

Files changed (1)
  1. app.py +69 -72
app.py CHANGED
@@ -5,24 +5,28 @@ from langchain.document_loaders import PyPDFLoader
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain.vectorstores import Chroma
  from langchain.chains import ConversationalRetrievalChain
- from langchain.embeddings import HuggingFaceEmbeddings
- from langchain.llms import HuggingFacePipeline
- from langchain.chains import ConversationChain
- from langchain.memory import ConversationBufferMemory
  from langchain.llms import HuggingFaceHub

  from pathlib import Path
  import chromadb

  # Load PDF document and create doc splits
  def load_doc(list_file_path, chunk_size, chunk_overlap):
      loaders = [PyPDFLoader(x) for x in list_file_path]
      pages = []
      for loader in loaders:
          pages.extend(loader.load())
-     text_splitter = RecursiveCharacterTextSplitter(
-         chunk_size=chunk_size,
-         chunk_overlap=chunk_overlap)
      doc_splits = text_splitter.split_documents(pages)
      return doc_splits
 
@@ -34,105 +38,98 @@ def create_db(splits, collection_name):
          documents=splits,
          embedding=embedding,
          client=new_client,
-         collection_name=collection_name,
      )
      return vectordb

  # Initialize langchain LLM chain
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
      llm = HuggingFaceHub(
          repo_id=llm_model,
-         model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
      )
      memory = ConversationBufferMemory(
          memory_key="chat_history",
          output_key='answer',
          return_messages=True
      )
      retriever = vector_db.as_retriever()
      qa_chain = ConversationalRetrievalChain.from_llm(
          llm,
          retriever=retriever,
          chain_type="stuff",
          memory=memory,
          return_source_documents=True,
-         verbose=False,
      )
      progress(0.9, desc="Done!")
      return qa_chain

- # Initialize database and LLM chain
- def initialize_demo(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
-     list_file_path = [x.name for x in list_file_obj if x is not None]
      collection_name = Path(list_file_path[0]).stem.replace(" ", "-")[:50]
      doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
      vector_db = create_db(doc_splits, collection_name)
      qa_chain = initialize_llmchain(
-         "mistralai/Mistral-7B-Instruct-v0.2",
-         0.7,
-         1024,
-         3,
          vector_db,
-         progress
      )
      return vector_db, collection_name, qa_chain, "Complete!"

- def format_chat_history(message, chat_history):
-     formatted_chat_history = []
-     for user_message, bot_message in chat_history:
-         formatted_chat_history.append(f"User: {user_message}")
-         formatted_chat_history.append(f"Assistant: {bot_message}")
-     return formatted_chat_history
-
- def conversation(qa_chain, message, history):
-     formatted_chat_history = format_chat_history(message, history)
-     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
-     response_answer = response["answer"]
-     if response_answer.find("Helpful Answer:") != -1:
-         response_answer = response_answer.split("Helpful Answer:")[-1]
-     response_sources = response["source_documents"]
-     response_source1 = response_sources[0].page_content.strip()
-     response_source2 = response_sources[1].page_content.strip()
-     response_source3 = response_sources[2].page_content.strip()
-     response_source1_page = response_sources[0].metadata["page"] + 1
-     response_source2_page = response_sources[1].metadata["page"] + 1
-     response_source3_page = response_sources[2].metadata["page"] + 1
-     new_history = history + [(message, response_answer)]
-     return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page

  def demo():
      with gr.Blocks(theme="base") as demo:
          vector_db = gr.State()
-         qa_chain = gr.State()
          collection_name = gr.State()
-
-         gr.Markdown(
-         """<center><h2>PDF-based chatbot (powered by LangChain and open-source LLMs)</center></h2>
-         <h3>Ask any questions about your PDF documents, along with follow-ups</h3>
-         <b>Note:</b> This AI assistant performs retrieval-augmented generation from your PDF documents. \
-         When generating answers, it takes past questions into account (via conversational memory), and includes document references for clarity purposes.</i>
-         <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate an output.<br>
-         """)
-
-         document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
-         slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
-         slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
-         db_progress = gr.Textbox(label="Vector database initialization", value="None")
-
-         # Initialize vector database and LLM chain in the background
-         vector_db, collection_name, qa_chain, status = initialize_demo([document], slider_chunk_size, slider_chunk_overlap, db_progress)
-
-         chatbot = gr.Chatbot(height=300)
-         msg = gr.Textbox(placeholder="Type message", container=True)
-         submit_btn = gr.Button("Submit")
-         clear_btn = gr.ClearButton([msg, chatbot])
-
-         msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot], queue=False)
-         submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot], queue=False)
-         clear_btn.click(lambda:[None,"",0,"",0,"",0], inputs=None, outputs=[chatbot], queue=False)
-
-     demo.queue().launch(debug=True)
-
- if __name__ == "__main__":
-     demo()
-
 
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain.vectorstores import Chroma
  from langchain.chains import ConversationalRetrievalChain
+ from langchain.embeddings import HuggingFaceEmbeddings
  from langchain.llms import HuggingFaceHub

  from pathlib import Path
  import chromadb

+ # List of available LLM models
+ list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1",
+     "google/gemma-7b-it", "google/gemma-2b-it",
+     "HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2",
+     "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct",
+     "google/flan-t5-xxl"
+ ]
+ list_llm_simple = [os.path.basename(llm) for llm in list_llm]
+
  # Load PDF document and create doc splits
  def load_doc(list_file_path, chunk_size, chunk_overlap):
      loaders = [PyPDFLoader(x) for x in list_file_path]
      pages = []
      for loader in loaders:
          pages.extend(loader.load())
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
      doc_splits = text_splitter.split_documents(pages)
      return doc_splits

          documents=splits,
          embedding=embedding,
          client=new_client,
+         collection_name=collection_name
      )
      return vectordb

  # Initialize langchain LLM chain
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+     if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
+         model_kwargs = {"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
+     elif llm_model == "microsoft/phi-2":
+         raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
+     elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
+         model_kwargs = {"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
+     else:
+         model_kwargs = {"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+
      llm = HuggingFaceHub(
          repo_id=llm_model,
+         model_kwargs=model_kwargs
      )
+
      memory = ConversationBufferMemory(
          memory_key="chat_history",
          output_key='answer',
          return_messages=True
      )
+
      retriever = vector_db.as_retriever()
+
      qa_chain = ConversationalRetrievalChain.from_llm(
          llm,
          retriever=retriever,
          chain_type="stuff",
          memory=memory,
          return_source_documents=True,
+         verbose=False
      )
+
      progress(0.9, desc="Done!")
      return qa_chain

+ def initialize_demo(list_file_obj, chunk_size, chunk_overlap, db_progress):
+     list_file_path = [file.name for file in list_file_obj if file is not None]
      collection_name = Path(list_file_path[0]).stem.replace(" ", "-")[:50]
      doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
      vector_db = create_db(doc_splits, collection_name)
      qa_chain = initialize_llmchain(
+         list_llm[0], # Using Mistral-7B-Instruct-v0.2 as the LLM model
+         0.7, # Temperature
+         1024, # Max Tokens
+         3, # Top K
          vector_db,
+         db_progress
      )
      return vector_db, collection_name, qa_chain, "Complete!"

+ def upload_file(file_obj):
+     list_file_path = []
+     for file in file_obj:
+         if file is not None:
+             file_path = file.name
+             list_file_path.append(file_path)
+     return list_file_path
 
  def demo():
      with gr.Blocks(theme="base") as demo:
          vector_db = gr.State()
          collection_name = gr.State()
+         qa_chain = gr.State()
+
+         with gr.Tab("Step 1 - Document pre-processing"):
+             document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
+             slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
+             slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
+             db_progress = gr.Textbox(label="Vector database initialization", value="None")
+             db_btn = gr.Button("Generate vector database...")
+
+         with gr.Tab("Step 2 - QA chain initialization"):
+             llm_progress = gr.Textbox(value="None", label="QA chain initialization")
+             qachain_btn = gr.Button("Initialize question-answering chain...")
+
+         with gr.Tab("Step 3 - Conversation with chatbot"):
+             chatbot = gr.Chatbot(height=300)
+             doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
+             source1_page = gr.Number(label="Page", scale=1)
+             doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
+             source2_page = gr.Number(label="Page", scale=1)
+             doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
+             source3_page = gr.Number(label="Page", scale=1)
+             msg = gr.Textbox(placeholder="Type message", container=True)
+             submit_btn = gr.Button("Submit")
+             clear_btn = gr.ClearButton([msg, chatbot])
+
+         document.upload(initialize_demo, inputs=[document, slider_chunk_size, slider_chunk_overlap, db_progress], outputs=[vector_db, collection_name, qa_chain, db_progress])
+         qachain_btn.click(initialize_llmchain, inputs=[qa_chain, llm_progress], outputs=[qa_chain, llm_progress])
+         submit_btn.click(lambda: None, inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2