|
import streamlit as st |
|
import pandas as pd |
|
import torch |
|
from transformers import BertTokenizer, AutoModelForSequenceClassification |
|
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, classification_report, f1_score, precision_recall_fscore_support |
|
import numpy as np |
|
import plotly.graph_objects as go |
|
import plotly.express as px |
|
from tqdm import tqdm |
|
|
|
def load_model_and_tokenizer():
    """Load the classifier, its tokenizer, and pick the compute device.

    The tokenizer comes from "bert-base-uncased" and the model from
    "CIS519PG/News_Classifier_Demo". The model is moved to "cuda" when
    ``torch.cuda.is_available()`` is true, otherwise to "cpu", and is put in
    eval mode.

    Returns:
        A ``(model, tokenizer, device)`` tuple, where ``device`` is the string
        "cuda" or "cpu". If anything raises while loading, the error is shown
        with ``st.error`` and ``(None, None, None)`` is returned instead.
    """
    try:
        bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        classifier = AutoModelForSequenceClassification.from_pretrained(
            "CIS519PG/News_Classifier_Demo"
        )

        if torch.cuda.is_available():
            target_device = "cuda"
        else:
            target_device = "cpu"

        classifier = classifier.to(target_device)
        classifier.eval()
    except Exception as e:
        # Top-level UI boundary: report the failure in the page, don't crash.
        st.error(f"Error loading model or tokenizer: {str(e)}")
        return None, None, None
    else:
        return classifier, bert_tokenizer, target_device
|
|
|
def preprocess_data(df):
    """Convert the uploaded DataFrame into a list of evaluation examples.

    Reads the "title" and "News Outlet" columns. Outlet names are stripped of
    surrounding whitespace and upper-cased; "FOX NEWS" is mapped to "FOXNEWS"
    and "NBC NEWS" to "NBC". Any other outlet value is kept as normalized.

    Rows where either the title or the outlet is missing (NaN/None) are
    dropped, because such rows cannot be tokenized or labelled. Titles and
    outlets are coerced to ``str`` so non-string cells cannot break the
    tokenizer.

    Args:
        df: pandas DataFrame with at least "title" and "News Outlet" columns.

    Returns:
        A list of ``{"title": str, "outlet": str}`` dicts in row order, or
        ``None`` after reporting the problem with ``st.error`` (for example
        when a required column is absent).
    """
    required_columns = ("title", "News Outlet")
    outlet_aliases = {"FOX NEWS": "FOXNEWS", "NBC NEWS": "NBC"}

    try:
        absent = [name for name in required_columns if name not in df.columns]
        if absent:
            # ValueError rather than KeyError so the message is not repr-quoted.
            raise ValueError(f"missing required column(s): {', '.join(absent)}")

        # Only complete rows are usable downstream.
        rows = df[list(required_columns)].dropna()
        titles = rows["title"].astype(str)
        outlets = rows["News Outlet"].astype(str).str.strip().str.upper()

        return [
            {"title": title, "outlet": outlet_aliases.get(outlet, outlet)}
            for title, outlet in zip(titles, outlets)
        ]
    except Exception as e:
        # UI boundary: surface the problem in the page and signal failure.
        st.error(f"Error preprocessing data: {str(e)}")
        return None
