mariagrandury committed on
Commit bed03be · 1 Parent(s): 18c8a5b

order imports, remove special case models and remove comments

Files changed (1)
  1. app.py +20 -86
app.py CHANGED
@@ -1,29 +1,23 @@
-import gradio as gr
 import os
-
-from langchain_community.document_loaders import PyPDFLoader
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import Chroma
-from langchain.chains import ConversationalRetrievalChain
-from langchain_community.embeddings import HuggingFaceEmbeddings
-from langchain_community.llms import HuggingFacePipeline
-from langchain.chains import ConversationChain
-from langchain.memory import ConversationBufferMemory
-from langchain_community.llms import HuggingFaceEndpoint
-
+import re
 from pathlib import Path
-import chromadb
-from unidecode import unidecode
 
-from transformers import AutoTokenizer
-import transformers
+import accelerate
+import chromadb
+import gradio as gr
 import torch
 import tqdm
-import accelerate
-import re
-
+import transformers
+from langchain.chains import ConversationalRetrievalChain, ConversationChain
+from langchain.memory import ConversationBufferMemory
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.llms import HuggingFaceEndpoint, HuggingFacePipeline
+from langchain_community.vectorstores import Chroma
+from transformers import AutoTokenizer
+from unidecode import unidecode
 
-# default_persist_directory = './chroma_HF/'
 list_llm = [
     "mistralai/Mistral-7B-Instruct-v0.2",
     "mistralai/Mixtral-8x7B-Instruct-v0.1",
@@ -31,8 +25,6 @@ list_llm = [
     "google/gemma-7b-it",
     "google/gemma-2b-it",
     "HuggingFaceH4/zephyr-7b-beta",
-    "microsoft/phi-2",
-    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
     "tiiuae/falcon-7b-instruct",
     "google/flan-t5-xxl",
 ]
@@ -65,7 +57,6 @@ def create_db(splits, collection_name):
         embedding=embedding,
         client=new_client,
         collection_name=collection_name,
-        # persist_directory=default_persist_directory
     )
     return vectordb
 
@@ -73,10 +64,7 @@ def create_db(splits, collection_name):
 # Load vector database
 def load_db():
     embedding = HuggingFaceEmbeddings()
-    vectordb = Chroma(
-        # persist_directory=default_persist_directory,
-        embedding_function=embedding
-    )
+    vectordb = Chroma(embedding_function=embedding)
     return vectordb
 
 
@@ -85,64 +73,20 @@ def initialize_llmchain(
     llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()
 ):
     progress(0.1, desc="Initializing HF tokenizer...")
-    # HuggingFacePipeline uses local model
-    # Note: it will download model locally...
-    # tokenizer=AutoTokenizer.from_pretrained(llm_model)
-    # progress(0.5, desc="Initializing HF pipeline...")
-    # pipeline=transformers.pipeline(
-    #     "text-generation",
-    #     model=llm_model,
-    #     tokenizer=tokenizer,
-    #     torch_dtype=torch.bfloat16,
-    #     trust_remote_code=True,
-    #     device_map="auto",
-    #     # max_length=1024,
-    #     max_new_tokens=max_tokens,
-    #     do_sample=True,
-    #     top_k=top_k,
-    #     num_return_sequences=1,
-    #     eos_token_id=tokenizer.eos_token_id
-    # )
-    # llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={'temperature': temperature})
 
     # HuggingFaceHub uses HF inference endpoints
     progress(0.5, desc="Initializing HF Hub...")
-    # Use of trust_remote_code as model_kwargs
-    # Warning: langchain issue
-    # URL: https://github.com/langchain-ai/langchain/issues/6080
     if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
         llm = HuggingFaceEndpoint(
             repo_id=llm_model,
-            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
             temperature=temperature,
             max_new_tokens=max_tokens,
             top_k=top_k,
             load_in_8bit=True,
         )
-    elif llm_model == "microsoft/phi-2":
-        # raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
-            temperature=temperature,
-            max_new_tokens=max_tokens,
-            top_k=top_k,
-            trust_remote_code=True,
-            torch_dtype="auto",
-        )
-    elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            # model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
-            temperature=temperature,
-            max_new_tokens=250,
-            top_k=top_k,
-        )
     else:
         llm = HuggingFaceEndpoint(
             repo_id=llm_model,
-            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
-            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
             temperature=temperature,
             max_new_tokens=max_tokens,
             top_k=top_k,
@@ -154,15 +98,14 @@ def initialize_llmchain(
     )
     # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
     retriever = vector_db.as_retriever()
+
     progress(0.8, desc="Defining retrieval chain...")
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
         retriever=retriever,
         chain_type="stuff",
         memory=memory,
-        # combine_docs_chain_kwargs={"prompt": your_prompt})
         return_source_documents=True,
-        # return_generated_question=False,
         verbose=False,
     )
     progress(0.9, desc="Done!")
@@ -197,22 +140,19 @@ def create_collection_name(filepath):
     return collection_name
 
 
-# Initialize database
 def initialize_database(
     list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()
 ):
-    # Create list of documents (when valid)
     list_file_path = [x.name for x in list_file_obj if x is not None]
-    # Create collection_name for vector database
     progress(0.1, desc="Creating collection name...")
     collection_name = create_collection_name(list_file_path[0])
+
     progress(0.25, desc="Loading document...")
-    # Load document and create splits
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
-    # Create or load vector database
+
     progress(0.5, desc="Generating vector database...")
-    # global vector_db
     vector_db = create_db(doc_splits, collection_name)
+
     progress(0.9, desc="Done!")
     return vector_db, collection_name, "Complete!"
 
@@ -220,7 +160,6 @@ def initialize_database(
 def initialize_LLM(
     llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()
 ):
-    # print("llm_option",llm_option)
     llm_name = list_llm[llm_option]
     print("llm_name: ", llm_name)
     qa_chain = initialize_llmchain(
@@ -239,7 +178,6 @@ def format_chat_history(message, chat_history):
 
 def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
-    # print("formatted_chat_history",formatted_chat_history)
 
     # Generate response using QA chain
     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
@@ -247,15 +185,13 @@ def conversation(qa_chain, message, history):
     if response_answer.find("Helpful Answer:") != -1:
         response_answer = response_answer.split("Helpful Answer:")[-1]
     response_sources = response["source_documents"]
+    # Langchain sources are zero-based
     response_source1 = response_sources[0].page_content.strip()
     response_source2 = response_sources[1].page_content.strip()
     response_source3 = response_sources[2].page_content.strip()
-    # Langchain sources are zero-based
     response_source1_page = response_sources[0].metadata["page"] + 1
     response_source2_page = response_sources[1].metadata["page"] + 1
     response_source3_page = response_sources[2].metadata["page"] + 1
-    # print ('chat response: ', response_answer)
-    # print('DB source', response_sources)
 
     # Append user message and response to chat history
     new_history = history + [(message, response_answer)]
@@ -278,8 +214,6 @@ def upload_file(file_obj):
     for idx, file in enumerate(file_obj):
         file_path = file_obj.name
         list_file_path.append(file_path)
-        # print(file_path)
-        # initialize_database(file_path, progress)
     return list_file_path
 
 
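
Note on the simplified branch above: with the microsoft/phi-2 and TinyLlama special cases removed, every model in list_llm is built through the same HuggingFaceEndpoint call (only Mixtral additionally passes load_in_8bit=True). A minimal sketch of that call path, with illustrative parameter values rather than values taken from this commit, and assuming a valid HUGGINGFACEHUB_API_TOKEN is available in the environment:

from langchain_community.llms import HuggingFaceEndpoint

# Illustrative values only; in app.py these are the temperature, max_tokens
# and top_k arguments passed into initialize_llmchain.
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",  # any entry of list_llm
    temperature=0.7,
    max_new_tokens=1024,
    top_k=3,
)
print(llm.invoke("Summarize retrieval-augmented generation in one sentence."))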