Jingyuan-Zhu committed
Commit 35fcfe6
Parent(s): 9724b8f

Upload 2 files
added evaluation pipeline and model weights

- ensemble_checkpoint.pth +3 -0
- final_ensemble_pipeline.ipynb +675 -0
ensemble_checkpoint.pth
ADDED (Git LFS pointer, +3 lines)

version https://git-lfs.github.com/spec/v1
oid sha256:fb5421224d0381aec5065db260c258b22e27594a432b2234d4ebb1e925c53589
size 34882689
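
Note that ensemble_checkpoint.pth is committed through Git LFS, so the diff above records only the three-line pointer, not the 34 MB checkpoint itself. A minimal sketch (my own, assuming the repo has been cloned with LFS enabled and the file sits in the working directory) to verify the fetched object against the pointer's oid and size:

import hashlib
import os

CKPT = "ensemble_checkpoint.pth"  # hypothetical local path after `git lfs pull`
EXPECTED_OID = "fb5421224d0381aec5065db260c258b22e27594a432b2234d4ebb1e925c53589"
EXPECTED_SIZE = 34882689

# A bare clone without LFS leaves the ~130-byte pointer in place of the weights.
assert os.path.getsize(CKPT) == EXPECTED_SIZE, "size mismatch: LFS object not fetched?"

with open(CKPT, "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
assert digest == EXPECTED_OID, "sha256 mismatch: corrupt or partial download"
print("checkpoint matches its LFS pointer")
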
final_ensemble_pipeline.ipynb
ADDED (Jupyter notebook, +675 lines)
Cell 1 (markdown):

Remember to change the test and model paths!
Since I'm using embeddings to encode the headlines as vectors, it takes 10+ minutes to encode the test set, which I cannot do on my end because I do not have access to the hidden test set!
Cell 2 (code, execution_count 1):

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler
from torch.utils.data import TensorDataset
from tqdm import tqdm
import numpy as np
import random
import os
import copy
import pandas as pd
from sklearn.metrics import precision_score, recall_score  # used by EnsembleClassifier.evaluate below
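
The cell above imports random and numpy, but the notebook never seeds them, so the 80/20 random_split and the weight initialization differ between runs. A small helper one could add here (my sketch, not part of the commit) if reproducible splits matter:

def set_seed(seed: int = 42) -> None:
    # Seed every RNG the pipeline touches: Python, NumPy, and PyTorch (CPU + CUDA).
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op on CPU-only machines

set_seed(42)
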
Cell 3 (code, execution_count 2):

class LabelSmoothingBCELoss(nn.Module):
    def __init__(self, smoothing=0.1):
        """
        Label Smoothing Binary Cross Entropy Loss

        Args:
            smoothing (float): Amount of label smoothing to apply
        """
        super(LabelSmoothingBCELoss, self).__init__()
        self.smoothing = smoothing

    def forward(self, predictions, targets):
        """
        Compute label-smoothed binary cross entropy loss

        Args:
            predictions (torch.Tensor): Model predictions
            targets (torch.Tensor): Binary labels

        Returns:
            torch.Tensor: Smoothed loss
        """
        # Apply label smoothing
        smooth_targets = targets * (1 - self.smoothing) + 0.5 * self.smoothing

        # Standard Binary Cross Entropy Loss
        loss = nn.functional.binary_cross_entropy(predictions, smooth_targets)

        return loss

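# Worked example of the smoothing formula above (added for illustration, safe
# to run): with smoothing=0.1, a hard label 1.0 becomes 1.0*0.9 + 0.5*0.1 = 0.95
# and 0.0 becomes 0.05, pulling both targets away from the saturating extremes.
assert torch.allclose(
    torch.tensor([1.0, 0.0]) * (1 - 0.1) + 0.5 * 0.1,
    torch.tensor([0.95, 0.05]),
)
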
class EarlyStoppingCallback:
    def __init__(self, patience=5, min_delta=0.001):
        """
        Early stopping mechanism

        Args:
            patience (int): Number of epochs to wait for improvement
            min_delta (float): Minimum change to qualify as an improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False
        self.best_model_state = None

    def __call__(self, val_loss, model):
        """
        Check if training should stop

        Args:
            val_loss (float): Current validation loss
            model (nn.Module): Current model state

        Returns:
            bool: Whether to stop training
        """
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            # Save the best model state
            self.best_model_state = copy.deepcopy(model.state_dict())
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

        return self.early_stop

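# Illustrative note (added): with patience=4 and min_delta=0.001 as used in
# train() below, training stops after 4 consecutive epochs whose validation
# loss fails to undercut the best loss by at least 0.001, and the best
# state_dict captured above is restored before the epoch loop exits.
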
class EnsembleMLPClassifier(nn.Module):
    def __init__(self,
                 input_dim=1024,  # BGE embedding dimension
                 hidden_layers=None,
                 dropout_rate=0.2,
                 activation=nn.ReLU(),  # Allow passing activation functions dynamically
                 device=None):
        super(EnsembleMLPClassifier, self).__init__()

        # Default configuration if not provided
        if hidden_layers is None:
            hidden_layers = [512, 256, 128]

        # Set device (GPU if available, else CPU)
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Store initialization parameters
        self.input_dim = input_dim
        self.hidden_layers = hidden_layers
        self.dropout_rate = dropout_rate
        self.activation = activation

        # Add linear gate mechanism
        self.gate = nn.Linear(input_dim, input_dim, bias=False)

        # Create layers dynamically based on hidden_layers specification
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_layers:
            # Dense Layer with dynamic activation and BatchNorm
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                activation,
                nn.Dropout(dropout_rate)
            ])
            prev_dim = hidden_dim

        # Final output layer for binary classification
        layers.append(nn.Linear(prev_dim, 1))
        layers.append(nn.Sigmoid())

        # Create the model and move to device
        self.model = nn.Sequential(*layers)
        self.to(self.device)

    def forward(self, x):
        """Forward pass through the network"""
        # Apply gating mechanism
        x = self.gate(x) * x
        return self.model(x)

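# Illustrative note (added): the gate is a bias-free 1024x1024 linear map, so
# self.gate(x) * x rescales each embedding dimension by a learned,
# input-dependent factor; a (batch, 1024) input stays (batch, 1024) before
# entering the MLP stack.
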
class EnsembleClassifier:
    def __init__(self, num_models=5, label_smoothing=0.1):
        self.models = self._create_diverse_models(num_models)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.label_smoothing = label_smoothing
        self.model_weights = None

    def _create_diverse_models(self, num_models):
        models = []

        # Predefined configurations for consistency across runs
        architectures = [
            {'hidden_layers': [512, 256, 128], 'dropout_rate': 0.2, 'activation': nn.ReLU()},
            {'hidden_layers': [1024, 512], 'dropout_rate': 0.3, 'activation': nn.LeakyReLU()},
            {'hidden_layers': [256, 128, 64], 'dropout_rate': 0.1, 'activation': nn.GELU()},
            {'hidden_layers': [512, 128], 'dropout_rate': 0.25, 'activation': nn.SELU()},
            {'hidden_layers': [256, 128], 'dropout_rate': 0.15, 'activation': nn.Tanh()}
        ]

        # Optimizer strategies
        optimizers = [optim.Adam, optim.AdamW, optim.SGD]

        for i in range(num_models):
            # Use predefined architectures in a consistent order
            config = architectures[i % len(architectures)]
            optimizer_fn = optimizers[i % len(optimizers)]

            model = EnsembleMLPClassifier(
                input_dim=1024,
                hidden_layers=config['hidden_layers'],
                dropout_rate=config['dropout_rate'],
                activation=config['activation']
            )

            # Custom weight initialization
            def init_weights(m):
                if isinstance(m, nn.Linear):
                    init_methods = [
                        nn.init.xavier_uniform_,
                        nn.init.kaiming_normal_,
                        nn.init.orthogonal_
                    ]
                    init_method = init_methods[i % len(init_methods)]  # Consistent initialization
                    init_method(m.weight)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)

            model.model.apply(init_weights)

            # Attach optimizer to model instance for flexibility
            model.optimizer_fn = optimizer_fn

            # Add L2 regularization to the model (Weight Decay)
            model.regularization = {
                'weight_decay': 1e-5  # Example regularization value
            }

            models.append(model)

        return models

    def train(self, train_dataset, batch_size=32, num_epochs=20):
        for model_idx, model in enumerate(tqdm(self.models, desc="Training Models", position=0)):
            print(f"Starting training for Model {model_idx + 1}/{len(self.models)}")

            # Randomly split 80% for training and 20% for validation
            total_size = len(train_dataset)
            train_size = int(0.8 * total_size)
            val_size = total_size - train_size

            train_subset, val_subset = random_split(train_dataset, [train_size, val_size])

            # Create data loaders for training and validation
            train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

            # Optimizer with learning rate scheduler
            optimizer = optim.AdamW(model.parameters(), lr=1e-3)
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=num_epochs,
                eta_min=1e-5
            )

            # Label Smoothing Loss
            criterion = LabelSmoothingBCELoss(smoothing=self.label_smoothing)

            # Early stopping
            early_stopping = EarlyStoppingCallback(patience=4, min_delta=0.001)

            model.train()
            epoch_progress = tqdm(range(num_epochs), desc=f"Model {model_idx} Training", position=1, leave=False)

            best_val_loss = float('inf')
            for epoch in epoch_progress:
                total_loss = 0

                # Training phase
                for batch in train_loader:
                    inputs, labels = batch
                    inputs, labels = inputs.to(model.device), labels.to(model.device)

                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = criterion(outputs, labels.float().unsqueeze(1))
                    loss.backward()
                    optimizer.step()

                    total_loss += loss.item()
                avg_train_loss = total_loss / len(train_loader)

                # Validation phase
                model.eval()
                val_loss = 0
                with torch.no_grad():
                    for val_batch in val_loader:
                        val_inputs, val_labels = val_batch
                        val_inputs, val_labels = val_inputs.to(model.device), val_labels.to(model.device)
                        val_outputs = model(val_inputs)
                        val_loss += criterion(val_outputs, val_labels.float().unsqueeze(1)).item()

                avg_val_loss = val_loss / len(val_loader)
                epoch_progress.set_postfix({
                    'train_loss': avg_train_loss,
                    'val_loss': avg_val_loss
                })

                # Early stopping check
                if early_stopping(avg_val_loss, model):
                    if early_stopping.best_model_state:
                        model.load_state_dict(early_stopping.best_model_state)
                    print(f"Early stopping triggered for Model {model_idx}")
                    break

                # Learning rate adjustment
                scheduler.step()

                # Reset to training mode
                model.train()

            # Store model's final state after training
            model.eval()

    def compute_test_weights(self, test_loader):
        """
        Compute model weights based on test accuracy while emphasizing distinctions.
        """
        model_accuracies = []
        for model_idx, model in enumerate(self.models):
            correct = 0
            total = 0
            model.eval()
            with torch.no_grad():
                for inputs, labels in test_loader:
                    inputs, labels = inputs.to(model.device), labels.to(model.device)
                    outputs = model(inputs)
                    preds = (outputs > 0.5).float()
                    correct += (preds == labels).sum().item()
                    total += labels.size(0)
            accuracy = correct / total
            model_accuracies.append(accuracy)

        # Apply a power transformation for distinction
        accuracies = np.array(model_accuracies)
        print(f"Raw model accuracies: {accuracies}")

        # Use power scaling to exaggerate differences (e.g., square the accuracies)
        power_scaling_factor = 2  # Choose 2 for squaring; can experiment with higher values
        scaled_accuracies = accuracies ** power_scaling_factor

        # Smooth the accuracies slightly to avoid over-reliance on any single model
        smoothed_accuracies = scaled_accuracies * (1 - 0.1) + 0.1 * np.mean(scaled_accuracies)

        # Normalize weights so they sum to 1
        weights = smoothed_accuracies / smoothed_accuracies.sum()

        # Store model weights
        self.model_weights = torch.tensor(weights, dtype=torch.float32).to(self.device)
        print(f"Model weights after scaling: {self.model_weights}")

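    # Worked example of the weighting above (added for illustration): raw
    # accuracies [0.90, 0.80] square to [0.81, 0.64]; smoothing with
    # 0.9*a + 0.1*mean gives [0.8015, 0.6485]; normalizing yields weights
    # [0.553, 0.447], so the stronger model leads without dominating.
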
    def predict(self, test_loader, confidence_threshold=0.5, return_raw_scores=True):
        """
        Prediction with confidence-weighted voting, optionally returning raw scores.
        """
        if self.model_weights is None:
            raise ValueError("Model weights not computed. Call compute_test_weights first.")

        all_predictions = []
        for model_idx, model in enumerate(self.models):
            model.eval()
            model_preds = []
            with torch.no_grad():
                for batch in test_loader:
                    inputs, _ = batch
                    inputs = inputs.to(model.device)
                    outputs = model(inputs)
                    model_preds.append(outputs)

            # Concatenate predictions for this model
            all_predictions.append(torch.cat(model_preds))

        # Stack predictions and compute weighted average
        stacked_preds = torch.stack(all_predictions, dim=1).squeeze(-1)
        weighted_preds = (stacked_preds * self.model_weights.view(1, -1)).sum(dim=1)

        # Final prediction with thresholding
        final_preds = (weighted_preds > confidence_threshold).float()

        # Optionally return raw probabilities for debugging
        if return_raw_scores:
            return final_preds, weighted_preds.cpu().numpy()

        return final_preds

    def save_models(self, save_dir='ensemble_models/model_test_4'):
        """
        Save ensemble model weights and model weights with progress tracking
        """
        os.makedirs(save_dir, exist_ok=True)

        save_data = {
            'models': {},
            'model_weights': self.model_weights.cpu().numpy() if self.model_weights is not None else None
        }

        for i, model in tqdm(enumerate(self.models), desc="Saving Models", total=len(self.models)):
            save_data['models'][i] = model.state_dict()

        torch.save(save_data, os.path.join(save_dir, 'ensemble_checkpoint.pth'))

    def load_models(self, save_dir='ensemble_models/model_test_4'):
        """
        Load ensemble model weights and model weights with progress tracking
        """
        checkpoint_path = os.path.join(save_dir, 'ensemble_checkpoint.pth')

        # map_location lets a checkpoint saved on GPU load on a CPU-only machine
        save_data = torch.load(checkpoint_path, map_location=self.device)

        for i, model in tqdm(enumerate(self.models), desc="Loading Models", total=len(self.models)):
            model.load_state_dict(save_data['models'][i])
            model.eval()  # Set to evaluation mode

        if save_data['model_weights'] is not None:
            self.model_weights = torch.tensor(save_data['model_weights'], dtype=torch.float32).to(self.device)

    def evaluate(self, test_loader):
        """
        Evaluate ensemble performance with weighted voting, supporting both CPU and GPU.
        """
        # Collect ground truth labels
        all_labels = torch.cat([labels for _, labels in test_loader], dim=0).to(self.device)

        # Get predictions for the entire test set
        test_preds = self.predict(test_loader, return_raw_scores=True)

        # Ensure predictions and labels are on the same device
        all_labels = all_labels.cpu().numpy().ravel()  # Flatten to 1D
        test_preds, raw_probs = test_preds
        test_preds = test_preds.cpu().numpy().ravel()  # Flatten to 1D

        # Calculate metrics
        accuracy = np.mean(test_preds == all_labels)
        precision = precision_score(all_labels, test_preds, zero_division=0)
        recall = recall_score(all_labels, test_preds, zero_division=0)

        return {
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall
        }
Cell 4 (code, execution_count 3):

Output (stdout):

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Requirement already satisfied: FlagEmbedding in /opt/conda/lib/python3.11/site-packages (1.3.3)
[... 79 similar "Requirement already satisfied" lines for FlagEmbedding's dependency tree omitted ...]
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

Output (widget):

Fetching 30 files: 0%|          | 0/30 [00:00<?, ?it/s]

Output (stdout):

Encoding titles...

Output (stderr):

You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.

Output (stdout):

Processed 20/20 titles
Source:

!pip install FlagEmbedding
from FlagEmbedding import BGEM3FlagModel

model = BGEM3FlagModel('BAAI/bge-m3')

# Remember to change the test path
test_data_path = "/home/jovyan/work/test_data_random_subset.csv"

data = pd.read_csv(test_data_path)
titles = data['title'].tolist()
labels = data['labels'].tolist()

batch_size = 32
embeddings = []

print('Encoding titles...')
for i in range(0, len(titles), batch_size):
    batch = titles[i:i + batch_size]
    batch_embeddings = model.encode(batch, batch_size=batch_size, max_length=512)['dense_vecs']
    embeddings.extend(batch_embeddings)
    print(f"Processed {i + len(batch)}/{len(titles)} titles")

embeddings_df = pd.DataFrame(embeddings)
embeddings_df['label'] = labels

# Convert embeddings and labels to PyTorch tensors
X_test = torch.FloatTensor(embeddings_df.iloc[:, :-1].values)  # Features
y_test = torch.FloatTensor(embeddings_df['label'].values).view(-1, 1)  # Labels

# Create DataLoader for the test dataset
test_dataset = TensorDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
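
Since encoding is the slow step the markdown cell warns about (10+ minutes on the full hidden test set), one option (my sketch with an assumed cache filename, not part of the commit) is to persist the dense vectors so repeated evaluation runs skip re-encoding:

cache_path = "test_embeddings.npy"  # hypothetical cache file next to the CSV

if os.path.exists(cache_path):
    embeddings = np.load(cache_path)  # reuse vectors from an earlier run
else:
    # ... run the encoding loop from the cell above, then persist the result:
    np.save(cache_path, np.asarray(embeddings))
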
Cell 5 (code, execution_count 5):

Output (stderr):

Loading Models: 100%|██████████| 5/5 [00:00<00:00, 1799.05it/s]

Output (stdout):

{'accuracy': 0.9, 'precision': 0.9, 'recall': 0.9}

Source:

from sklearn.metrics import precision_score, recall_score

ensemble = EnsembleClassifier(5)

# Load saved model weights
# Be sure to change to the actual path
ensemble.load_models(save_dir='/home/jovyan/work/ensemble_models/model_test_4')

# Evaluate the ensemble
results = ensemble.evaluate(test_loader)
print(results)
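
evaluate succeeds here only because load_models restored the per-model voting weights stored in the checkpoint; on a freshly trained ensemble, or a checkpoint saved without them, predict raises a ValueError. A sketch of that fallback (names all from the cells above); note that weighting models by their accuracy on the same test loader, as compute_test_weights does, leaks test labels into the ensemble, so a separate validation loader would be the safer input:

if ensemble.model_weights is None:
    # Derive the per-model weights from a labelled loader before predicting.
    ensemble.compute_test_weights(test_loader)

final_preds, raw_probs = ensemble.predict(test_loader, return_raw_scores=True)
print(final_preds[:10], raw_probs[:10])
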
Cell 6 (code, not executed): empty.

Notebook metadata: kernel "Python 3 (ipykernel)", Python 3.11.8, nbformat 4.5.