HemaMeena committed on
Commit 4d46aa9 · verified · 1 Parent(s): a51d3f1

Upload texttrail.py

Files changed (1):
  1. texttrail.py (+388, -0)
texttrail.py ADDED
# -*- coding: utf-8 -*-
"""TextTrail.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/19FMO4hPcBUvq4whuvATRRXDxJ4ONqElV
"""

! nvidia-smi -L

# Commented out IPython magic to ensure Python compatibility.
# %%time
#
# from IPython.display import clear_output
#
# ! pip install sentence_transformers==2.2.2
#
# ! pip install -qq -U langchain-community
# ! pip install -U langchain-huggingface
# ! pip install -qq -U tiktoken
# ! pip install -qq -U pypdf
# ! pip install -qq -U faiss-gpu
# ! pip install -qq -U InstructorEmbedding
#
# ! pip install -qq -U transformers
# ! pip install -qq -U accelerate
# ! pip install -qq -U bitsandbytes
#
# clear_output()

# Commented out IPython magic to ensure Python compatibility.
# %%time
#
# import warnings
# warnings.filterwarnings("ignore")
#
# import os
# import glob
# import textwrap
# import time
#
# import langchain
#
# ### loaders
# from langchain.document_loaders import PyPDFLoader, DirectoryLoader
#
# ### splits
# from langchain.text_splitter import RecursiveCharacterTextSplitter
#
# ### prompts
# from langchain import PromptTemplate, LLMChain
#
# ### vector stores
# from langchain.vectorstores import FAISS
#
# ### models
# from langchain.llms import HuggingFacePipeline
# from langchain.embeddings import HuggingFaceInstructEmbeddings
#
# ### retrievers
# from langchain.chains import RetrievalQA
#
# import torch
# import transformers
# from transformers import (
#     AutoTokenizer, AutoModelForCausalLM,
#     BitsAndBytesConfig,
#     pipeline
# )
#
# clear_output()

sorted(glob.glob('/content/anatomy_vol_*'))

class CFG:
    # LLMs
    model_name = 'llama2-13b-chat' # wizardlm, llama2-7b-chat, llama2-13b-chat, mistral-7B
    temperature = 0
    top_p = 0.95
    repetition_penalty = 1.15

    # splitting
    split_chunk_size = 800
    split_overlap = 0

    # embeddings
    embeddings_model_repo = 'sentence-transformers/all-MiniLM-L6-v2'

    # similar passages
    k = 6

    # paths
    PDFs_path = '/content/'
    Embeddings_path = '/content/faiss-hp-sentence-transformers'
    Output_folder = './rag-vectordb'

def get_model(model = CFG.model_name):

    print('\nDownloading model: ', model, '\n\n')

    if model == 'wizardlm':
        model_repo = 'TheBloke/wizardLM-7B-HF'

        tokenizer = AutoTokenizer.from_pretrained(model_repo)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit = True,
            bnb_4bit_quant_type = "nf4",
            bnb_4bit_compute_dtype = torch.float16,
            bnb_4bit_use_double_quant = True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_repo,
            quantization_config = bnb_config,
            device_map = 'auto',
            low_cpu_mem_usage = True
        )

        max_len = 1024

    elif model == 'llama2-7b-chat':
        model_repo = 'daryl149/llama-2-7b-chat-hf'

        tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit = True,
            bnb_4bit_quant_type = "nf4",
            bnb_4bit_compute_dtype = torch.float16,
            bnb_4bit_use_double_quant = True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_repo,
            quantization_config = bnb_config,
            device_map = 'auto',
            low_cpu_mem_usage = True,
            trust_remote_code = True
        )

        max_len = 2048

    elif model == 'llama2-13b-chat':
        model_repo = 'daryl149/llama-2-13b-chat-hf'

        tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit = True,
            bnb_4bit_quant_type = "nf4",
            bnb_4bit_compute_dtype = torch.float16,
            bnb_4bit_use_double_quant = True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_repo,
            quantization_config = bnb_config,
            device_map = 'auto',
            low_cpu_mem_usage = True,
            trust_remote_code = True
        )

        max_len = 2048 #8192

    elif model == 'mistral-7B':
        model_repo = 'mistralai/Mistral-7B-v0.1'

        tokenizer = AutoTokenizer.from_pretrained(model_repo)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit = True,
            bnb_4bit_quant_type = "nf4",
            bnb_4bit_compute_dtype = torch.float16,
            bnb_4bit_use_double_quant = True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_repo,
            quantization_config = bnb_config,
            device_map = 'auto',
            low_cpu_mem_usage = True,
        )

        max_len = 1024

    else:
        print("Not implemented model (tokenizer and backbone)")

    return tokenizer, model, max_len

print(torch.cuda.is_available())
print(torch.cuda.device_count())

# Commented out IPython magic to ensure Python compatibility.
# %%time
#
# tokenizer, model, max_len = get_model(model = CFG.model_name)
#
# clear_output()

model.eval()

### check how Accelerate split the model across the available devices (GPUs)
model.hf_device_map

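### Optional sanity check (not part of the original notebook): report the
### quantized model's approximate size, assuming transformers' built-in
### get_memory_footprint() helper is available in the installed version.
print(f'Model memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB')
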
### hugging face pipeline
pipe = pipeline(
    task = "text-generation",
    model = model,
    tokenizer = tokenizer,
    pad_token_id = tokenizer.eos_token_id,
    # do_sample = True,
    max_length = max_len,
    temperature = CFG.temperature,
    top_p = CFG.top_p,
    repetition_penalty = CFG.repetition_penalty
)

### langchain pipeline
llm = HuggingFacePipeline(pipeline = pipe)

llm

query = "what are the structural organization of a human body"
llm.invoke(query)

"""Langchain"""

CFG.model_name

"""Loader"""

# Commented out IPython magic to ensure Python compatibility.
# %%time
#
# loader = DirectoryLoader(
#     CFG.PDFs_path,
#     glob="./*.pdf",
#     loader_cls=PyPDFLoader,
#     show_progress=True,
#     use_multithreading=True
# )
#
# documents = loader.load()

print(f'We have {len(documents)} pages in total')

documents[8].page_content

"""Splitter"""

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size = CFG.split_chunk_size,
    chunk_overlap = CFG.split_overlap
)

texts = text_splitter.split_documents(documents)

print(f'We have created {len(texts)} chunks from {len(documents)} pages')

"""Create Embeddings"""

# Commented out IPython magic to ensure Python compatibility.
# %%time
# from langchain.embeddings.huggingface import HuggingFaceEmbeddings
#
# vectordb = FAISS.from_documents(
#     texts,
#     HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2')
# )
#
# ### persist vector database
# vectordb.save_local(f"{CFG.Output_folder}/faiss_index_rag") # save in output folder
# # vectordb.save_local(f"{CFG.Embeddings_path}/faiss_index_hp") # save in input folder
#
# clear_output()

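### Reload the persisted index in a later session (a minimal sketch, not in the
### original notebook): assumes the same embedding model and output folder as
### above; recent langchain-community releases also require the
### allow_dangerous_deserialization flag when loading a pickled FAISS index.
# from langchain.embeddings.huggingface import HuggingFaceEmbeddings
#
# embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2')
# vectordb = FAISS.load_local(
#     f"{CFG.Output_folder}/faiss_index_rag",
#     embeddings,
#     allow_dangerous_deserialization=True
# )
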
"""Prompt Template"""

prompt_template = """
Don't try to make up an answer; if you don't know, just say that you don't know.
Answer in the same language the question was asked.
Use only the following pieces of context to answer the question at the end.

{context}

Question: {question}
Answer:"""


PROMPT = PromptTemplate(
    template = prompt_template,
    input_variables = ["context", "question"]
)

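### Quick sanity check (illustrative, not in the original notebook): render the
### template with placeholder values to preview the exact prompt the chain sends.
print(PROMPT.format(context='<retrieved passages go here>', question='<user question goes here>'))
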
"""Retriever chain"""

retriever = vectordb.as_retriever(search_type = "similarity", search_kwargs = {"k": CFG.k})

qa_chain = RetrievalQA.from_chain_type(
    llm = llm,
    chain_type = "stuff", # map_reduce, map_rerank, stuff, refine
    retriever = retriever,
    chain_type_kwargs = {"prompt": PROMPT},
    return_source_documents = True,
    verbose = False
)

question = "what are the structural organization of a human body"
vectordb.max_marginal_relevance_search(question, k = CFG.k)

### testing similarity search
question = "what are the structural organization of a human body"
vectordb.similarity_search(question, k = CFG.k)

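### The retriever wraps the same similarity search (illustrative check, not in
### the original notebook): it should return the k passages the QA chain stuffs
### into the prompt.
retriever.get_relevant_documents(question)
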
"""Post-process outputs"""

def wrap_text_preserve_newlines(text, width=700):
    # Split the input text into lines based on newline characters
    lines = text.split('\n')

    # Wrap each line individually
    wrapped_lines = [textwrap.fill(line, width=width) for line in lines]

    # Join the wrapped lines back together using newline characters
    wrapped_text = '\n'.join(wrapped_lines)

    return wrapped_text


def process_llm_response(llm_response):
    ans = wrap_text_preserve_newlines(llm_response['result'])

    sources_used = ' \n'.join(
        [
            source.metadata['source'].split('/')[-1][:-4]
            + ' - page: '
            + str(source.metadata['page'])
            for source in llm_response['source_documents']
        ]
    )

    ans = ans + '\n\nSources: \n' + sources_used
    return ans

def llm_ans(query):
    start = time.time()

    llm_response = qa_chain.invoke(query)
    ans = process_llm_response(llm_response)

    end = time.time()

    time_elapsed = int(round(end - start, 0))
    time_elapsed_str = f'\n\nTime elapsed: {time_elapsed} s'
    return ans + time_elapsed_str

query = question = "what are the structural organization of a human body"
print(llm_ans(query))

"""Gradio Chat UI (Inspired from HinePo)"""

import locale
locale.getpreferredencoding = lambda: "UTF-8"

! pip install --upgrade gradio -qq
clear_output()

import gradio as gr
print(gr.__version__)

def predict(message, history):
    # output = message # debug mode

    output = str(llm_ans(message)).replace("\n", "<br/>")
    return output

demo = gr.ChatInterface(
    predict,
    title = f'Open-Source LLM ({CFG.model_name}) Question Answering'
)

demo.queue()
demo.launch()