Upload 2 files
- VisionBERT.py +533 -0
- data/Vision_Survey_Cleaned.csv +0 -0
VisionBERT.py
ADDED
@@ -0,0 +1,533 @@
import os
from typing import Dict
from datasets import Dataset
import torch
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, accuracy_score
from transformers import AutoTokenizer, TrainingArguments, Trainer, AutoModelForSequenceClassification, DataCollatorWithPadding
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.cluster import KMeans
from torch.nn import CrossEntropyLoss
import pickle

os.environ['OMP_NUM_THREADS'] = '7'

class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs: bool = False, num_items_in_batch: int = None):
        """
        Custom loss computation with sample weights
        """
        labels = inputs.get("labels")
        weights = inputs.get("weight")

        # Forward pass
        outputs = model(**{k: v for k, v in inputs.items()
                           if k not in ["weight", "labels"]})
        logits = outputs.get("logits")

        # Add labels back to outputs
        outputs["labels"] = labels

        # Compute weighted loss
        if weights is not None:
            weights = weights.to(logits.device)
            loss_fct = CrossEntropyLoss(reduction='none')
            loss = loss_fct(logits.view(-1, self.model.config.num_labels),
                            labels.view(-1))

            # Adjust weights if num_items_in_batch is provided
            if num_items_in_batch:
                weights = weights[:num_items_in_batch]

            loss = (loss * weights.view(-1)).mean()
        else:
            loss_fct = CrossEntropyLoss(label_smoothing=0.1)
            loss = loss_fct(logits.view(-1, self.model.config.num_labels),
                            labels.view(-1))

        outputs["loss"] = loss
        return (loss, outputs) if return_outputs else loss

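# A rough illustration (not executed anywhere in this file; values are hypothetical) of what
# the weighted branch above computes: per-example cross-entropy scaled by each row's survey
# weight, then averaged.
#
#   logits  = torch.tensor([[2.0, 0.5], [0.2, 1.5]])
#   labels  = torch.tensor([0, 1])
#   weights = torch.tensor([0.9, 0.1])
#   per_example = CrossEntropyLoss(reduction='none')(logits, labels)
#   loss = (per_example * weights).mean()   # rows with larger weights dominate the average
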
def create_feature_vector(df):
    """Create numerical feature vector for clustering with sample size weighting, handling missing/unseen labels."""

    # Initialize LabelEncoders
    le_gender = LabelEncoder()
    le_race = LabelEncoder()
    le_risk = LabelEncoder()

    # Fill missing values before fitting so 'Unknown' is a known class for each encoder
    gender_encoded = le_gender.fit_transform(df['Gender'].fillna('Unknown').astype(str))
    race_encoded = le_race.fit_transform(df['RaceEthnicity'].fillna('Unknown').astype(str))
    risk_encoded = le_risk.fit_transform(df['RiskFactor'].fillna('Unknown').astype(str))

    # Ordinal encoding for age groups, covering all expected categories
    age_map = {
        '12-17 years': 0,
        '18-39 years': 1,
        '40-64 years': 2,
        '65-79 years': 3,
        '80 years and older': 4,
    }

    # Missing or unseen age groups map to -1
    age_encoded = df['Age'].map(lambda x: age_map.get(x, -1))

    # Combine features
    features = np.column_stack([
        age_encoded,
        gender_encoded,
        race_encoded,
        risk_encoded,
        df['Sample_Size'].values  # Add sample size as a feature
    ])

    # Scale features
    scaler = StandardScaler()
    features_scaled = scaler.fit_transform(features)

    return features_scaled, scaler

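# Illustrative only: create_feature_vector assumes a DataFrame with at least the columns
# shown below (the values here are hypothetical):
#
#   toy = pd.DataFrame({
#       'Age': ['40-64 years'],
#       'Gender': ['Female'],
#       'RaceEthnicity': ['White, non-Hispanic'],
#       'RiskFactor': ['Diabetes'],
#       'Sample_Size': [1234],
#   })
#   X, scaler = create_feature_vector(toy)   # X has shape (1, 5)
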
def weighted_kmeans(X, sample_weights, n_clusters, max_iter=300, random_state=42):
    """Custom K-means implementation that considers sample weights"""
    n_samples = X.shape[0]

    # Initialize centroids by sampling points in proportion to their weights
    rng = np.random.RandomState(random_state)
    weighted_indices = rng.choice(n_samples, size=n_clusters, p=sample_weights / sample_weights.sum())
    centroids = X[weighted_indices]

    for _ in range(max_iter):
        # Assign points to nearest centroid
        distances = np.sqrt(((X[:, np.newaxis] - centroids) ** 2).sum(axis=2))
        labels = np.argmin(distances, axis=1)

        # Update centroids using weighted means
        new_centroids = np.zeros_like(centroids)
        for k in range(n_clusters):
            mask = labels == k
            if mask.any():
                weights_k = sample_weights[mask]
                new_centroids[k] = np.average(X[mask], axis=0, weights=weights_k)
            else:
                # Keep the previous centroid if a cluster ends up empty
                new_centroids[k] = centroids[k]

        # Check for convergence
        if np.allclose(centroids, new_centroids):
            break

        centroids = new_centroids

    return labels, centroids

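# A minimal sanity check for weighted_kmeans (hypothetical, with synthetic data):
#
#   X = np.random.RandomState(0).rand(100, 5)
#   w = np.ones(100) / 100                     # uniform weights reduce to ordinary k-means
#   labels, centroids = weighted_kmeans(X, w, n_clusters=3)
#   assert labels.shape == (100,) and centroids.shape == (3, 5)
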
def prepare_data(file_path='data/Vision_Survey_Cleaned.csv'):
    """Load and prepare the vision health dataset with sample-size-aware clustering."""
    print("\nLoading and preparing data...")
    df = pd.read_csv(file_path)

    # Filter data
    vision_cat = ['Best-corrected visual acuity']
    df = df[df['Question'].isin(vision_cat)].copy()
    df = df[df["RiskFactor"] != "All participants"]
    df = df[df["RiskFactorResponse"] != "Total"]

    # Reset index after filtering
    df = df.reset_index(drop=True)

    # Create feature vectors for clustering
    features_scaled, scaler = create_feature_vector(df)

    # Normalize sample sizes for weights
    sample_weights = df['Sample_Size'].values
    sample_weights = sample_weights / sample_weights.sum()

    # Apply weighted clustering
    n_clusters = min(5, len(df))
    clusters, centroids = weighted_kmeans(
        features_scaled,
        sample_weights,
        n_clusters=n_clusters
    )

    # Add clusters as a column
    df['cluster'] = clusters

    # Calculate cluster importance based on total sample size in each cluster
    cluster_total_samples = df.groupby('cluster')['Sample_Size'].sum()
    cluster_weights = cluster_total_samples / cluster_total_samples.sum()

    # Enhanced feature engineering with clustering information
    df['doc'] = df.apply(
        lambda x: f"""
        Patient Demographics:
        - Age Category: {x['Age']}
        - Gender: {x['Gender']}
        - Race/Ethnicity: {x['RaceEthnicity']}

        Risk Factors:
        - {x['RiskFactor']}: {x['RiskFactorResponse']}

        Additional Information:
        - Sample Size: {x['Sample_Size']}
        - Cluster Profile: {x['cluster']} (Weight: {cluster_weights.get(x['cluster'], 0):.3f})
        """.strip(),
        axis=1
    )

    # Encode labels
    le = LabelEncoder()
    df['labels'] = le.fit_transform(df['Response'].astype(str))

    # Combine sample size weights with cluster importance
    df['weight'] = df.apply(
        lambda x: (x['Sample_Size'] / df['Sample_Size'].sum()) *
                  cluster_weights.get(x['cluster'], 0),
        axis=1
    )

    # Create train and test splits with stratification
    train_df, test_df = train_test_split(
        df,
        test_size=0.2,
        stratify=df['labels'],
        random_state=42
    )

    # Convert to dict format
    train_data = {
        'doc': train_df['doc'].tolist(),
        'labels': train_df['labels'].tolist(),
        'weight': train_df['weight'].tolist()
    }

    test_data = {
        'doc': test_df['doc'].tolist(),
        'labels': test_df['labels'].tolist(),
        'weight': test_df['weight'].tolist()
    }

    # Convert to datasets
    train_dataset = Dataset.from_dict(train_data)
    test_dataset = Dataset.from_dict(test_data)

    dataset_dict = {
        'train': train_dataset,
        'test': test_dataset
    }

    # Print detailed dataset statistics
    print("\nDataset Summary:")
    print(f"Training samples: {len(train_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    print("\nCluster Distribution:")
    for i in range(n_clusters):
        cluster_mask = df['cluster'] == i
        cluster_samples = df[cluster_mask]['Sample_Size'].sum()
        print(f"\nCluster {i} (Total samples: {cluster_samples:,}, Weight: {cluster_weights.get(i, 0):.3f}):")
        print("Most common characteristics:")
        for col in ['Age', 'Gender', 'RaceEthnicity', 'RiskFactor']:
            values = df.loc[cluster_mask, col].value_counts().head(3)
            samples = df[cluster_mask].groupby(col)['Sample_Size'].sum().sort_values(ascending=False).head(3)
            print(f"{col}:")
            for val, count in values.items():
                sample_count = samples.get(val, 0)  # Use .get() for safety
                print(f"  - {val}: {count} groups ({sample_count:,} individuals)")

    print("\nLabel Distribution:")
    for idx, label in enumerate(le.classes_):
        count = (df['labels'] == idx).sum()
        total_size = df[df['labels'] == idx]['Sample_Size'].sum()
        print(f"{label}: {count} groups, {total_size:,} individuals")

    return dataset_dict, le

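# For reference, each row's training weight produced above combines its share of the total
# sample size with the weight of its cluster:
#
#   weight_i = (Sample_Size_i / sum(Sample_Size)) * cluster_weight[cluster_i]
#
# A quick, hypothetical way to inspect the prepared splits:
#
#   dataset_dict, le = prepare_data()
#   print(dataset_dict['train'][0]['doc'])   # one generated patient-profile document
#   print(le.classes_)                       # the response categories being predicted
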
def main():
    # Setup
    output_dir = "models/vision-classifier"
    os.makedirs(output_dir, exist_ok=True)

    # Load the dataset
    dataset_dict, label_encoder = prepare_data()

    # Initialize the tokenizer
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    # Define tokenization function within main to have access to tokenizer
    def tokenize_function(examples):
        """Tokenize the input texts and maintain the correct column names"""
        tokenized = tokenizer(
            examples["doc"],
            truncation=True,
            padding='max_length',
            max_length=128,
            return_tensors=None
        )
        # Keep the additional columns
        tokenized['labels'] = examples['labels']
        tokenized['weight'] = examples['weight']
        return tokenized

    # Tokenize the datasets
    tokenized_datasets = {}
    for split, dataset in dataset_dict.items():
        tokenized_datasets[split] = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=['doc']
        )

    # Print sample to verify
    print("\nSample tokenized data:", tokenized_datasets["train"][0])

    # Initialize the model
    model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased",
        num_labels=len(label_encoder.classes_),
        id2label={i: label for i, label in enumerate(label_encoder.classes_)},
        label2id={label: i for i, label in enumerate(label_encoder.classes_)},
    )

    # Data collator
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # Check device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nTraining on device: {device}")

    # Move model to device
    model.to(device)

    # Set up training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=3e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=7,
        weight_decay=0.01,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        remove_unused_columns=False,  # keep the 'weight' column for WeightedTrainer
        push_to_hub=True,  # requires a Hugging Face login; set to False to train locally only
    )

    # Create the Trainer
    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        data_collator=data_collator,
    )

    # Train the model
    print("\nStarting training...")
    trainer.train()

    # Save the model
    print("\nSaving model...")
    trainer.save_model(output_dir=os.path.join(output_dir, "model"))

    # Save the tokenizer
    tokenizer.save_pretrained(os.path.join(output_dir, "tokenizer"))

    # Save the label encoder
    label_encoder_path = os.path.join(output_dir, "label_encoder.pkl")
    with open(label_encoder_path, 'wb') as f:
        pickle.dump(label_encoder, f)

    return trainer, model, tokenizer, label_encoder

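# After main() completes, the output directory should roughly look like this (a sketch;
# the checkpoint folders come from the Trainer's per-epoch save strategy):
#
#   models/vision-classifier/
#       checkpoint-*/        # per-epoch Trainer checkpoints
#       model/               # saved by trainer.save_model()
#       tokenizer/           # saved by tokenizer.save_pretrained()
#       label_encoder.pkl    # pickled LabelEncoder
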
def evaluate_model(model, eval_dataset, tokenizer, label_encoder, device) -> Dict:
    """
    Evaluate model performance using multiple metrics
    """
    model.eval()
    all_predictions = []
    all_labels = []

    # Process each example in the evaluation dataset
    for item in eval_dataset:
        # Tokenize input
        inputs = tokenizer(
            item['doc'],
            truncation=True,
            padding=True,
            return_tensors="pt"
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Get predictions
        with torch.no_grad():
            outputs = model(**inputs)
            predictions = torch.argmax(outputs.logits, dim=1)

        all_predictions.extend(predictions.cpu().numpy())
        all_labels.append(item['labels'])

    # Calculate overall metrics
    accuracy = accuracy_score(all_labels, all_predictions)
    precision, recall, f1, support = precision_recall_fscore_support(
        all_labels,
        all_predictions,
        average='weighted'
    )

    # Calculate per-class metrics, keeping the arrays aligned with label_encoder.classes_
    # even if a class does not appear in the evaluation split
    class_indices = list(range(len(label_encoder.classes_)))
    per_class_precision, per_class_recall, per_class_f1, _ = precision_recall_fscore_support(
        all_labels,
        all_predictions,
        labels=class_indices,
        average=None
    )

    # Create confusion matrix
    conf_matrix = confusion_matrix(all_labels, all_predictions, labels=class_indices)

    # Combine metrics
    metrics = {
        'accuracy': accuracy,
        'weighted_precision': precision,
        'weighted_recall': recall,
        'weighted_f1': f1,
        'confusion_matrix': conf_matrix,
        'per_class_metrics': {
            label: {
                'precision': p,
                'recall': r,
                'f1': f
            } for label, p, r, f in zip(
                label_encoder.classes_,
                per_class_precision,
                per_class_recall,
                per_class_f1
            )
        }
    }

    return metrics

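# Hypothetical standalone use of evaluate_model, assuming the artifacts saved by main()
# have already been loaded (see the __main__ block below):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   metrics = evaluate_model(model, dataset_dict['test'], tokenizer, label_encoder, device)
#   print(metrics['accuracy'], metrics['weighted_f1'])
#   print(metrics['per_class_metrics'])      # keyed by the original response labels
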
def print_evaluation_report(metrics: Dict, label_encoder):
    """
    Print formatted evaluation report
    """
    print("\n" + "=" * 50)
    print("MODEL EVALUATION REPORT")
    print("=" * 50)

    print("\nOverall Metrics:")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Weighted Precision: {metrics['weighted_precision']:.4f}")
    print(f"Weighted Recall: {metrics['weighted_recall']:.4f}")
    print(f"Weighted F1-Score: {metrics['weighted_f1']:.4f}")

    print("\nPer-Class Metrics:")
    print("-" * 50)
    print(f"{'Class':<30} {'Precision':>10} {'Recall':>10} {'F1-Score':>10}")
    print("-" * 50)

    for label, class_metrics in metrics['per_class_metrics'].items():
        print(
            f"{label:<30} {class_metrics['precision']:>10.4f} {class_metrics['recall']:>10.4f} {class_metrics['f1']:>10.4f}")

    print("\nConfusion Matrix:")
    print("-" * 50)
    conf_matrix = metrics['confusion_matrix']
    print(conf_matrix)

if __name__ == "__main__":
    output_dir = "models/vision-classifier"
    model_path = os.path.join(output_dir, "model")
    tokenizer_path = os.path.join(output_dir, "tokenizer")

    if os.path.exists(model_path):
        print("\nLoading pre-trained model...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            model = AutoModelForSequenceClassification.from_pretrained(model_path)
            label_encoder_path = os.path.join(output_dir, "label_encoder.pkl")
            if os.path.exists(label_encoder_path):
                with open(label_encoder_path, 'rb') as f:
                    label_encoder = pickle.load(f)
            else:
                print("Warning: Label encoder not found. Running full training...")
                trainer, model, tokenizer, label_encoder = main()

            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model.to(device)
            print(f"Model loaded successfully and moved to {device}")

            # Load test dataset for evaluation
            dataset_dict, _ = prepare_data()

            # Run evaluation
            print("\nEvaluating model performance...")
            eval_metrics = evaluate_model(
                model,
                dataset_dict['test'],
                tokenizer,
                label_encoder,
                device
            )

            # Print evaluation report
            print_evaluation_report(eval_metrics, label_encoder)

        except Exception as e:
            print(f"Error loading model: {e}")
            print("Running full training instead...")
            trainer, model, tokenizer, label_encoder = main()
    else:
        print("\nNo pre-trained model found. Running training...")
        trainer, model, tokenizer, label_encoder = main()

def predict_vision_status(text, model, tokenizer, label_encoder):
    """Make prediction using the loaded/trained model"""
    inputs = tokenizer(
        text,
        truncation=True,
        padding=True,
        return_tensors="pt"
    )

    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        # Apply softmax to get probabilities
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)

    # Convert to numpy array
    probabilities = probabilities.cpu().numpy()[0]

    # Create list of (label, probability) tuples
    predictions = []
    for idx, prob in enumerate(probabilities):
        label = label_encoder.inverse_transform([idx])[0]
        predictions.append((label, float(prob)))

    # Sort by probability in descending order
    predictions.sort(key=lambda x: x[1], reverse=True)

    return predictions


if __name__ == "__main__":
    # Demo prediction; guarded so that importing this module does not trigger inference
    example_text = "Age: 40-64 years, Gender: Female, Race: White, non-Hispanic, Diabetes: No"
    predictions = predict_vision_status(example_text, model, tokenizer, label_encoder)

    print(f"\nPredictions for: {example_text}")
    print("\nLabel Confidence Scores:")
    print("-" * 50)
    for label, confidence in predictions:
        print(f"{label:<30} {confidence:.2%}")

data/Vision_Survey_Cleaned.csv
ADDED
The diff for this file is too large to render.