SeeknnDestroy
commited on
remove modernbert
Browse files
app.py
CHANGED
@@ -73,22 +73,6 @@ def generate_e5_instruct_embedding(text, model_name='intfloat/multilingual-e5-la
|
|
73 |
inference_time = time.time() - start_time
|
74 |
return embeddings[0].numpy(), inference_time
|
75 |
|
76 |
-
def generate_modernbert_embedding(text, model_name="answerdotai/ModernBERT-base"):
|
77 |
-
"""Generate ModernBERT embeddings for a single text."""
|
78 |
-
start_time = time.time()
|
79 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
80 |
-
model = AutoModel.from_pretrained(model_name)
|
81 |
-
|
82 |
-
# Tokenize and generate embedding
|
83 |
-
inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
|
84 |
-
with torch.no_grad():
|
85 |
-
outputs = model(**inputs)
|
86 |
-
# Take [CLS] token embedding
|
87 |
-
embeddings = outputs.last_hidden_state[:, 0, :]
|
88 |
-
|
89 |
-
inference_time = time.time() - start_time
|
90 |
-
return embeddings[0].numpy(), inference_time
|
91 |
-
|
92 |
def mean_pooling(token_embeddings, attention_mask):
|
93 |
"""Mean pooling function for E5 models."""
|
94 |
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
@@ -121,9 +105,6 @@ def load_models():
|
|
121 |
with open(os.path.join(MODEL_DIR, 'azure_knn_classifier.pkl'), 'rb') as f:
|
122 |
models['Azure KNN Classifier'] = pickle.load(f)
|
123 |
|
124 |
-
with open(os.path.join(MODEL_DIR, 'modernbert_rf_classifier.pkl'), 'rb') as f:
|
125 |
-
models['ModernBERT RF Classifier'] = pickle.load(f)
|
126 |
-
|
127 |
with open(os.path.join(MODEL_DIR, 'gte_classifier.pkl'), 'rb') as f:
|
128 |
models['GTE Classifier'] = pickle.load(f)
|
129 |
|
@@ -265,26 +246,8 @@ def predict_text_streaming(text):
|
|
265 |
})
|
266 |
yield format_progress(70, f"Completed {model_name}"), format_results(results)
|
267 |
|
268 |
-
# Process ModernBERT model
|
269 |
-
yield format_progress(80, "Processing ModernBERT RF Classifier..."), format_results(results)
|
270 |
-
modernbert_embedding, embed_time = generate_modernbert_embedding(text)
|
271 |
-
model = models['ModernBERT RF Classifier']
|
272 |
-
embedding_2d = modernbert_embedding.reshape(1, -1)
|
273 |
-
prediction = model.predict(embedding_2d)[0]
|
274 |
-
probabilities = model.predict_proba(embedding_2d)[0]
|
275 |
-
confidence = max(probabilities)
|
276 |
-
inference_time = time.time() - start_time
|
277 |
-
|
278 |
-
results.append({
|
279 |
-
'model': 'ModernBERT RF Classifier',
|
280 |
-
'prediction': prediction,
|
281 |
-
'confidence': confidence,
|
282 |
-
'time': inference_time + embed_time
|
283 |
-
})
|
284 |
-
yield format_progress(90, "Completed ModernBERT RF Classifier"), format_results(results)
|
285 |
-
|
286 |
# Process GTE model
|
287 |
-
yield format_progress(
|
288 |
gte_embedding, embed_time = generate_gte_embedding(text)
|
289 |
model = models['GTE Classifier']
|
290 |
embedding_2d = gte_embedding.reshape(1, -1)
|
|
|
73 |
inference_time = time.time() - start_time
|
74 |
return embeddings[0].numpy(), inference_time
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
def mean_pooling(token_embeddings, attention_mask):
|
77 |
"""Mean pooling function for E5 models."""
|
78 |
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
|
|
105 |
with open(os.path.join(MODEL_DIR, 'azure_knn_classifier.pkl'), 'rb') as f:
|
106 |
models['Azure KNN Classifier'] = pickle.load(f)
|
107 |
|
|
|
|
|
|
|
108 |
with open(os.path.join(MODEL_DIR, 'gte_classifier.pkl'), 'rb') as f:
|
109 |
models['GTE Classifier'] = pickle.load(f)
|
110 |
|
|
|
246 |
})
|
247 |
yield format_progress(70, f"Completed {model_name}"), format_results(results)
|
248 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
# Process GTE model
|
250 |
+
yield format_progress(90, "Processing GTE Classifier..."), format_results(results)
|
251 |
gte_embedding, embed_time = generate_gte_embedding(text)
|
252 |
model = models['GTE Classifier']
|
253 |
embedding_2d = gte_embedding.reshape(1, -1)
|