# News_Classifier_Demo / eval_pipeline.py
# Header copied from the hosting site's file view: author tigerlinlxt,
# commit c5afc23 "Update eval_pipeline.py" (verified), 7.12 kB.
import streamlit as st
import pandas as pd
import torch
from transformers import BertTokenizer, AutoModelForSequenceClassification
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, classification_report, f1_score, precision_recall_fscore_support
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from tqdm import tqdm
def load_model_and_tokenizer():
    """Load the fine-tuned news classifier, its tokenizer and the compute device.

    The model is moved to CUDA when it is available (CPU otherwise) and is
    switched to evaluation mode before being returned.

    Returns:
        ``(model, tokenizer, device)`` on success, where ``device`` is the
        string ``"cuda"`` or ``"cpu"``. If anything fails, the error is shown
        in the Streamlit UI and ``(None, None, None)`` is returned.
    """
    try:
        # NOTE(review): the tokenizer is loaded from the base checkpoint, not
        # from the fine-tuned repo. Presumably the classifier was fine-tuned
        # from bert-base-uncased; confirm the vocabularies match.
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        classifier = AutoModelForSequenceClassification.from_pretrained(
            "CIS519PG/News_Classifier_Demo"
        )
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        classifier = classifier.to(target_device)
        classifier.eval()
    except Exception as e:
        st.error(f"Error loading model or tokenizer: {str(e)}")
        return None, None, None
    return classifier, tokenizer, target_device
def preprocess_data(df):
    """Turn the uploaded DataFrame into a list of ``{"title", "outlet"}`` records.

    Outlet names are stripped and upper-cased. The spellings ``"FOX NEWS"``
    and ``"NBC NEWS"`` are mapped to the label names ``"FOXNEWS"`` and
    ``"NBC"``; any other outlet value is passed through normalized. Titles
    are converted to ``str`` so the tokenizer always receives text.

    Rows with a missing title or outlet are skipped, with a Streamlit
    warning giving the count. Previously, a single NaN outlet made
    ``.strip()`` raise and discarded the whole dataset.

    Args:
        df: DataFrame that must contain ``title`` and ``outlet`` columns.

    Returns:
        List of ``{"title": str, "outlet": str}`` dicts in row order, or
        ``None`` (after showing an error) if a required column is missing.
    """
    required_columns = ("title", "outlet")
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        st.error(
            "Error preprocessing data: missing required column(s): "
            + ", ".join(missing_columns)
        )
        return None

    # Raw outlet spellings (after strip/upper) -> label names used by the model.
    outlet_aliases = {"FOX NEWS": "FOXNEWS", "NBC NEWS": "NBC"}

    complete_rows = df.dropna(subset=list(required_columns))
    skipped = len(df) - len(complete_rows)
    if skipped:
        st.warning(f"Skipped {skipped} row(s) with a missing title or outlet.")

    processed_data = []
    for title, outlet in zip(complete_rows["title"], complete_rows["outlet"]):
        normalized = str(outlet).strip().upper()
        processed_data.append({
            "title": str(title),
            "outlet": outlet_aliases.get(normalized, normalized),
        })
    return processed_data
