HemaMeena committed
Commit a9a5bcb · verified · 1 Parent(s): 6c1dd09

Update app.py

Files changed (1)
  1. app.py +84 -179
app.py CHANGED
@@ -36,7 +36,7 @@ sorted(glob.glob('/content/anatomy_vol_*'))
 
 class CFG:
     # LLMs
-    model_name = 'llama2-13b-chat' # wizardlm, llama2-7b-chat, llama2-13b-chat, mistral-7B
+    model_name = 'llama2-13b-chat' # wizardlm, llama2-7b-chat, llama2-13b-chat, mistral-7B
     temperature = 0
     top_p = 0.95
     repetition_penalty = 1.15
@@ -53,11 +53,12 @@ class CFG:
 
     # paths
     PDFs_path = '/content/'
-    Embeddings_path = '/content/faiss-hp-sentence-transformers'
+    Embeddings_path = '/content/faiss-hp-sentence-transformers'
     Output_folder = './rag-vectordb'
 
-def get_model(model = CFG.model_name):
-
+
+# Define model loading function
+def get_model(model=CFG.model_name):
     print('\nDownloading model: ', model, '\n\n')
 
     if model == 'wizardlm':
@@ -66,17 +67,17 @@ class CFG:
         tokenizer = AutoTokenizer.from_pretrained(model_repo)
 
         bnb_config = BitsAndBytesConfig(
-            load_in_4bit = True,
-            bnb_4bit_quant_type = "nf4",
-            bnb_4bit_compute_dtype = torch.float16,
-            bnb_4bit_use_double_quant = True,
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
         )
 
         model = AutoModelForCausalLM.from_pretrained(
             model_repo,
-            quantization_config = bnb_config,
-            device_map = 'auto',
-            low_cpu_mem_usage = True
+            quantization_config=bnb_config,
+            device_map='auto',
+            low_cpu_mem_usage=True
         )
 
         max_len = 1024
@@ -87,18 +88,18 @@ class CFG:
         tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)
 
         bnb_config = BitsAndBytesConfig(
-            load_in_4bit = True,
-            bnb_4bit_quant_type = "nf4",
-            bnb_4bit_compute_dtype = torch.float16,
-            bnb_4bit_use_double_quant = True,
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
         )
 
         model = AutoModelForCausalLM.from_pretrained(
             model_repo,
-            quantization_config = bnb_config,
-            device_map = 'auto',
-            low_cpu_mem_usage = True,
-            trust_remote_code = True
+            quantization_config=bnb_config,
+            device_map='auto',
+            low_cpu_mem_usage=True,
+            trust_remote_code=True
         )
 
         max_len = 2048
@@ -109,23 +110,20 @@ class CFG:
         tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)
 
         bnb_config = BitsAndBytesConfig(
-            load_in_4bit = True,
-            bnb_4bit_quant_type = "nf4",
-            bnb_4bit_compute_dtype = torch.float16,
-            bnb_4bit_use_double_quant = True,
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
        )
 
         model = AutoModelForCausalLM.from_pretrained(
             model_repo,
-            quantization_config = bnb_config,
-
-            low_cpu_mem_usage = True,
-            trust_remote_code = True
+            quantization_config=bnb_config,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True
         )
 
-        max_len = 2048 #8192
-        truncation=True, # Explicitly enable truncation
-        padding="max_len" # Optional: pad to max_length
+        max_len = 2048 # 8192
 
     elif model == 'mistral-7B':
         model_repo = 'mistralai/Mistral-7B-v0.1'
@@ -133,17 +131,17 @@ class CFG:
         tokenizer = AutoTokenizer.from_pretrained(model_repo)
 
         bnb_config = BitsAndBytesConfig(
-            load_in_4bit = True,
-            bnb_4bit_quant_type = "nf4",
-            bnb_4bit_compute_dtype = torch.float16,
-            bnb_4bit_use_double_quant = True,
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
         )
 
         model = AutoModelForCausalLM.from_pretrained(
             model_repo,
-            quantization_config = bnb_config,
-            device_map = 'auto',
-            low_cpu_mem_usage = True,
+            quantization_config=bnb_config,
+            device_map='auto',
+            low_cpu_mem_usage=True,
         )
 
         max_len = 1024
@@ -153,120 +151,26 @@ class CFG:
 
     return tokenizer, model, max_len
 
-def get_model(model = CFG.model_name):
-
-    print('\nDownloading model: ', model, '\n\n')
-
-    if model == 'wizardlm':
-        model_repo = 'TheBloke/wizardLM-7B-HF'
-
-        tokenizer = AutoTokenizer.from_pretrained(model_repo)
-
-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit = True,
-            bnb_4bit_quant_type = "nf4",
-            bnb_4bit_compute_dtype = torch.float16,
-            bnb_4bit_use_double_quant = True,
-        )
-
-        model = AutoModelForCausalLM.from_pretrained(
-            model_repo,
-            quantization_config = bnb_config,
-            device_map = 'auto',
-            low_cpu_mem_usage = True
-        )
-
-        max_len = 1024
-
-    elif model == 'llama2-7b-chat':
-        model_repo = 'daryl149/llama-2-7b-chat-hf'
-
-        tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)
-
-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit = True,
-            bnb_4bit_quant_type = "nf4",
-            bnb_4bit_compute_dtype = torch.float16,
-            bnb_4bit_use_double_quant = True,
-        )
-
-        model = AutoModelForCausalLM.from_pretrained(
-            model_repo,
-            quantization_config = bnb_config,
-            device_map = 'auto',
-            low_cpu_mem_usage = True,
-            trust_remote_code = True
-        )
 
-        max_len = 2048
-
-    elif model == 'llama2-13b-chat':
-        model_repo = 'daryl149/llama-2-13b-chat-hf'
-
-        tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)
-
-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit = True,
-            bnb_4bit_quant_type = "nf4",
-            bnb_4bit_compute_dtype = torch.float16,
-            bnb_4bit_use_double_quant = True,
-        )
-
-        model = AutoModelForCausalLM.from_pretrained(
-            model_repo,
-            quantization_config = bnb_config,
-
-            low_cpu_mem_usage = True,
-            trust_remote_code = True
-        )
-
-        max_len = 2048 #8192
-        truncation=True, # Explicitly enable truncation
-        padding="max_len" # Optional: pad to max_length
-
-    elif model == 'mistral-7B':
-        model_repo = 'mistralai/Mistral-7B-v0.1'
-
-        tokenizer = AutoTokenizer.from_pretrained(model_repo)
-
-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit = True,
-            bnb_4bit_quant_type = "nf4",
-            bnb_4bit_compute_dtype = torch.float16,
-            bnb_4bit_use_double_quant = True,
-        )
-
-        model = AutoModelForCausalLM.from_pretrained(
-            model_repo,
-            quantization_config = bnb_config,
-            device_map = 'auto',
-            low_cpu_mem_usage = True,
-        )
-
-        max_len = 1024
-
-    else:
-        print("Not implemented model (tokenizer and backbone)")
-
-    return tokenizer, model, max_len
-
-tokenizer, model, max_len = get_model(model = CFG.model_name)
+# Initialize model and tokenizer
+tokenizer, model, max_len = get_model(model=CFG.model_name)
 
+# Set up pipeline for LLM
 pipe = pipeline(
-    task = "text-generation",
-    model = model,
-    tokenizer = tokenizer,
-    pad_token_id = tokenizer.eos_token_id,
-    # do_sample = True,
-    max_length = max_len,
-    temperature = CFG.temperature,
-    top_p = CFG.top_p,
-    repetition_penalty = CFG.repetition_penalty
+    task="text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    pad_token_id=tokenizer.eos_token_id,
+    max_length=max_len,
+    temperature=CFG.temperature,
+    top_p=CFG.top_p,
+    repetition_penalty=CFG.repetition_penalty
 )
 
-### langchain pipeline
-llm = HuggingFacePipeline(pipeline = pipe)
+# Langchain pipeline
+llm = HuggingFacePipeline(pipeline=pipe)
 
+# Load PDFs from content
 loader = DirectoryLoader(
     CFG.PDFs_path,
     glob="./*.pdf",
@@ -277,13 +181,15 @@ loader = DirectoryLoader(
 
 documents = loader.load()
 
+# Split documents into chunks
 text_splitter = RecursiveCharacterTextSplitter(
-    chunk_size = CFG.split_chunk_size,
-    chunk_overlap = CFG.split_overlap
+    chunk_size=CFG.split_chunk_size,
+    chunk_overlap=CFG.split_overlap
 )
 
 texts = text_splitter.split_documents(documents)
 
+# Set up vector store with embeddings
 from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 
 vectordb = FAISS.from_documents(
@@ -291,10 +197,10 @@ vectordb = FAISS.from_documents(
     HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2')
 )
 
-### persist vector database
-vectordb.save_local(f"{CFG.Output_folder}/faiss_index_rag") # save in output folder
-# vectordb.save_local(f"{CFG.Embeddings_path}/faiss_index_hp") # save in input folder
+# Persist vector database
+vectordb.save_local(f"{CFG.Output_folder}/faiss_index_rag")
 
+# Define prompt template for question answering
 prompt_template = """
 Don't try to make up an answer, if you don't know just say that you don't know.
 Answer in the same language the question was asked.
@@ -305,39 +211,35 @@ Use only the following pieces of context to answer the question at the end.
 Question: {question}
 Answer:"""
 
-
 PROMPT = PromptTemplate(
-    template = prompt_template,
-    input_variables = ["context", "question"]
+    template=prompt_template,
+    input_variables=["context", "question"]
 )
 
-retriever = vectordb.as_retriever(search_kwargs = {"k": CFG.k, "search_type" : "similarity"})
+# Set up retriever from vector store
+retriever = vectordb.as_retriever(search_kwargs={"k": CFG.k, "search_type": "similarity"})
 
+# Create the QA chain
 qa_chain = RetrievalQA.from_chain_type(
-    llm = llm,
-    chain_type = "stuff", # map_reduce, map_rerank, stuff, refine
-    retriever = retriever,
-    chain_type_kwargs = {"prompt": PROMPT},
-    return_source_documents = True,
-    verbose = False
+    llm=llm,
+    chain_type="stuff", # map_reduce, map_rerank, stuff, refine
+    retriever=retriever,
+    chain_type_kwargs={"prompt": PROMPT},
+    return_source_documents=True,
+    verbose=False
 )
 
+# Function to wrap text to preserve newlines
 def wrap_text_preserve_newlines(text, width=700):
-    # Split the input text into lines based on newline characters
     lines = text.split('\n')
-
-    # Wrap each line individually
     wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
-
-    # Join the wrapped lines back together using newline characters
     wrapped_text = '\n'.join(wrapped_lines)
-
     return wrapped_text
 
 
+# Function to process the LLM response
 def process_llm_response(llm_response):
     ans = wrap_text_preserve_newlines(llm_response['result'])
-
     sources_used = ' \n'.join(
         [
             source.metadata['source'].split('/')[-1][:-4]
@@ -346,10 +248,11 @@ def process_llm_response(llm_response):
             for source in llm_response['source_documents']
         ]
     )
-
    ans = ans + '\n\nSources: \n' + sources_used
    return ans
 
+
+# Function to get LLM response
 def llm_ans(query):
     start = time.time()
 
@@ -357,26 +260,28 @@ def llm_ans(query):
     ans = process_llm_response(llm_response)
 
     end = time.time()
-
     time_elapsed = int(round(end - start, 0))
     time_elapsed_str = f'\n\nTime elapsed: {time_elapsed} s'
     return ans + time_elapsed_str
 
+
+# Correct locale issue
 import locale
 locale.getpreferredencoding = lambda: "UTF-8"
 
+# Gradio interface
 import gradio as gr
 
-def predict(message, history):
-    # output = message # debug mode
-
-    output = str(llm_ans(message)).replace("\n", "<br/>")
-    return output
-
-demo = gr.ChatInterface(
-    predict,
-    title = f' Open-Source LLM ({CFG.model_name}) Question Answering'
-)
-
-demo.queue()
-demo.launch()
+def predict(message, history):
+    output = str(llm_ans(message)).replace("\n", "<br/>")
+    return output
+
+demo = gr.Interface(
+    fn=predict,
+    inputs=gr.Textbox(label="Enter your question"),
+    outputs="html",
+    title=f'Open-Source LLM ({CFG.model_name}) Question Answering'
+)
+
+demo.queue()
+demo.launch()
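
The updated app.py persists the vector store with vectordb.save_local(f"{CFG.Output_folder}/faiss_index_rag"). A minimal sketch of reloading that index in a later session, assuming the same langchain release and embedding model used above (recent langchain versions may additionally require allow_dangerous_deserialization=True in load_local):

# Sketch: reload the FAISS index persisted by app.py and rebuild a retriever.
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2')
vectordb = FAISS.load_local('./rag-vectordb/faiss_index_rag', embeddings)
retriever = vectordb.as_retriever(search_kwargs={"k": 3, "search_type": "similarity"})  # k=3 is an assumed value; app.py uses CFG.k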
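
With app.py loaded in the same session, the RetrievalQA chain can be exercised directly; since RetrievalQA has a single input key, it accepts the bare query string. The question below is only a placeholder for the anatomy PDFs loaded from /content/:

# Sketch: query the QA chain defined in app.py; the question is a hypothetical example.
query = "Which chambers make up the human heart?"
print(llm_ans(query))  # wrapped answer, source file names, and elapsed time

# Inspect the raw chain output (a dict with 'result' and 'source_documents'):
response = qa_chain(query)
for doc in response['source_documents']:
    print(doc.metadata['source'])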
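
The commit swaps gr.ChatInterface for gr.Interface with a single Textbox input. A quick way to smoke-test the handler without opening the UI is to call predict directly; note that gr.Interface passes one value per input component to fn, so the unused history parameter would need a default (e.g. history=None) to match that wiring, and the sketch below supplies it explicitly.

# Sketch: exercise the Gradio handler directly; the question is a placeholder.
html_answer = predict("Which bones form the orbit of the eye?", history=[])
print(html_answer[:300])

# In a Colab notebook, a public URL is usually needed:
# demo.launch(share=True)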