HemaMeena committed on
Commit bbe50c0 · verified · 1 Parent(s): d4ee3e6

Upload 2 files

Files changed (2)
  1. app.py +285 -0
  2. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,285 @@
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ import os
+ import glob
+ import textwrap
+ import time
+
+ import langchain
+
+ ### loaders
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader
+
+ ### splits
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+ ### prompts
+ from langchain import PromptTemplate, LLMChain
+
+ ### vector stores
+ from langchain.vectorstores import FAISS
+
+ ### models
+ from langchain.llms import HuggingFacePipeline
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
+
+ ### retrievers
+ from langchain.chains import RetrievalQA
+
+ import torch
+ import transformers
+ from transformers import (
+     AutoTokenizer, AutoModelForCausalLM,
+     BitsAndBytesConfig,
+     pipeline
+ )
+ import gradio as gr
+ import locale
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+
+ class CFG:
+     # LLMs
+     model_name = 'llama2-13b-chat'  # wizardlm, llama2-7b-chat, llama2-13b-chat, mistral-7B
+     temperature = 0
+     top_p = 0.95
+     repetition_penalty = 1.15
+
+     # splitting
+     split_chunk_size = 800
+     split_overlap = 0
+
+     # embeddings
+     embeddings_model_repo = 'sentence-transformers/all-MiniLM-L6-v2'
+
+     # similar passages
+     k = 6
+
+     # paths
+     PDFs_path = './'
+     Embeddings_path = './faiss-hp-sentence-transformers'
+     Output_folder = './rag-vectordb'
+
+ def get_model(model = CFG.model_name):
+
+     print('\nDownloading model: ', model, '\n\n')
+
+     if model == 'wizardlm':
+         model_repo = 'TheBloke/wizardLM-7B-HF'
+
+         tokenizer = AutoTokenizer.from_pretrained(model_repo)
+
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit = True,
+             bnb_4bit_quant_type = "nf4",
+             bnb_4bit_compute_dtype = torch.float16,
+             bnb_4bit_use_double_quant = True,
+         )
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_repo,
+             quantization_config = bnb_config,
+             device_map = 'auto',
+             low_cpu_mem_usage = True
+         )
+
+         max_len = 1024
+
+     elif model == 'llama2-7b-chat':
+         model_repo = 'daryl149/llama-2-7b-chat-hf'
+
+         tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)
+
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit = True,
+             bnb_4bit_quant_type = "nf4",
+             bnb_4bit_compute_dtype = torch.float16,
+             bnb_4bit_use_double_quant = True,
+         )
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_repo,
+             quantization_config = bnb_config,
+             device_map = 'auto',
+             low_cpu_mem_usage = True,
+             trust_remote_code = True
+         )
+
+         max_len = 2048
+
+     elif model == 'llama2-13b-chat':
+         model_repo = 'daryl149/llama-2-13b-chat-hf'
+
+         tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)
+
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit = True,
+             bnb_4bit_quant_type = "nf4",
+             bnb_4bit_compute_dtype = torch.float16,
+             bnb_4bit_use_double_quant = True,
+         )
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_repo,
+             quantization_config = bnb_config,
+             device_map = 'auto',
+             low_cpu_mem_usage = True,
+             trust_remote_code = True
+         )
+
+         max_len = 2048  # 8192
+
+     elif model == 'mistral-7B':
+         model_repo = 'mistralai/Mistral-7B-v0.1'
+
+         tokenizer = AutoTokenizer.from_pretrained(model_repo)
+
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit = True,
+             bnb_4bit_quant_type = "nf4",
+             bnb_4bit_compute_dtype = torch.float16,
+             bnb_4bit_use_double_quant = True,
+         )
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_repo,
+             quantization_config = bnb_config,
+             device_map = 'auto',
+             low_cpu_mem_usage = True,
+         )
+
+         max_len = 1024
+
+     else:
+         raise ValueError("Not implemented model (tokenizer and backbone)")
+
+     return tokenizer, model, max_len
+
+ def wrap_text_preserve_newlines(text, width=700):
+     # Split the input text into lines based on newline characters
+     lines = text.split('\n')
+
+     # Wrap each line individually
+     wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
+
+     # Join the wrapped lines back together using newline characters
+     wrapped_text = '\n'.join(wrapped_lines)
+
+     return wrapped_text
+
+
+ def process_llm_response(llm_response):
+     # Wrap the answer text and append the source file and page of each retrieved chunk
+     ans = wrap_text_preserve_newlines(llm_response['result'])
+
+     sources_used = ' \n'.join(
+         [
+             source.metadata['source'].split('/')[-1][:-4]
+             + ' - page: '
+             + str(source.metadata['page'])
+             for source in llm_response['source_documents']
+         ]
+     )
+
+     ans = ans + '\n\nSources: \n' + sources_used
+     return ans
+
+
+ def llm_ans(query):
+     # Run the RetrievalQA chain and report how long the answer took
+     start = time.time()
+
+     llm_response = qa_chain.invoke(query)
+     ans = process_llm_response(llm_response)
+
+     end = time.time()
+
+     time_elapsed = int(round(end - start, 0))
+     time_elapsed_str = f'\n\nTime elapsed: {time_elapsed} s'
+     return ans + time_elapsed_str
+
+
+ def predict(message, history):
+     # Gradio ChatInterface callback; history is unused, each query is answered independently
+     output = str(llm_ans(message)).replace("\n", "<br/>")
+     return output
+
+
+ tokenizer, model, max_len = get_model(model = CFG.model_name)
+
+ pipe = pipeline(
+     task = "text-generation",
+     model = model,
+     tokenizer = tokenizer,
+     pad_token_id = tokenizer.eos_token_id,
+     # do_sample = True,
+     max_length = max_len,
+     temperature = CFG.temperature,
+     top_p = CFG.top_p,
+     repetition_penalty = CFG.repetition_penalty
+ )
+
+ ### langchain pipeline
+ llm = HuggingFacePipeline(pipeline = pipe)
+
+ ### load the PDFs in CFG.PDFs_path and split them into chunks
+ loader = DirectoryLoader(
+     CFG.PDFs_path,
+     glob="./*.pdf",
+     loader_cls=PyPDFLoader,
+     show_progress=True,
+     use_multithreading=True
+ )
+
+ documents = loader.load()
+ text_splitter = RecursiveCharacterTextSplitter(
+     chunk_size = CFG.split_chunk_size,
+     chunk_overlap = CFG.split_overlap
+ )
+
+ texts = text_splitter.split_documents(documents)
+
+ vectordb = FAISS.from_documents(
+     texts,
+     HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2')
+ )
+
+ ### persist vector database
+ vectordb.save_local(f"{CFG.Output_folder}/faiss_index_rag")
+
+ retriever = vectordb.as_retriever(search_type = "similarity", search_kwargs = {"k": CFG.k})
+
+ prompt_template = """
+ Don't try to make up an answer, if you don't know just say that you don't know.
+ Answer in the same language the question was asked.
+ Use only the following pieces of context to answer the question at the end.
+
+ {context}
+
+ Question: {question}
+ Answer:"""
+
+ PROMPT = PromptTemplate(
+     template = prompt_template,
+     input_variables = ["context", "question"]
+ )
+
+ ### retrieval-augmented QA chain, using the prompt defined above
+ qa_chain = RetrievalQA.from_chain_type(
+     llm = llm,
+     chain_type = "stuff",  # map_reduce, map_rerank, stuff, refine
+     retriever = retriever,
+     chain_type_kwargs = {"prompt": PROMPT},
+     return_source_documents = True,
+     verbose = False
+ )
+
+ locale.getpreferredencoding = lambda: "UTF-8"
+
+ demo = gr.ChatInterface(
+     predict,
+     title = f'Open-Source LLM ({CFG.model_name}) Question Answering'
+ )
+
+ demo.queue()
+ demo.launch()
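
Note that CFG.Embeddings_path points at a prebuilt FAISS index, yet app.py always re-embeds the PDFs at startup and only writes a new index to CFG.Output_folder. A minimal sketch of reusing the persisted index instead, assuming the CFG class and the imports above are in scope and that the saved index was built with the same embedding model (this is not part of the committed code):

```python
# Sketch only: reload a previously saved FAISS index instead of re-embedding the PDFs.
# Assumes the index at CFG.Embeddings_path was created with all-mpnet-base-v2 embeddings.
embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2')

vectordb = FAISS.load_local(
    CFG.Embeddings_path,                    # './faiss-hp-sentence-transformers'
    embeddings,
    allow_dangerous_deserialization=True,   # required by recent langchain-community releases; drop on older ones
)

retriever = vectordb.as_retriever(search_type="similarity", search_kwargs={"k": CFG.k})
```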
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ sentence_transformers==2.2.2
+
+ langchain-community
+ langchain-huggingface
+ tiktoken
+ pypdf
+ faiss-gpu
+ InstructorEmbedding
+
+ transformers
+ accelerate
+ bitsandbytes
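
The list is unpinned except for sentence_transformers and is installed on the Space with pip install -r requirements.txt. Two caveats, stated here as assumptions rather than facts from the commit: faiss-gpu on PyPI ships CUDA-only wheels, so faiss-cpu is the usual drop-in on CPU-only hardware, and app.py also imports langchain, torch, and gradio directly, which may need to be listed explicitly if the runtime image does not already provide them.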