Jingyuan-Zhu commited on
Commit
35fcfe6
β€’
1 Parent(s): 9724b8f

Upload 2 files

Browse files

added evaluation pipeline and model weights

ensemble_checkpoint.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb5421224d0381aec5065db260c258b22e27594a432b2234d4ebb1e925c53589
3
+ size 34882689
final_ensemble_pipeline.ipynb ADDED
@@ -0,0 +1,675 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "74bd5ceb-afa1-4bfd-ba39-10af717cf2a5",
6
+ "metadata": {},
7
+ "source": [
8
+ "Remember to change the test and model Path!\n",
9
+ "Since I'm using Embedding to encode headlines to vector, it takes 10+ min. to encode information for test set which I cannot do it on my end since I do not have access to hiddne test set! "
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 1,
15
+ "id": "a458f2b7-3ab1-479f-9627-ef7ef8ef76b4",
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": [
19
+ "import torch\n",
20
+ "import torch.nn as nn\n",
21
+ "import torch.optim as optim\n",
22
+ "from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler\n",
23
+ "from tqdm import tqdm\n",
24
+ "import numpy as np\n",
25
+ "import random\n",
26
+ "import os\n",
27
+ "import copy\n",
28
+ "from torch.utils.data import TensorDataset\n",
29
+ "import pandas as pd"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": 2,
35
+ "id": "d7943628-3454-4d21-a95d-ca53acd9b6dc",
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "class LabelSmoothingBCELoss(nn.Module):\n",
40
+ " def __init__(self, smoothing=0.1):\n",
41
+ " \"\"\"\n",
42
+ " Label Smoothing Binary Cross Entropy Loss\n",
43
+ " \n",
44
+ " Args:\n",
45
+ " smoothing (float): Amount of label smoothing to apply\n",
46
+ " \"\"\"\n",
47
+ " super(LabelSmoothingBCELoss, self).__init__()\n",
48
+ " self.smoothing = smoothing\n",
49
+ " \n",
50
+ " def forward(self, predictions, targets):\n",
51
+ " \"\"\"\n",
52
+ " Compute label-smoothed binary cross entropy loss\n",
53
+ " \n",
54
+ " Args:\n",
55
+ " predictions (torch.Tensor): Model predictions\n",
56
+ " targets (torch.Tensor): Binary labels\n",
57
+ " \n",
58
+ " Returns:\n",
59
+ " torch.Tensor: Smoothed loss\n",
60
+ " \"\"\"\n",
61
+ " # Apply label smoothing\n",
62
+ " smooth_targets = targets * (1 - self.smoothing) + 0.5 * self.smoothing\n",
63
+ " \n",
64
+ " # Standard Binary Cross Entropy Loss\n",
65
+ " loss = nn.functional.binary_cross_entropy(predictions, smooth_targets)\n",
66
+ " \n",
67
+ " return loss\n",
68
+ "\n",
69
+ "class EarlyStoppingCallback:\n",
70
+ " def __init__(self, patience=5, min_delta=0.001):\n",
71
+ " \"\"\"\n",
72
+ " Early stopping mechanism\n",
73
+ " \n",
74
+ " Args:\n",
75
+ " patience (int): Number of epochs to wait for improvement\n",
76
+ " min_delta (float): Minimum change to qualify as an improvement\n",
77
+ " \"\"\"\n",
78
+ " self.patience = patience\n",
79
+ " self.min_delta = min_delta\n",
80
+ " self.counter = 0\n",
81
+ " self.best_loss = float('inf')\n",
82
+ " self.early_stop = False\n",
83
+ " self.best_model_state = None\n",
84
+ " \n",
85
+ " def __call__(self, val_loss, model):\n",
86
+ " \"\"\"\n",
87
+ " Check if training should stop\n",
88
+ " \n",
89
+ " Args:\n",
90
+ " val_loss (float): Current validation loss\n",
91
+ " model (nn.Module): Current model state\n",
92
+ " \n",
93
+ " Returns:\n",
94
+ " bool: Whether to stop training\n",
95
+ " \"\"\"\n",
96
+ " if val_loss < self.best_loss - self.min_delta:\n",
97
+ " self.best_loss = val_loss\n",
98
+ " self.counter = 0\n",
99
+ " # Save the best model state\n",
100
+ " self.best_model_state = copy.deepcopy(model.state_dict())\n",
101
+ " else:\n",
102
+ " self.counter += 1\n",
103
+ " if self.counter >= self.patience:\n",
104
+ " self.early_stop = True\n",
105
+ " \n",
106
+ " return self.early_stop\n",
107
+ "\n",
108
+ "class EnsembleMLPClassifier(nn.Module):\n",
109
+ " def __init__(self, \n",
110
+ " input_dim=1024, # BGE embedding dimension\n",
111
+ " hidden_layers=None,\n",
112
+ " dropout_rate=0.2,\n",
113
+ " activation=nn.ReLU(), # Allow passing activation functions dynamically\n",
114
+ " device=None):\n",
115
+ " super(EnsembleMLPClassifier, self).__init__()\n",
116
+ " \n",
117
+ " # Default configuration if not provided\n",
118
+ " if hidden_layers is None:\n",
119
+ " hidden_layers = [512, 256, 128]\n",
120
+ " \n",
121
+ " # Set device (GPU if available, else CPU)\n",
122
+ " self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
123
+ " \n",
124
+ " # Store initialization parameters\n",
125
+ " self.input_dim = input_dim\n",
126
+ " self.hidden_layers = hidden_layers\n",
127
+ " self.dropout_rate = dropout_rate\n",
128
+ " self.activation = activation\n",
129
+ " \n",
130
+ " # Add linear gate mechanism\n",
131
+ " self.gate = nn.Linear(input_dim, input_dim, bias=False)\n",
132
+ " \n",
133
+ " # Create layers dynamically based on hidden_layers specification\n",
134
+ " layers = []\n",
135
+ " prev_dim = input_dim\n",
136
+ " for hidden_dim in hidden_layers:\n",
137
+ " # Dense Layer with dynamic activation and BatchNorm\n",
138
+ " layers.extend([\n",
139
+ " nn.Linear(prev_dim, hidden_dim),\n",
140
+ " nn.BatchNorm1d(hidden_dim),\n",
141
+ " activation,\n",
142
+ " nn.Dropout(dropout_rate)\n",
143
+ " ])\n",
144
+ " prev_dim = hidden_dim\n",
145
+ " \n",
146
+ " # Final output layer for binary classification\n",
147
+ " layers.append(nn.Linear(prev_dim, 1))\n",
148
+ " layers.append(nn.Sigmoid())\n",
149
+ " \n",
150
+ " # Create the model and move to device\n",
151
+ " self.model = nn.Sequential(*layers)\n",
152
+ " self.to(self.device)\n",
153
+ "\n",
154
+ " def forward(self, x):\n",
155
+ " \"\"\"Forward pass through the network\"\"\"\n",
156
+ " # Apply gating mechanism\n",
157
+ " x = self.gate(x) * x\n",
158
+ " return self.model(x)\n",
159
+ "\n",
160
+ "class EnsembleClassifier:\n",
161
+ " def __init__(self, num_models=5, label_smoothing=0.1):\n",
162
+ " self.models = self._create_diverse_models(num_models)\n",
163
+ " self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
164
+ " self.label_smoothing = label_smoothing\n",
165
+ " self.model_weights = None \n",
166
+ " \n",
167
+ " def _create_diverse_models(self, num_models):\n",
168
+ " models = []\n",
169
+ " \n",
170
+ " # Predefined configurations for consistency across runs\n",
171
+ " architectures = [\n",
172
+ " {'hidden_layers': [512, 256, 128], 'dropout_rate': 0.2, 'activation': nn.ReLU()},\n",
173
+ " {'hidden_layers': [1024, 512], 'dropout_rate': 0.3, 'activation': nn.LeakyReLU()},\n",
174
+ " {'hidden_layers': [256, 128, 64], 'dropout_rate': 0.1, 'activation': nn.GELU()},\n",
175
+ " {'hidden_layers': [512, 128], 'dropout_rate': 0.25, 'activation': nn.SELU()},\n",
176
+ " {'hidden_layers': [256, 128], 'dropout_rate': 0.15, 'activation': nn.Tanh()}\n",
177
+ " ]\n",
178
+ " \n",
179
+ " # Optimizer strategies\n",
180
+ " optimizers = [optim.Adam, optim.AdamW, optim.SGD]\n",
181
+ " \n",
182
+ " for i in range(num_models):\n",
183
+ " # Use predefined architectures in a consistent order\n",
184
+ " config = architectures[i % len(architectures)]\n",
185
+ " optimizer_fn = optimizers[i % len(optimizers)]\n",
186
+ " \n",
187
+ " model = EnsembleMLPClassifier(\n",
188
+ " input_dim=1024,\n",
189
+ " hidden_layers=config['hidden_layers'],\n",
190
+ " dropout_rate=config['dropout_rate'],\n",
191
+ " activation=config['activation']\n",
192
+ " )\n",
193
+ " \n",
194
+ " # Custom weight initialization\n",
195
+ " def init_weights(m):\n",
196
+ " if isinstance(m, nn.Linear):\n",
197
+ " init_methods = [\n",
198
+ " nn.init.xavier_uniform_,\n",
199
+ " nn.init.kaiming_normal_,\n",
200
+ " nn.init.orthogonal_\n",
201
+ " ]\n",
202
+ " init_method = init_methods[i % len(init_methods)] # Consistent initialization\n",
203
+ " init_method(m.weight)\n",
204
+ " if m.bias is not None:\n",
205
+ " nn.init.zeros_(m.bias)\n",
206
+ " \n",
207
+ " model.model.apply(init_weights)\n",
208
+ " \n",
209
+ " # Attach optimizer to model instance for flexibility\n",
210
+ " model.optimizer_fn = optimizer_fn\n",
211
+ " \n",
212
+ " # Add L2 regularization to the model (Weight Decay)\n",
213
+ " model.regularization = {\n",
214
+ " 'weight_decay': 1e-5 # Example regularization value\n",
215
+ " }\n",
216
+ " \n",
217
+ " models.append(model)\n",
218
+ " \n",
219
+ " return models\n",
220
+ " \n",
221
+ " def train(self, train_dataset, batch_size=32, num_epochs=20):\n",
222
+ " for model_idx, model in enumerate(tqdm(self.models, desc=\"Training Models\", position=0)):\n",
223
+ " print(f\"Starting training for Model {model_idx + 1}/{len(self.models)}\")\n",
224
+ " \n",
225
+ " # Randomly split 80% for training and 20% for validation\n",
226
+ " total_size = len(train_dataset)\n",
227
+ " train_size = int(0.8 * total_size)\n",
228
+ " val_size = total_size - train_size\n",
229
+ " \n",
230
+ " train_subset, val_subset = random_split(train_dataset, [train_size, val_size])\n",
231
+ " \n",
232
+ " # Create data loaders for training and validation\n",
233
+ " train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)\n",
234
+ " val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)\n",
235
+ " \n",
236
+ " # Optimizer with learning rate scheduler\n",
237
+ " optimizer = optim.AdamW(model.parameters(), lr=1e-3)\n",
238
+ " scheduler = optim.lr_scheduler.CosineAnnealingLR(\n",
239
+ " optimizer, \n",
240
+ " T_max=num_epochs, \n",
241
+ " eta_min=1e-5\n",
242
+ " )\n",
243
+ " \n",
244
+ " # Label Smoothing Loss\n",
245
+ " criterion = LabelSmoothingBCELoss(smoothing=self.label_smoothing)\n",
246
+ " \n",
247
+ " # Early stopping\n",
248
+ " early_stopping = EarlyStoppingCallback(patience=4, min_delta=0.001)\n",
249
+ " \n",
250
+ " model.train()\n",
251
+ " epoch_progress = tqdm(range(num_epochs), desc=f\"Model {model_idx} Training\", position=1, leave=False)\n",
252
+ " \n",
253
+ " best_val_loss = float('inf')\n",
254
+ " for epoch in epoch_progress:\n",
255
+ " total_loss = 0\n",
256
+ " \n",
257
+ " # Training phase\n",
258
+ " for batch in train_loader:\n",
259
+ " inputs, labels = batch\n",
260
+ " inputs, labels = inputs.to(model.device), labels.to(model.device)\n",
261
+ " \n",
262
+ " optimizer.zero_grad()\n",
263
+ " outputs = model(inputs)\n",
264
+ " loss = criterion(outputs, labels.float().unsqueeze(1))\n",
265
+ " loss.backward()\n",
266
+ " optimizer.step()\n",
267
+ " \n",
268
+ " total_loss += loss.item()\n",
269
+ " avg_train_loss = total_loss / len(train_loader)\n",
270
+ " \n",
271
+ " # Validation phase\n",
272
+ " model.eval()\n",
273
+ " val_loss = 0\n",
274
+ " with torch.no_grad():\n",
275
+ " for val_batch in val_loader:\n",
276
+ " val_inputs, val_labels = val_batch\n",
277
+ " val_inputs, val_labels = val_inputs.to(model.device), val_labels.to(model.device)\n",
278
+ " val_outputs = model(val_inputs)\n",
279
+ " val_loss += criterion(val_outputs, val_labels.float().unsqueeze(1)).item()\n",
280
+ " \n",
281
+ " avg_val_loss = val_loss / len(val_loader)\n",
282
+ " epoch_progress.set_postfix({\n",
283
+ " 'train_loss': avg_train_loss,\n",
284
+ " 'val_loss': avg_val_loss\n",
285
+ " })\n",
286
+ " \n",
287
+ " # Early stopping check\n",
288
+ " if early_stopping(avg_val_loss, model):\n",
289
+ " if early_stopping.best_model_state:\n",
290
+ " model.load_state_dict(early_stopping.best_model_state)\n",
291
+ " print(f\"Early stopping triggered for Model {model_idx}\")\n",
292
+ " break\n",
293
+ " \n",
294
+ " # Learning rate adjustment\n",
295
+ " scheduler.step()\n",
296
+ " \n",
297
+ " # Reset to training mode\n",
298
+ " model.train()\n",
299
+ " \n",
300
+ " # Store model's final state after training\n",
301
+ " model.eval()\n",
302
+ " \n",
303
+ " def compute_test_weights(self, test_loader):\n",
304
+ " \"\"\"\n",
305
+ " Compute model weights based on test accuracy while emphasizing distinctions.\n",
306
+ " \"\"\"\n",
307
+ " model_accuracies = []\n",
308
+ " for model_idx, model in enumerate(self.models):\n",
309
+ " correct = 0\n",
310
+ " total = 0\n",
311
+ " model.eval()\n",
312
+ " with torch.no_grad():\n",
313
+ " for inputs, labels in test_loader:\n",
314
+ " inputs, labels = inputs.to(model.device), labels.to(model.device)\n",
315
+ " outputs = model(inputs)\n",
316
+ " preds = (outputs > 0.5).float()\n",
317
+ " correct += (preds == labels).sum().item()\n",
318
+ " total += labels.size(0)\n",
319
+ " accuracy = correct / total\n",
320
+ " model_accuracies.append(accuracy)\n",
321
+ " \n",
322
+ " # Apply a power transformation for distinction\n",
323
+ " accuracies = np.array(model_accuracies)\n",
324
+ " print(f\"Raw model accuracies: {accuracies}\")\n",
325
+ " \n",
326
+ " # Use power scaling to exaggerate differences (e.g., square the accuracies)\n",
327
+ " power_scaling_factor = 2 # Choose 2 for squaring, can experiment with higher values\n",
328
+ " scaled_accuracies = accuracies ** power_scaling_factor\n",
329
+ " \n",
330
+ " # Smooth the accuracies slightly to avoid over-reliance on any single model\n",
331
+ " smoothed_accuracies = scaled_accuracies * (1 - 0.1) + 0.1 * np.mean(scaled_accuracies)\n",
332
+ " \n",
333
+ " # Normalize weights so they sum to 1\n",
334
+ " weights = smoothed_accuracies / smoothed_accuracies.sum()\n",
335
+ " \n",
336
+ " # Store model weights\n",
337
+ " self.model_weights = torch.tensor(weights, dtype=torch.float32).to(self.device)\n",
338
+ " print(f\"Model weights after scaling: {self.model_weights}\")\n",
339
+ "\n",
340
+ "\n",
341
+ " def predict(self, test_loader, confidence_threshold=0.5, return_raw_scores=True):\n",
342
+ " \"\"\"\n",
343
+ " Prediction with confidence-weighted voting, optionally returning raw scores.\n",
344
+ " \"\"\"\n",
345
+ " if self.model_weights is None:\n",
346
+ " raise ValueError(\"Model weights not computed. Call compute_test_weights first.\")\n",
347
+ " \n",
348
+ " all_predictions = []\n",
349
+ " for model_idx, model in enumerate(self.models):\n",
350
+ " model.eval()\n",
351
+ " model_preds = []\n",
352
+ " with torch.no_grad():\n",
353
+ " for batch in test_loader:\n",
354
+ " inputs, _ = batch\n",
355
+ " inputs = inputs.to(model.device)\n",
356
+ " outputs = model(inputs)\n",
357
+ " model_preds.append(outputs)\n",
358
+ " \n",
359
+ " # Concatenate predictions for this model\n",
360
+ " all_predictions.append(torch.cat(model_preds))\n",
361
+ " \n",
362
+ " # Stack predictions and compute weighted average\n",
363
+ " stacked_preds = torch.stack(all_predictions, dim=1).squeeze(-1)\n",
364
+ " weighted_preds = (stacked_preds * self.model_weights.view(1, -1)).sum(dim=1)\n",
365
+ " \n",
366
+ " # Final prediction with thresholding\n",
367
+ " final_preds = (weighted_preds > confidence_threshold).float()\n",
368
+ " \n",
369
+ " # Optionally return raw probabilities for debugging\n",
370
+ " if return_raw_scores:\n",
371
+ " return final_preds, weighted_preds.cpu().numpy()\n",
372
+ " \n",
373
+ " return final_preds\n",
374
+ "\n",
375
+ "\n",
376
+ " def save_models(self, save_dir='ensemble_models/model_test_4'):\n",
377
+ " \"\"\"\n",
378
+ " Save ensemble model weights and model weights with progress tracking\n",
379
+ " \"\"\"\n",
380
+ " os.makedirs(save_dir, exist_ok=True)\n",
381
+ "\n",
382
+ " save_data = {\n",
383
+ " 'models': {},\n",
384
+ " 'model_weights': self.model_weights.cpu().numpy() if self.model_weights is not None else None\n",
385
+ " }\n",
386
+ "\n",
387
+ " for i, model in tqdm(enumerate(self.models), desc=\"Saving Models\", total=len(self.models)):\n",
388
+ " save_data['models'][i] = model.state_dict()\n",
389
+ "\n",
390
+ " torch.save(save_data, os.path.join(save_dir, 'ensemble_checkpoint.pth'))\n",
391
+ "\n",
392
+ " def load_models(self, save_dir='ensemble_models/model_test_4'):\n",
393
+ " \"\"\"\n",
394
+ " Load ensemble model weights and model weights with progress tracking\n",
395
+ " \"\"\"\n",
396
+ " checkpoint_path = os.path.join(save_dir, 'ensemble_checkpoint.pth')\n",
397
+ "\n",
398
+ " save_data = torch.load(checkpoint_path)\n",
399
+ "\n",
400
+ " for i, model in tqdm(enumerate(self.models), desc=\"Loading Models\", total=len(self.models)):\n",
401
+ " model.load_state_dict(save_data['models'][i])\n",
402
+ " model.eval() # Set to evaluation mode\n",
403
+ "\n",
404
+ " if save_data['model_weights'] is not None:\n",
405
+ " self.model_weights = torch.tensor(save_data['model_weights'], dtype=torch.float32).to(self.device)\n",
406
+ " \n",
407
+ " def evaluate(self, test_loader):\n",
408
+ " \"\"\"\n",
409
+ " Evaluate ensemble performance with weighted voting, supporting both CPU and GPU.\n",
410
+ " \"\"\"\n",
411
+ " # Collect ground truth labels\n",
412
+ " all_labels = torch.cat([labels for _, labels in test_loader], dim=0).to(self.device)\n",
413
+ " \n",
414
+ " # Get predictions for the entire test set\n",
415
+ " test_preds = self.predict(test_loader, return_raw_scores=True)\n",
416
+ " \n",
417
+ " # Ensure predictions and labels are on the same device\n",
418
+ " all_labels = all_labels.cpu().numpy().ravel() # Flatten to 1D\n",
419
+ " test_preds, raw_probs = test_preds\n",
420
+ " test_preds = test_preds.cpu().numpy().ravel() # Flatten to 1D\n",
421
+ " \n",
422
+ " # Print debug information\n",
423
+ " # print(\"Ground truth labels (all_labels):\", all_labels)\n",
424
+ " # print(\"Predicted classes (test_preds):\", test_preds)\n",
425
+ " # print(\"Raw probabilities (raw_probs):\", raw_probs) \n",
426
+ " \n",
427
+ " # Calculate metrics\n",
428
+ " accuracy = np.mean(test_preds == all_labels)\n",
429
+ " precision = precision_score(all_labels, test_preds, zero_division=0)\n",
430
+ " recall = recall_score(all_labels, test_preds, zero_division=0)\n",
431
+ " \n",
432
+ " return {\n",
433
+ " \"accuracy\": accuracy,\n",
434
+ " \"precision\": precision,\n",
435
+ " \"recall\": recall\n",
436
+ " }"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": 3,
442
+ "id": "a95bb0eb-48ba-4c46-9cc5-4f6a1ee19dee",
443
+ "metadata": {},
444
+ "outputs": [
445
+ {
446
+ "name": "stdout",
447
+ "output_type": "stream",
448
+ "text": [
449
+ "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
450
+ "Requirement already satisfied: FlagEmbedding in /opt/conda/lib/python3.11/site-packages (1.3.3)\n",
451
+ "Requirement already satisfied: torch>=1.6.0 in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (2.2.2+cu121)\n",
452
+ "Requirement already satisfied: transformers==4.44.2 in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (4.44.2)\n",
453
+ "Requirement already satisfied: datasets==2.19.0 in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (2.19.0)\n",
454
+ "Requirement already satisfied: accelerate>=0.20.1 in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (1.2.0)\n",
455
+ "Requirement already satisfied: sentence-transformers in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (3.3.1)\n",
456
+ "Requirement already satisfied: peft in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (0.14.0)\n",
457
+ "Requirement already satisfied: ir-datasets in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (0.5.9)\n",
458
+ "Requirement already satisfied: sentencepiece in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (0.2.0)\n",
459
+ "Requirement already satisfied: protobuf in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (4.25.3)\n",
460
+ "Requirement already satisfied: filelock in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (3.9.0)\n",
461
+ "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (1.26.4)\n",
462
+ "Requirement already satisfied: pyarrow>=12.0.0 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (15.0.2)\n",
463
+ "Requirement already satisfied: pyarrow-hotfix in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (0.6)\n",
464
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (0.3.8)\n",
465
+ "Requirement already satisfied: pandas in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (2.2.2)\n",
466
+ "Requirement already satisfied: requests>=2.19.0 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (2.31.0)\n",
467
+ "Requirement already satisfied: tqdm>=4.62.1 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (4.66.2)\n",
468
+ "Requirement already satisfied: xxhash in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (3.5.0)\n",
469
+ "Requirement already satisfied: multiprocess in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (0.70.16)\n",
470
+ "Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /opt/conda/lib/python3.11/site-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets==2.19.0->FlagEmbedding) (2024.3.1)\n",
471
+ "Requirement already satisfied: aiohttp in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (3.11.10)\n",
472
+ "Requirement already satisfied: huggingface-hub>=0.21.2 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (0.26.5)\n",
473
+ "Requirement already satisfied: packaging in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (24.0)\n",
474
+ "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (6.0.1)\n",
475
+ "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.11/site-packages (from transformers==4.44.2->FlagEmbedding) (2024.11.6)\n",
476
+ "Requirement already satisfied: safetensors>=0.4.1 in /opt/conda/lib/python3.11/site-packages (from transformers==4.44.2->FlagEmbedding) (0.4.5)\n",
477
+ "Requirement already satisfied: tokenizers<0.20,>=0.19 in /opt/conda/lib/python3.11/site-packages (from transformers==4.44.2->FlagEmbedding) (0.19.1)\n",
478
+ "Requirement already satisfied: psutil in /opt/conda/lib/python3.11/site-packages (from accelerate>=0.20.1->FlagEmbedding) (5.9.8)\n",
479
+ "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (4.11.0)\n",
480
+ "Requirement already satisfied: sympy in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (1.12)\n",
481
+ "Requirement already satisfied: networkx in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (3.3)\n",
482
+ "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (3.1.3)\n",
483
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (12.1.105)\n",
484
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (12.1.105)\n",
485
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (12.1.105)\n",
486
+ "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (8.9.2.26)\n",
487
+ "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (12.1.3.1)\n",
488
+ "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (11.0.2.54)\n",
489
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (10.3.2.106)\n",
490
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (11.4.5.107)\n",
491
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (12.1.0.106)\n",
492
+ "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (2.19.3)\n",
493
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (12.1.105)\n",
494
+ "Requirement already satisfied: triton==2.2.0 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (2.2.0)\n",
495
+ "Requirement already satisfied: nvidia-nvjitlink-cu12 in /opt/conda/lib/python3.11/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.6.0->FlagEmbedding) (12.4.127)\n",
496
+ "Requirement already satisfied: beautifulsoup4>=4.4.1 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (4.12.3)\n",
497
+ "Requirement already satisfied: inscriptis>=2.2.0 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (2.5.0)\n",
498
+ "Requirement already satisfied: lxml>=4.5.2 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (5.3.0)\n",
499
+ "Requirement already satisfied: trec-car-tools>=2.5.4 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (2.6)\n",
500
+ "Requirement already satisfied: lz4>=3.1.10 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (4.3.3)\n",
501
+ "Requirement already satisfied: warc3-wet>=0.2.3 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (0.2.5)\n",
502
+ "Requirement already satisfied: warc3-wet-clueweb09>=0.2.5 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (0.2.5)\n",
503
+ "Requirement already satisfied: zlib-state>=0.1.3 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (0.1.9)\n",
504
+ "Requirement already satisfied: ijson>=3.1.3 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (3.3.0)\n",
505
+ "Requirement already satisfied: unlzw3>=0.2.1 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (0.2.2)\n",
506
+ "Requirement already satisfied: scikit-learn in /opt/conda/lib/python3.11/site-packages (from sentence-transformers->FlagEmbedding) (1.4.2)\n",
507
+ "Requirement already satisfied: scipy in /opt/conda/lib/python3.11/site-packages (from sentence-transformers->FlagEmbedding) (1.13.0)\n",
508
+ "Requirement already satisfied: Pillow in /opt/conda/lib/python3.11/site-packages (from sentence-transformers->FlagEmbedding) (10.3.0)\n",
509
+ "Requirement already satisfied: soupsieve>1.2 in /opt/conda/lib/python3.11/site-packages (from beautifulsoup4>=4.4.1->ir-datasets->FlagEmbedding) (2.5)\n",
510
+ "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (2.4.4)\n",
511
+ "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (1.3.2)\n",
512
+ "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (23.2.0)\n",
513
+ "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (1.5.0)\n",
514
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (6.1.0)\n",
515
+ "Requirement already satisfied: propcache>=0.2.0 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (0.2.1)\n",
516
+ "Requirement already satisfied: yarl<2.0,>=1.17.0 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (1.18.3)\n",
517
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.11/site-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (3.3.2)\n",
518
+ "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.11/site-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (3.7)\n",
519
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.11/site-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (2.2.1)\n",
520
+ "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.11/site-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (2024.2.2)\n",
521
+ "Requirement already satisfied: cbor>=1.0.0 in /opt/conda/lib/python3.11/site-packages (from trec-car-tools>=2.5.4->ir-datasets->FlagEmbedding) (1.0.0)\n",
522
+ "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.11/site-packages (from jinja2->torch>=1.6.0->FlagEmbedding) (2.1.5)\n",
523
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/lib/python3.11/site-packages (from pandas->datasets==2.19.0->FlagEmbedding) (2.9.0)\n",
524
+ "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.11/site-packages (from pandas->datasets==2.19.0->FlagEmbedding) (2024.1)\n",
525
+ "Requirement already satisfied: tzdata>=2022.7 in /opt/conda/lib/python3.11/site-packages (from pandas->datasets==2.19.0->FlagEmbedding) (2024.1)\n",
526
+ "Requirement already satisfied: joblib>=1.2.0 in /opt/conda/lib/python3.11/site-packages (from scikit-learn->sentence-transformers->FlagEmbedding) (1.4.0)\n",
527
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.11/site-packages (from scikit-learn->sentence-transformers->FlagEmbedding) (3.4.0)\n",
528
+ "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.11/site-packages (from sympy->torch>=1.6.0->FlagEmbedding) (1.3.0)\n",
529
+ "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->datasets==2.19.0->FlagEmbedding) (1.16.0)\n",
530
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
531
+ "\u001b[0m"
532
+ ]
533
+ },
534
+ {
535
+ "data": {
536
+ "application/vnd.jupyter.widget-view+json": {
537
+ "model_id": "a24dee20be054f138b75c100ab2e6a36",
538
+ "version_major": 2,
539
+ "version_minor": 0
540
+ },
541
+ "text/plain": [
542
+ "Fetching 30 files: 0%| | 0/30 [00:00<?, ?it/s]"
543
+ ]
544
+ },
545
+ "metadata": {},
546
+ "output_type": "display_data"
547
+ },
548
+ {
549
+ "name": "stdout",
550
+ "output_type": "stream",
551
+ "text": [
552
+ "Encoding titles...\n"
553
+ ]
554
+ },
555
+ {
556
+ "name": "stderr",
557
+ "output_type": "stream",
558
+ "text": [
559
+ "You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
560
+ ]
561
+ },
562
+ {
563
+ "name": "stdout",
564
+ "output_type": "stream",
565
+ "text": [
566
+ "Processed 20/20 titles\n"
567
+ ]
568
+ }
569
+ ],
570
+ "source": [
571
+ "!pip install FlagEmbedding\n",
572
+ "from FlagEmbedding import BGEM3FlagModel\n",
573
+ "model = BGEM3FlagModel('BAAI/bge-m3')\n",
574
+ "\n",
575
+ "# Remember to change the test path\n",
576
+ "test_data_path = \"/home/jovyan/work/test_data_random_subset.csv\"\n",
577
+ "\n",
578
+ "data = data = pd.read_csv(test_data_path)\n",
579
+ "titles = data['title'].tolist()\n",
580
+ "labels = data['labels'].tolist()\n",
581
+ "\n",
582
+ "batch_size = 32\n",
583
+ "embeddings = []\n",
584
+ "\n",
585
+ "print('Encoding titles...')\n",
586
+ "for i in range(0, len(titles), batch_size):\n",
587
+ " batch = titles[i:i + batch_size]\n",
588
+ " batch_embeddings = model.encode(batch, batch_size=batch_size, max_length=512)['dense_vecs']\n",
589
+ " embeddings.extend(batch_embeddings)\n",
590
+ " print(f\"Processed {i + len(batch)}/{len(titles)} titles\")\n",
591
+ "\n",
592
+ "embeddings_df = pd.DataFrame(embeddings)\n",
593
+ "embeddings_df['label'] = labels\n",
594
+ "\n",
595
+ "# Convert embeddings and labels to PyTorch tensors\n",
596
+ "X_test = torch.FloatTensor(embeddings_df.iloc[:, :-1].values) # Features\n",
597
+ "y_test = torch.FloatTensor(embeddings_df['label'].values).view(-1, 1) # Labels\n",
598
+ "\n",
599
+ "# Create DataLoader for the test dataset\n",
600
+ "test_dataset = TensorDataset(X_test, y_test)\n",
601
+ "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
602
+ ]
603
+ },
604
+ {
605
+ "cell_type": "code",
606
+ "execution_count": 5,
607
+ "id": "c6bcf956-4e26-4278-a6fe-9955322cf06a",
608
+ "metadata": {},
609
+ "outputs": [
610
+ {
611
+ "name": "stderr",
612
+ "output_type": "stream",
613
+ "text": [
614
+ "Loading Models: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5/5 [00:00<00:00, 1799.05it/s]"
615
+ ]
616
+ },
617
+ {
618
+ "name": "stdout",
619
+ "output_type": "stream",
620
+ "text": [
621
+ "{'accuracy': 0.9, 'precision': 0.9, 'recall': 0.9}\n"
622
+ ]
623
+ },
624
+ {
625
+ "name": "stderr",
626
+ "output_type": "stream",
627
+ "text": [
628
+ "\n"
629
+ ]
630
+ }
631
+ ],
632
+ "source": [
633
+ "from sklearn.metrics import precision_score, recall_score\n",
634
+ "ensemble = EnsembleClassifier(5) \n",
635
+ "\n",
636
+ "# Load saved model weights\n",
637
+ "# Be sure to change to the actual path\n",
638
+ "ensemble.load_models(save_dir='/home/jovyan/work/ensemble_models/model_test_4')\n",
639
+ "\n",
640
+ "# Evaluate the ensemble\n",
641
+ "results = ensemble.evaluate(test_loader)\n",
642
+ "print(results)"
643
+ ]
644
+ },
645
+ {
646
+ "cell_type": "code",
647
+ "execution_count": null,
648
+ "id": "77da0a63-cf76-4cbb-8b48-da115e124946",
649
+ "metadata": {},
650
+ "outputs": [],
651
+ "source": []
652
+ }
653
+ ],
654
+ "metadata": {
655
+ "kernelspec": {
656
+ "display_name": "Python 3 (ipykernel)",
657
+ "language": "python",
658
+ "name": "python3"
659
+ },
660
+ "language_info": {
661
+ "codemirror_mode": {
662
+ "name": "ipython",
663
+ "version": 3
664
+ },
665
+ "file_extension": ".py",
666
+ "mimetype": "text/x-python",
667
+ "name": "python",
668
+ "nbconvert_exporter": "python",
669
+ "pygments_lexer": "ipython3",
670
+ "version": "3.11.8"
671
+ }
672
+ },
673
+ "nbformat": 4,
674
+ "nbformat_minor": 5
675
+ }