---
title: News Source Classifier
emoji: 📰
colorFrom: blue
colorTo: red
sdk: streamlit
app_file: eval_pipeline.py
library_name: transformers
pinned: false
language: en
license: mit
tags:
- text-classification
- news-classification
- BERT
- pytorch
- transformers
pipeline_tag: text-classification
widget:
- example_title: "Politics News Headline"
  text: "Trump's campaign rival decides between voting for him or Biden"
- example_title: "International News Headline"
  text: "World Food Programme Director Cindy McCain: Northern Gaza is in a 'full-blown famine'"
- example_title: "Domestic News Headline"
  text: "Ohio sheriff suggests residents keep a list of homes with Harris yard signs"
model-index:
- name: News Source Classifier
  results:
  - task:
      type: text-classification
      name: Text Classification
    dataset:
      name: Custom FOX-NBC Dataset
      type: Custom
    metrics:
    - name: F1 Score
      type: f1
      value: 0.85
---
# News Source Classifier - BERT Model
## Model Overview
This repository contains a fine-tuned BERT model that classifies news headlines as coming from either Fox News or NBC News, along with a Streamlit evaluation pipeline for assessing model performance.
### Model Details
- **Base Model**: BERT (bert-base-uncased)
- **Task**: Binary classification (Fox News vs NBC News)
- **Model ID**: CIS519PG/News_Classifier_Demo
- **Training Data**: News headlines from Fox News and NBC News
- **Input**: News article headlines (text)
- **Output**: Binary classification with probability scores
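The model returns two logits per headline. As a minimal sketch, assuming the label order FOXNEWS = 0 and NBC = 1 used by `label2id` in the evaluation code below, and assuming a fixed 0.5 decision threshold (the pipeline instead searches for the threshold with the best F1), logits can be turned into probability scores and labels like this. The helper name `logits_to_predictions` is hypothetical; in real use the logits would come from `model(**inputs).logits`.

```python
import numpy as np

# Label order assumed from the evaluation pipeline's label2id mapping.
ID2LABEL = {0: "FOXNEWS", 1: "NBC"}


def logits_to_predictions(logits, threshold=0.5):
    """Convert raw logits of shape (n, 2) into class probabilities and label names.

    Hypothetical helper; the 0.5 default threshold is an assumption.
    """
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)
    # Predict NBC when its probability exceeds the threshold, otherwise Fox News.
    pred_ids = (probs[:, 1] > threshold).astype(int)
    labels = [ID2LABEL[int(i)] for i in pred_ids]
    return probs, labels


probs, labels = logits_to_predictions([[2.0, 0.0], [0.0, 0.0], [-1.0, 1.0]])
print(labels)  # ['FOXNEWS', 'FOXNEWS', 'NBC']
```

Note that a headline with exactly equal logits falls to Fox News here because the comparison is strict, matching the `>` used in the pipeline's threshold sweep.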
## Evaluation Pipeline Setup
### Prerequisites
- Python 3.8+
- pip package manager
### Required Dependencies
Install the required packages using pip:
```bash
pip install streamlit pandas torch transformers scikit-learn numpy plotly tqdm
```
### Running the Evaluation Pipeline
1. Save the evaluation code below as `eval_pipeline.py` (it is also available in this repository's files):
```python
import streamlit as st
import pandas as pd
import torch
from transformers import BertTokenizer, AutoModelForSequenceClassification
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, classification_report, f1_score, precision_recall_fscore_support
import numpy as np
import plotly.graph_objects as go


def load_model_and_tokenizer():
    """Load the base tokenizer and the fine-tuned classifier, on GPU if available."""
    try:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        model = AutoModelForSequenceClassification.from_pretrained("CIS519PG/News_Classifier_Demo")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        model.eval()
        return model, tokenizer, device
    except Exception as e:
        st.error(f"Error loading model or tokenizer: {str(e)}")
        return None, None, None


def preprocess_data(df):
    """Normalize outlet names to the keys used by label2id ("FOXNEWS", "NBC")."""
    try:
        # The documented CSV format uses a "News Outlet" column; fall back to "outlet".
        outlet_col = "News Outlet" if "News Outlet" in df.columns else "outlet"
        processed_data = []
        for _, row in df.iterrows():
            outlet = str(row[outlet_col]).strip().upper()
            if outlet == "FOX NEWS":
                outlet = "FOXNEWS"
            elif outlet == "NBC NEWS":
                outlet = "NBC"
            processed_data.append({
                "title": row["title"],
                "outlet": outlet
            })
        return processed_data
    except Exception as e:
        st.error(f"Error preprocessing data: {str(e)}")
        return None


def evaluate_model(model, tokenizer, device, test_dataset):
    """Run batched inference; return true label ids and softmax probabilities."""
    label2id = {"FOXNEWS": 0, "NBC": 1}
    all_logits = []
    references = []
    batch_size = 16  # lower this value if you run out of memory
    progress_bar = st.progress(0)
    for i in range(0, len(test_dataset), batch_size):
        progress_bar.progress(min(i / len(test_dataset), 1.0))
        batch = test_dataset[i:i + batch_size]
        texts = [item['title'] for item in batch]
        encoded = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )
        inputs = {k: v.to(device) for k, v in encoded.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits.cpu().numpy()
        true_labels = [label2id[item['outlet']] for item in batch]
        all_logits.extend(logits)
        references.extend(true_labels)
    progress_bar.progress(1.0)
    # Stack the per-example logits into one array before applying softmax.
    probabilities = torch.softmax(torch.tensor(np.array(all_logits)), dim=1).numpy()
    return references, probabilities


def plot_roc_curve(references, probabilities):
    """ROC curve using the NBC-class probability as the score."""
    fpr, tpr, _ = roc_curve(references, probabilities[:, 1])
    auc_score = roc_auc_score(references, probabilities[:, 1])
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=fpr, y=tpr, name=f'ROC Curve (AUC = {auc_score:.4f})'))
    fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], name='Random Guess', line=dict(dash='dash')))
    fig.update_layout(
        title='ROC Curve',
        xaxis_title='False Positive Rate',
        yaxis_title='True Positive Rate',
        showlegend=True
    )
    return fig, auc_score


def plot_metrics_by_threshold(references, probabilities):
    """Sweep thresholds on the NBC probability and keep the one with the best F1."""
    thresholds = np.arange(0.0, 1.0, 0.01)
    metrics = {
        'threshold': thresholds,
        'f1': [],
        'precision': [],
        'recall': []
    }
    best_f1 = -1.0  # start below any valid F1 so best_metrics is always populated
    best_metrics = {}
    for threshold in thresholds:
        preds = (probabilities[:, 1] > threshold).astype(int)
        # zero_division=0 avoids warnings when no headline is predicted as NBC.
        f1 = f1_score(references, preds, zero_division=0)
        precision, recall, _, _ = precision_recall_fscore_support(
            references, preds, average='binary', zero_division=0
        )
        metrics['f1'].append(f1)
        metrics['precision'].append(precision)
        metrics['recall'].append(recall)
        if f1 > best_f1:
            best_f1 = f1
            cm = confusion_matrix(references, preds, labels=[0, 1])
            report = classification_report(
                references, preds, labels=[0, 1],
                target_names=['FOXNEWS', 'NBC'], digits=4, zero_division=0
            )
            best_metrics = {
                'threshold': threshold,
                'f1_score': f1,
                'confusion_matrix': cm,
                'classification_report': report
            }
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=thresholds, y=metrics['f1'], name='F1 Score'))
    fig.add_trace(go.Scatter(x=thresholds, y=metrics['precision'], name='Precision'))
    fig.add_trace(go.Scatter(x=thresholds, y=metrics['recall'], name='Recall'))
    fig.update_layout(
        title='Metrics by Threshold',
        xaxis_title='Threshold',
        yaxis_title='Score',
        showlegend=True
    )
    return fig, best_metrics


def plot_confusion_matrix(cm):
    """Annotated heatmap: rows are true labels, columns are predicted labels."""
    labels = ['FOXNEWS', 'NBC']
    annotations = []
    for i in range(len(labels)):
        for j in range(len(labels)):
            annotations.append(
                dict(
                    text=str(cm[i, j]),
                    x=labels[j],
                    y=labels[i],
                    showarrow=False,
                    font=dict(color='white' if cm[i, j] > cm.max() / 2 else 'black')
                )
            )
    fig = go.Figure(data=go.Heatmap(
        z=cm,
        x=labels,
        y=labels,
        colorscale='Blues',
        showscale=True
    ))
    fig.update_layout(
        title='Confusion Matrix',
        xaxis_title='Predicted Label',
        yaxis_title='True Label',
        annotations=annotations
    )
    return fig


def main():
    st.title("News Classifier Model Evaluation")
    uploaded_file = st.file_uploader("Upload your test dataset (CSV)", type=['csv'])
    if uploaded_file is not None:
        df = pd.read_csv(uploaded_file)
        st.write("Preview of uploaded data:")
        st.dataframe(df.head())
        model, tokenizer, device = load_model_and_tokenizer()
        if model is not None and tokenizer is not None:
            test_dataset = preprocess_data(df)
            if test_dataset:
                st.write(f"Total examples: {len(test_dataset)}")
                with st.spinner('Evaluating model...'):
                    references, probabilities = evaluate_model(model, tokenizer, device, test_dataset)
                roc_fig, auc_score = plot_roc_curve(references, probabilities)
                st.plotly_chart(roc_fig)
                st.metric("AUC-ROC Score", f"{auc_score:.4f}")
                metrics_fig, best_metrics = plot_metrics_by_threshold(references, probabilities)
                st.plotly_chart(metrics_fig)
                st.subheader("Best Threshold Evaluation")
                col1, col2 = st.columns(2)
                with col1:
                    st.metric("Best Threshold", f"{best_metrics['threshold']:.2f}")
                with col2:
                    st.metric("Best F1 Score", f"{best_metrics['f1_score']:.4f}")
                st.subheader("Confusion Matrix")
                cm_fig = plot_confusion_matrix(best_metrics['confusion_matrix'])
                st.plotly_chart(cm_fig)
                st.subheader("Classification Report")
                st.text(best_metrics['classification_report'])


if __name__ == "__main__":
    main()
```
2. Run the Streamlit application:
```bash
streamlit run eval_pipeline.py
```
3. Streamlit serves the app locally (by default at `http://localhost:8501`) and usually opens it in your default browser.
### Using the Web Interface
1. **Upload Test Data**:
   - Prepare your test data in CSV format.
   - Expected columns:
     - A leading unnamed index column (automatic numbering, as written by pandas; not read by the pipeline)
     - "title": The news headline text
     - "label": Binary label (0 for Fox News, 1 for NBC News); the pipeline derives true labels from the outlet name, so keep the two consistent
     - "News Outlet": The source ("Fox News" or "NBC News")
2. **View Evaluation Results**:
   The pipeline will display:
   - Data preview
   - ROC curve with AUC score
   - Metrics vs. threshold plot
   - Best threshold and F1 score
   - Confusion matrix visualization
   - Detailed classification report
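The AUC shown next to the ROC curve has a concrete interpretation: it is the probability that a randomly chosen NBC headline receives a higher NBC score than a randomly chosen Fox News headline, with ties counted as half. The pipeline computes it with scikit-learn's `roc_auc_score`; the brute-force pairwise sketch below (a hypothetical helper, quadratic in dataset size, so only for illustration on toy numbers) computes the same quantity directly from that definition.

```python
import numpy as np


def auc_by_pairs(references, scores):
    """AUC-ROC as the fraction of (NBC, Fox News) pairs ranked correctly; ties count half.

    Hypothetical illustration helper; O(n_pos * n_neg) memory and time.
    """
    y = np.asarray(references)
    s = np.asarray(scores, dtype=np.float64)
    pos = s[y == 1]  # NBC-probability scores for NBC headlines (label 1)
    neg = s[y == 0]  # NBC-probability scores for Fox News headlines (label 0)
    # Compare every positive score against every negative score.
    diff = pos[:, None] - neg[None, :]
    wins = (diff > 0).sum() + 0.5 * (diff == 0).sum()
    return wins / (len(pos) * len(neg))


refs = [0, 0, 1, 1]
nbc_probs = [0.1, 0.6, 0.4, 0.9]  # toy scores, not model output
print(auc_by_pairs(refs, nbc_probs))  # 0.75
```

Three of the four (NBC, Fox News) pairs are ordered correctly, giving 0.75; a random scorer would hover around 0.5, which is what the dashed "Random Guess" line represents.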
### Sample Data Format
```csv
,title,label,News Outlet
0,"Jack Carr's take on the late Tom Clancy, born on this day in 1947",0,Fox News
1,"Feeding America CEO asks community to help others amid today's high inflation",0,Fox News
2,"World Food Programme Director Cindy McCain: Northern Gaza is in a 'full-blown famine'",1,NBC News
3,"Ohio sheriff suggests residents keep a list of homes with Harris yard signs",1,NBC News
```
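If your headlines are already in Python, one way to produce a file in this layout is with pandas, whose `to_csv` writes the DataFrame index as the leading unnamed column by default. A minimal sketch, assuming the numeric label is derived from the outlet name (two headlines from the sample above are reused as data):

```python
import io

import pandas as pd

rows = [
    {"title": "Feeding America CEO asks community to help others amid today's high inflation",
     "News Outlet": "Fox News"},
    {"title": "Ohio sheriff suggests residents keep a list of homes with Harris yard signs",
     "News Outlet": "NBC News"},
]
df = pd.DataFrame(rows)
# Derive the numeric label from the outlet: 0 for Fox News, 1 for NBC News.
df["label"] = (df["News Outlet"] == "NBC News").astype(int)
df = df[["title", "label", "News Outlet"]]

buffer = io.StringIO()
df.to_csv(buffer)  # index=True by default, producing the unnamed index column
csv_text = buffer.getvalue()
print(csv_text)
```

Write to a file path such as `test.csv` instead of the in-memory buffer when preparing a file for upload; headlines containing commas are quoted automatically.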
## Model Architecture
- Base model: BERT (bert-base-uncased)
- Fine-tuned for binary classification
- Uses PyTorch and Hugging Face Transformers
## Limitations and Bias
This model has been trained on news headlines from specific sources (Fox News and NBC News) and time periods, which may introduce certain biases:
- Limited to two specific news sources
- Temporal bias based on training data collection period
- May not generalize well to other news sources or formats
## Evaluation Metrics
The pipeline reports the following metrics:
- AUC-ROC Score
- F1 Score
- Precision & Recall
- Confusion Matrix
- Detailed Classification Report
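F1, precision, and recall as reported by `f1_score` and `precision_recall_fscore_support(average='binary')` treat NBC (label 1) as the positive class, and all three can be read off the confusion matrix. A sketch with toy counts (not results from this model), assuming the matrix follows scikit-learn's layout of rows = true labels and columns = predictions in the order [FOXNEWS, NBC]; the helper name is hypothetical:

```python
import numpy as np


def binary_metrics(cm):
    """Precision, recall and F1 for the positive class (NBC, label 1).

    Expects a 2x2 matrix laid out like scikit-learn's confusion_matrix:
    rows are true labels, columns are predictions, order [FOXNEWS, NBC].
    """
    (tn, fp), (fn, tp) = np.asarray(cm)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return precision, recall, f1


# Toy counts: 40 Fox headlines correct, 10 Fox called NBC, 5 NBC called Fox, 45 NBC correct.
p, r, f1 = binary_metrics([[40, 10], [5, 45]])
print(f"precision={p:.3f} recall={r:.3f} f1={f1:.3f}")  # precision=0.818 recall=0.900 f1=0.857
```

The classification report additionally lists the same metrics with FOXNEWS as the positive class, plus macro and weighted averages.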
## Troubleshooting
Common issues and solutions:
1. **CUDA/GPU Error**:
   - The pipeline automatically falls back to the CPU if CUDA is not available.
   - No action is needed from the user.
2. **Memory Issues**:
   - The default batch size is 16.
   - If you run out of GPU or CPU memory, lower `batch_size` in `evaluate_model`.
3. **File Format Error**:
   - Ensure the CSV file has the exact column names "title", "label", and "News Outlet".
   - Verify that label values are 0 or 1.
   - Confirm that "News Outlet" values are exactly "Fox News" or "NBC News".
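Before uploading, the checks above can be automated. The sketch below is a hypothetical helper (not part of `eval_pipeline.py`) that assumes the exact column names and outlet spellings documented in this card, and additionally flags rows whose label disagrees with the outlet:

```python
import io

import pandas as pd

REQUIRED_COLUMNS = ["title", "label", "News Outlet"]
OUTLET_TO_LABEL = {"Fox News": 0, "NBC News": 1}


def validate_test_csv(df):
    """Return a list of problems found; an empty list means the file looks usable."""
    missing = [c for c in REQUIRED_COLUMNS if c not in df.columns]
    if missing:
        return [f"missing columns: {missing}"]
    problems = []
    bad_labels = set(df["label"].unique()) - {0, 1}
    if bad_labels:
        problems.append(f"unexpected label values: {sorted(bad_labels)}")
    bad_outlets = set(df["News Outlet"].unique()) - set(OUTLET_TO_LABEL)
    if bad_outlets:
        problems.append(f"unexpected outlet values: {sorted(bad_outlets)}")
    else:
        # Every outlet is known, so check that each row's label matches it.
        expected = df["News Outlet"].map(OUTLET_TO_LABEL)
        mismatched = int((expected != df["label"]).sum())
        if mismatched:
            problems.append(f"{mismatched} rows where label and outlet disagree")
    return problems


csv_text = ",title,label,News Outlet\n0,Headline A,0,Fox News\n1,Headline B,1,NBC\n"
df = pd.read_csv(io.StringIO(csv_text))
print(validate_test_csv(df))  # ["unexpected outlet values: ['NBC']"]
```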
## License
This project is licensed under the MIT License.