vishwask committed
Commit 2ed5c8a · verified · 1 Parent(s): 036be0e

Update app.py

Files changed (1)
  1. app.py +65 -102
app.py CHANGED
@@ -21,12 +21,16 @@ import tqdm
 import accelerate


-# default_persist_directory = './chroma_HF/'
+#Set parameters

-list_llm = ["mistralai/Mistral-7B-Instruct-v0.2"]
-list_llm_simple = [os.path.basename(llm) for llm in list_llm]
+llm_model = 'mistralai/Mixtral-8x7B-Instruct-v0.1'
+list_file_path = '/home/niti/something'
+chunk_size = 1024
+chunk_overlap = 128
+temperature = 0.1
+max_tokens = 6000
+top_k = 3

-# Load PDF document and create doc splits
 def load_doc(list_file_path, chunk_size, chunk_overlap):
     # Processing for one document only
     # loader = PyPDFLoader(file_path)
@@ -55,7 +59,6 @@ def create_db(splits, collection_name):
     )
     return vectordb

-
 # Load vector database
 def load_db():
     embedding = HuggingFaceEmbeddings()
@@ -64,99 +67,38 @@ def load_db():
         embedding_function=embedding)
     return vectordb

-
 # Initialize langchain LLM chain
-def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    progress(0.1, desc="Initializing HF tokenizer...")
-
-    # HuggingFaceHub uses HF inference endpoints
-    progress(0.5, desc="Initializing HF Hub...")
+def initialize_llmchain(vector_db):
+    llm = HuggingFaceHub(repo_id = llm_model,
+        model_kwargs={"temperature": temperature,
+            "max_new_tokens": max_tokens,
+            "top_k": top_k,
+            "load_in_8bit": True})

-    # Use of trust_remote_code as model_kwargs
-    # Warning: langchain issue
-    # URL: https://github.com/langchain-ai/langchain/issues/6080
-    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
-        llm = HuggingFaceHub(
-            repo_id=llm_model,
-            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
-        )
-    elif llm_model == "microsoft/phi-2":
-        raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
-        llm = HuggingFaceHub(
-            repo_id=llm_model,
-            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
-        )
-    elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
-        llm = HuggingFaceHub(
-            repo_id=llm_model,
-            model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
-        )
-    elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
-        raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
-        llm = HuggingFaceHub(
-            repo_id=llm_model,
-            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
-        )
-    else:
-        llm = HuggingFaceHub(
-            repo_id=llm_model,
-            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
-            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
-        )
-
-    progress(0.75, desc="Defining buffer memory...")
-    memory = ConversationBufferMemory(
-        memory_key="chat_history",
-        output_key='answer',
-        return_messages=True
-    )
-    # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
+    memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
     retriever=vector_db.as_retriever()
-    progress(0.8, desc="Defining retrieval chain...")
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
         retriever=retriever,
         chain_type="stuff",
         memory=memory,
-        # combine_docs_chain_kwargs={"prompt": your_prompt})
         return_source_documents=True,
-        #return_generated_question=False,
         verbose=False,
     )
-    progress(0.9, desc="Done!")
     return qa_chain

-def start(llm_model, temperature, max_tokens, top_k,
-          vector_db, list_file_obj, chunk_size, chunk_overlap,
-          qa_chain, message, history):
-    # HuggingFaceHub uses HF inference endpoints
-    # Use of trust_remote_code as model_kwargs
-    # Warning: langchain issue
-    # URL: https://github.com/langchain-ai/langchain/issues/6080
-    llm = HuggingFaceHub(repo_id=llm_model, model_kwargs={"temperature": temperature,
-        "max_new_tokens": max_tokens,
-        "top_k": top_k,
-        "load_in_8bit": True})
-    memory = ConversationBufferMemory(memory_key="chat_history",output_key='answer',return_messages=True)

-    retriever=vector_db.as_retriever()
-    qa_chain = ConversationalRetrievalChain.from_llm(
-        llm,
-        retriever=retriever,
-        chain_type="stuff",
-        memory=memory,
-        # combine_docs_chain_kwargs={"prompt": your_prompt})
-        return_source_documents=True,
-        #return_generated_question=False,
-        verbose=False,
-    )
+vector_db, collection_name = initialize_database()

+#list_file_obj = document
+
+# Initialize database
+def initialize_database(list_file_obj):
     # Create list of documents (when valid)
     list_file_path = [x.name for x in list_file_obj if x is not None]
-
     # Create collection_name for vector database
+    progress(0.1, desc="Creating collection name...")
     collection_name = Path(list_file_path[0]).stem
-
     # Fix potential issues from naming convention
     ## Remove space
     collection_name = collection_name.replace(" ","-")
@@ -169,13 +111,33 @@ def start(llm_model, temperature, max_tokens, top_k,
         collection_name[-1] = 'Z'
     # print('list_file_path: ', list_file_path)
     print('Collection name: ', collection_name)
-
+    progress(0.25, desc="Loading document...")
     # Load document and create splits
     doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
-
     # Create or load vector database
+    progress(0.5, desc="Generating vector database...")
+    # global vector_db
     vector_db = create_db(doc_splits, collection_name)
+    progress(0.9, desc="Done!")
+    return vector_db, collection_name
+
+
+def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
+    # print("llm_option",llm_option)
+    llm_name = llm_model
+    qa_chain = initialize_llmchain(llm_name, temperature, max_tokens, top_k, vector_db)
+    return qa_chain
+

+def format_chat_history(message, chat_history):
+    formatted_chat_history = []
+    for user_message, bot_message in chat_history:
+        formatted_chat_history.append(f"User: {user_message}")
+        formatted_chat_history.append(f"Assistant: {bot_message}")
+    return formatted_chat_history
+
+
+def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
     #print("formatted_chat_history",formatted_chat_history)

@@ -197,17 +159,22 @@ def start(llm_model, temperature, max_tokens, top_k,

     # Append user message and response to chat history
     new_history = history + [(message, response_answer)]
+    # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
+    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page

-    return qa_chain, vector_db, collection_name, new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
-
+document = os.listdir(list_file_path)
+vector_db, collection_name = initialize_database(document)
+qa_chain = initialize_LLM(vector_db)
+
+
 def demo():
-    with gr.Blocks(theme="base") as demo:
+    with gr.Blocks(theme='base') as demo:
         vector_db = gr.State()
         qa_chain = gr.State()
         collection_name = gr.State()
-
+
         chatbot = gr.Chatbot(height=300)
-        with gr.Accordion("Advanced - Document references", open=False):
+        with gr.Accordion('References', open=True):
             with gr.Row():
                 doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                 source1_page = gr.Number(label="Page", scale=1)
@@ -218,19 +185,18 @@ def demo():
                 doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                 source3_page = gr.Number(label="Page", scale=1)
         with gr.Row():
-            msg = gr.Textbox(placeholder="Type message", container=True)
+            msg = gr.Textbox(placeholder = 'Ask your question', container = True)
         with gr.Row():
-            submit_btn = gr.Button("Submit")
-            clear_btn = gr.ClearButton([msg, chatbot])
-
-        msg.submit(start,
-            inputs=[llm_model, temperature, max_tokens, top_k,
-                vector_db, list_file_obj, chunk_size, chunk_overlap,
-                qa_chain, message, history],
-            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page,
-                doc_source2, source2_page,
-                doc_source3, source3_page],
-            queue=False)
+            submit_btn = gr.Button('Submit')
+            clear_button = gr.ClearButton([msg, chatbot])
+
+
+
+
+        msg.submit(conversation, \
+            inputs=[qa_chain, msg, chatbot], \
+            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
+            queue=False)
         submit_btn.click(conversation, \
             inputs=[qa_chain, msg, chatbot], \
             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
@@ -239,8 +205,5 @@ def demo():
             inputs=None, \
             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
             queue=False)
-
     demo.queue().launch(debug=True)
-
-if __name__ == "__main__":
-    demo()
+
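For orientation, the chain construction this revision converges on is the standard LangChain retrieval-QA pattern (a HuggingFaceHub LLM plus ConversationBufferMemory feeding a ConversationalRetrievalChain over a Chroma index). Below is a minimal, self-contained sketch of that pattern; the repo id, PDF path, and parameter values are illustrative placeholders, not values taken from this commit.

# Minimal sketch of the HuggingFaceHub + ConversationalRetrievalChain pattern used above.
# Repo id, path, and parameters are placeholders; HUGGINGFACEHUB_API_TOKEN must be set.
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

# Load one PDF, split it into chunks, and index the chunks in Chroma.
pages = PyPDFLoader("example.pdf").load()
splits = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=128).split_documents(pages)
vector_db = Chroma.from_documents(splits, embedding=HuggingFaceEmbeddings())

# Remote LLM served through the Hugging Face Inference API.
llm = HuggingFaceHub(
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",  # placeholder repo id
    model_kwargs={"temperature": 0.1, "max_new_tokens": 512, "top_k": 3},
)

# Conversational retrieval chain with chat memory that also returns source documents.
memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
qa_chain = ConversationalRetrievalChain.from_llm(
    llm,
    retriever=vector_db.as_retriever(),
    chain_type="stuff",
    memory=memory,
    return_source_documents=True,
)

# Ask a question; the chain expects a question plus the prior chat history.
result = qa_chain({"question": "What is this document about?", "chat_history": []})
print(result["answer"])
print(result["source_documents"][0].metadata.get("page"))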