import streamlit as st
import pandas as pd
import torch
from transformers import BertTokenizer, AutoModelForSequenceClassification
from sklearn.metrics import (
    roc_auc_score,
    roc_curve,
    confusion_matrix,
    classification_report,
    f1_score,
    precision_recall_fscore_support,
)
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from tqdm import tqdm

TOKENIZER_NAME = "bert-base-uncased"
MODEL_NAME = "CIS519PG/News_Classifier_Demo"

# Class names in label-id order (FOXNEWS -> 0, NBC -> 1). This ordering is
# assumed to match the ids the fine-tuned model was trained with — TODO confirm
# against the model config's id2label.
LABELS = ["FOXNEWS", "NBC"]
LABEL2ID = {name: idx for idx, name in enumerate(LABELS)}

# Spellings seen in the "News Outlet" column (after strip/upper) mapped to LABELS.
OUTLET_ALIASES = {"FOX NEWS": "FOXNEWS", "NBC NEWS": "NBC"}


@st.cache_resource
def _load_cached_model_and_tokenizer():
    """Load tokenizer and model once per server process and reuse them across reruns.

    Exceptions propagate (and therefore are not cached) so that a transient
    download failure can be retried on the next rerun.
    """
    tokenizer = BertTokenizer.from_pretrained(TOKENIZER_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()  # disable dropout for deterministic evaluation
    return model, tokenizer, device


def load_model_and_tokenizer():
    """Return ``(model, tokenizer, device)``, or ``(None, None, None)`` on failure.

    The error is reported in the Streamlit UI rather than raised.
    """
    try:
        return _load_cached_model_and_tokenizer()
    except Exception as e:  # UI boundary: surface any loading failure to the user
        st.error(f"Error loading model or tokenizer: {str(e)}")
        return None, None, None


def preprocess_data(df):
    """Convert the uploaded DataFrame into a list of ``{"title", "outlet"}`` dicts.

    Reads the ``"title"`` and ``"News Outlet"`` columns. Outlet names are
    stripped, upper-cased and normalised via ``OUTLET_ALIASES``. Rows with a
    missing title/outlet, or with an outlet not in ``LABELS``, are skipped and
    reported with a warning (previously they crashed later during evaluation).

    Returns:
        The list of examples (possibly empty), or ``None`` if a required
        column is missing.
    """
    processed_data = []
    skipped_missing = 0
    skipped_unknown = 0
    unknown_outlets = set()
    try:
        rows = zip(df["title"], df["News Outlet"])
    except KeyError as e:
        st.error(f"Error preprocessing data: missing column {e}")
        return None

    for title, raw_outlet in rows:
        if pd.isna(title) or pd.isna(raw_outlet):
            skipped_missing += 1
            continue
        outlet = str(raw_outlet).strip().upper()
        outlet = OUTLET_ALIASES.get(outlet, outlet)
        if outlet not in LABEL2ID:
            skipped_unknown += 1
            unknown_outlets.add(outlet)
            continue
        processed_data.append({"title": str(title), "outlet": outlet})

    if skipped_missing:
        st.warning(f"Skipped {skipped_missing} row(s) with a missing title or outlet.")
    if skipped_unknown:
        st.warning(
            f"Skipped {skipped_unknown} row(s) with unrecognised outlet(s): "
            f"{', '.join(sorted(unknown_outlets))}. Expected one of: {', '.join(LABELS)}."
        )
    return processed_data


def evaluate_model(model, tokenizer, device, test_dataset, batch_size=16, max_length=128):
    """Run the model over ``test_dataset`` in batches.

    Args:
        model, tokenizer, device: as returned by ``load_model_and_tokenizer``.
        test_dataset: list of ``{"title", "outlet"}`` dicts from ``preprocess_data``.
        batch_size: number of titles tokenised and scored per forward pass.
        max_length: tokenizer truncation length.

    Returns:
        ``(references, probabilities)``. ``references`` is a list of label ids.
        ``probabilities`` is a NumPy array holding the softmax over the model's
        logits, one row per example; column 1 is used as the NBC score
        downstream (assumes a 2-label head — TODO confirm).
    """
    all_logits = []
    references = []
    total = len(test_dataset)
    progress_bar = st.progress(0)

    for start in range(0, total, batch_size):
        progress_bar.progress(min(start / total, 1.0))
        batch = test_dataset[start:start + batch_size]
        texts = [item["title"] for item in batch]
        encoded = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        inputs = {k: v.to(device) for k, v in encoded.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        # Keep logits as CPU tensors so they can be concatenated in one call.
        all_logits.append(outputs.logits.cpu())
        references.extend(LABEL2ID[item["outlet"]] for item in batch)

    progress_bar.progress(1.0)
    if not all_logits:
        return references, np.empty((0, len(LABELS)), dtype=np.float32)
    probabilities = torch.softmax(torch.cat(all_logits), dim=1).numpy()
    return references, probabilities


def plot_roc_curve(references, probabilities):
    """Build a ROC-curve figure using column 1 of ``probabilities`` as the score.

    Requires both classes to be present in ``references`` (otherwise sklearn
    raises ``ValueError``). Returns ``(figure, auc_score)``.
    """
    scores = probabilities[:, 1]
    fpr, tpr, _ = roc_curve(references, scores)
    auc_score = roc_auc_score(references, scores)

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=fpr, y=tpr, name=f'ROC Curve (AUC = {auc_score:.4f})'))
    fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], name='Random Guess', line=dict(dash='dash')))
    fig.update_layout(
        title='ROC Curve',
        xaxis_title='False Positive Rate',
        yaxis_title='True Positive Rate',
        showlegend=True,
    )
    return fig, auc_score


