|
import html
import os
import pickle
import time

import fasttext
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from openai import AzureOpenAI
from transformers import AutoTokenizer, AutoModel

from config import get_fasttext_path, is_model_enabled
|
|
|
# Read settings from a .env file (python-dotenv) before any os.getenv lookup below.
load_dotenv()

# API version string passed to the Azure OpenAI client.
AZURE_API_VERSION = "2024-02-01"

# Directory that load_models() joins with each pickled classifier filename.
MODEL_DIR = "models"

# Shared Azure OpenAI client used by get_azure_embedding(); credentials and
# endpoint come from the environment.
azure_client = AzureOpenAI(
    azure_endpoint=os.getenv("AZURE_OPENAI_EMBEDDING_ENDPOINT"),
    api_version=AZURE_API_VERSION,
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
)
|
|
|
# (tokenizer, model) pairs already loaded by generate_e5_embedding, keyed by
# model name, so a checkpoint is only loaded once per process.
_E5_MODEL_CACHE = {}


def generate_e5_embedding(text, model_name='intfloat/multilingual-e5-large'):
    """Generate an E5 embedding for a single text.

    The text is prefixed with ``"query: "``, encoded by the model, mean-pooled
    over the attention mask and L2-normalised.

    Args:
        text: The string to embed.
        model_name: Hugging Face model identifier to load.

    Returns:
        Tuple ``(embedding, seconds)``: the embedding of the single input as a
        1-D NumPy array, and the wall-clock time spent in this call. The first
        call for a given ``model_name`` also includes the time to load the
        tokenizer and model.
    """
    start_time = time.time()

    # Previously the tokenizer and model were re-created on every call; reuse
    # them instead so repeated requests only pay for the forward pass.
    if model_name not in _E5_MODEL_CACHE:
        _E5_MODEL_CACHE[model_name] = (
            AutoTokenizer.from_pretrained(model_name),
            AutoModel.from_pretrained(model_name),
        )
    tokenizer, model = _E5_MODEL_CACHE[model_name]

    inputs = tokenizer(
        f"query: {text}", padding=True, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # Average token vectors, ignoring padded positions, then unit-normalise.
    embeddings = mean_pooling(outputs.last_hidden_state, inputs['attention_mask'])
    embeddings = F.normalize(embeddings, p=2, dim=1)

    inference_time = time.time() - start_time
    return embeddings[0].numpy(), inference_time
|
|
|
# (tokenizer, model) pairs already loaded by generate_e5_instruct_embedding,
# keyed by model name, so a checkpoint is only loaded once per process.
_E5_INSTRUCT_MODEL_CACHE = {}


def generate_e5_instruct_embedding(text, model_name='intfloat/multilingual-e5-large-instruct'):
    """Generate an E5-instruct embedding for a single text.

    The text is prefixed with ``"query: "``, encoded by the model, mean-pooled
    over the attention mask and L2-normalised.

    Args:
        text: The string to embed.
        model_name: Hugging Face model identifier to load.

    Returns:
        Tuple ``(embedding, seconds)``: the embedding of the single input as a
        1-D NumPy array, and the wall-clock time spent in this call. The first
        call for a given ``model_name`` also includes the time to load the
        tokenizer and model.
    """
    start_time = time.time()

    # Previously the tokenizer and model were re-created on every call; reuse
    # them instead so repeated requests only pay for the forward pass.
    if model_name not in _E5_INSTRUCT_MODEL_CACHE:
        _E5_INSTRUCT_MODEL_CACHE[model_name] = (
            AutoTokenizer.from_pretrained(model_name),
            AutoModel.from_pretrained(model_name),
        )
    tokenizer, model = _E5_INSTRUCT_MODEL_CACHE[model_name]

    # NOTE(review): the instruct variant's model card appears to document an
    # "Instruct: <task>\nQuery: <text>" prompt rather than "query: ". The
    # prefix is left unchanged because it must match whatever was used to
    # train the downstream classifier -- confirm against the training code.
    inputs = tokenizer(
        f"query: {text}", padding=True, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # Average token vectors, ignoring padded positions, then unit-normalise.
    embeddings = mean_pooling(outputs.last_hidden_state, inputs['attention_mask'])
    embeddings = F.normalize(embeddings, p=2, dim=1)

    inference_time = time.time() - start_time
    return embeddings[0].numpy(), inference_time
|
|
|
def mean_pooling(token_embeddings, attention_mask):
    """Average token embeddings along dim 1, counting only unmasked positions.

    ``attention_mask`` gets a trailing axis and is expanded to the shape of
    ``token_embeddings``; positions where the mask is 0 contribute nothing to
    the sum. The divisor is clamped to 1e-9 so a fully masked row does not
    divide by zero.
    """
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    weighted_sum = (token_embeddings * mask).sum(dim=1)
    token_count = mask.sum(dim=1).clamp(min=1e-9)
    return weighted_sum / token_count
|
|
|
def get_azure_embedding(text):
    """Request a ``text-embedding-3-large`` embedding for ``text`` from Azure OpenAI.

    Returns:
        Tuple ``(embedding, seconds)``: the first embedding in the response as
        a NumPy array, and the wall-clock duration of the API call.
    """
    began = time.time()
    response = azure_client.embeddings.create(
        input=text,
        model="text-embedding-3-large",
    )
    elapsed = time.time() - began

    vector = np.array(response.data[0].embedding)
    return vector, elapsed
|
|
|
|
|
def load_models():
    """Load every model that the configuration marks as enabled.

    Pickled classifiers are read from ``MODEL_DIR``; fastText models are
    loaded from the path returned by ``get_fasttext_path`` for their key.

    Returns:
        Dict mapping display name to loaded model, pickled classifiers first
        and fastText models after, each group in declaration order.
    """
    # Display name -> pickle filename inside MODEL_DIR.
    pickled_classifiers = (
        ('E5 Classifier', 'e5_classifier.pkl'),
        ('E5-Instruct Classifier', 'e5_large_instruct_classifier.pkl'),
        ('Azure Classifier', 'azure_classifier.pkl'),
        ('Azure KNN Classifier', 'azure_knn_classifier.pkl'),
        ('GTE Classifier', 'gte_classifier.pkl'),
    )
    # Display name -> key understood by get_fasttext_path().
    fasttext_variants = (
        ('FastText Default', 'fasttext_default'),
        ('FastText Preprocessed', 'fasttext_preprocessed'),
        ('Fasttext WordnNGram 1', 'word_n_gram_1'),
        ('Fasttext WordnNGram 2', 'word_n_gram_2'),
        ('Fasttext WordnNGram 3', 'word_n_gram_3'),
        ('Fasttext Low Overfit', 'low_overfit'),
    )

    loaded = {}

    for display_name, pickle_file in pickled_classifiers:
        if not is_model_enabled(display_name):
            continue
        with open(os.path.join(MODEL_DIR, pickle_file), 'rb') as handle:
            loaded[display_name] = pickle.load(handle)

    for display_name, path_key in fasttext_variants:
        if is_model_enabled(display_name):
            loaded[display_name] = fasttext.load_model(get_fasttext_path(path_key))

    return loaded
|
|
|
def format_results(results):
    """Render prediction results as a dark-themed HTML table.

    Args:
        results: Iterable of dicts with keys ``'model'``, ``'prediction'``,
            ``'confidence'`` and ``'time'``.

    Returns:
        HTML string with a header row followed by one row per result; the
        confidence cell is coloured via ``get_confidence_color``.
    """
    header_cells = "".join(
        f"<th style='padding: 12px; text-align: left; border: 1px solid #34495e;'>{title}</th>"
        for title in ("Model", "Prediction", "Confidence", "Time (sec)")
    )
    fragments = [
        "<div style='font-family: monospace; padding: 10px 20px;'>",
        "<table style='width: 100%; border-collapse: collapse; background-color: #1a1a1a; color: #ffffff; margin-bottom: 0;'>",
        "<tr style='background-color: #2c3e50;'>",
        header_cells,
        "</tr>",
    ]

    for row in results:
        score_color = get_confidence_color(row['confidence'])
        fragments.extend((
            "<tr style='background-color: #2d2d2d; border-bottom: 1px solid #404040;'>",
            f"<td style='padding: 12px; border: 1px solid #404040;'>{row['model']}</td>",
            f"<td style='padding: 12px; border: 1px solid #404040;'><span style='color: #4CAF50; font-weight: bold;'>{row['prediction']}</span></td>",
            f"<td style='padding: 12px; border: 1px solid #404040;'><span style='color: {score_color}; font-weight: bold;'>{row['confidence']:.4f}</span></td>",
            f"<td style='padding: 12px; border: 1px solid #404040;'>{row['time']:.4f}</td>",
            "</tr>",
        ))

    fragments.append("</table></div>")
    return "".join(fragments)
|
|
|
def format_progress(progress_value, desc):
    """Build the HTML for the progress bar shown above the results.

    Args:
        progress_value: Completion percentage; used as the bar's CSS width
            and shown with one decimal place.
        desc: Status text displayed above the bar.

    Returns:
        An empty string once ``progress_value`` reaches 100 (which hides the
        bar), otherwise the bar markup.
    """
    still_running = progress_value < 100
    if not still_running:
        return ""

    return f"""
    <div style='width: 100%; background-color: #1a1a1a; padding: 10px; border-radius: 5px; margin-bottom: 10px;'>
        <div style='color: white; margin-bottom: 5px;'>{desc}</div>
        <div style='background-color: #2d2d2d; border-radius: 3px;'>
            <div style='background-color: #6b46c1; width: {progress_value}%; height: 20px; border-radius: 3px; transition: width 0.3s ease;'></div>
        </div>
        <div style='color: white; text-align: right; margin-top: 5px;'>{progress_value:.1f}%</div>
    </div>
    """
|
|
|
def get_confidence_color(confidence):
    """Map a confidence score to a hex colour.

    Scores of at least 0.8 are green, at least 0.5 orange, and anything else
    red.
    """
    # Checked from the highest threshold down; the first match wins.
    thresholds = (
        (0.8, "#00ff00"),
        (0.5, "#ffa500"),
    )
    for minimum, color in thresholds:
        if confidence >= minimum:
            return color
    return "#ff4444"
|
|
|
|
|
# (tokenizer, model) pairs already loaded by generate_gte_embedding, keyed by
# model name, so a checkpoint is only loaded once per process.
_GTE_MODEL_CACHE = {}


def generate_gte_embedding(text, model_name='Alibaba-NLP/gte-multilingual-base'):
    """Generate a GTE embedding for a single text.

    The embedding is the L2-normalised hidden state of the first token of the
    encoded input.

    Args:
        text: The string to embed.
        model_name: Hugging Face model identifier to load.

    Returns:
        Tuple ``(embedding, seconds)``: the embedding of the single input as a
        1-D NumPy array, and the wall-clock time spent in this call. The first
        call for a given ``model_name`` also includes the time to load the
        tokenizer and model.
    """
    start_time = time.time()

    # Previously the tokenizer and model were re-created on every call; reuse
    # them instead so repeated requests only pay for the forward pass.
    if model_name not in _GTE_MODEL_CACHE:
        # NOTE(review): trust_remote_code=True allows code from the model
        # repository to run locally; only use model names you trust.
        _GTE_MODEL_CACHE[model_name] = (
            AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
            AutoModel.from_pretrained(model_name, trust_remote_code=True),
        )
    tokenizer, model = _GTE_MODEL_CACHE[model_name]

    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Take the first token's vector for each sequence, then unit-normalise.
    embeddings = outputs.last_hidden_state[:, 0, :]
    embeddings = F.normalize(embeddings, p=2, dim=1)

    inference_time = time.time() - start_time
    return embeddings[0].numpy(), inference_time
|
|
|
|
|
def _classify_embedding(model, embedding):
    """Run a classifier exposing predict/predict_proba on one embedding vector.

    Args:
        model: Fitted classifier with ``predict`` and ``predict_proba``.
        embedding: NumPy array; reshaped to a single-row matrix before use.

    Returns:
        Tuple ``(label, confidence, seconds)`` where ``confidence`` is the
        largest class probability and ``seconds`` is the time spent here.
    """
    started = time.time()
    features = embedding.reshape(1, -1)
    label = model.predict(features)[0]
    confidence = max(model.predict_proba(features)[0])
    return label, confidence, time.time() - started


def predict_text_streaming(text):
    """Classify ``text`` with every enabled model, streaming partial results.

    This is a generator: each yielded item is a ``(progress_html,
    results_html)`` pair for the two HTML outputs of the interface. fastText
    models run first, followed by the E5, E5-Instruct, Azure and GTE
    classifiers. The last item always has an empty progress string so the
    progress bar disappears.

    Args:
        text: The user-supplied text to classify.

    Yields:
        Tuples of HTML strings. On failure a single pair with an empty
        progress string and a red error message is yielded instead of raising.
    """
    try:
        if not text or not text.strip():
            yield "", "<div style='color: red; padding: 20px;'>Please enter some text to classify.</div>"
            return

        # NOTE(review): load_models() re-reads every model file on each
        # request; consider caching it if the configuration is static.
        models = load_models()
        results = []

        if not models:
            # A `return <value>` inside a generator is never delivered to the
            # consumer, so the message has to be yielded.
            yield "", "<div style='color: red; padding: 20px;'>No models are enabled in the configuration.</div>"
            return

        progress_step = 100.0 / len(models)
        current_progress = 0

        yield format_progress(current_progress, "Loading models..."), format_results(results)

        # fastText's predict() refuses input containing newlines, but the
        # textbox is multi-line, so collapse line breaks into spaces.
        fasttext_text = " ".join(text.splitlines())

        for model_name, model in models.items():
            if not isinstance(model, fasttext.FastText._FastText):
                continue

            yield format_progress(current_progress, f"Processing {model_name}..."), format_results(results)
            start_time = time.time()
            labels, scores = model.predict(fasttext_text)
            inference_time = time.time() - start_time

            results.append({
                'model': model_name,
                'prediction': labels[0].replace('__label__', ''),
                'confidence': float(scores[0]),
                'time': inference_time,
            })
            current_progress += progress_step
            yield format_progress(current_progress, f"Completed {model_name}"), format_results(results)

        # Classifier name -> (embedding function, status message). The
        # E5-Instruct classifier gets its own embedding function; it used to
        # be fed the plain E5 embedding, leaving the instruct embedder unused.
        embedders = {
            'E5 Classifier': (generate_e5_embedding, "Generating E5 embeddings..."),
            'E5-Instruct Classifier': (generate_e5_instruct_embedding, "Generating E5-Instruct embeddings..."),
            'Azure Classifier': (get_azure_embedding, "Generating Azure embeddings..."),
            'Azure KNN Classifier': (get_azure_embedding, "Generating Azure embeddings..."),
            'GTE Classifier': (generate_gte_embedding, "Processing GTE Classifier..."),
        }
        # Embeddings are computed once per embedding function, so the two
        # Azure classifiers share a single API call.
        computed = {}

        for model_name, (embed_fn, status) in embedders.items():
            if model_name not in models:
                continue

            if embed_fn not in computed:
                yield format_progress(current_progress, status), format_results(results)
                computed[embed_fn] = embed_fn(text)
            embedding, embed_time = computed[embed_fn]

            # Timing starts inside the helper for every classifier; the GTE
            # branch used to reuse a start time left over from an earlier
            # stage (or had none at all).
            prediction, confidence, inference_time = _classify_embedding(
                models[model_name], embedding
            )

            results.append({
                'model': model_name,
                'prediction': prediction,
                'confidence': confidence,
                'time': inference_time + embed_time,
            })
            current_progress += progress_step
            yield format_progress(current_progress, f"Completed {model_name}"), format_results(results)

        # Always finish at exactly 100 so the progress bar is hidden whichever
        # models are enabled and despite float rounding in the running total.
        yield format_progress(100, "Completed!"), format_results(results)

    except Exception as e:
        # Top-level UI boundary: report the failure instead of raising.
        # Escape the message because it is inserted into HTML.
        yield "", f"<div style='color: red; padding: 20px;'>Error occurred: {html.escape(str(e))}</div>"
|
|
|
|
|
# Custom CSS handed to the interface: removes the gaps between the main
# containers and the vertical margins of the .feedback element.
css = """
.main {
    gap: 0 !important;
}
.contain {
    gap: 0 !important;
}
.feedback {
    margin-top: 0 !important;
    margin-bottom: 0 !important;
}
"""

# Gradio UI: one multi-line text input and two HTML outputs that receive the
# (progress, results) pairs yielded by predict_text_streaming.
iface = gr.Interface(
    title="Text Classification Model Comparison",
    description="Compare predictions from different text classification models (Results stream as they become available)",
    fn=predict_text_streaming,
    inputs=gr.Textbox(label="Enter text to classify", lines=3),
    outputs=[
        gr.HTML(label="Progress"),
        gr.HTML(label="Model Predictions"),
    ],
    theme=gr.themes.Soft(),
    css=css,
)


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch(debug=True)