|
|
|
def evaluate_model(model, tokenizer, device, test_dataset):
    """Run the classifier over the dataset and return labels and probabilities.

    Titles are tokenized in batches of 16 (padding, truncation, max length
    128) and passed through ``model`` under ``torch.no_grad()``. A Streamlit
    progress bar is updated once per batch.

    Args:
        model: callable returning an object with a ``.logits`` tensor.
        tokenizer: callable producing a mapping of tensors for a list of texts.
        device: device the input tensors are moved to before the forward pass.
        test_dataset: sliceable sequence of dicts with "title" and "outlet"
            keys; "outlet" must be "FOXNEWS" or "NBC".

    Returns:
        ``(references, probabilities)``: ``references`` is a list of integer
        labels (FOXNEWS=0, NBC=1) and ``probabilities`` is a NumPy array with
        one softmax row per example. For an empty dataset an empty list and a
        ``(0, 2)`` array are returned.

    Raises:
        KeyError: if an example's outlet is not a known label. This is raised
            before any inference is performed.
    """
    label2id = {"FOXNEWS": 0, "NBC": 1}
    batch_size = 16
    total = len(test_dataset)

    # Resolve every label up front so a bad outlet fails fast instead of
    # after (possibly many) forward passes have already been spent.
    references = [label2id[item['outlet']] for item in test_dataset]

    if total == 0:
        # Nothing to run; width follows the number of known labels.
        return references, np.empty((0, len(label2id)), dtype=np.float32)

    progress_bar = st.progress(0)
    batch_logits = []

    for start in range(0, total, batch_size):
        progress_bar.progress(min(start / total, 1.0))

        batch = test_dataset[start:start + batch_size]
        texts = [item['title'] for item in batch]

        encoded = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )
        inputs = {k: v.to(device) for k, v in encoded.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        # Keep tensors (on CPU) and concatenate once at the end; building a
        # tensor from a list of NumPy arrays is slow and triggers a warning.
        batch_logits.append(outputs.logits.cpu())

    progress_bar.progress(1.0)

    probabilities = torch.softmax(torch.cat(batch_logits, dim=0), dim=1).numpy()
    return references, probabilities
|
|
|
def plot_roc_curve(references, probabilities):
    """Build the ROC figure and compute the AUC for the class-1 scores.

    Args:
        references: true labels passed to ``roc_curve`` / ``roc_auc_score``.
        probabilities: 2-D array; column 1 is used as the score.

    Returns:
        ``(fig, auc_score)``: a Plotly figure with the ROC trace and a dashed
        diagonal "Random Guess" trace, and the AUC value.
    """
    positive_scores = probabilities[:, 1]
    false_pos_rate, true_pos_rate, _ = roc_curve(references, positive_scores)
    auc_score = roc_auc_score(references, positive_scores)

    roc_trace = go.Scatter(
        x=false_pos_rate,
        y=true_pos_rate,
        name=f'ROC Curve (AUC = {auc_score:.4f})',
    )
    # Diagonal reference line.
    chance_trace = go.Scatter(
        x=[0, 1],
        y=[0, 1],
        name='Random Guess',
        line={'dash': 'dash'},
    )

    fig = go.Figure()
    for trace in (roc_trace, chance_trace):
        fig.add_trace(trace)

    layout_options = {
        'title': 'ROC Curve',
        'xaxis_title': 'False Positive Rate',
        'yaxis_title': 'True Positive Rate',
        'showlegend': True,
    }
    fig.update_layout(**layout_options)
    return fig, auc_score
|
|
|
def plot_metrics_by_threshold(references, probabilities):
    """Sweep decision thresholds and report the one with the best F1.

    For each threshold in ``np.arange(0.0, 1.0, 0.01)`` an example is
    predicted as class 1 when its column-1 probability is strictly greater
    than the threshold. Precision, recall and F1 for class 1 are recorded and
    plotted against the threshold.

    Args:
        references: true labels (0 = FOXNEWS, 1 = NBC).
        probabilities: 2-D array; column 1 is used as the class-1 score.

    Returns:
        ``(fig, best_metrics)``: the Plotly figure and a dict with keys
        "threshold", "f1_score", "confusion_matrix" (always 2x2) and
        "classification_report" for the first threshold achieving the highest
        F1. The dict is always populated, even when every F1 is 0.
    """
    class_ids = [0, 1]
    class_names = ['FOXNEWS', 'NBC']
    positive_scores = probabilities[:, 1]

    thresholds = np.arange(0.0, 1.0, 0.01)
    metrics = {
        'threshold': thresholds,
        'f1': [],
        'precision': [],
        'recall': []
    }

    # Start below any attainable F1 so the first threshold always becomes the
    # initial best; otherwise an all-zero sweep would leave the result empty.
    best_f1 = -1.0
    best_threshold = thresholds[0]

    for threshold in thresholds:
        preds = (positive_scores > threshold).astype(int)
        # One call yields precision, recall and F1; zero_division=0 keeps the
        # 0.0 result for undefined metrics without emitting warnings.
        precision, recall, f1, _ = precision_recall_fscore_support(
            references, preds, average='binary', zero_division=0
        )
        metrics['f1'].append(f1)
        metrics['precision'].append(precision)
        metrics['recall'].append(recall)

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    # Compute the detailed report once, for the winning threshold only.
    best_preds = (positive_scores > best_threshold).astype(int)
    best_metrics = {
        'threshold': best_threshold,
        'f1_score': best_f1,
        # labels= pins both outputs to two classes even if one is absent.
        'confusion_matrix': confusion_matrix(
            references, best_preds, labels=class_ids
        ),
        'classification_report': classification_report(
            references,
            best_preds,
            labels=class_ids,
            target_names=class_names,
            digits=4,
            zero_division=0,
        ),
    }

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=thresholds, y=metrics['f1'], name='F1 Score'))
    fig.add_trace(go.Scatter(x=thresholds, y=metrics['precision'], name='Precision'))
    fig.add_trace(go.Scatter(x=thresholds, y=metrics['recall'], name='Recall'))
    fig.update_layout(
        title='Metrics by Threshold',
        xaxis_title='Threshold',
        yaxis_title='Score',
        showlegend=True
    )
    return fig, best_metrics