def plot_metrics_by_threshold(references, probabilities, thresholds=None):
    """Sweep decision thresholds on column 1 of ``probabilities``.

    A prediction is positive (label 1) when its score is strictly greater
    than the threshold.

    Args:
        references: true label ids.
        probabilities: 2-D array; column 1 is the positive-class score.
        thresholds: values to sweep; defaults to 0.00, 0.01, ..., 0.99.

    Returns:
        ``(figure, best_metrics)``. ``best_metrics`` always has the keys
        ``'threshold'``, ``'f1_score'``, ``'confusion_matrix'`` (len(LABELS)
        square) and ``'classification_report'`` for the first threshold that
        attains the maximum F1.
    """
    if thresholds is None:
        thresholds = np.arange(0.0, 1.0, 0.01)
    thresholds = np.asarray(thresholds)
    label_ids = list(range(len(LABELS)))
    scores = probabilities[:, 1]

    f1_scores, precisions, recalls = [], [], []
    for threshold in thresholds:
        preds = (scores > threshold).astype(int)
        # zero_division=0 gives the same values as sklearn's default but
        # without a warning for thresholds where nothing is predicted positive.
        precision, recall, f1, _ = precision_recall_fscore_support(
            references, preds, average='binary', zero_division=0
        )
        f1_scores.append(f1)
        precisions.append(precision)
        recalls.append(recall)

    # argmax returns the first maximum, matching a strict ">" running-best scan,
    # and always yields a result (the old best_f1=0 start could leave it empty).
    best_idx = int(np.argmax(f1_scores))
    best_threshold = thresholds[best_idx]
    best_preds = (scores > best_threshold).astype(int)
    best_metrics = {
        'threshold': best_threshold,
        'f1_score': f1_scores[best_idx],
        # labels= keeps the matrix/report full-size even if a class is absent.
        'confusion_matrix': confusion_matrix(references, best_preds, labels=label_ids),
        'classification_report': classification_report(
            references,
            best_preds,
            labels=label_ids,
            target_names=LABELS,
            digits=4,
            zero_division=0,
        ),
    }

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=thresholds, y=f1_scores, name='F1 Score'))
    fig.add_trace(go.Scatter(x=thresholds, y=precisions, name='Precision'))
    fig.add_trace(go.Scatter(x=thresholds, y=recalls, name='Recall'))
    fig.update_layout(
        title='Metrics by Threshold',
        xaxis_title='Threshold',
        yaxis_title='Score',
        showlegend=True,
    )
    return fig, best_metrics


def plot_confusion_matrix(cm, labels=None):
    """Render a confusion matrix (rows = true label, columns = predicted) as a heatmap.

    Each cell is annotated with its count. The text is white on cells above half
    of the maximum count, so it stays readable on the darker part of the
    'Blues' scale.
    """
    if labels is None:
        labels = LABELS
    cm = np.asarray(cm)
    contrast_cutoff = cm.max() / 2
    annotations = [
        dict(
            text=str(cm[i, j]),
            x=labels[j],
            y=labels[i],
            showarrow=False,
            font=dict(color='white' if cm[i, j] > contrast_cutoff else 'black'),
        )
        for i in range(len(labels))
        for j in range(len(labels))
    ]

    fig = go.Figure(data=go.Heatmap(
        z=cm,
        x=labels,
        y=labels,
        colorscale='Blues',
        showscale=True,
    ))
    fig.update_layout(
        title='Confusion Matrix',
        xaxis_title='Predicted Label',
        yaxis_title='True Label',
        annotations=annotations,
    )
    return fig


def main():
    """Streamlit entry point: upload a CSV, evaluate the classifier and show metrics."""
    st.title("News Classifier Model Evaluation")

    uploaded_file = st.file_uploader("Upload your test dataset (CSV)", type=['csv'])
    if uploaded_file is None:
        return

    try:
        df = pd.read_csv(uploaded_file)
    except ValueError as e:  # ParserError, EmptyDataError and UnicodeDecodeError subclass ValueError
        st.error(f"Error reading CSV file: {e}")
        return
    st.write("Preview of uploaded data:")
    st.dataframe(df.head())

    model, tokenizer, device = load_model_and_tokenizer()
    # Explicit None checks: tokenizers define __len__, so truthiness is not a
    # reliable "loaded successfully" test.
    if model is None or tokenizer is None:
        return

    test_dataset = preprocess_data(df)
    if test_dataset is None:
        return
    if not test_dataset:
        st.warning("No usable examples found in the uploaded file.")
        return

    st.write(f"Total examples: {len(test_dataset)}")
    with st.spinner('Evaluating model...'):
        references, probabilities = evaluate_model(model, tokenizer, device, test_dataset)

    # ROC/AUC and the threshold sweep are undefined with a single class present.
    if len(set(references)) < 2:
        st.warning(
            "ROC/AUC and threshold metrics need examples of both "
            f"{' and '.join(LABELS)}; the uploaded data contains only one class."
        )
        return

    roc_fig, auc_score = plot_roc_curve(references, probabilities)
    st.plotly_chart(roc_fig)
    st.metric("AUC-ROC Score", f"{auc_score:.4f}")

    metrics_fig, best_metrics = plot_metrics_by_threshold(references, probabilities)
    st.plotly_chart(metrics_fig)

    st.subheader("Best Threshold Evaluation")
    col1, col2 = st.columns(2)
    with col1:
        st.metric("Best Threshold", f"{best_metrics['threshold']:.2f}")
    with col2:
        st.metric("Best F1 Score", f"{best_metrics['f1_score']:.4f}")

    st.subheader("Confusion Matrix")
    cm_fig = plot_confusion_matrix(best_metrics['confusion_matrix'])
    st.plotly_chart(cm_fig)

    st.subheader("Classification Report")
    st.text(best_metrics['classification_report'])


if __name__ == "__main__":
    main()