vishwask committed on
Commit bdc42f3 · verified · 1 Parent(s): 626e763

Update app.py

Files changed (1)
  1. app.py +84 -111
app.py CHANGED
@@ -20,9 +20,16 @@ import torch
  import tqdm
  import accelerate

  def load_doc(list_file_path, chunk_size, chunk_overlap):
      # Processing for one document only
-     # loader = Py PDFLoader(file_path)
      # pages = loader.load()
      loaders = [PyPDFLoader(x) for x in list_file_path]
      pages = []
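The middle of load_doc falls between this hunk and the next, so the splitter setup itself is not visible in the diff. A minimal sketch of that missing step, assuming LangChain's RecursiveCharacterTextSplitter (the splitter actually used by the app is an assumption):

    from langchain.document_loaders import PyPDFLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    def load_doc(list_file_path, chunk_size, chunk_overlap):
        # Load every uploaded PDF and collect its pages
        loaders = [PyPDFLoader(x) for x in list_file_path]
        pages = []
        for loader in loaders:
            pages.extend(loader.load())
        # Assumed splitter; chunk_size and chunk_overlap are the values passed in by the caller
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        doc_splits = text_splitter.split_documents(pages)
        return doc_splits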
@@ -35,7 +42,6 @@ def load_doc(list_file_path, chunk_size, chunk_overlap):
      doc_splits = text_splitter.split_documents(pages)
      return doc_splits

-
  # Create vector database
  def create_db(splits, collection_name):
      embedding = HuggingFaceEmbeddings()
@@ -49,6 +55,7 @@ def create_db(splits, collection_name):
      )
      return vectordb

  # Load vector database
  def load_db():
      embedding = HuggingFaceEmbeddings()
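The middle of create_db is likewise elided. A hypothetical reconstruction assuming a Chroma vector store; the commented-out default_persist_directory = './chroma_HF/' added later in this commit hints at Chroma, but that is an inference:

    from langchain.vectorstores import Chroma
    from langchain.embeddings import HuggingFaceEmbeddings

    def create_db(splits, collection_name):
        embedding = HuggingFaceEmbeddings()
        # Assumed store: build a Chroma collection from the document splits
        vectordb = Chroma.from_documents(
            documents=splits,
            embedding=embedding,
            collection_name=collection_name,
        )
        return vectordb

load_db() presumably does the reverse and reopens a persisted collection, e.g. Chroma(persist_directory='./chroma_HF/', embedding_function=embedding), but that path is only the commented-out default.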
@@ -57,38 +64,95 @@ def load_db():
          embedding_function=embedding)
      return vectordb

  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
      progress(0.1, desc="Initializing HF tokenizer...")

      # HuggingFaceHub uses HF inference endpoints
      progress(0.5, desc="Initializing HF Hub...")
      # Use of trust_remote_code as model_kwargs
      # Warning: langchain issue
      # URL: https://github.com/langchain-ai/langchain/issues/6080
-
-     llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1"
-     llm = HuggingFaceHub(repo_id=llm_model, model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True})
-     progress(0.75, desc="Defining buffer memory...")

-     memory = ConversationBufferMemory(memory_key="chat_history",output_key='answer',return_messages=True )
-
      # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
      retriever=vector_db.as_retriever()
-
      progress(0.8, desc="Defining retrieval chain...")
-
-     qa_chain = ConversationalRetrievalChain.from_llm(llm,retriever=retriever,chain_type="stuff", memory=memory,return_source_documents=True,verbose=False,)
      progress(0.9, desc="Done!")
      return qa_chain

- # Initialize database
- def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
-
      # Create list of documents (when valid)
      list_file_path = [x.name for x in list_file_obj if x is not None]

      # Create collection_name for vector database
-     progress(0.1, desc="Creating collection name...")
      collection_name = Path(list_file_path[0]).stem

      # Fix potential issues from naming convention
@@ -96,112 +160,21 @@ def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Pr
      collection_name = collection_name.replace(" ","-")
      ## Limit length to 50 characters
      collection_name = collection_name[:50]
-
      ## Enforce start and end as alphanumeric character
      if not collection_name[0].isalnum():
          collection_name = 'A' + collection_name[1:]
      if not collection_name[-1].isalnum():
          collection_name = collection_name[:-1] + 'Z'
-
      # print('list_file_path: ', list_file_path)
      print('Collection name: ', collection_name)
-     progress(0.25, desc="Loading document...")
-
      # Load document and create splits
      doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)

      # Create or load vector database
-     progress(0.5, desc="Generating vector database...")
-
-     # global vector_db
      vector_db = create_db(doc_splits, collection_name)
-     progress(0.9, desc="Done!")
-     return vector_db, collection_name, "Complete!"
-
-
-
- def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-     llm_name = list_llm[llm_option]
-     print("llm_name: ",llm_name)
-     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
-     return qa_chain, "Complete!"
-
- def format_chat_history(message, chat_history):
-     formatted_chat_history = []
-     for user_message, bot_message in chat_history:
-         formatted_chat_history.append(f"User: {user_message}")
-         formatted_chat_history.append(f"Assistant: {bot_message}")
-     return formatted_chat_history
-
-
- def conversation(qa_chain, message, history):
-     formatted_chat_history = format_chat_history(message, history)
-     #print("formatted_chat_history",formatted_chat_history)
-
-     # Generate response using QA chain
-     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
-     response_answer = response["answer"]
-     if response_answer.find("Helpful Answer:") != -1:
-         response_answer = response_answer.split("Helpful Answer:")[-1]
-     response_sources = response["source_documents"]
-     response_source1 = response_sources[0].page_content.strip()
-     response_source2 = response_sources[1].page_content.strip()
-     response_source3 = response_sources[2].page_content.strip()
-     # Langchain sources are zero-based
-     response_source1_page = response_sources[0].metadata["page"] + 1
-     response_source2_page = response_sources[1].metadata["page"] + 1
-     response_source3_page = response_sources[2].metadata["page"] + 1
-     # print ('chat response: ', response_answer)
-     # print('DB source', response_sources)

-     # Append user message and response to chat history
-     new_history = history + [(message, response_answer)]
-     # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
-     return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
-
-
-
- def demo():
-     with gr.Blocks(theme="base") as demo:
-         vector_db = gr.State()
-         qa_chain = gr.State()
-         collection_name = gr.State()
-
-         gr.Markdown("""RAG USING MIXTRAL""")
-
-         with gr.Row():
-             chatbot = gr.Chatbot(height=300)
-         with gr.Accordion("Advanced - Document references", open=False):
-             with gr.Row():
-                 doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
-                 source1_page = gr.Number(label="Page", scale=1)
-             with gr.Row():
-                 doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
-                 source2_page = gr.Number(label="Page", scale=1)
-             with gr.Row():
-                 doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
-                 source3_page = gr.Number(label="Page", scale=1)
-         with gr.Row():
-             msg = gr.Textbox(placeholder="Type message", container=True)
-         with gr.Row():
-             submit_btn = gr.Button("Submit")
-             clear_btn = gr.ClearButton([msg, chatbot])
-
-         # Chatbot events
-         submit_btn.click(conversation, \
-             inputs=[qa_chain, msg, chatbot], \
-             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-             queue=False)
-         clear_btn.click(lambda:[None,"",0,"",0,"",0], \
-             inputs=None, \
-             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-             queue=False)
-     demo.queue().launch(debug=True)
-
- # Replace the placeholders with your actual functions
- def conversation(*args, **kwargs):
-     pass
-
-
-
-
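For reference, the removed format_chat_history helper flattened Gradio's (user, bot) history pairs into plain strings before passing them to the chain; a small worked example with invented values:

    history = [("What is the report about?", "It summarises Q3 sales.")]
    format_chat_history("Any risks mentioned?", history)
    # -> ['User: What is the report about?', 'Assistant: It summarises Q3 sales.']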
 
  import tqdm
  import accelerate

+
+ # default_persist_directory = './chroma_HF/'
+
+ list_llm = ["mistralai/Mistral-7B-Instruct-v0.2"]
+ list_llm_simple = [os.path.basename(llm) for llm in list_llm]
+
+ # Load PDF document and create doc splits
  def load_doc(list_file_path, chunk_size, chunk_overlap):
      # Processing for one document only
+     # loader = PyPDFLoader(file_path)
      # pages = loader.load()
      loaders = [PyPDFLoader(x) for x in list_file_path]
      pages = []
 
      doc_splits = text_splitter.split_documents(pages)
      return doc_splits

  # Create vector database
  def create_db(splits, collection_name):
      embedding = HuggingFaceEmbeddings()
 
      )
      return vectordb

+
  # Load vector database
  def load_db():
      embedding = HuggingFaceEmbeddings()
 
          embedding_function=embedding)
      return vectordb

+
+ # Initialize langchain LLM chain
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
      progress(0.1, desc="Initializing HF tokenizer...")

      # HuggingFaceHub uses HF inference endpoints
      progress(0.5, desc="Initializing HF Hub...")
+
      # Use of trust_remote_code as model_kwargs
      # Warning: langchain issue
      # URL: https://github.com/langchain-ai/langchain/issues/6080
+     if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
+         llm = HuggingFaceHub(
+             repo_id=llm_model,
+             model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
+         )
+     elif llm_model == "microsoft/phi-2":
+         raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
+         llm = HuggingFaceHub(
+             repo_id=llm_model,
+             model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
+         )
+     elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
+         llm = HuggingFaceHub(
+             repo_id=llm_model,
+             model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
+         )
+     elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
+         raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
+         llm = HuggingFaceHub(
+             repo_id=llm_model,
+             model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+         )
+     else:
+         llm = HuggingFaceHub(
+             repo_id=llm_model,
+             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
+             model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+         )

+     progress(0.75, desc="Defining buffer memory...")
+     memory = ConversationBufferMemory(
+         memory_key="chat_history",
+         output_key='answer',
+         return_messages=True
+     )
      # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
      retriever=vector_db.as_retriever()
      progress(0.8, desc="Defining retrieval chain...")
+     qa_chain = ConversationalRetrievalChain.from_llm(
+         llm,
+         retriever=retriever,
+         chain_type="stuff",
+         memory=memory,
+         # combine_docs_chain_kwargs={"prompt": your_prompt})
+         return_source_documents=True,
+         #return_generated_question=False,
+         verbose=False,
+     )
      progress(0.9, desc="Done!")
      return qa_chain
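A hypothetical usage sketch for the chain returned above, based on the conversation() handler that this commit removes; the question, history, and printed values are invented:

    # qa_chain as returned by initialize_llmchain(...); the input keys follow the removed conversation() handler
    response = qa_chain({"question": "What does the document say about pricing?",
                         "chat_history": ["User: Hi", "Assistant: Hello!"]})
    print(response["answer"])                        # generated answer text
    print(response["source_documents"][0].metadata)  # e.g. {'page': 4, 'source': 'report.pdf'} (page is zero-based)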
 
+ def start(llm_model, temperature, max_tokens, top_k, vector_db, list_file_obj, chunk_size, chunk_overlap):
+     # HuggingFaceHub uses HF inference endpoints
+     # Use of trust_remote_code as model_kwargs
+     # Warning: langchain issue
+     # URL: https://github.com/langchain-ai/langchain/issues/6080
+     llm = HuggingFaceHub(repo_id=llm_model, model_kwargs={"temperature": temperature,
+                                                            "max_new_tokens": max_tokens,
+                                                            "top_k": top_k,
+                                                            "load_in_8bit": True})
+     memory = ConversationBufferMemory(memory_key="chat_history",output_key='answer',return_messages=True)
+
+     retriever=vector_db.as_retriever()
+     qa_chain = ConversationalRetrievalChain.from_llm(
+         llm,
+         retriever=retriever,
+         chain_type="stuff",
+         memory=memory,
+         # combine_docs_chain_kwargs={"prompt": your_prompt})
+         return_source_documents=True,
+         #return_generated_question=False,
+         verbose=False,
+     )
+
      # Create list of documents (when valid)
      list_file_path = [x.name for x in list_file_obj if x is not None]

      # Create collection_name for vector database
      collection_name = Path(list_file_path[0]).stem

      # Fix potential issues from naming convention
 
      collection_name = collection_name.replace(" ","-")
      ## Limit length to 50 characters
      collection_name = collection_name[:50]
      ## Enforce start and end as alphanumeric character
      if not collection_name[0].isalnum():
          collection_name = 'A' + collection_name[1:]
      if not collection_name[-1].isalnum():
          collection_name = collection_name[:-1] + 'Z'
      # print('list_file_path: ', list_file_path)
      print('Collection name: ', collection_name)
+
      # Load document and create splits
      doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)

      # Create or load vector database
      vector_db = create_db(doc_splits, collection_name)
+
+
+     return qa_chain, vector_db, collection_name
+
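A hypothetical usage sketch for the new start() helper, not part of this commit. Note that start() builds its retriever from the vector_db argument before it creates a fresh database from the uploaded files, so it appears to expect an existing database; the database, file list, and generation settings below are made up for illustration:

    existing_db = load_db()        # assumed pre-built collection (see load_db above)
    qa_chain, vector_db, collection_name = start(
        llm_model=list_llm[0],     # "mistralai/Mistral-7B-Instruct-v0.2"
        temperature=0.7,
        max_tokens=1024,
        top_k=3,
        vector_db=existing_db,
        list_file_obj=uploaded_files,  # hypothetical list of uploaded file objects, e.g. from a gr.Files component
        chunk_size=600,
        chunk_overlap=40,
    )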