HemaMeena committed on
Commit d57fe52 · verified · 1 parent: e5c6476

Update app.py

Files changed (1):
  app.py  +206  -284
app.py CHANGED
@@ -1,285 +1,207 @@
- import warnings
- warnings.filterwarnings("ignore")
-
- import os
- import glob
- import textwrap
- import time
-
- import langchain
-
- ### loaders
- from langchain.document_loaders import PyPDFLoader, DirectoryLoader
-
- ### splits
- from langchain.text_splitter import RecursiveCharacterTextSplitter
-
- ### prompts
- from langchain import PromptTemplate, LLMChain
-
- ### vector stores
- from langchain.vectorstores import FAISS
-
- ### models
- from langchain.llms import HuggingFacePipeline
- from langchain.embeddings import HuggingFaceInstructEmbeddings
-
- ### retrievers
- from langchain.chains import RetrievalQA
-
- import torch
- import transformers
- from transformers import (
-     AutoTokenizer, AutoModelForCausalLM,
-     BitsAndBytesConfig,
-     pipeline
- )
- import gradio as gr
- import locale
- import time
- from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-
- class CFG:
-     # LLMs
-     model_name = 'llama2-13b-chat' # wizardlm, llama2-7b-chat, llama2-13b-chat, mistral-7B
-     temperature = 0
-     top_p = 0.95
-     repetition_penalty = 1.15
-
-     # splitting
-     split_chunk_size = 800
-     split_overlap = 0
-
-     # embeddings
-     embeddings_model_repo = 'sentence-transformers/all-MiniLM-L6-v2'
-
-     # similar passages
-     k = 6
-
-     # paths
-     PDFs_path = './'
-     Embeddings_path = './faiss-hp-sentence-transformers'
-     Output_folder = './rag-vectordb'
-
- def get_model(model = CFG.model_name):
-
-     print('\nDownloading model: ', model, '\n\n')
-
-     if model == 'wizardlm':
-         model_repo = 'TheBloke/wizardLM-7B-HF'
-
-         tokenizer = AutoTokenizer.from_pretrained(model_repo)
-
-         bnb_config = BitsAndBytesConfig(
-             load_in_4bit = True,
-             bnb_4bit_quant_type = "nf4",
-             bnb_4bit_compute_dtype = torch.float16,
-             bnb_4bit_use_double_quant = True,
-         )
-
-         model = AutoModelForCausalLM.from_pretrained(
-             model_repo,
-             quantization_config = bnb_config,
-             device_map = 'auto',
-             low_cpu_mem_usage = True
-         )
-
-         max_len = 1024
-
-     elif model == 'llama2-7b-chat':
-         model_repo = 'daryl149/llama-2-7b-chat-hf'
-
-         tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)
-
-         bnb_config = BitsAndBytesConfig(
-             load_in_4bit = True,
-             bnb_4bit_quant_type = "nf4",
-             bnb_4bit_compute_dtype = torch.float16,
-             bnb_4bit_use_double_quant = True,
-         )
-
-         model = AutoModelForCausalLM.from_pretrained(
-             model_repo,
-             quantization_config = bnb_config,
-             device_map = 'auto',
-             low_cpu_mem_usage = True,
-             trust_remote_code = True
-         )
-
-         max_len = 2048
-
-     elif model == 'llama2-13b-chat':
-         model_repo = 'daryl149/llama-2-13b-chat-hf'
-
-         tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)
-
-         bnb_config = BitsAndBytesConfig(
-             load_in_4bit = True,
-             bnb_4bit_quant_type = "nf4",
-             bnb_4bit_compute_dtype = torch.float16,
-             bnb_4bit_use_double_quant = True,
-         )
-
-         model = AutoModelForCausalLM.from_pretrained(
-             model_repo,
-             quantization_config = bnb_config,
-
-             low_cpu_mem_usage = True,
-             trust_remote_code = True
-         )
-
-         max_len = 2048 #8192
-         truncation=True, # Explicitly enable truncation
-         padding="max_len" # Optional: pad to max_length
-
-     elif model == 'mistral-7B':
-         model_repo = 'mistralai/Mistral-7B-v0.1'
-
-         tokenizer = AutoTokenizer.from_pretrained(model_repo)
-
-         bnb_config = BitsAndBytesConfig(
-             load_in_4bit = True,
-             bnb_4bit_quant_type = "nf4",
-             bnb_4bit_compute_dtype = torch.float16,
-             bnb_4bit_use_double_quant = True,
-         )
-
-         model = AutoModelForCausalLM.from_pretrained(
-             model_repo,
-             quantization_config = bnb_config,
-             device_map = 'auto',
-             low_cpu_mem_usage = True,
-         )
-
-         max_len = 1024
-
-     else:
-         print("Not implemented model (tokenizer and backbone)")
-
-     return tokenizer, model, max_len
-
- def wrap_text_preserve_newlines(text, width=700):
-     # Split the input text into lines based on newline characters
-     lines = text.split('\n')
-
-     # Wrap each line individually
-     wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
-
-     # Join the wrapped lines back together using newline characters
-     wrapped_text = '\n'.join(wrapped_lines)
-
-     return wrapped_text
-
-
- def process_llm_response(llm_response):
-     ans = wrap_text_preserve_newlines(llm_response['result'])
-
-     sources_used = ' \n'.join(
-         [
-             source.metadata['source'].split('/')[-1][:-4]
-             + ' - page: '
-             + str(source.metadata['page'])
-             for source in llm_response['source_documents']
-         ]
-     )
-
-     ans = ans + '\n\nSources: \n' + sources_used
-     return ans
-
- def llm_ans(query):
-     start = time.time()
-
-     llm_response = qa_chain.invoke(query)
-     ans = process_llm_response(llm_response)
-
-     end = time.time()
-
-     time_elapsed = int(round(end - start, 0))
-     time_elapsed_str = f'\n\nTime elapsed: {time_elapsed} s'
-     return ans + time_elapsed_str
-
- def predict(message, history):
-     output = str(llm_ans(message)).replace("\n", "<br/>")
-     return output
-
-
-
-
- tokenizer, model, max_len = get_model(model = CFG.model_name)
-
- pipe = pipeline(
-     task = "text-generation",
-     model = model,
-     tokenizer = tokenizer,
-     pad_token_id = tokenizer.eos_token_id,
-     # do_sample = True,
-     max_length = max_len,
-     temperature = CFG.temperature,
-     top_p = CFG.top_p,
-     repetition_penalty = CFG.repetition_penalty
- )
-
- ### langchain pipeline
- llm = HuggingFacePipeline(pipeline = pipe)
-
- loader = DirectoryLoader(
-     CFG.PDFs_path,
-     glob="./*.pdf",
-     loader_cls=PyPDFLoader,
-     show_progress=True,
-     use_multithreading=True
- )
-
- documents = loader.load()
- text_splitter = RecursiveCharacterTextSplitter(
-     chunk_size = CFG.split_chunk_size,
-     chunk_overlap = CFG.split_overlap
- )
-
- texts = text_splitter.split_documents(documents)
-
- vectordb = FAISS.from_documents(
-     texts,
-     HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2')
- )
-
- ### persist vector database
- vectordb.save_local(f"{CFG.Output_folder}/faiss_index_rag")
-
- retriever = vectordb.as_retriever(search_kwargs = {"k": CFG.k, "search_type" : "similarity"})
-
- qa_chain = RetrievalQA.from_chain_type(
-     llm = llm,
-     chain_type = "stuff", # map_reduce, map_rerank, stuff, refine
-     retriever = retriever,
-     chain_type_kwargs = {"prompt": PROMPT},
-     return_source_documents = True,
-     verbose = False
- )
-
- prompt_template = """
- Don't try to make up an answer, if you don't know just say that you don't know.
- Answer in the same language the question was asked.
- Use only the following pieces of context to answer the question at the end.
-
- {context}
-
- Question: {question}
- Answer:"""
-
-
- PROMPT = PromptTemplate(
-     template = prompt_template,
-     input_variables = ["context", "question"]
- )
-
-
- locale.getpreferredencoding = lambda: "UTF-8"
-
- demo = gr.ChatInterface(
-     predict,
-     title = f' Open-Source LLM ({CFG.model_name}) Question Answering'
- )
-
- demo.queue()
 
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ import os
+ import glob
+ import textwrap
+ import time
+
+ import langchain
+
+ ### loaders
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader
+
+ ### splits
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+ ### prompts
+ from langchain import PromptTemplate, LLMChain
+
+ ### vector stores
+ from langchain.vectorstores import FAISS
+
+ ### models
+ from langchain.llms import HuggingFacePipeline
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
+
+ ### retrievers
+ from langchain.chains import RetrievalQA
+
+ import torch
+ import transformers
+ from transformers import (
+     AutoTokenizer, AutoModelForCausalLM,
+     BitsAndBytesConfig,
+     pipeline
+ )
+ import gradio as gr
+ import locale
+ import time
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+
+ class CFG:
+     # LLMs
+     model_name = 'llama2-13b-chat' # wizardlm, llama2-7b-chat, llama2-13b-chat, mistral-7B
+     temperature = 0
+     top_p = 0.95
+     repetition_penalty = 1.15
+
+     # splitting
+     split_chunk_size = 800
+     split_overlap = 0
+
+     # embeddings
+     embeddings_model_repo = 'sentence-transformers/all-MiniLM-L6-v2'
+
+     # similar passages
+     k = 6
+
+     # paths
+     PDFs_path = './'
+     Embeddings_path = './faiss-hp-sentence-transformers'
+     Output_folder = './rag-vectordb'
+
+ def get_model(model=CFG.model_name):
+     print('\nDownloading model: ', model, '\n\n')
+     model_repo = None
+
+     if model == 'llama2-13b-chat':
+         model_repo = 'daryl149/llama-2-13b-chat-hf'
+
+     if model_repo:
+         tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)
+         model = AutoModelForCausalLM.from_pretrained(
+             model_repo,
+             device_map="auto",
+             trust_remote_code=True
+         )
+         max_len = 2048
+     else:
+         raise ValueError("Model not implemented: " + model)
+
+     return tokenizer, model, max_len
+
+ def wrap_text_preserve_newlines(text, width=700):
+     # Split the input text into lines based on newline characters
+     lines = text.split('\n')
+
+     # Wrap each line individually
+     wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
+
+     # Join the wrapped lines back together using newline characters
+     wrapped_text = '\n'.join(wrapped_lines)
+
+     return wrapped_text
+
+
+ def process_llm_response(llm_response):
+     ans = wrap_text_preserve_newlines(llm_response['result'])
+
+     sources_used = ' \n'.join(
+         [
+             source.metadata['source'].split('/')[-1][:-4]
+             + ' - page: '
+             + str(source.metadata['page'])
+             for source in llm_response['source_documents']
+         ]
+     )
+
+     ans = ans + '\n\nSources: \n' + sources_used
+     return ans
+
+ def llm_ans(query):
+     start = time.time()
+
+     llm_response = qa_chain.invoke(query)
+     ans = process_llm_response(llm_response)
+
+     end = time.time()
+
+     time_elapsed = int(round(end - start, 0))
+     time_elapsed_str = f'\n\nTime elapsed: {time_elapsed} s'
+     return ans + time_elapsed_str
+
+ def predict(message, history):
+     output = str(llm_ans(message)).replace("\n", "<br/>")
+     return output
+
+
+
+
+ tokenizer, model, max_len = get_model(model = CFG.model_name)
+
+ pipe = pipeline(
+     task="text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     pad_token_id=tokenizer.eos_token_id,
+     max_length=max_len,
+     temperature=CFG.temperature,
+     top_p=CFG.top_p,
+     repetition_penalty=CFG.repetition_penalty
+ )
+
+ ### langchain pipeline
+ llm = HuggingFacePipeline(pipeline = pipe)
+
+ loader = DirectoryLoader(
+     CFG.PDFs_path,
+     glob="./*.pdf",
+     loader_cls=PyPDFLoader,
+     show_progress=True,
+     use_multithreading=True
+ )
+
+ documents = loader.load()
+ text_splitter = RecursiveCharacterTextSplitter(
+     chunk_size = CFG.split_chunk_size,
+     chunk_overlap = CFG.split_overlap
+ )
+
+ texts = text_splitter.split_documents(documents)
+
+ vectordb = FAISS.from_documents(
+     texts,
+     HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2')
+ )
+
+ ### persist vector database
+ vectordb.save_local(f"{CFG.Output_folder}/faiss_index_rag")
+
+ retriever = vectordb.as_retriever(search_kwargs = {"k": CFG.k, "search_type" : "similarity"})
+
+ qa_chain = RetrievalQA.from_chain_type(
+     llm = llm,
+     chain_type = "stuff", # map_reduce, map_rerank, stuff, refine
+     retriever = retriever,
+     chain_type_kwargs = {"prompt": PROMPT},
+     return_source_documents = True,
+     verbose = False
+ )
+
+ prompt_template = """
+ Don't try to make up an answer, if you don't know just say that you don't know.
+ Answer in the same language the question was asked.
+ Use only the following pieces of context to answer the question at the end.
+
+ {context}
+
+ Question: {question}
+ Answer:"""
+
+
+ PROMPT = PromptTemplate(
+     template = prompt_template,
+     input_variables = ["context", "question"]
+ )
+
+
+ locale.getpreferredencoding = lambda: "UTF-8"
+
+ demo = gr.ChatInterface(
+     predict,
+     title = f' Open-Source LLM ({CFG.model_name}) Question Answering'
+ )
+
+ demo.queue()
  demo.launch()
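
Note: in both the previous and the updated app.py, qa_chain is built with chain_type_kwargs = {"prompt": PROMPT} before prompt_template and PROMPT are defined further down the file, so importing or running the module as ordered raises a NameError. A minimal reordering sketch, using only the identifiers already present in app.py and not part of this commit, is:

# Define the prompt before the chain that uses it (same definitions as in app.py)
prompt_template = """
Don't try to make up an answer, if you don't know just say that you don't know.
Answer in the same language the question was asked.
Use only the following pieces of context to answer the question at the end.

{context}

Question: {question}
Answer:"""

PROMPT = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question"]
)

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    chain_type_kwargs={"prompt": PROMPT},
    return_source_documents=True,
    verbose=False
)

Note also that the rewritten get_model() drops the BitsAndBytesConfig 4-bit path, so llama-2-13b-chat is now loaded unquantized and will need considerably more GPU memory than the previous version.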