talhasarit41 committed
Commit 2a1307e · verified · 1 Parent(s): 07af763

Create app.py

Files changed (1)
  1. app.py +332 -0
app.py ADDED
@@ -0,0 +1,332 @@
+ import gradio as gr
+ import pickle
+ import fasttext
+ import numpy as np
+ import os
+ import torch
+ import time
+ from transformers import AutoTokenizer, AutoModel
+ import torch.nn.functional as F
+ from openai import AzureOpenAI
+
+ # Azure OpenAI Configuration
+ AZURE_OPENAI_EMBEDDING_ENDPOINT = ...
+ AZURE_API_VERSION = "2024-02-01"
+ AZURE_OPENAI_API_KEY = ...
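+ # NOTE: the endpoint and API key above are redacted placeholders (`...`);
+ # real values (e.g., read from environment variables) must be supplied
+ # before the app can run.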
+
+ # Model directory
+ MODEL_DIR = "saved_models_synthetic"
+
+ # Initialize Azure OpenAI client
+ azure_client = AzureOpenAI(
+     api_key=AZURE_OPENAI_API_KEY,
+     api_version=AZURE_API_VERSION,
+     azure_endpoint=AZURE_OPENAI_EMBEDDING_ENDPOINT
+ )
+
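+ # NOTE: each embedding helper below reloads its tokenizer/model from the
+ # Hub on every call; caching them at module level would speed up repeated
+ # predictions.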
+ def generate_e5_embedding(text, model_name='intfloat/multilingual-e5-large'):
+     """Generate E5 embeddings for a single text."""
+     start_time = time.time()
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModel.from_pretrained(model_name)
+
+     # Add prefix for E5 models
+     text = f"query: {text}"
+
+     # Tokenize and generate embedding
+     inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     # Mean pooling
+     attention_mask = inputs['attention_mask']
+     embeddings = mean_pooling(outputs.last_hidden_state, attention_mask)
+     # Normalize embeddings
+     embeddings = F.normalize(embeddings, p=2, dim=1)
+
+     inference_time = time.time() - start_time
+     return embeddings[0].numpy(), inference_time
+
+ def generate_e5_instruct_embedding(text, model_name='intfloat/multilingual-e5-large-instruct'):
+     """Generate E5-instruct embeddings for a single text."""
+     start_time = time.time()
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModel.from_pretrained(model_name)
+
+     # Add prefix for E5 models
+     text = f"query: {text}"
+
+     # Tokenize and generate embedding
+     inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     # Mean pooling
+     attention_mask = inputs['attention_mask']
+     embeddings = mean_pooling(outputs.last_hidden_state, attention_mask)
+     # Normalize embeddings
+     embeddings = F.normalize(embeddings, p=2, dim=1)
+
+     inference_time = time.time() - start_time
+     return embeddings[0].numpy(), inference_time
+
+ def generate_modernbert_embedding(text, model_name="answerdotai/ModernBERT-base"):
+     """Generate ModernBERT embeddings for a single text."""
+     start_time = time.time()
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModel.from_pretrained(model_name)
+
+     # Tokenize and generate embedding
+     inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model(**inputs)
+     # Take [CLS] token embedding
+     embeddings = outputs.last_hidden_state[:, 0, :]
+
+     inference_time = time.time() - start_time
+     return embeddings[0].numpy(), inference_time
+
+ def mean_pooling(token_embeddings, attention_mask):
+     """Mean pooling function for E5 models."""
+     input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+     return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+ def get_azure_embedding(text):
+     """Get embeddings from Azure OpenAI API."""
+     start_time = time.time()
+     response = azure_client.embeddings.create(
+         model="text-embedding-3-large",
+         input=text
+     )
+     inference_time = time.time() - start_time
+     return np.array(response.data[0].embedding), inference_time
+
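+ # NOTE: load_models() below re-reads every classifier from disk on each
+ # call; a deployed app would typically load them once at startup and reuse
+ # the loaded objects.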
+ # Load models
+ def load_models():
+     models = {}
+
+     # Load pickle models
+     with open(os.path.join(MODEL_DIR, 'e5_classifier.pkl'), 'rb') as f:
+         models['E5 Classifier'] = pickle.load(f)
+
+     with open(os.path.join(MODEL_DIR, 'e5_large_instruct_classifier.pkl'), 'rb') as f:
+         models['E5-Instruct Classifier'] = pickle.load(f)
+
+     with open(os.path.join(MODEL_DIR, 'azure_classifier.pkl'), 'rb') as f:
+         models['Azure Classifier'] = pickle.load(f)
+
+     with open(os.path.join(MODEL_DIR, 'azure_knn_classifier.pkl'), 'rb') as f:
+         models['Azure KNN Classifier'] = pickle.load(f)
+
+     with open(os.path.join(MODEL_DIR, 'modernbert_rf_classifier.pkl'), 'rb') as f:
+         models['ModernBERT RF Classifier'] = pickle.load(f)
+
+     with open(os.path.join(MODEL_DIR, 'gte_classifier.pkl'), 'rb') as f:
+         models['GTE Classifier'] = pickle.load(f)
+
+     # Load FastText models
+     models['FastText Raw'] = fasttext.load_model(os.path.join(MODEL_DIR, 'fasttext_raw.bin'))
+     models['FastText Preprocessed'] = fasttext.load_model(os.path.join(MODEL_DIR, 'fasttext_preprocessed.bin'))
+
+     return models
+
+ def format_results(results):
+     """Format results into HTML for better visualization."""
+     html = "<div style='font-family: monospace; padding: 10px 20px;'>"
+     html += "<table style='width: 100%; border-collapse: collapse; background-color: #1a1a1a; color: #ffffff; margin-bottom: 0;'>"
+     html += "<tr style='background-color: #2c3e50;'>"
+     html += "<th style='padding: 12px; text-align: left; border: 1px solid #34495e;'>Model</th>"
+     html += "<th style='padding: 12px; text-align: left; border: 1px solid #34495e;'>Prediction</th>"
+     html += "<th style='padding: 12px; text-align: left; border: 1px solid #34495e;'>Confidence</th>"
+     html += "<th style='padding: 12px; text-align: left; border: 1px solid #34495e;'>Time (sec)</th>"
+     html += "</tr>"
+
+     for result in results:
+         color = get_confidence_color(result['confidence'])
+         html += "<tr style='background-color: #2d2d2d; border-bottom: 1px solid #404040;'>"
+         html += f"<td style='padding: 12px; border: 1px solid #404040;'>{result['model']}</td>"
+         html += f"<td style='padding: 12px; border: 1px solid #404040;'>{result['prediction']}</td>"
+         html += f"<td style='padding: 12px; border: 1px solid #404040;'><span style='color: {color}; font-weight: bold;'>{result['confidence']:.4f}</span></td>"
+         html += f"<td style='padding: 12px; border: 1px solid #404040;'>{result['time']:.4f}</td>"
+         html += "</tr>"
+
+     html += "</table></div>"
+     return html
+
+ def format_progress(progress_value, desc):
+     """Format progress bar HTML."""
+     if progress_value >= 100:
+         return ""  # Return empty string when complete
+
+     html = f"""
+     <div style='width: 100%; background-color: #1a1a1a; padding: 10px; border-radius: 5px; margin-bottom: 10px;'>
+         <div style='color: white; margin-bottom: 5px;'>{desc}</div>
+         <div style='background-color: #2d2d2d; border-radius: 3px;'>
+             <div style='background-color: #6b46c1; width: {progress_value}%; height: 20px; border-radius: 3px; transition: width 0.3s ease;'></div>
+         </div>
+         <div style='color: white; text-align: right; margin-top: 5px;'>{progress_value:.1f}%</div>
+     </div>
+     """
+     return html
+
+ def get_confidence_color(confidence):
+     """Return color based on confidence score."""
+     if confidence >= 0.8:
+         return "#00ff00"  # Bright green for high confidence
+     elif confidence >= 0.5:
+         return "#ffa500"  # Bright orange for medium confidence
+     else:
+         return "#ff4444"  # Bright red for low confidence
+
+ # GTE embedding generation
+ def generate_gte_embedding(text, model_name='Alibaba-NLP/gte-base'):
+     """Generate GTE embeddings for a single text."""
+     start_time = time.time()
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModel.from_pretrained(model_name)
+
+     # Tokenize and generate embedding
+     inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model(**inputs)
+         embeddings = outputs.last_hidden_state[:, 0, :]  # [CLS] token
+         embeddings = F.normalize(embeddings, p=2, dim=1)  # normalize
+
+     inference_time = time.time() - start_time
+     return embeddings[0].numpy(), inference_time
+
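+ # Gradio runs a generator function as a streaming endpoint: each `yield`
+ # below pushes an updated (progress HTML, results HTML) pair to the UI.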
+ # Make predictions (streaming version)
+ def predict_text_streaming(text):
+     try:
+         models = load_models()
+         results = []
+
+         # First yield empty table and progress bar
+         yield format_progress(0, "Loading models..."), format_results(results)
+
+         # Process FastText models first (they're fastest as they don't need embeddings)
+         for model_name, model in models.items():
+             if isinstance(model, fasttext.FastText._FastText):
+                 yield format_progress(10, f"Processing {model_name}..."), format_results(results)
+                 start_time = time.time()
+                 prediction = model.predict(text)
+                 label = prediction[0][0].replace('__label__', '')
+                 confidence = float(prediction[1][0])
+                 inference_time = time.time() - start_time
+
+                 results.append({
+                     'model': model_name,
+                     'prediction': label,
+                     'confidence': confidence,
+                     'time': inference_time
+                 })
+                 yield format_progress(20, f"Completed {model_name}"), format_results(results)
+
+         # Process E5 models (each classifier uses its matching embedding model;
+         # the original code fed plain E5 embeddings to both classifiers and
+         # never called generate_e5_instruct_embedding)
+         e5_embedders = {
+             'E5 Classifier': generate_e5_embedding,
+             'E5-Instruct Classifier': generate_e5_instruct_embedding,
+         }
+         for model_name, embed_fn in e5_embedders.items():
+             yield format_progress(30, f"Processing {model_name}..."), format_results(results)
+             embedding, embed_time = embed_fn(text)
+             start_time = time.time()
+             model = models[model_name]
+             embedding_2d = embedding.reshape(1, -1)
+             prediction = model.predict(embedding_2d)[0]
+             probabilities = model.predict_proba(embedding_2d)[0]
+             confidence = max(probabilities)
+             inference_time = time.time() - start_time
+
+             results.append({
+                 'model': model_name,
+                 'prediction': prediction,
+                 'confidence': confidence,
+                 'time': inference_time + embed_time
+             })
+             yield format_progress(40, f"Completed {model_name}"), format_results(results)
+
+         # Process Azure models
+         yield format_progress(50, "Processing Azure Embeddings..."), format_results(results)
+         azure_embedding, embed_time = get_azure_embedding(text)
+         for model_name in ['Azure Classifier', 'Azure KNN Classifier']:
+             start_time = time.time()
+             model = models[model_name]
+             embedding_2d = azure_embedding.reshape(1, -1)
+             prediction = model.predict(embedding_2d)[0]
+             probabilities = model.predict_proba(embedding_2d)[0]
+             confidence = max(probabilities)
+             inference_time = time.time() - start_time
+
+             results.append({
+                 'model': model_name,
+                 'prediction': prediction,
+                 'confidence': confidence,
+                 'time': inference_time + embed_time
+             })
+             yield format_progress(70, f"Completed {model_name}"), format_results(results)
+
+         # Process ModernBERT model
+         yield format_progress(80, "Processing ModernBERT RF Classifier..."), format_results(results)
+         modernbert_embedding, embed_time = generate_modernbert_embedding(text)
+         start_time = time.time()  # reset timer (original reused a stale start_time)
+         model = models['ModernBERT RF Classifier']
+         embedding_2d = modernbert_embedding.reshape(1, -1)
+         prediction = model.predict(embedding_2d)[0]
+         probabilities = model.predict_proba(embedding_2d)[0]
+         confidence = max(probabilities)
+         inference_time = time.time() - start_time
+
+         results.append({
+             'model': 'ModernBERT RF Classifier',
+             'prediction': prediction,
+             'confidence': confidence,
+             'time': inference_time + embed_time
+         })
+         yield format_progress(90, "Completed ModernBERT RF Classifier"), format_results(results)
+
+         # Process GTE model
+         yield format_progress(95, "Processing GTE Classifier..."), format_results(results)
+         gte_embedding, embed_time = generate_gte_embedding(text)
+         start_time = time.time()  # reset timer (original reused a stale start_time)
+         model = models['GTE Classifier']
+         embedding_2d = gte_embedding.reshape(1, -1)
+         prediction = model.predict(embedding_2d)[0]
+         probabilities = model.predict_proba(embedding_2d)[0]
+         confidence = max(probabilities)
+         inference_time = time.time() - start_time
+
+         results.append({
+             'model': 'GTE Classifier',
+             'prediction': prediction,
+             'confidence': confidence,
+             'time': inference_time + embed_time
+         })
+         yield format_progress(100, "Completed!"), format_results(results)
+
+     except Exception as e:
+         yield "", f"<div style='color: red; padding: 20px;'>Error occurred: {str(e)}</div>"
+
+ # Create Gradio interface with custom CSS
+ css = """
+ .main {
+     gap: 0 !important;
+ }
+ .contain {
+     gap: 0 !important;
+ }
+ .feedback {
+     margin-top: 0 !important;
+     margin-bottom: 0 !important;
+ }
+ """
+
+ iface = gr.Interface(
+     fn=predict_text_streaming,
+     inputs=gr.Textbox(label="Enter text to classify", lines=3),
+     outputs=[
+         gr.HTML(label="Progress"),
+         gr.HTML(label="Model Predictions")
+     ],
+     title="Text Classification Model Comparison",
+     description="Compare predictions from different text classification models (results stream in as they become available)",
+     theme=gr.themes.Soft(),
+     css=css
+ )
+
+ if __name__ == "__main__":
+     iface.launch(debug=True)
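+
+ # `debug=True` prints full tracebacks to the console; passing `share=True`
+ # to `launch()` would additionally create a temporary public link.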