def evaluate_model(model, tokenizer, device, test_dataset):
    """Run the classifier over ``test_dataset`` and return labels and probabilities.

    Titles are encoded in batches of 16 (truncated to 128 tokens) and run
    through ``model`` without gradient tracking. A Streamlit progress bar
    tracks the batches.

    Args:
        model: sequence-classification model already placed on ``device``.
        tokenizer: tokenizer used to encode each item's ``title``.
        device: device the encoded inputs are moved to (e.g. ``"cuda"``/``"cpu"``).
        test_dataset: list of ``{"title": ..., "outlet": ...}`` dicts, where
            ``outlet`` is ``"FOXNEWS"`` or ``"NBC"``.

    Returns:
        ``(references, probabilities)``:
        - ``references`` is a list of int labels (0 = FOXNEWS, 1 = NBC) in
          dataset order.
        - ``probabilities`` is a numpy array of row-wise softmax scores over
          the model's logits.
        For an empty dataset this returns ``([], empty (0, 2) float32 array)``.
        NOTE(review): column i of ``probabilities`` is logit i; that this
        matches ``label2id`` is assumed, so confirm against the model's id2label.

    Raises:
        ValueError: if an item's ``outlet`` is not a known label. This is
            checked before any inference runs.
    """
    label2id = {"FOXNEWS": 0, "NBC": 1}
    batch_size = 16

    # Resolve every label up front so a bad row fails fast with a clear
    # message, rather than after the model has processed earlier batches.
    references = []
    for row_index, item in enumerate(test_dataset):
        outlet = item.get("outlet")
        if outlet not in label2id:
            raise ValueError(
                f"Unknown outlet {outlet!r} at row {row_index}; "
                f"expected one of {sorted(label2id)}"
            )
        references.append(label2id[outlet])

    if not references:
        # Nothing to score. Softmax over an empty tensor would fail on dim=1.
        return references, np.empty((0, len(label2id)), dtype=np.float32)

    total = len(test_dataset)
    progress_bar = st.progress(0)
    batch_logits = []
    for start in range(0, total, batch_size):
        progress_bar.progress(min(start / total, 1.0))
        batch = test_dataset[start:start + batch_size]
        encoded = tokenizer(
            [item["title"] for item in batch],
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt",
        )
        inputs = {key: tensor.to(device) for key, tensor in encoded.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        batch_logits.append(outputs.logits.cpu())
    progress_bar.progress(1.0)

    # Concatenate the per-batch tensors once. Building a tensor from a list
    # of per-row numpy arrays is slow, and torch warns about it.
    probabilities = torch.softmax(torch.cat(batch_logits), dim=1).numpy()
    return references, probabilities
def plot_roc_curve(references, probabilities):
    """Build a ROC-curve figure using the label-1 (NBC) probability column.

    Args:
        references: true integer labels (0 or 1).
        probabilities: 2-D array of per-class probabilities. Column 1 is used
            as the positive-class score.

    Returns:
        ``(fig, auc_score)``: a plotly figure with the ROC curve plus a dashed
        chance diagonal, and the ROC AUC score.
    """
    positive_scores = probabilities[:, 1]
    fpr, tpr, _ = roc_curve(references, positive_scores)
    auc_score = roc_auc_score(references, positive_scores)

    roc_trace = go.Scatter(x=fpr, y=tpr, name=f'ROC Curve (AUC = {auc_score:.4f})')
    chance_trace = go.Scatter(
        x=[0, 1], y=[0, 1], name='Random Guess', line={'dash': 'dash'}
    )
    fig = go.Figure(data=[roc_trace, chance_trace])
    fig.update_layout(
        title='ROC Curve',
        xaxis_title='False Positive Rate',
        yaxis_title='True Positive Rate',
        showlegend=True,
    )
    return fig, auc_score
def plot_metrics_by_threshold(references, probabilities):
    """Sweep decision thresholds and plot F1 / precision / recall for label 1 (NBC).

    For each threshold in 0.00, 0.01, ..., 0.99, the prediction is
    ``probabilities[:, 1] > threshold``. Scores use binary averaging with
    label 1 as the positive class.

    Args:
        references: true integer labels (0 = FOXNEWS, 1 = NBC).
        probabilities: 2-D array of per-class probabilities (column 1 = NBC).

    Returns:
        ``(fig, best_metrics)``. ``best_metrics`` is always populated, even if
        every F1 is zero, and contains, for the threshold with the highest F1
        (first one wins ties):
        - ``threshold``
        - ``f1_score``
        - ``confusion_matrix``: a 2x2 matrix
        - ``classification_report``: a text report
    """
    thresholds = np.arange(0.0, 1.0, 0.01)
    positive_scores = probabilities[:, 1]
    # Pin the label set so the matrix stays 2x2 and the report accepts both
    # target names even when a class is absent from references/predictions.
    label_ids = [0, 1]
    target_names = ['FOXNEWS', 'NBC']

    f1_scores, precisions, recalls = [], [], []
    for threshold in thresholds:
        preds = (positive_scores > threshold).astype(int)
        # One call yields all three scores (f1_score would recompute them).
        # zero_division=0 matches sklearn's default fallback value but
        # suppresses the UndefinedMetricWarning at thresholds with no
        # positive predictions.
        precision, recall, f1, _ = precision_recall_fscore_support(
            references, preds, average='binary', zero_division=0
        )
        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1)

    # argmax returns the first maximum, matching the strict ">" update of
    # the original loop, and always yields an index. Previously best_metrics
    # stayed empty if no F1 was > 0.
    best_index = int(np.argmax(f1_scores))
    best_threshold = thresholds[best_index]
    best_preds = (positive_scores > best_threshold).astype(int)
    best_metrics = {
        'threshold': best_threshold,
        'f1_score': f1_scores[best_index],
        'confusion_matrix': confusion_matrix(references, best_preds, labels=label_ids),
        'classification_report': classification_report(
            references,
            best_preds,
            labels=label_ids,
            target_names=target_names,
            digits=4,
            zero_division=0,
        ),
    }

    fig = go.Figure()
    for trace_name, values in (
        ('F1 Score', f1_scores),
        ('Precision', precisions),
        ('Recall', recalls),
    ):
        fig.add_trace(go.Scatter(x=thresholds, y=values, name=trace_name))
    fig.update_layout(
        title='Metrics by Threshold',
        xaxis_title='Threshold',
        yaxis_title='Score',
        showlegend=True,
    )
    return fig, best_metrics
