zainnobody commited on
Commit
c9fb0e9
1 Parent(s): 9dbe43a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +382 -0
app.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
+ from sentence_transformers import SentenceTransformer, util
4
+ from sklearn.feature_extraction.text import TfidfVectorizer
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+ import re
7
+ import traceback
8
+ import torch
9
+ import os
10
+ from sentence_transformers import SentenceTransformer, util
11
+ from sklearn.feature_extraction.text import TfidfVectorizer
12
+ from sklearn.metrics.pairwise import cosine_similarity
13
+ import re
14
+ import pandas as pd
15
+ import json
16
+
17
+
18
+ # Preprocessing text by lowercasing, removing punctuation, and extra spaces
19
+ def optimized_preprocess_text(text):
20
+ text = text.lower()
21
+ text = re.sub(r'[^\w\s]', '', text)
22
+ text = re.sub(r'\s+', ' ', text).strip()
23
+ return text
24
+
25
+ # Compute cosine similarity between two texts using TF-IDF
26
+ def optimized_compute_text_similarity(text1, text2):
27
+ tfidf = TfidfVectorizer(stop_words='english', ngram_range=(1, 1))
28
+ tfidf_matrix = tfidf.fit_transform([text1, text2])
29
+ cosine_sim = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2]).flatten()
30
+ return cosine_sim[0]
31
+
32
+ # Compute SBERT similarity between question and context
33
+ def compute_sbert_similarity(question, context, model):
34
+ embeddings = model.encode([question, context], convert_to_tensor=True)
35
+ similarity = util.pytorch_cos_sim(embeddings[0], embeddings[1]).item()
36
+ return similarity
37
+
38
+ # Use hybrid approach: TF-IDF to narrow down top N contexts, then SBERT for refined similarity
39
+ def hybrid_sbert_approach(question, filtered_contexts, model, top_n=10):
40
+ tfidf = TfidfVectorizer(stop_words='english')
41
+ contexts_combined = [question] + filtered_contexts
42
+ tfidf_matrix = tfidf.fit_transform(contexts_combined)
43
+
44
+ # Calculate TF-IDF similarity and rank contexts
45
+ similarity_scores = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:]).flatten()
46
+ ranked_contexts = [filtered_contexts[i] for i in similarity_scores.argsort()[::-1][:top_n]]
47
+
48
+ # Refine using SBERT
49
+ sbert_similarities = [compute_sbert_similarity(question, context, model) for context in ranked_contexts]
50
+ ranked_by_sbert = sorted(zip(ranked_contexts, sbert_similarities), key=lambda x: x[1], reverse=True)
51
+
52
+ return [context for context, _ in ranked_by_sbert]
53
+
54
+ # RAG with optimized SBERT function
55
+ def optimized_generate_rag_context(question, filtered_contexts, selected_context_window=2):
56
+ hybrid_retrieved_contexts = hybrid_sbert_approach(question, filtered_contexts, sbert_model, top_n=int(selected_context_window))
57
+ rag_context = "\n".join(hybrid_retrieved_contexts[:selected_context_window])
58
+ return rag_context
59
+
60
+ # Extract unique contexts and filter them by length
61
+ def extract_and_filter_contexts(data, min_length=151, max_length=3706):
62
+ unique_contexts = data['context'].unique()
63
+ filtered_contexts = [context for context in unique_contexts if min_length <= len(context) <= max_length]
64
+ return filtered_contexts
65
+
66
+ # Compute the TF-IDF matrix for the question and contexts
67
+ def compute_tfidf_and_similarity_scores(question, contexts):
68
+ tfidf = TfidfVectorizer(stop_words='english')
69
+ contexts_combined = [question] + contexts
70
+ tfidf_matrix = tfidf.fit_transform(contexts_combined)
71
+
72
+ # Calculate the cosine similarity scores
73
+ similarity_scores = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:]).flatten()
74
+ return tfidf_matrix, similarity_scores
75
+
76
+ # Rank contexts based on similarity scores
77
+ def rank_contexts_by_similarity(contexts, similarity_scores):
78
+ ranked_indices = similarity_scores.argsort()[::-1]
79
+ ranked_contexts = [contexts[i] for i in ranked_indices]
80
+ ranked_scores = similarity_scores[ranked_indices]
81
+ return ranked_contexts, ranked_scores
82
+
83
+ # Select the top contexts based on the selected window
84
+ def select_top_contexts(selected_context_window, ranked_contexts, ranked_scores):
85
+ count = int(selected_context_window)
86
+ top_contexts = ranked_contexts[:count]
87
+ top_scores = ranked_scores[:count]
88
+ return top_contexts, top_scores
89
+
90
+
91
+ # Helper function to maintain chat history and generate the response
92
+ def maintain_chat_history(message, chat_history):
93
+ if chat_history is None:
94
+ chat_history = []
95
+ chat_history.append({"role": "user", "content": message})
96
+ return chat_history
97
+
98
+ def generate_rag_context(question, filtered_contexts, selected_context_window = 3):
99
+ tfidf_matrix, similarity_scores = compute_tfidf_and_similarity_scores(question, filtered_contexts)
100
+ ranked_contexts, ranked_scores = rank_contexts_by_similarity(filtered_contexts, similarity_scores)
101
+ top_contexts, top_scores = select_top_contexts(str(selected_context_window), ranked_contexts, ranked_scores)
102
+ rag_context = "\n".join(top_contexts)
103
+ return rag_context
104
+
105
+ def load_squad_data(filepath):
106
+ with open(filepath, 'r') as f:
107
+ squad_data = json.load(f)
108
+ return squad_data
109
+
110
+
111
+
112
+ # Preprocess the data: extract contexts, questions, and answers from the SQuAD data
113
+ def raw_preprocess_data(squad_data):
114
+ contexts = []
115
+ questions = []
116
+ answers = []
117
+
118
+ for group in squad_data['data']:
119
+ for passage in group['paragraphs']:
120
+ context = passage['context']
121
+ for qa in passage['qas']:
122
+ question = qa['question']
123
+ for answer in qa['answers']:
124
+ contexts.append(context)
125
+ questions.append(question)
126
+ # Make a copy to avoid modifying the original answer
127
+ answers.append({
128
+ 'text': answer['text'],
129
+ 'answer_start': answer['answer_start']
130
+ })
131
+
132
+ return contexts, questions, answers
133
+
134
+
135
+ # Add the end index of the answer in the context
136
+ def add_end_idx(answers, contexts):
137
+ for answer, context in zip(answers, contexts):
138
+ gold_text = answer['text']
139
+ start_idx = answer['answer_start']
140
+ end_idx = start_idx + len(gold_text)
141
+
142
+ if context[start_idx:end_idx] == gold_text:
143
+ answer['answer_end'] = end_idx
144
+ else:
145
+ # Try to find the correct position if there's a mismatch
146
+ for n in range(1, 30):
147
+ if context[start_idx - n:end_idx - n] == gold_text:
148
+ answer['answer_start'] = start_idx - n
149
+ answer['answer_end'] = end_idx - n
150
+ break
151
+ elif context[start_idx + n:end_idx + n] == gold_text:
152
+ answer['answer_start'] = start_idx + n
153
+ answer['answer_end'] = end_idx + n
154
+ break
155
+ else:
156
+ answer['answer_start'] = -1
157
+ answer['answer_end'] = -1
158
+
159
+
160
+ # Create a DataFrame from the contexts, questions, and answers
161
+ def create_dataframe(contexts, questions, answers):
162
+ data = pd.DataFrame({
163
+ 'context': contexts,
164
+ 'question': questions,
165
+ 'answer_text': [answer['text'] for answer in answers],
166
+ 'answer_start': [answer['answer_start'] for answer in answers],
167
+ 'answer_end': [answer.get('answer_end', -1) for answer in answers]
168
+ })
169
+
170
+ # Remove samples with -1 start index
171
+ data = data[data['answer_start'] != -1].reset_index(drop=True)
172
+ return data
173
+
174
+ # Check if a GPU (CUDA) is available; otherwise, use the CPU
175
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
176
+
177
+
178
+ # Loading the pre-trained SBERT model globally for efficiency
179
+ sbert_model = SentenceTransformer('all-MiniLM-L6-v2')
180
+
181
+ # Available models
182
+ electra_models = [
183
+ "./models/fine_tuned_electra_model_1000",
184
+ "./models/fine_tuned_electra_model_20000",
185
+ "./models/fine_tuned_electra_model_5000",
186
+ "./models/fine_tuned_electra_model_all"
187
+ ]
188
+ other_models = [
189
+ "./models/fine_tuned_bert_base_cased_1000",
190
+ "./models/fine_tuned_bert_base_cased_all",
191
+ "./models/fine_tuned_distilbert_base_uncased_10000",
192
+ "./models/fine_tuned_distilgpt2_10000",
193
+ "./models/fine_tuned_retro-reader_intensive_1000",
194
+ "./models/fine_tuned_retro-reader_intensive_5000",
195
+ "./models/fine_tuned_retro-reader_sketchy_1000"
196
+ ]
197
+
198
+ DATA_DIR = './data'
199
+
200
+ # Load and preprocess data
201
+ squad_data = load_squad_data(DATA_DIR+ '/train-v1.1.json')
202
+ contexts, questions, answers = raw_preprocess_data(squad_data)
203
+ add_end_idx(answers, contexts)
204
+ data = create_dataframe(contexts, questions, answers)
205
+
206
+ # Function to generate a response with logging and custom content
207
+ def generate_response(message, chat_history, model_name, debug, rag, selected_context_window):
208
+ try:
209
+ if chat_history is None:
210
+ chat_history = []
211
+ context = message
212
+
213
+ # Determine if the model is for question answering based on its name
214
+ is_question_answering = "electra_model" in model_name
215
+
216
+ # Initialize the tokenizer and model
217
+ if is_question_answering:
218
+ model = pipeline("question-answering", model=model_name, tokenizer=model_name, device=device)
219
+ else:
220
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
221
+ model = AutoModelForCausalLM.from_pretrained(model_name)
222
+ model.to(device)
223
+
224
+ # Append the new user message to the chat history
225
+ chat_history.append({"role": "user", "content": message})
226
+
227
+ if is_question_answering:
228
+ if rag:
229
+ filtered_contexts = extract_and_filter_contexts(data, min_length=100, max_length=4000)
230
+ context = generate_rag_context(message, filtered_contexts, selected_context_window)
231
+ else:
232
+ context = "\n".join([turn["content"] for turn in chat_history if turn["role"] == "user"])
233
+
234
+ if debug:
235
+ print("context:\n" + context)
236
+ print("message:\n" + message)
237
+
238
+ # Call the pipeline for question-answering
239
+ answer = model(question=message, context=context)
240
+ response = answer['answer']
241
+
242
+ else:
243
+ # Prepare the conversation history for a regular chatbot
244
+ conversation = ""
245
+ for turn in chat_history:
246
+ if turn["role"] == "user":
247
+ conversation += f"User: {turn['content']}\n"
248
+ else:
249
+ conversation += f"Assistant: {turn['content']}\n"
250
+
251
+ if debug:
252
+ print("Conversation being sent to the model:\n", conversation)
253
+
254
+ # Encode the input and generate a response
255
+ inputs = tokenizer.encode(conversation + "Assistant:", return_tensors='pt').to(device)
256
+ outputs = model.generate(
257
+ inputs,
258
+ max_length=inputs.shape[1] + 100,
259
+ pad_token_id=tokenizer.eos_token_id,
260
+ do_sample=True,
261
+ top_p=0.95,
262
+ top_k=50,
263
+ temperature=0.7,
264
+ eos_token_id=tokenizer.eos_token_id,
265
+ )
266
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
267
+
268
+ # Extract the assistant's reply
269
+ response = response[len(conversation):].strip()
270
+ if "User:" in response:
271
+ response = response.split("User:")[0].strip()
272
+
273
+ # Append the assistant's response to the chat history
274
+ chat_history.append({"role": "assistant", "content": response})
275
+ if debug:
276
+ print("Generated response:", response)
277
+ print("Configurations:")
278
+ print(f"Model Name: {model_name}")
279
+ print(f"Is Question Answering: {is_question_answering}")
280
+ print(f"RAG Enabled: {rag}")
281
+ print(f"Selected Context Window: {selected_context_window}")
282
+
283
+ # Return the updated chat history and the assistant's response
284
+ display_history = [[turn["content"], chat_history[i + 1]["content"]] for i, turn in enumerate(chat_history[:-1]) if turn["role"] == "user" and i + 1 < len(chat_history)]
285
+ return display_history, chat_history
286
+
287
+ except Exception as e:
288
+ # Capture the traceback details
289
+ error_message = f"An error occurred: {str(e)}"
290
+ detailed_error = traceback.format_exc()
291
+ chat_history.append({"role": "assistant", "content": error_message})
292
+ if debug:
293
+ print("Error Details:\n", detailed_error)
294
+
295
+ # Ensure safe generation of the display history
296
+ try:
297
+ display_history = [[turn["content"], chat_history[i + 1]["content"]] for i, turn in enumerate(chat_history[:-1]) if turn["role"] == "user" and i + 1 < len(chat_history)]
298
+ except Exception as history_error:
299
+ if debug:
300
+ print("Error while generating display history:", str(history_error))
301
+ display_history = []
302
+
303
+ return display_history, chat_history
304
+
305
+ # Gradio Interface Configuration
306
+ def run_prod_chatbot(local=True):
307
+ with gr.Blocks() as demo:
308
+ gr.Markdown("""
309
+ <div style="text-align: center;">
310
+ <h1><strong>SQuAD Q&A ChatBot</strong></h1>
311
+ <h3>Authors: <a href="https://github.com/zainnobody">Zain Ali</a> & <a href="https://github.com/AIBenHopwood/">Ben Hopwood</a></h3>
312
+ <p>
313
+ <a href="https://github.com/zainnobody/AAI-520-Final-Project" target="_blank">Code: GitHub link</a> &nbsp;|&nbsp;
314
+ <a href="https://huggingface.co/zainnobody/AAI-520-Final-Project-Models" target="_blank">Models: Huggingface link</a>
315
+ </p>
316
+ </div>
317
+
318
+ <div style="text-align: center;">
319
+ <p>
320
+ This project aims to develop a chatbot capable of multi-turn, context-adaptive conversations across various topics, using the Stanford Question Answering Dataset (SQuAD) as the primary source for training.
321
+ </p>
322
+ </div>
323
+
324
+ <div style="text-align: center;">
325
+ <h4>University of San Diego - AAI 520</h4>
326
+ </div>
327
+
328
+ """)
329
+ with gr.Row(variant="compact"):
330
+ model_dropdown = gr.Dropdown(
331
+ choices=electra_models + other_models,
332
+ label="Select Model",
333
+ value="./models/fine_tuned_electra_model_all"
334
+ )
335
+ # Column for Use RAG and Debug Mode checkboxes
336
+ with gr.Column():
337
+ rag_checkbox = gr.Checkbox(
338
+ label="Use RAG",
339
+ value=True,
340
+ interactive=True
341
+ )
342
+ debug_checkbox = gr.Checkbox(
343
+ label="Debug Mode",
344
+ value=False
345
+ )
346
+ context_window_dropdown = gr.Dropdown(
347
+ choices=[1, 2, 3],
348
+ label="Select Context Window",
349
+ value=1
350
+ )
351
+
352
+ # Commented out the is_question_answering_checkbox, making it auto detectable. Leaving this as a reminder that other models do not use pipeline
353
+ # is_question_answering_checkbox = gr.Checkbox(
354
+ # label="Use Question Answering (Electra Only)",
355
+ # value=True
356
+ # )
357
+
358
+ chatbot = gr.Chatbot()
359
+ state = gr.State([])
360
+
361
+ with gr.Row():
362
+ # Textbox taking 75% of the space
363
+ msg = gr.Textbox(label="Your message", placeholder="Type your message here and press Enter", scale=3)
364
+ # Send button taking 25% of the space and stretching full width
365
+ send_btn = gr.Button("Send", scale=1)
366
+
367
+
368
+
369
+ send_btn.click(lambda message, chat_history, model_name, debug, rag, selected_context_window: generate_response(message, chat_history, model_name, debug, rag, selected_context_window),
370
+ inputs=[msg, state, model_dropdown, debug_checkbox, rag_checkbox, context_window_dropdown],
371
+ outputs=[chatbot, state])
372
+ msg.submit(lambda message, chat_history, model_name, debug, rag, selected_context_window: generate_response(message, chat_history, model_name, debug, rag, selected_context_window),
373
+ inputs=[msg, state, model_dropdown, debug_checkbox, rag_checkbox, context_window_dropdown],
374
+ outputs=[chatbot, state])
375
+
376
+ if local:
377
+ demo.launch(share=True)
378
+ else:
379
+ demo.launch(server_name="0.0.0.0", server_port=None)
380
+
381
+ # Launch the Gradio app
382
+ run_prod_chatbot()