HemaMeena committed
Commit 6c1dd09 · verified · 1 Parent(s): d114adc

Update app.py

Files changed (1)
  1. app.py +236 -89
app.py CHANGED
@@ -1,26 +1,42 @@
- import gradio as gr
- import time
import os
import glob
import textwrap
+ import time

- from transformers import (
-     AutoTokenizer, AutoModelForCausalLM,
-     BitsAndBytesConfig,
-     pipeline
- )
+ import langchain
+
+ ### loaders
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
+
+ ### splits
from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+ ### prompts
+ from langchain import PromptTemplate, LLMChain
+
+ ### vector stores
from langchain.vectorstores import FAISS
+
+ ### models
from langchain.llms import HuggingFacePipeline
from langchain.embeddings import HuggingFaceInstructEmbeddings
+
+ ### retrievers
from langchain.chains import RetrievalQA
- from langchain.prompts import PromptTemplate

- # Configuration class
+ import torch
+ import transformers
+ from transformers import (
+     AutoTokenizer, AutoModelForCausalLM,
+     BitsAndBytesConfig,
+     pipeline
+ )
+
+ sorted(glob.glob('/content/anatomy_vol_*'))
+
class CFG:
    # LLMs
-     model_name = 'llama2-13b-chat' # wizardlm, llama2-7b-chat, llama2-13b-chat, mistral-7B
+     model_name = 'llama2-13b-chat' # wizardlm, llama2-7b-chat, llama2-13b-chat, mistral-7B
    temperature = 0
    top_p = 0.95
    repetition_penalty = 1.15
@@ -36,103 +52,221 @@ class CFG:
    k = 6

    # paths
-     PDFs_path = './' # Set to your PDF path
-     Embeddings_path = './faiss-hp-sentence-transformers'
+     PDFs_path = '/content/'
+     Embeddings_path = '/content/faiss-hp-sentence-transformers'
    Output_folder = './rag-vectordb'

- # Set preferred encoding to UTF-8 (for non-ASCII characters)
- import locale
- locale.getpreferredencoding = lambda: "UTF-8"
+ def get_model(model = CFG.model_name):

- # Function to get model
- def get_model(model = CFG.model_name):
    print('\nDownloading model: ', model, '\n\n')
-
+
    if model == 'wizardlm':
        model_repo = 'TheBloke/wizardLM-7B-HF'

        tokenizer = AutoTokenizer.from_pretrained(model_repo)
+
        bnb_config = BitsAndBytesConfig(
-             load_in_4bit=True,
-             bnb_4bit_quant_type="nf4",
-             bnb_4bit_compute_dtype=torch.float16,
-             bnb_4bit_use_double_quant=True,
+             load_in_4bit = True,
+             bnb_4bit_quant_type = "nf4",
+             bnb_4bit_compute_dtype = torch.float16,
+             bnb_4bit_use_double_quant = True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_repo,
-             quantization_config=bnb_config,
-             device_map='auto',
-             low_cpu_mem_usage=True
+             quantization_config = bnb_config,
+             device_map = 'auto',
+             low_cpu_mem_usage = True
        )

        max_len = 1024

    elif model == 'llama2-7b-chat':
        model_repo = 'daryl149/llama-2-7b-chat-hf'
+
        tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)

        bnb_config = BitsAndBytesConfig(
-             load_in_4bit=True,
-             bnb_4bit_quant_type="nf4",
-             bnb_4bit_compute_dtype=torch.float16,
-             bnb_4bit_use_double_quant=True,
+             load_in_4bit = True,
+             bnb_4bit_quant_type = "nf4",
+             bnb_4bit_compute_dtype = torch.float16,
+             bnb_4bit_use_double_quant = True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_repo,
-             quantization_config=bnb_config,
-             device_map='auto',
-             low_cpu_mem_usage=True,
-             trust_remote_code=True
+             quantization_config = bnb_config,
+             device_map = 'auto',
+             low_cpu_mem_usage = True,
+             trust_remote_code = True
        )

        max_len = 2048

    elif model == 'llama2-13b-chat':
        model_repo = 'daryl149/llama-2-13b-chat-hf'
+
        tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)

        bnb_config = BitsAndBytesConfig(
-             load_in_4bit=True,
-             bnb_4bit_quant_type="nf4",
-             bnb_4bit_compute_dtype=torch.float16,
-             bnb_4bit_use_double_quant=True,
+             load_in_4bit = True,
+             bnb_4bit_quant_type = "nf4",
+             bnb_4bit_compute_dtype = torch.float16,
+             bnb_4bit_use_double_quant = True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_repo,
-             quantization_config=bnb_config,
-             low_cpu_mem_usage=True,
-             trust_remote_code=True
+             quantization_config = bnb_config,
+
+             low_cpu_mem_usage = True,
+             trust_remote_code = True
+         )
+
+         max_len = 2048 #8192
+         truncation=True, # Explicitly enable truncation
+         padding="max_len" # Optional: pad to max_length
+
+     elif model == 'mistral-7B':
+         model_repo = 'mistralai/Mistral-7B-v0.1'
+
+         tokenizer = AutoTokenizer.from_pretrained(model_repo)
+
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit = True,
+             bnb_4bit_quant_type = "nf4",
+             bnb_4bit_compute_dtype = torch.float16,
+             bnb_4bit_use_double_quant = True,
+         )
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_repo,
+             quantization_config = bnb_config,
+             device_map = 'auto',
+             low_cpu_mem_usage = True,
+         )
+
+         max_len = 1024
+
+     else:
+         print("Not implemented model (tokenizer and backbone)")
+
+     return tokenizer, model, max_len
+
+ def get_model(model = CFG.model_name):
+
+     print('\nDownloading model: ', model, '\n\n')
+
+     if model == 'wizardlm':
+         model_repo = 'TheBloke/wizardLM-7B-HF'
+
+         tokenizer = AutoTokenizer.from_pretrained(model_repo)
+
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit = True,
+             bnb_4bit_quant_type = "nf4",
+             bnb_4bit_compute_dtype = torch.float16,
+             bnb_4bit_use_double_quant = True,
+         )
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_repo,
+             quantization_config = bnb_config,
+             device_map = 'auto',
+             low_cpu_mem_usage = True
+         )
+
+         max_len = 1024
+
+     elif model == 'llama2-7b-chat':
+         model_repo = 'daryl149/llama-2-7b-chat-hf'
+
+         tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)
+
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit = True,
+             bnb_4bit_quant_type = "nf4",
+             bnb_4bit_compute_dtype = torch.float16,
+             bnb_4bit_use_double_quant = True,
+         )
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_repo,
+             quantization_config = bnb_config,
+             device_map = 'auto',
+             low_cpu_mem_usage = True,
+             trust_remote_code = True
        )

        max_len = 2048

+     elif model == 'llama2-13b-chat':
+         model_repo = 'daryl149/llama-2-13b-chat-hf'
+
+         tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)
+
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit = True,
+             bnb_4bit_quant_type = "nf4",
+             bnb_4bit_compute_dtype = torch.float16,
+             bnb_4bit_use_double_quant = True,
+         )
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_repo,
+             quantization_config = bnb_config,
+
+             low_cpu_mem_usage = True,
+             trust_remote_code = True
+         )
+
+         max_len = 2048 #8192
+         truncation=True, # Explicitly enable truncation
+         padding="max_len" # Optional: pad to max_length
+
+     elif model == 'mistral-7B':
+         model_repo = 'mistralai/Mistral-7B-v0.1'
+
+         tokenizer = AutoTokenizer.from_pretrained(model_repo)
+
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit = True,
+             bnb_4bit_quant_type = "nf4",
+             bnb_4bit_compute_dtype = torch.float16,
+             bnb_4bit_use_double_quant = True,
+         )
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_repo,
+             quantization_config = bnb_config,
+             device_map = 'auto',
+             low_cpu_mem_usage = True,
+         )
+
+         max_len = 1024
+
    else:
-         print("Model not implemented!")
+         print("Not implemented model (tokenizer and backbone)")

    return tokenizer, model, max_len

- # Get the model
- tokenizer, model, max_len = get_model(CFG.model_name)
+ tokenizer, model, max_len = get_model(model = CFG.model_name)

- # Set up Hugging Face pipeline
pipe = pipeline(
-     task="text-generation",
-     model=model,
-     tokenizer=tokenizer,
-     pad_token_id=tokenizer.eos_token_id,
-     max_length=max_len,
-     temperature=CFG.temperature,
-     top_p=CFG.top_p,
-     repetition_penalty=CFG.repetition_penalty
+     task = "text-generation",
+     model = model,
+     tokenizer = tokenizer,
+     pad_token_id = tokenizer.eos_token_id,
+     # do_sample = True,
+     max_length = max_len,
+     temperature = CFG.temperature,
+     top_p = CFG.top_p,
+     repetition_penalty = CFG.repetition_penalty
)

- # Langchain pipeline
- llm = HuggingFacePipeline(pipeline=pipe)
+ ### langchain pipeline
+ llm = HuggingFacePipeline(pipeline = pipe)

loader = DirectoryLoader(
    CFG.PDFs_path,
    glob="./*.pdf",
@@ -140,25 +274,27 @@ loader = DirectoryLoader(
    show_progress=True,
    use_multithreading=True
)
+
documents = loader.load()

- # Split the documents
text_splitter = RecursiveCharacterTextSplitter(
-     chunk_size=CFG.split_chunk_size,
-     chunk_overlap=CFG.split_overlap
+     chunk_size = CFG.split_chunk_size,
+     chunk_overlap = CFG.split_overlap
)
+
texts = text_splitter.split_documents(documents)

- # Set up vector store
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+
vectordb = FAISS.from_documents(
    texts,
-     HuggingFaceInstructEmbeddings(model_name=CFG.embeddings_model_repo)
+     HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2')
)

- # Save the vector store
- vectordb.save_local(f"{CFG.Output_folder}/faiss_index_rag")
+ ### persist vector database
+ vectordb.save_local(f"{CFG.Output_folder}/faiss_index_rag") # save in output folder
+ # vectordb.save_local(f"{CFG.Embeddings_path}/faiss_index_hp") # save in input folder

- # Define the prompt template
prompt_template = """
Don't try to make up an answer, if you don't know just say that you don't know.
Answer in the same language the question was asked.
@@ -169,34 +305,39 @@ Use only the following pieces of context to answer the question at the end.
Question: {question}
Answer:"""

+
PROMPT = PromptTemplate(
-     template=prompt_template,
-     input_variables=["context", "question"]
+     template = prompt_template,
+     input_variables = ["context", "question"]
)

- # Set up retriever
- retriever = vectordb.as_retriever(search_kwargs={"k": CFG.k, "search_type": "similarity"})
+ retriever = vectordb.as_retriever(search_kwargs = {"k": CFG.k, "search_type" : "similarity"})

- # Create the retrieval-based QA chain
qa_chain = RetrievalQA.from_chain_type(
-     llm=llm,
-     chain_type="stuff", # other options: "map_reduce", "map_rerank", "refine"
-     retriever=retriever,
-     chain_type_kwargs={"prompt": PROMPT},
-     return_source_documents=True,
-     verbose=False
+     llm = llm,
+     chain_type = "stuff", # map_reduce, map_rerank, stuff, refine
+     retriever = retriever,
+     chain_type_kwargs = {"prompt": PROMPT},
+     return_source_documents = True,
+     verbose = False
)

- # Function to wrap text for proper display
def wrap_text_preserve_newlines(text, width=700):
+     # Split the input text into lines based on newline characters
    lines = text.split('\n')
+
+     # Wrap each line individually
    wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
+
+     # Join the wrapped lines back together using newline characters
    wrapped_text = '\n'.join(wrapped_lines)
+
    return wrapped_text

- # Function to process model response
+
def process_llm_response(llm_response):
    ans = wrap_text_preserve_newlines(llm_response['result'])
+
    sources_used = ' \n'.join(
        [
            source.metadata['source'].split('/')[-1][:-4]
@@ -205,31 +346,37 @@ def process_llm_response(llm_response):
            for source in llm_response['source_documents']
        ]
    )
+
    ans = ans + '\n\nSources: \n' + sources_used
    return ans

- # Function to get the answer from the model
def llm_ans(query):
    start = time.time()
+
    llm_response = qa_chain.invoke(query)
    ans = process_llm_response(llm_response)
+
    end = time.time()

    time_elapsed = int(round(end - start, 0))
    time_elapsed_str = f'\n\nTime elapsed: {time_elapsed} s'
    return ans + time_elapsed_str

- # Function for Gradio chat interface
- def predict(message, history):
-     output = str(llm_ans(message)).replace("\n", "<br/>")
-     return output
+ import locale
+ locale.getpreferredencoding = lambda: "UTF-8"

- # Set up Gradio interface
- demo = gr.ChatInterface(
-     fn=predict,
-     title=f'Open-Source LLM ({CFG.model_name}) Question Answering'
- )
+ import gradio as gr
+
+ def predict(message, history):
+     # output = message # debug mode
+
+     output = str(llm_ans(message)).replace("\n", "<br/>")
+     return output
+
+ demo = gr.ChatInterface(
+     predict,
+     title = f' Open-Source LLM ({CFG.model_name}) Question Answering'
+ )

- # Start the Gradio interface
- demo.queue()
- demo.launch()
+ demo.queue()
+ demo.launch()
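
The updated chain can also be smoke-tested without the Gradio UI by calling the helpers defined above once the vector store and qa_chain have been built. A minimal sketch, assuming the PDFs under CFG.PDFs_path were indexed as in the code above; the question string is only an illustration:

# Illustrative check: run after the objects above exist,
# e.g. just before demo.launch() in app.py.
sample_question = "What are the main topics covered in the loaded PDFs?"  # placeholder query
print(llm_ans(sample_question))  # prints the answer, the source PDF names, and the elapsed time

Because llm_ans() goes through qa_chain with return_source_documents=True, the printed output includes the source filenames, which is a quick way to confirm the FAISS index was built from the expected documents.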