|
|
|
def plot_confusion_matrix(cm):
    """Render a confusion matrix as an annotated Plotly heatmap.

    Args:
        cm: 2-D array indexed as ``cm[true, predicted]`` over the labels
            FOXNEWS and NBC; must support ``cm[i, j]`` and ``cm.max()``.

    Returns:
        A Plotly figure with a "Blues" heatmap and one text annotation per
        cell. Cell text is white when the count exceeds half the maximum
        count and black otherwise.
    """
    labels = ['FOXNEWS', 'NBC']

    # One annotation per (true row, predicted column) cell.
    annotations = [
        dict(
            text=str(cm[row, col]),
            x=predicted_name,
            y=true_name,
            showarrow=False,
            font=dict(color='white' if cm[row, col] > cm.max()/2 else 'black'),
        )
        for row, true_name in enumerate(labels)
        for col, predicted_name in enumerate(labels)
    ]

    heatmap = go.Heatmap(
        z=cm,
        x=labels,
        y=labels,
        colorscale='Blues',
        showscale=True,
    )
    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title='Confusion Matrix',
        xaxis_title='Predicted Label',
        yaxis_title='True Label',
        annotations=annotations,
    )
    return fig
|
|
|
def main():
    """Streamlit page: upload a CSV, evaluate the classifier, show results.

    After a CSV is uploaded, the page previews it, loads the model, converts
    the rows to examples, runs inference, and renders the ROC curve, the
    per-threshold metrics plot, and the confusion matrix and classification
    report at the best-F1 threshold. Each failure point stops the run early
    with a message instead of an exception or a silently blank page.
    """
    st.title("News Classifier Model Evaluation")

    uploaded_file = st.file_uploader("Upload your test dataset (CSV)", type=['csv'])
    if uploaded_file is None:
        return

    # The upload is user-supplied, so a malformed file is an expected case.
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
        st.error(f"Error reading CSV file: {e}")
        return

    st.write("Preview of uploaded data:")
    st.dataframe(df.head())

    model, tokenizer, device = load_model_and_tokenizer()
    # Compare against None explicitly rather than relying on object
    # truthiness; the loader reports its own error before returning None.
    if model is None or tokenizer is None:
        return

    test_dataset = preprocess_data(df)
    if test_dataset is None:
        # preprocess_data has already displayed the error.
        return
    if not test_dataset:
        st.warning("The uploaded file contains no rows to evaluate.")
        return

    st.write(f"Total examples: {len(test_dataset)}")

    with st.spinner('Evaluating model...'):
        references, probabilities = evaluate_model(model, tokenizer, device, test_dataset)

    roc_fig, auc_score = plot_roc_curve(references, probabilities)
    st.plotly_chart(roc_fig)
    st.metric("AUC-ROC Score", f"{auc_score:.4f}")

    metrics_fig, best_metrics = plot_metrics_by_threshold(references, probabilities)
    st.plotly_chart(metrics_fig)

    st.subheader("Best Threshold Evaluation")
    col1, col2 = st.columns(2)
    with col1:
        st.metric("Best Threshold", f"{best_metrics['threshold']:.2f}")
    with col2:
        st.metric("Best F1 Score", f"{best_metrics['f1_score']:.4f}")

    st.subheader("Confusion Matrix")
    cm_fig = plot_confusion_matrix(best_metrics['confusion_matrix'])
    st.plotly_chart(cm_fig)

    st.subheader("Classification Report")
    st.text(best_metrics['classification_report'])


if __name__ == "__main__":
    main()