def plot_confusion_matrix(cm):
    """Render a 2x2 confusion matrix as an annotated plotly heatmap.

    Args:
        cm: 2x2 array-like indexed as ``cm[true_label, predicted_label]``,
            with 0 = FOXNEWS and 1 = NBC.

    Returns:
        A plotly figure. Cell counts are drawn as text, in white on cells
        above half the maximum count and black otherwise, so they stay
        readable on the Blues colorscale.
    """
    class_names = ['FOXNEWS', 'NBC']
    contrast_cutoff = cm.max() / 2
    cell_labels = [
        dict(
            text=str(cm[true_idx, pred_idx]),
            x=pred_name,
            y=true_name,
            showarrow=False,
            font=dict(color='white' if cm[true_idx, pred_idx] > contrast_cutoff else 'black'),
        )
        for true_idx, true_name in enumerate(class_names)
        for pred_idx, pred_name in enumerate(class_names)
    ]

    heatmap = go.Heatmap(
        z=cm,
        x=class_names,
        y=class_names,
        colorscale='Blues',
        showscale=True,
    )
    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title='Confusion Matrix',
        xaxis_title='Predicted Label',
        yaxis_title='True Label',
        annotations=cell_labels,
    )
    return fig
def _render_results(references, probabilities):
    """Show the ROC curve, threshold sweep and best-threshold diagnostics."""
    roc_fig, auc_score = plot_roc_curve(references, probabilities)
    st.plotly_chart(roc_fig)
    st.metric("AUC-ROC Score", f"{auc_score:.4f}")

    metrics_fig, best_metrics = plot_metrics_by_threshold(references, probabilities)
    st.plotly_chart(metrics_fig)

    st.subheader("Best Threshold Evaluation")
    if not best_metrics:
        # Defensive: the threshold sweep may return no best threshold.
        st.warning("No threshold produced a non-zero F1 score.")
        return
    col1, col2 = st.columns(2)
    with col1:
        st.metric("Best Threshold", f"{best_metrics['threshold']:.2f}")
    with col2:
        st.metric("Best F1 Score", f"{best_metrics['f1_score']:.4f}")

    st.subheader("Confusion Matrix")
    cm_fig = plot_confusion_matrix(best_metrics['confusion_matrix'])
    st.plotly_chart(cm_fig)

    st.subheader("Classification Report")
    st.text(best_metrics['classification_report'])


def main():
    """Streamlit page: upload a CSV, run the classifier and display evaluation metrics.

    The CSV needs ``title`` and ``outlet`` columns. Each step stops early with
    a message in the UI when it cannot proceed.
    """
    st.title("News Classifier Model Evaluation")
    uploaded_file = st.file_uploader("Upload your test dataset (CSV)", type=['csv'])
    if uploaded_file is None:
        return

    try:
        df = pd.read_csv(uploaded_file)
    except ValueError as e:
        # pandas ParserError / EmptyDataError and UnicodeDecodeError all
        # subclass ValueError.
        st.error(f"Error reading CSV file: {e}")
        return
    st.write("Preview of uploaded data:")
    st.dataframe(df.head())

    model, tokenizer, device = load_model_and_tokenizer()
    if model is None or tokenizer is None:
        return  # loader already reported the error

    test_dataset = preprocess_data(df)
    if test_dataset is None:
        return  # preprocessing already reported the error
    if not test_dataset:
        st.warning("The uploaded dataset contains no examples to evaluate.")
        return
    st.write(f"Total examples: {len(test_dataset)}")

    # ROC AUC and the F1 sweep are undefined unless both classes occur.
    # Check before spending time on inference.
    if len({item["outlet"] for item in test_dataset}) < 2:
        st.error(
            "The dataset must contain examples of both FOXNEWS and NBC "
            "to compute evaluation metrics."
        )
        return

    with st.spinner('Evaluating model...'):
        references, probabilities = evaluate_model(model, tokenizer, device, test_dataset)

    _render_results(references, probabilities)
# Script entry point: build the page when this file is executed directly.
# Presumably launched via `streamlit run eval_pipeline.py`; confirm the
# deployment's launch command.
if __name__ == "__main__":
    main()