HemaMeena committed
Commit e2d159d · verified · 1 Parent(s): 7726cbb

Update app.py

Files changed (1):
  app.py  +231 -60
app.py CHANGED
@@ -1,64 +1,235 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
  )
- if __name__ == "__main__":
-     demo.launch()
  import gradio as gr
+ import time
+ import os
+ import glob
+ import textwrap
+ import torch
+ from transformers import (
+     AutoTokenizer, AutoModelForCausalLM,
+     BitsAndBytesConfig,
+     pipeline
+ )
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.vectorstores import FAISS
+ from langchain.llms import HuggingFacePipeline
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
+ from langchain.chains import RetrievalQA
+ from langchain.prompts import PromptTemplate
+
+ # Configuration class
+ class CFG:
+     # LLMs
+     model_name = 'llama2-13b-chat'  # wizardlm, llama2-7b-chat, llama2-13b-chat, mistral-7B
+     temperature = 0
+     top_p = 0.95
+     repetition_penalty = 1.15
+
+     # splitting
+     split_chunk_size = 800
+     split_overlap = 0
+
+     # embeddings
+     embeddings_model_repo = 'sentence-transformers/all-MiniLM-L6-v2'
+
+     # similar passages
+     k = 6
+
+     # paths
+     PDFs_path = './'  # Set to your PDF path
+     Embeddings_path = './faiss-hp-sentence-transformers'
+     Output_folder = './rag-vectordb'
+
+ # Set preferred encoding to UTF-8 (for non-ASCII characters)
+ import locale
+ locale.getpreferredencoding = lambda: "UTF-8"
+
+ # Function to get model
+ def get_model(model=CFG.model_name):
+     print('\nDownloading model: ', model, '\n\n')
+
+     if model == 'wizardlm':
+         model_repo = 'TheBloke/wizardLM-7B-HF'
+
+         tokenizer = AutoTokenizer.from_pretrained(model_repo)
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_use_double_quant=True,
+         )
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_repo,
+             quantization_config=bnb_config,
+             device_map='auto',
+             low_cpu_mem_usage=True
+         )
+
+         max_len = 1024
+
+     elif model == 'llama2-7b-chat':
+         model_repo = 'daryl149/llama-2-7b-chat-hf'
+         tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)
+
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_use_double_quant=True,
+         )
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_repo,
+             quantization_config=bnb_config,
+             device_map='auto',
+             low_cpu_mem_usage=True,
+             trust_remote_code=True
+         )
+
+         max_len = 2048
+
+     elif model == 'llama2-13b-chat':
+         model_repo = 'daryl149/llama-2-13b-chat-hf'
+         tokenizer = AutoTokenizer.from_pretrained(model_repo, use_fast=True)
+
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_use_double_quant=True,
+         )
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_repo,
+             quantization_config=bnb_config,
+             low_cpu_mem_usage=True,
+             trust_remote_code=True
+         )
+
+         max_len = 2048
+
+     else:
+         print("Model not implemented!")
+
+     return tokenizer, model, max_len
+
+ # Get the model
+ tokenizer, model, max_len = get_model(CFG.model_name)
+
+ # Set up Hugging Face pipeline
+ pipe = pipeline(
+     task="text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     pad_token_id=tokenizer.eos_token_id,
+     max_length=max_len,
+     temperature=CFG.temperature,
+     top_p=CFG.top_p,
+     repetition_penalty=CFG.repetition_penalty
  )
+ # Langchain pipeline
+ llm = HuggingFacePipeline(pipeline=pipe)
+
+ # Load the documents
+ loader = DirectoryLoader(
+     CFG.PDFs_path,
+     glob="./*.pdf",
+     loader_cls=PyPDFLoader,
+     show_progress=True,
+     use_multithreading=True
+ )
+ documents = loader.load()
+
+ # Split the documents
+ text_splitter = RecursiveCharacterTextSplitter(
+     chunk_size=CFG.split_chunk_size,
+     chunk_overlap=CFG.split_overlap
+ )
+ texts = text_splitter.split_documents(documents)
+
+ # Set up vector store
+ vectordb = FAISS.from_documents(
+     texts,
+     HuggingFaceInstructEmbeddings(model_name=CFG.embeddings_model_repo)
+ )
+
+ # Save the vector store
+ vectordb.save_local(f"{CFG.Output_folder}/faiss_index_rag")
+
+ # Define the prompt template
+ prompt_template = """
+ Don't try to make up an answer, if you don't know just say that you don't know.
+ Answer in the same language the question was asked.
+ Use only the following pieces of context to answer the question at the end.
+
+ {context}
+
+ Question: {question}
+ Answer:"""
+
+ PROMPT = PromptTemplate(
+     template=prompt_template,
+     input_variables=["context", "question"]
+ )
+
+ # Set up retriever
+ retriever = vectordb.as_retriever(search_kwargs={"k": CFG.k, "search_type": "similarity"})
+
+ # Create the retrieval-based QA chain
+ qa_chain = RetrievalQA.from_chain_type(
+     llm=llm,
+     chain_type="stuff",  # other options: "map_reduce", "map_rerank", "refine"
+     retriever=retriever,
+     chain_type_kwargs={"prompt": PROMPT},
+     return_source_documents=True,
+     verbose=False
+ )
+
+ # Function to wrap text for proper display
+ def wrap_text_preserve_newlines(text, width=700):
+     lines = text.split('\n')
+     wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
+     wrapped_text = '\n'.join(wrapped_lines)
+     return wrapped_text
+
+ # Function to process model response
+ def process_llm_response(llm_response):
+     ans = wrap_text_preserve_newlines(llm_response['result'])
+     sources_used = ' \n'.join(
+         [
+             source.metadata['source'].split('/')[-1][:-4]
+             + ' - page: '
+             + str(source.metadata['page'])
+             for source in llm_response['source_documents']
+         ]
+     )
+     ans = ans + '\n\nSources: \n' + sources_used
+     return ans
+
+ # Function to get the answer from the model
+ def llm_ans(query):
+     start = time.time()
+     llm_response = qa_chain.invoke(query)
+     ans = process_llm_response(llm_response)
+     end = time.time()
+
+     time_elapsed = int(round(end - start, 0))
+     time_elapsed_str = f'\n\nTime elapsed: {time_elapsed} s'
+     return ans + time_elapsed_str
+
+ # Function for Gradio chat interface
+ def predict(message, history):
+     output = str(llm_ans(message)).replace("\n", "<br/>")
+     return output
+
+ # Set up Gradio interface
+ demo = gr.ChatInterface(
+     fn=predict,
+     title=f'Open-Source LLM ({CFG.model_name}) Question Answering'
+ )

+ # Start the Gradio interface
+ demo.queue()
+ demo.launch()
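
A minimal sketch of how the FAISS index saved by the new app.py could be reloaded and queried on its own. This assumes the same environment and the CFG paths defined above; the query string is only a placeholder, not from the commit.

# Sketch: reload the FAISS index written by app.py and inspect the top matches.
# Assumes the index exists at CFG.Output_folder ('./rag-vectordb') as built above;
# "example question" is a placeholder query.
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceInstructEmbeddings

embeddings = HuggingFaceInstructEmbeddings(
    model_name='sentence-transformers/all-MiniLM-L6-v2'
)
vectordb = FAISS.load_local('./rag-vectordb/faiss_index_rag', embeddings)

for doc in vectordb.similarity_search("example question", k=6):
    print(doc.metadata['source'], '- page:', doc.metadata['page'])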