Quexoo committed
Commit 293e02f · verified · 1 Parent(s): 6cea2e1

Upload 2 files

Files changed (2):
  1. VisionBERT.py +533 -0
  2. data/Vision_Survey_Cleaned.csv +0 -0
VisionBERT.py ADDED
@@ -0,0 +1,533 @@
import os
from typing import Dict
from datasets import Dataset
import torch
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, accuracy_score
from transformers import AutoTokenizer, TrainingArguments, Trainer, AutoModelForSequenceClassification, DataCollatorWithPadding
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.cluster import KMeans
from torch.nn import CrossEntropyLoss
import pickle

os.environ['OMP_NUM_THREADS'] = '7'

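# WeightedTrainer overrides Trainer.compute_loss so that each example's
# "weight" column scales its cross-entropy term: the batch loss is the mean
# of weight_i * CE(logits_i, label_i). When no weights are present it falls
# back to standard cross-entropy with label smoothing (0.1).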
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs: bool = False, num_items_in_batch: int = None):
        """
        Custom loss computation with sample weights
        """
        labels = inputs.get("labels")
        weights = inputs.get("weight")

        # Forward pass
        outputs = model(**{k: v for k, v in inputs.items()
                           if k not in ["weight", "labels"]})
        logits = outputs.get("logits")

        # Add labels back to outputs
        outputs["labels"] = labels

        # Compute weighted loss
        if weights is not None:
            weights = weights.to(logits.device)
            loss_fct = CrossEntropyLoss(reduction='none')
            loss = loss_fct(logits.view(-1, self.model.config.num_labels),
                            labels.view(-1))

            # Adjust weights if num_items_in_batch is provided
            if num_items_in_batch:
                weights = weights[:num_items_in_batch]

            loss = (loss * weights.view(-1)).mean()
        else:
            loss_fct = CrossEntropyLoss(label_smoothing=0.1)
            loss = loss_fct(logits.view(-1, self.model.config.num_labels),
                            labels.view(-1))

        outputs["loss"] = loss
        return (loss, outputs) if return_outputs else loss

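# Feature construction for clustering: each survey row becomes the vector
# [age_code, gender_code, race_code, risk_factor_code, sample_size],
# label-encoded and then standardized with StandardScaler.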
def create_feature_vector(df):
    """Create numerical feature vector for clustering with sample size weighting, handling missing/unseen labels."""

    # Initialize LabelEncoders
    le_gender = LabelEncoder()
    le_race = LabelEncoder()
    le_risk = LabelEncoder()

    # Fill missing values first, then fit/transform so 'Unknown' is always a known class
    gender_encoded = le_gender.fit_transform(df['Gender'].fillna('Unknown'))
    race_encoded = le_race.fit_transform(df['RaceEthnicity'].fillna('Unknown'))
    risk_encoded = le_risk.fit_transform(df['RiskFactor'].fillna('Unknown'))

    # Create age groups numerical representation with a default for missing values
    age_map = {
        '12-17 years': 0,
        '18-39 years': 1,
        '40-64 years': 2,
        '65-79 years': 3,
        '80 years and older': 4  # Include all possible labels, even if missing
    }

    # Use `.get()` with a default value for missing/unseen age groups
    age_encoded = df['Age'].map(lambda x: age_map.get(x, -1))

    # Combine features
    features = np.column_stack([
        age_encoded,
        gender_encoded,
        race_encoded,
        risk_encoded,
        df['Sample_Size'].values  # Add sample size as a feature
    ])

    # Scale features
    scaler = StandardScaler()
    features_scaled = scaler.fit_transform(features)

    return features_scaled, scaler

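# Sample-size-weighted k-means: centroids are initialized by sampling rows in
# proportion to their weights, and each update sets centroid k to the weighted
# mean of its assigned points, mu_k = sum_i(w_i * x_i) / sum_i(w_i), iterating
# until the centroids stop moving (np.allclose) or max_iter is reached.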
def weighted_kmeans(X, sample_weights, n_clusters, max_iter=300, random_state=42):
    """Custom K-means implementation that considers sample weights"""
    n_samples = X.shape[0]

    # Initialize centroids randomly from the weighted distribution
    rng = np.random.RandomState(random_state)
    weighted_indices = rng.choice(n_samples, size=n_clusters, p=sample_weights / sample_weights.sum())
    centroids = X[weighted_indices]

    for _ in range(max_iter):
        # Assign points to nearest centroid
        distances = np.sqrt(((X[:, np.newaxis] - centroids) ** 2).sum(axis=2))
        labels = np.argmin(distances, axis=1)

        # Update centroids using weighted means
        new_centroids = np.zeros_like(centroids)
        for k in range(n_clusters):
            mask = labels == k
            if mask.any():
                weights_k = sample_weights[mask]
                new_centroids[k] = np.average(X[mask], axis=0, weights=weights_k)

        # Check for convergence
        if np.allclose(centroids, new_centroids):
            break

        centroids = new_centroids

    return labels, centroids

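# Data preparation: filters the survey to best-corrected visual acuity rows,
# clusters the demographic profiles with the weighted k-means above, renders
# each row as a structured text document, and assigns every row the weight
# (Sample_Size_i / total_sample_size) * cluster_weight(cluster_i) before a
# stratified 80/20 train/test split.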
def prepare_data(file_path='data/Vision_Survey_Cleaned.csv'):
    """Load and prepare the vision health dataset with sample-size-aware clustering."""
    print("\nLoading and preparing data...")
    df = pd.read_csv(file_path)

    # Filter data
    vision_cat = ['Best-corrected visual acuity']
    df = df[df['Question'].isin(vision_cat)].copy()
    df = df[df["RiskFactor"] != "All participants"]
    df = df[df["RiskFactorResponse"] != "Total"]

    # Reset index after filtering
    df = df.reset_index(drop=True)

    # Create feature vectors for clustering
    features_scaled, scaler = create_feature_vector(df)

    # Normalize sample sizes for weights
    sample_weights = df['Sample_Size'].values
    sample_weights = sample_weights / sample_weights.sum()

    # Apply weighted clustering
    n_clusters = min(5, len(df))
    clusters, centroids = weighted_kmeans(
        features_scaled,
        sample_weights,
        n_clusters=n_clusters
    )

    # Add clusters as a column
    df['cluster'] = clusters

    # Calculate cluster importance based on total sample size in each cluster
    cluster_total_samples = df.groupby('cluster')['Sample_Size'].sum()
    cluster_weights = cluster_total_samples / cluster_total_samples.sum()

    # Enhanced feature engineering with clustering information
    df['doc'] = df.apply(
        lambda x: f"""
Patient Demographics:
- Age Category: {x['Age']}
- Gender: {x['Gender']}
- Race/Ethnicity: {x['RaceEthnicity']}

Risk Factors:
- {x['RiskFactor']}: {x['RiskFactorResponse']}

Additional Information:
- Sample Size: {x['Sample_Size']}
- Cluster Profile: {x['cluster']} (Weight: {cluster_weights.get(x['cluster'], 0):.3f})
""".strip(),
        axis=1
    )

    # Encode labels
    le = LabelEncoder()
    df['labels'] = le.fit_transform(df['Response'].astype(str))

    # Combine sample size weights with cluster importance
    df['weight'] = df.apply(
        lambda x: (x['Sample_Size'] / df['Sample_Size'].sum()) *
        cluster_weights.get(x['cluster'], 0),
        axis=1
    )

    # Create train and test splits with stratification
    train_df, test_df = train_test_split(
        df,
        test_size=0.2,
        stratify=df['labels'],
        random_state=42
    )

    # Convert to dict format
    train_data = {
        'doc': train_df['doc'].tolist(),
        'labels': train_df['labels'].tolist(),
        'weight': train_df['weight'].tolist()
    }

    test_data = {
        'doc': test_df['doc'].tolist(),
        'labels': test_df['labels'].tolist(),
        'weight': test_df['weight'].tolist()
    }

    # Convert to datasets
    train_dataset = Dataset.from_dict(train_data)
    test_dataset = Dataset.from_dict(test_data)

    dataset_dict = {
        'train': train_dataset,
        'test': test_dataset
    }

    # Print detailed dataset statistics
    print("\nDataset Summary:")
    print(f"Training samples: {len(train_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    print("\nCluster Distribution:")
    for i in range(n_clusters):
        cluster_mask = df['cluster'] == i
        cluster_samples = df[cluster_mask]['Sample_Size'].sum()
        print(f"\nCluster {i} (Total samples: {cluster_samples:,}, Weight: {cluster_weights.get(i, 0):.3f}):")
        print("Most common characteristics:")
        for col in ['Age', 'Gender', 'RaceEthnicity', 'RiskFactor']:
            values = df[col][cluster_mask].value_counts().head(3)
            samples = df[cluster_mask].groupby(col)['Sample_Size'].sum().sort_values(ascending=False).head(3)
            print(f"{col}:")
            for val, count in values.items():
                sample_count = samples.get(val, 0)  # Use .get() for safety
                print(f" - {val}: {count} groups ({sample_count:,} individuals)")

    print("\nLabel Distribution:")
    for label, idx in zip(le.classes_, range(len(le.classes_))):
        count = (df['labels'] == idx).sum()
        total_size = df[df['labels'] == idx]['Sample_Size'].sum()
        print(f"{label}: {count} groups, {total_size:,} individuals")

    return dataset_dict, le

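# Training entry point: builds the datasets, tokenizes the generated documents
# with the distilbert-base-uncased tokenizer (max_length 128), fine-tunes the
# classifier with the sample-weighted loss, and saves the model, tokenizer, and
# label encoder under models/vision-classifier/.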
def main():
    # Setup
    output_dir = "models/vision-classifier"
    os.makedirs(output_dir, exist_ok=True)

    # Load the dataset
    dataset_dict, label_encoder = prepare_data()

    # Initialize the tokenizer
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    # Define tokenization function within main to have access to tokenizer
    def tokenize_function(examples):
        """Tokenize the input texts and maintain the correct column names"""
        tokenized = tokenizer(
            examples["doc"],
            truncation=True,
            padding='max_length',
            max_length=128,
            return_tensors=None
        )
        # Keep the additional columns
        tokenized['labels'] = examples['labels']
        tokenized['weight'] = examples['weight']
        return tokenized

    # Tokenize the datasets
    tokenized_datasets = {}
    for split, dataset in dataset_dict.items():
        tokenized_datasets[split] = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=['doc']
        )

    # Print sample to verify
    print("\nSample tokenized data:", tokenized_datasets["train"][0])

    # Initialize the model
    model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased",
        num_labels=len(label_encoder.classes_),
        id2label={i: label for i, label in enumerate(label_encoder.classes_)},
        label2id={label: i for i, label in enumerate(label_encoder.classes_)},
    )

    # Data collator
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # Check device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nTraining on device: {device}")

    # Move model to device
    model.to(device)

    # Set up training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=3e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=7,
        weight_decay=0.01,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        remove_unused_columns=False,
        push_to_hub=True,
    )

    # Create the Trainer
    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        data_collator=data_collator,
    )

    # Train the model
    print("\nStarting training...")
    trainer.train()

    # Save the model
    print("\nSaving model...")
    trainer.save_model(output_dir=os.path.join(output_dir, "model"))

    # Save the tokenizer
    tokenizer.save_pretrained(os.path.join(output_dir, "tokenizer"))

    # Save the label encoder
    label_encoder_path = os.path.join(output_dir, "label_encoder.pkl")
    with open(label_encoder_path, 'wb') as f:
        pickle.dump(label_encoder, f)

    return trainer, model, tokenizer, label_encoder

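# Evaluation runs the fine-tuned model over the raw (untokenized) test split
# one document at a time and reports accuracy, weighted precision/recall/F1,
# per-class metrics, and a confusion matrix.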
def evaluate_model(model, eval_dataset, tokenizer, label_encoder, device) -> Dict:
    """
    Evaluate model performance using multiple metrics
    """
    model.eval()
    all_predictions = []
    all_labels = []

    # Process each example in evaluation dataset
    for item in eval_dataset:
        # Tokenize input
        inputs = tokenizer(
            item['doc'],
            truncation=True,
            padding=True,
            return_tensors="pt"
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Get predictions
        with torch.no_grad():
            outputs = model(**inputs)
            predictions = torch.argmax(outputs.logits, dim=1)

        all_predictions.extend(predictions.cpu().numpy())
        all_labels.append(item['labels'])

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_predictions)
    precision, recall, f1, support = precision_recall_fscore_support(
        all_labels,
        all_predictions,
        average='weighted'
    )

    # Calculate per-class metrics
    per_class_precision, per_class_recall, per_class_f1, _ = precision_recall_fscore_support(
        all_labels,
        all_predictions,
        average=None
    )

    # Create confusion matrix
    conf_matrix = confusion_matrix(all_labels, all_predictions)

    # Combine metrics
    metrics = {
        'accuracy': accuracy,
        'weighted_precision': precision,
        'weighted_recall': recall,
        'weighted_f1': f1,
        'confusion_matrix': conf_matrix,
        'per_class_metrics': {
            label: {
                'precision': p,
                'recall': r,
                'f1': f
            } for label, p, r, f in zip(
                label_encoder.classes_,
                per_class_precision,
                per_class_recall,
                per_class_f1
            )
        }
    }

    return metrics

def print_evaluation_report(metrics: Dict, label_encoder):
    """
    Print formatted evaluation report
    """
    print("\n" + "=" * 50)
    print("MODEL EVALUATION REPORT")
    print("=" * 50)

    print("\nOverall Metrics:")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Weighted Precision: {metrics['weighted_precision']:.4f}")
    print(f"Weighted Recall: {metrics['weighted_recall']:.4f}")
    print(f"Weighted F1-Score: {metrics['weighted_f1']:.4f}")

    print("\nPer-Class Metrics:")
    print("-" * 50)
    print(f"{'Class':<30} {'Precision':>10} {'Recall':>10} {'F1-Score':>10}")
    print("-" * 50)

    for label, class_metrics in metrics['per_class_metrics'].items():
        print(
            f"{label:<30} {class_metrics['precision']:>10.4f} {class_metrics['recall']:>10.4f} {class_metrics['f1']:>10.4f}")

    print("\nConfusion Matrix:")
    print("-" * 50)
    conf_matrix = metrics['confusion_matrix']
    print(conf_matrix)

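# Script entry point: reuse a previously saved model/tokenizer/label encoder
# from models/vision-classifier if they exist, otherwise fall back to a full
# training run via main(); when a saved model is available, the test split is
# re-evaluated and a report is printed.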
if __name__ == "__main__":
    output_dir = "models/vision-classifier"
    model_path = os.path.join(output_dir, "model")
    tokenizer_path = os.path.join(output_dir, "tokenizer")

    if os.path.exists(model_path):
        print("\nLoading pre-trained model...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            model = AutoModelForSequenceClassification.from_pretrained(model_path)
            label_encoder_path = os.path.join(output_dir, "label_encoder.pkl")
            if os.path.exists(label_encoder_path):
                with open(label_encoder_path, 'rb') as f:
                    label_encoder = pickle.load(f)
            else:
                print("Warning: Label encoder not found. Running full training...")
                trainer, model, tokenizer, label_encoder = main()

            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model.to(device)
            print(f"Model loaded successfully and moved to {device}")

            # Load test dataset for evaluation
            dataset_dict, _ = prepare_data()

            # Run evaluation
            print("\nEvaluating model performance...")
            eval_metrics = evaluate_model(
                model,
                dataset_dict['test'],
                tokenizer,
                label_encoder,
                device
            )

            # Print evaluation report
            print_evaluation_report(eval_metrics, label_encoder)

        except Exception as e:
            print(f"Error loading model: {e}")
            print("Running full training instead...")
            trainer, model, tokenizer, label_encoder = main()
    else:
        print("\nNo pre-trained model found. Running training...")
        trainer, model, tokenizer, label_encoder = main()

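# Inference helper: tokenizes one free-text demographic description, applies a
# softmax to the classifier logits, and returns (label, probability) pairs
# sorted from most to least likely.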
def predict_vision_status(text, model, tokenizer, label_encoder):
    """Make prediction using the loaded/trained model"""
    inputs = tokenizer(
        text,
        truncation=True,
        padding=True,
        return_tensors="pt"
    )

    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        # Apply softmax to get probabilities
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)

    # Convert to numpy array
    probabilities = probabilities.cpu().numpy()[0]

    # Create list of (label, probability) tuples
    predictions = []
    for idx, prob in enumerate(probabilities):
        label = label_encoder.inverse_transform([idx])[0]
        predictions.append((label, float(prob)))

    # Sort by probability in descending order
    predictions.sort(key=lambda x: x[1], reverse=True)

    return predictions


if __name__ == "__main__":
    # Example prediction (guarded so importing this module does not run it)
    example_text = "Age: 40-64 years, Gender: Female, Race: White, non-Hispanic, Diabetes: No"
    predictions = predict_vision_status(example_text, model, tokenizer, label_encoder)

    print(f"\nPredictions for: {example_text}")
    print("\nLabel Confidence Scores:")
    print("-" * 50)
    for label, confidence in predictions:
        print(f"{label:<30} {confidence:.2%}")
data/Vision_Survey_Cleaned.csv ADDED
The diff for this file is too large to render. See raw diff