Jiahuita committed
Commit a232772 · verified · 1 Parent(s): 445733e

Upload eval_pipeline.py

Files changed (1):
  1. eval_pipeline.py +193 -0
eval_pipeline.py ADDED
@@ -0,0 +1,193 @@
import streamlit as st
import pandas as pd
import torch
from transformers import BertTokenizer, AutoModelForSequenceClassification
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, classification_report, f1_score, precision_recall_fscore_support
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from tqdm import tqdm


def load_model_and_tokenizer():
    # Load the tokenizer and the fine-tuned classifier, moving the model to GPU when available.
    try:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        model = AutoModelForSequenceClassification.from_pretrained("CIS519PG/News_Classifier_Demo")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        model.eval()
        return model, tokenizer, device
    except Exception as e:
        st.error(f"Error loading model or tokenizer: {str(e)}")
        return None, None, None


def preprocess_data(df):
    # Normalize outlet names so they match the model's label set.
    try:
        processed_data = []
        for _, row in df.iterrows():
            outlet = row["News Outlet"].strip().upper()
            if outlet == "FOX NEWS":
                outlet = "FOXNEWS"
            elif outlet == "NBC NEWS":
                outlet = "NBC"

            processed_data.append({
                "title": row["title"],
                "outlet": outlet
            })
        return processed_data
    except Exception as e:
        st.error(f"Error preprocessing data: {str(e)}")
        return None


def evaluate_model(model, tokenizer, device, test_dataset):
    # Run batched inference and collect softmax probabilities alongside the true labels.
    label2id = {"FOXNEWS": 0, "NBC": 1}
    all_logits = []
    references = []

    batch_size = 16
    progress_bar = st.progress(0)

    for i in range(0, len(test_dataset), batch_size):
        # Update progress
        progress = min(i / len(test_dataset), 1.0)
        progress_bar.progress(progress)

        batch = test_dataset[i:i + batch_size]
        texts = [item['title'] for item in batch]

        encoded = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )

        inputs = {k: v.to(device) for k, v in encoded.items()}
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits.cpu().numpy()

        true_labels = [label2id[item['outlet']] for item in batch]
        all_logits.extend(logits)
        references.extend(true_labels)

    progress_bar.progress(1.0)
    probabilities = torch.softmax(torch.tensor(all_logits), dim=1).numpy()
    return references, probabilities


def plot_roc_curve(references, probabilities):
    fpr, tpr, _ = roc_curve(references, probabilities[:, 1])
    auc_score = roc_auc_score(references, probabilities[:, 1])
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=fpr, y=tpr, name=f'ROC Curve (AUC = {auc_score:.4f})'))
    fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], name='Random Guess', line=dict(dash='dash')))
    fig.update_layout(
        title='ROC Curve',
        xaxis_title='False Positive Rate',
        yaxis_title='True Positive Rate',
        showlegend=True
    )
    return fig, auc_score


def plot_metrics_by_threshold(references, probabilities):
    # Sweep decision thresholds and keep the metrics at the threshold with the best F1 score.
    thresholds = np.arange(0.0, 1.0, 0.01)
    metrics = {
        'threshold': thresholds,
        'f1': [],
        'precision': [],
        'recall': []
    }
    best_f1 = 0
    best_threshold = 0
    best_metrics = {}
    for threshold in thresholds:
        preds = (probabilities[:, 1] > threshold).astype(int)
        f1 = f1_score(references, preds)
        precision, recall, _, _ = precision_recall_fscore_support(references, preds, average='binary')
        metrics['f1'].append(f1)
        metrics['precision'].append(precision)
        metrics['recall'].append(recall)
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold
            cm = confusion_matrix(references, preds)
            report = classification_report(references, preds, target_names=['FOXNEWS', 'NBC'], digits=4)
            best_metrics = {
                'threshold': threshold,
                'f1_score': f1,
                'confusion_matrix': cm,
                'classification_report': report
            }
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=thresholds, y=metrics['f1'], name='F1 Score'))
    fig.add_trace(go.Scatter(x=thresholds, y=metrics['precision'], name='Precision'))
    fig.add_trace(go.Scatter(x=thresholds, y=metrics['recall'], name='Recall'))
    fig.update_layout(
        title='Metrics by Threshold',
        xaxis_title='Threshold',
        yaxis_title='Score',
        showlegend=True
    )
    return fig, best_metrics


def plot_confusion_matrix(cm):
    labels = ['FOXNEWS', 'NBC']
    annotations = []
    for i in range(len(labels)):
        for j in range(len(labels)):
            annotations.append(
                dict(
                    text=str(cm[i, j]),
                    x=labels[j],
                    y=labels[i],
                    showarrow=False,
                    font=dict(color='white' if cm[i, j] > cm.max() / 2 else 'black')
                )
            )
    fig = go.Figure(data=go.Heatmap(
        z=cm,
        x=labels,
        y=labels,
        colorscale='Blues',
        showscale=True
    ))
    fig.update_layout(
        title='Confusion Matrix',
        xaxis_title='Predicted Label',
        yaxis_title='True Label',
        annotations=annotations
    )
    return fig


def main():
    # Streamlit entry point: upload a CSV, run evaluation, and render the plots and reports.
    st.title("News Classifier Model Evaluation")
    uploaded_file = st.file_uploader("Upload your test dataset (CSV)", type=['csv'])
    if uploaded_file is not None:
        df = pd.read_csv(uploaded_file)
        st.write("Preview of uploaded data:")
        st.dataframe(df.head())
        model, tokenizer, device = load_model_and_tokenizer()
        if model and tokenizer:
            test_dataset = preprocess_data(df)
            if test_dataset:
                st.write(f"Total examples: {len(test_dataset)}")
                with st.spinner('Evaluating model...'):
                    references, probabilities = evaluate_model(model, tokenizer, device, test_dataset)
                roc_fig, auc_score = plot_roc_curve(references, probabilities)
                st.plotly_chart(roc_fig)
                st.metric("AUC-ROC Score", f"{auc_score:.4f}")
                metrics_fig, best_metrics = plot_metrics_by_threshold(references, probabilities)
                st.plotly_chart(metrics_fig)
                st.subheader("Best Threshold Evaluation")
                col1, col2 = st.columns(2)
                with col1:
                    st.metric("Best Threshold", f"{best_metrics['threshold']:.2f}")
                with col2:
                    st.metric("Best F1 Score", f"{best_metrics['f1_score']:.4f}")
                st.subheader("Confusion Matrix")
                cm_fig = plot_confusion_matrix(best_metrics['confusion_matrix'])
                st.plotly_chart(cm_fig)
                st.subheader("Classification Report")
                st.text(best_metrics['classification_report'])


if __name__ == "__main__":
    main()
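
For quick local testing, a minimal sketch of how the app might be exercised: the column names "title" and "News Outlet" are what preprocess_data expects, while the headline rows and the file name test_headlines.csv are only illustrative.

# Create a small CSV in the format the uploader expects (columns: "title", "News Outlet").
# The example rows and file name are hypothetical; any CSV with these two columns should work.
import pandas as pd

sample = pd.DataFrame({
    "title": [
        "Senate passes new budget bill",
        "Wildfire forces evacuations in California",
    ],
    "News Outlet": ["FOX NEWS", "NBC NEWS"],
})
sample.to_csv("test_headlines.csv", index=False)

# Launch the Streamlit app, then upload test_headlines.csv through the file uploader:
#   streamlit run eval_pipeline.py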