
arjunanand13 committed on
Commit
8d5368b
1 Parent(s): b797353

Update app.py

Files changed (1)
  1. app.py +159 -93
app.py CHANGED
@@ -1,16 +1,20 @@
  import concurrent.futures
- import threading
  import torch
  from datetime import datetime
  import json
  import gradio as gr
- import re
- import faiss
- import numpy as np
- from sentence_transformers import SentenceTransformer
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
- from langchain.document_loaders import DirectoryLoader, TextLoader
- from langchain.text_splitter import RecursiveCharacterTextSplitter

  class DocumentRetrievalAndGeneration:
      def __init__(self, embedding_model_name, lm_model_id, data_folder):
@@ -18,7 +22,6 @@ class DocumentRetrievalAndGeneration:
          self.embeddings = SentenceTransformer(embedding_model_name)
          self.gpu_index = self.create_faiss_index()
          self.llm = self.initialize_llm(lm_model_id)
-         self.cancel_flag = threading.Event()

      def load_documents(self, folder_path):
          loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
@@ -47,6 +50,7 @@ class DocumentRetrievalAndGeneration:
              bnb_4bit_quant_type="nf4",
              bnb_4bit_compute_dtype=torch.bfloat16
          )
          model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
          tokenizer = AutoTokenizer.from_pretrained(model_id)
          generate_text = pipeline(
@@ -60,85 +64,84 @@ class DocumentRetrievalAndGeneration:
          return generate_text

      def generate_response_with_timeout(self, model_inputs):
-         def target(future):
-             if self.cancel_flag.is_set():
-                 return
-             generated_ids = self.llm.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
-             if not self.cancel_flag.is_set():
-                 future.set_result(generated_ids)
-             else:
-                 future.set_exception(TimeoutError("Text generation process was canceled"))
-
-         future = concurrent.futures.Future()
-         thread = threading.Thread(target=target, args=(future,))
-         thread.start()
-
          try:
-             generated_ids = future.result(timeout=60)  # Timeout set to 60 seconds
              return generated_ids
          except concurrent.futures.TimeoutError:
-             self.cancel_flag.set()
              raise TimeoutError("Text generation process timed out")

-     def qa_infer_gradio(self, query):
-         # Set the cancel flag to false for the new query
-         self.cancel_flag.clear()
-
-         try:
-             query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
-             distances, indices = self.gpu_index.search(np.array([query_embedding]), k=5)
-
-             content = ""
-             for idx in indices[0]:
-                 content += "-" * 50 + "\n"
-                 content += self.all_splits[idx].page_content + "\n"
-
-             prompt = f"""<s>
-             You are a knowledgeable assistant with access to a comprehensive database.
-             I need you to answer my question and provide related information in a specific format.
-             I have provided five relatable json files {content}, choose the most suitable chunks for answering the query
-             Here's what I need:
-             Include a final answer without additional comments, sign-offs, or extra phrases. Be direct and to the point.
-             content
-             Here's my question:
-             Query:{query}
-             Solution==>
-             RETURN ONLY SOLUTION . IF THEIR IS NO ANSWER RELATABLE IN RETRIEVED CHUNKS , RETURN " NO SOLUTION AVAILABLE"
-             Example1
-             Query: "How to use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM",
-             Solution: "To use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM, you need to modify the configuration file of the NDK application. Specifically, change the processor reference from 'A15_0' to 'IPU1_0'.",
-
-             Example2
-             Query: "Can BQ25896 support I2C interface?",
-             Solution: "Yes, the BQ25896 charger supports the I2C interface for communication."
-             </s>
-             """
-
-             messages = [{"role": "user", "content": prompt}]
-             encodeds = self.llm.tokenizer.apply_chat_template(messages, return_tensors="pt")
-             model_inputs = encodeds.to(self.llm.model.device)
-
-             start_time = datetime.now()
-             generated_ids = self.generate_response_with_timeout(model_inputs)
-             elapsed_time = datetime.now() - start_time
-
-             decoded = self.llm.tokenizer.batch_decode(generated_ids)
-             generated_response = decoded[0]

-             match = re.search(r'Solution:(.*?)</s>', generated_response, re.DOTALL | re.IGNORECASE)
-             if match:
-                 solution_text = match.group(1).strip()
-             else:
-                 solution_text = "NO SOLUTION AVAILABLE"

-             print("Generated response:", generated_response)
-             print("Time elapsed:", elapsed_time)
-             print("Device in use:", self.llm.model.device)

-             return solution_text, content

-         except TimeoutError:
-             return "timeout", content

  if __name__ == "__main__":
      # Example usage
@@ -148,7 +151,8 @@ if __name__ == "__main__":

      doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder)

-     # Define Gradio interface function
      def launch_interface():
          css_code = """
          .gradio-container {
@@ -164,21 +168,23 @@ if __name__ == "__main__":
              font-size: 16px; /* Increase font size */
              font-weight: bold; /* Make text bold */
          }
-         """
-         EXAMPLES = ["Can the VIP and CSI2 modules operate simultaneously?",
-                     "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for CortexM4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
-                     "Could you clarify the maximum number of cameras that can be connected simultaneously to the video input ports on the TDA2x SoC, considering it supports up to 10 multiplexed input ports and includes 3 dedicated video input modules?"]
-
          file_path = "ticketNames.txt"
-
          # Read the file content
          with open(file_path, "r") as file:
              content = file.read()
          ticket_names = json.loads(content)
          dropdown = gr.Dropdown(label="Sample queries", choices=ticket_names)
-
-         # Define Gradio interface
-         interface = gr.Interface(
              fn=doc_retrieval_gen.qa_infer_gradio,
              inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
              allow_flagging='never',
@@ -187,9 +193,69 @@ if __name__ == "__main__":
              outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RELATED QUERIES")],
              css=css_code
          )
-
-         # Launch Gradio interface
-         interface.launch(debug=True)
-
      # Launch the interface
      launch_interface()

+ "Single Thread"
+
+ import os
+ import multiprocessing
  import concurrent.futures
+ from langchain.document_loaders import TextLoader, DirectoryLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.vectorstores import FAISS
+ from sentence_transformers import SentenceTransformer
+ import faiss
  import torch
+ import numpy as np
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
  from datetime import datetime
  import json
  import gradio as gr
+ import re

  class DocumentRetrievalAndGeneration:
      def __init__(self, embedding_model_name, lm_model_id, data_folder):

          self.embeddings = SentenceTransformer(embedding_model_name)
          self.gpu_index = self.create_faiss_index()
          self.llm = self.initialize_llm(lm_model_id)

      def load_documents(self, folder_path):
          loader = DirectoryLoader(folder_path, loader_cls=TextLoader)

              bnb_4bit_quant_type="nf4",
              bnb_4bit_compute_dtype=torch.bfloat16
          )
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
          model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
          tokenizer = AutoTokenizer.from_pretrained(model_id)
          generate_text = pipeline(

          return generate_text

      def generate_response_with_timeout(self, model_inputs):
          try:
+             with concurrent.futures.ThreadPoolExecutor() as executor:
+                 future = executor.submit(self.llm.model.generate, model_inputs, max_new_tokens=1000, do_sample=True)
+                 generated_ids = future.result(timeout=60)  # Timeout set to 60 seconds
              return generated_ids
          except concurrent.futures.TimeoutError:
              raise TimeoutError("Text generation process timed out")
+
+     def query_and_generate_response(self, query):
+         query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
+         distances, indices = self.gpu_index.search(np.array([query_embedding]), k=5)

+         content = ""
+         for idx in indices[0]:
+             content += "-" * 50 + "\n"
+             content += self.all_splits[idx].page_content + "\n"
+             print("CHUNK", idx)
+             print(self.all_splits[idx].page_content)
+             print("############################")
+         prompt = f"""<s>
+         You are a knowledgeable assistant with access to a comprehensive database.
+         I need you to answer my question and provide related information in a specific format.
+         I have provided five relatable json files {content}, choose the most suitable chunks for answering the query
+         Here's what I need:
+         Include a final answer without additional comments, sign-offs, or extra phrases. Be direct and to the point.
+         content
+         Here's my question:
+         Query:{query}
+         Solution==>
+         RETURN ONLY SOLUTION . IF THEIR IS NO ANSWER RELATABLE IN RETRIEVED CHUNKS , RETURN " NO SOLUTION AVAILABLE"
+         Example1
+         Query: "How to use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM",
+         Solution: "To use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM, you need to modify the configuration file of the NDK application. Specifically, change the processor reference from 'A15_0' to 'IPU1_0'.",
+
+         Example2
+         Query: "Can BQ25896 support I2C interface?",
+         Solution: "Yes, the BQ25896 charger supports the I2C interface for communication."
+         </s>
+         """
+         # prompt = f"Query: {query}\nSolution: {content}\n"

+         # Encode and prepare inputs
+         messages = [{"role": "user", "content": prompt}]
+         encodeds = self.llm.tokenizer.apply_chat_template(messages, return_tensors="pt")
+         model_inputs = encodeds.to(self.llm.device)
+
+         # Perform inference and measure time
+         start_time = datetime.now()
+         generated_ids = self.generate_response_with_timeout(model_inputs)
+         # generated_ids = self.llm.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
+         elapsed_time = datetime.now() - start_time

+         # Decode and return output
+         decoded = self.llm.tokenizer.batch_decode(generated_ids)
+         generated_response = decoded[0]
+         match1 = re.search(r'\[/INST\](.*?)</s>', generated_response, re.DOTALL)
+
+         match2 = re.search(r'Solution:(.*?)</s>', generated_response, re.DOTALL | re.IGNORECASE)
+         if match1:
+             solution_text = match1.group(1).strip()
+             print(solution_text)
+             if "Solution:" in solution_text:
+                 solution_text = solution_text.split("Solution:", 1)[1].strip()
+         elif match2:
+             solution_text = match2.group(1).strip()
+             print(solution_text)
+
+         else:
+             solution_text = generated_response
+         print("Generated response:", generated_response)
+         print("Time elapsed:", elapsed_time)
+         print("Device in use:", self.llm.device)

+         return solution_text, content

+     def qa_infer_gradio(self, query):
+         response = self.query_and_generate_response(query)
+         return response

  if __name__ == "__main__":
      # Example usage

      doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder)

+     """Dual Interface"""
+
      def launch_interface():
          css_code = """
          .gradio-container {

              font-size: 16px; /* Increase font size */
              font-weight: bold; /* Make text bold */
          }
+         """
+         EXAMPLES = [
+             "On which devices can the VIP and CSI2 modules operate simultaneously?",
+             "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for CortexM4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
+             "Could you clarify the maximum number of cameras that can be connected simultaneously to the video input ports on the TDA2x SoC, considering it supports up to 10 multiplexed input ports and includes 3 dedicated video input modules?"
+         ]
+
          file_path = "ticketNames.txt"
+
          # Read the file content
          with open(file_path, "r") as file:
              content = file.read()
          ticket_names = json.loads(content)
          dropdown = gr.Dropdown(label="Sample queries", choices=ticket_names)
+
+         # Define Gradio interfaces
+         tab1 = gr.Interface(
              fn=doc_retrieval_gen.qa_infer_gradio,
              inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
              allow_flagging='never',

              outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RELATED QUERIES")],
              css=css_code
          )
+         tab2 = gr.Interface(
+             fn=doc_retrieval_gen.qa_infer_gradio,
+             inputs=[dropdown],
+             allow_flagging='never',
+             outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RELATED QUERIES")],
+             css=css_code
+         )
+
+         # Combine interfaces into a tabbed interface
+         gr.TabbedInterface(
+             [tab1, tab2],
+             ["Textbox Input", "FAQs"],
+             title="TI E2E FORUM",
+             css=css_code
+         ).launch(debug=True)
+
      # Launch the interface
      launch_interface()
+
+
+
+     """Single Interface"""
+     # def launch_interface():
+     #     css_code = """
+     #     .gradio-container {
+     #         background-color: #daccdb;
+     #     }
+     #     /* Button styling for all buttons */
+     #     button {
+     #         background-color: #927fc7; /* Default color for all other buttons */
+     #         color: black;
+     #         border: 1px solid black;
+     #         padding: 10px;
+     #         margin-right: 10px;
+     #         font-size: 16px; /* Increase font size */
+     #         font-weight: bold; /* Make text bold */
+     #     }
+     #     """
+     #     EXAMPLES = ["On which devices can the VIP and CSI2 modules operate simultaneously? ",
+     #                 "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for CortexM4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
+     #                 "Could you clarify the maximum number of cameras that can be connected simultaneously to the video input ports on the TDA2x SoC, considering it supports up to 10 multiplexed input ports and includes 3 dedicated video input modules?"]
+
+     #     file_path = "ticketNames.txt"
+
+     #     # Read the file content
+     #     with open(file_path, "r") as file:
+     #         content = file.read()
+     #     ticket_names = json.loads(content)
+     #     dropdown = gr.Dropdown(label="Sample queries", choices=ticket_names)
+
+     #     # Define Gradio interface
+     #     interface = gr.Interface(
+     #         fn=doc_retrieval_gen.qa_infer_gradio,
+     #         inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
+     #         allow_flagging='never',
+     #         examples=EXAMPLES,
+     #         cache_examples=False,
+     #         outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RELATED QUERIES")],
+     #         css=css_code
+     #     )
+
+     #     # Launch Gradio interface
+     #     interface.launch(debug=True)
+
+     # # Launch the interface
+     # launch_interface()
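
The core behavioral change in this commit is that generate_response_with_timeout now wraps model.generate in a ThreadPoolExecutor and bounds the wait with future.result(timeout=60), replacing the earlier raw thread plus cancel-flag mechanism. Below is a minimal, standalone sketch of that pattern, not the Space's code: the model id "gpt2", the token budget, and the helper name generate_with_timeout are placeholders chosen for illustration.

# Minimal sketch of an executor-based generation timeout (assumed names, not from the Space).
import concurrent.futures

from transformers import AutoModelForCausalLM, AutoTokenizer

def generate_with_timeout(model, input_ids, timeout_s=60, max_new_tokens=100):
    # Run generate() in a worker thread and stop waiting after timeout_s seconds.
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = executor.submit(model.generate, input_ids,
                             max_new_tokens=max_new_tokens, do_sample=True)
    try:
        return future.result(timeout=timeout_s)
    except concurrent.futures.TimeoutError:
        raise TimeoutError("Text generation process timed out")
    finally:
        # wait=False returns control immediately; the worker thread itself cannot be
        # interrupted and simply finishes (or is abandoned) in the background.
        executor.shutdown(wait=False)

if __name__ == "__main__":
    model_id = "gpt2"  # placeholder model for illustration only
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    input_ids = tokenizer("Hello", return_tensors="pt").input_ids
    print(tokenizer.batch_decode(generate_with_timeout(model, input_ids))[0])

One design note: the committed version uses a with-block around the executor, and its implicit shutdown(wait=True) still waits for the running generate() call to finish before the TimeoutError propagates; the sketch uses shutdown(wait=False) so the caller sees the timeout immediately, although the generation thread cannot actually be cancelled in either variant.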
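The other notable change is the move from a single gr.Interface to two interfaces combined with gr.TabbedInterface (free-text queries in one tab, dropdown FAQs in the other). A stripped-down sketch of that layout follows; the stub answer_query function and the single dropdown choice are stand-ins for the Space's qa_infer_gradio and the contents of ticketNames.txt.

# Sketch of the dual-tab Gradio layout; answer_query and sample_queries are stand-ins.
import gradio as gr

def answer_query(query):
    # Placeholder for DocumentRetrievalAndGeneration.qa_infer_gradio
    return f"Solution for: {query}", "retrieved context would appear here"

sample_queries = ["Can BQ25896 support I2C interface?"]  # stand-in for ticketNames.txt

free_text_tab = gr.Interface(
    fn=answer_query,
    inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
    outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RELATED QUERIES")],
    allow_flagging="never",
)
faq_tab = gr.Interface(
    fn=answer_query,
    inputs=[gr.Dropdown(label="Sample queries", choices=sample_queries)],
    outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RELATED QUERIES")],
    allow_flagging="never",
)

demo = gr.TabbedInterface([free_text_tab, faq_tab], ["Textbox Input", "FAQs"], title="TI E2E FORUM")

if __name__ == "__main__":
    demo.launch()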