HemaMeena committed on
Commit 57baa30 · verified · 1 Parent(s): 8291021

Delete app.py

Files changed (1)
  1. app.py +0 -376
app.py DELETED
@@ -1,376 +0,0 @@
- import os
- import glob
- import textwrap
- import time
- import langchain
- import locale
- import gradio as gr
-
- locale.getpreferredencoding = lambda: "UTF-8"
-
- from langchain.document_loaders import PyPDFLoader, DirectoryLoader
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain import PromptTemplate, LLMChain
-
- from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-
- ### vector stores
- from langchain.vectorstores import FAISS
-
- ### models
- from langchain.llms import HuggingFacePipeline
- from langchain.embeddings import HuggingFaceInstructEmbeddings
-
- ### retrievers
- from langchain.chains import RetrievalQA
-
- import torch
- import transformers
- from transformers import (
-     AutoTokenizer, AutoModelForCausalLM,
-     BitsAndBytesConfig,
-     pipeline
- )
-
- ### sanity check: list the anatomy PDF volumes in /content (result is not used)
- sorted(glob.glob('/content/anatomy_vol_*'))
-
-
- def wrap_text_preserve_newlines(text, width=700):
-     # Split the input text into lines based on newline characters
-     lines = text.split('\n')
-
-     # Wrap each line individually
-     wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
-
-     # Join the wrapped lines back together using newline characters
-     wrapped_text = '\n'.join(wrapped_lines)
-
-     return wrapped_text
-
-
- def process_llm_response(llm_response):
-     # Wrap the answer text and append the source file and page of every retrieved chunk
-     ans = wrap_text_preserve_newlines(llm_response['result'])
-
-     sources_used = ' \n'.join(
-         [
-             source.metadata['source'].split('/')[-1][:-4]
-             + ' - page: '
-             + str(source.metadata['page'])
-             for source in llm_response['source_documents']
-         ]
-     )
-
-     ans = ans + '\n\nSources: \n' + sources_used
-     return ans
-
- def llm_ans(query):
-     # Run the retrieval-QA chain and append the elapsed time to the answer
-     start = time.time()
-
-     llm_response = qa_chain.invoke(query)
-     ans = process_llm_response(llm_response)
-
-     end = time.time()
-
-     time_elapsed = int(round(end - start, 0))
-     time_elapsed_str = f'\n\nTime elapsed: {time_elapsed} s'
-     return ans + time_elapsed_str
-
-
- def predict(message, history):
-     # Gradio ChatInterface callback: render newlines as HTML line breaks
-     output = str(llm_ans(message)).replace("\n", "<br/>")
-     return output
-
-
- def get_model(model):
-     # NOTE: the model name is passed in explicitly (CFG is defined later in this file)
-
-     if model == 'wizardlm':
-         model_repo = 'TheBloke/wizardLM-7B-HF'
-
-         tokenizer = AutoTokenizer.from_pretrained(model_repo)
-
-         bnb_config = BitsAndBytesConfig(
-             load_in_4bit = True,
-             bnb_4bit_quant_type = "nf4",
-             bnb_4bit_compute_dtype = torch.float16,
-             bnb_4bit_use_double_quant = True,
-         )
-
-         model = AutoModelForCausalLM.from_pretrained(
-             model_repo,
-             quantization_config = bnb_config,
-             device_map = 'auto',
-             low_cpu_mem_usage = True
-         )
-
-         max_len = 1024
-
-     elif model == 'llama2-7b-chat':
-         model_repo = 'daryl149/llama-2-7b-chat-hf'
-
-         tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)
-
-         bnb_config = BitsAndBytesConfig(
-             load_in_4bit = True,
-             bnb_4bit_quant_type = "nf4",
-             bnb_4bit_compute_dtype = torch.float16,
-             bnb_4bit_use_double_quant = True,
-         )
-
-         model = AutoModelForCausalLM.from_pretrained(
-             model_repo,
-             quantization_config = bnb_config,
-             device_map = 'auto',
-             low_cpu_mem_usage = True,
-             trust_remote_code = True
-         )
-
-         max_len = 2048
-
-     elif model == 'llama2-13b-chat':
-         model_repo = 'daryl149/llama-2-13b-chat-hf'
-
-         tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)
-
-         bnb_config = BitsAndBytesConfig(
-             load_in_4bit = True,
-             bnb_4bit_quant_type = "nf4",
-             bnb_4bit_compute_dtype = torch.float16,
-             bnb_4bit_use_double_quant = True,
-         )
-
-         model = AutoModelForCausalLM.from_pretrained(
-             model_repo,
-             quantization_config = bnb_config,
-             device_map = 'auto',
-             low_cpu_mem_usage = True,
-             trust_remote_code = True
-         )
-
-         max_len = 2048 #8192
-         # truncation=True,   # Explicitly enable truncation
-         # padding="max_len"  # Optional: pad to max_length
-
-     elif model == 'mistral-7B':
-         model_repo = 'mistralai/Mistral-7B-v0.1'
-
-         tokenizer = AutoTokenizer.from_pretrained(model_repo)
-
-         bnb_config = BitsAndBytesConfig(
-             load_in_4bit = True,
-             bnb_4bit_quant_type = "nf4",
-             bnb_4bit_compute_dtype = torch.float16,
-             bnb_4bit_use_double_quant = True,
-         )
-
-         model = AutoModelForCausalLM.from_pretrained(
-             model_repo,
-             quantization_config = bnb_config,
-             device_map = 'auto',
-             low_cpu_mem_usage = True,
-         )
-
-         max_len = 1024
-
-     else:
-         raise ValueError(f"Not implemented model (tokenizer and backbone): {model}")
-
-     return tokenizer, model, max_len
-
-
- class CFG:
-     # LLMs
-     model_name = 'llama2-13b-chat' # wizardlm, llama2-7b-chat, llama2-13b-chat, mistral-7B
-     temperature = 0
-     top_p = 0.95
-     repetition_penalty = 1.15
-
-     # splitting
-     split_chunk_size = 800
-     split_overlap = 0
-
-     # embeddings
-     embeddings_model_repo = 'sentence-transformers/all-MiniLM-L6-v2'
-
-     # similar passages
-     k = 6
-
-     # paths
-     PDFs_path = '/content/'
-     Embeddings_path = '/content/faiss-hp-sentence-transformers'
-     Output_folder = './rag-vectordb'
-
-
- tokenizer, model, max_len = get_model(model = CFG.model_name)
-
- pipe = pipeline(
-     task = "text-generation",
-     model = model,
-     tokenizer = tokenizer,
-     pad_token_id = tokenizer.eos_token_id,
-     # do_sample = True,
-     max_length = max_len,
-     temperature = CFG.temperature,
-     top_p = CFG.top_p,
-     repetition_penalty = CFG.repetition_penalty
- )
-
- ### langchain pipeline
- llm = HuggingFacePipeline(pipeline = pipe)
-
- loader = DirectoryLoader(
-     CFG.PDFs_path,
-     glob="./*.pdf",
-     loader_cls=PyPDFLoader,
-     show_progress=True,
-     use_multithreading=True
- )
-
- documents = loader.load()
-
- text_splitter = RecursiveCharacterTextSplitter(
-     chunk_size = CFG.split_chunk_size,
-     chunk_overlap = CFG.split_overlap
- )
-
- texts = text_splitter.split_documents(documents)
-
- vectordb = FAISS.from_documents(
-     texts,
-     # note: this hard-coded embedding model differs from CFG.embeddings_model_repo
-     HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2')
- )
-
- ### persist vector database
- vectordb.save_local(f"{CFG.Output_folder}/faiss_index_rag") # save in output folder
- # vectordb.save_local(f"{CFG.Embeddings_path}/faiss_index_hp") # save in input folder
-
- prompt_template = """
343
- Don't try to make up an answer, if you don't know just say that you don't know.
344
- Answer in the same language the question was asked.
345
- Use only the following pieces of context to answer the question at the end.
346
-
347
- {context}
348
-
349
- Question: {question}
350
- Answer:"""
351
-
352
-
353
- PROMPT = PromptTemplate(
354
- template = prompt_template,
355
- input_variables = ["context", "question"]
356
- )
357
-
358
- retriever = vectordb.as_retriever(search_type = "similarity", search_kwargs = {"k": CFG.k})
-
- qa_chain = RetrievalQA.from_chain_type(
-     llm = llm,
-     chain_type = "stuff", # map_reduce, map_rerank, stuff, refine
-     retriever = retriever,
-     chain_type_kwargs = {"prompt": PROMPT},
-     return_source_documents = True,
-     verbose = False
- )
-
- def start_demo():
-     demo = gr.ChatInterface(
-         predict,
-         title=f'Open-Source LLM ({CFG.model_name}) Question Answering'
-     )
-     demo.queue()
-     demo.launch()
-
- # launch the Gradio chat UI
- start_demo()