Mohammaderfan koupaei committed on
Commit 06ca50d · 1 Parent(s): b5e09fa
scripts/config/config.py CHANGED
@@ -12,22 +12,25 @@ class TrainingConfig:
 
     # Training parameters
     num_epochs: int = 5
-    learning_rate: float = 2e-5
-    warmup_ratio: float = 0.1
+    learning_rate: float = 1e-5  # Reduced from 2e-5
+    warmup_ratio: float = 0.2  # Increased from 0.1
     weight_decay: float = 0.01
     max_grad_norm: float = 1.0
     gradient_accumulation_steps: int = 4
-    fp16: bool = True  # Enable mixed precision training
-    max_length: int = 256  # Reduce from 512
-    batch_size: int = 4  # Reduce from 8
+    fp16: bool = True
 
     # Data parameters
-    max_length: int = 512
+    max_length: int = 256
+    batch_size: int = 4
     train_ratio: float = 0.8
 
+    # Loss parameters
+    pos_weight_multiplier: float = 5.0  # Weight multiplier for positive classes
+    label_smoothing: float = 0.1  # Label smoothing factor
+
     # Output parameters
     output_dir: Path = Path("outputs")
-    save_steps: int = 100
+    save_steps: int = 50
     eval_steps: int = 50
 
     # Device
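Two things worth noting in this hunk: the old dataclass declared `max_length` twice (256 under the training parameters, 512 under the data parameters), and the new version collapses that into a single `max_length: int = 256`; the new `# Loss parameters` block exists to feed the class-weighted loss introduced in `scripts/training/trainer.py` below. With the committed values, the effective batch size and scheduler lengths follow from the trainer's own formulas; a quick sanity-check sketch (the dataset size here is hypothetical, it is not part of this commit):

# Sanity-check sketch for the new config values (train set size is assumed).
batch_size = 4
gradient_accumulation_steps = 4
num_epochs = 5
warmup_ratio = 0.2

effective_batch = batch_size * gradient_accumulation_steps  # 16 samples per optimizer step
train_examples = 8000                                       # hypothetical; not given in this commit
loader_batches = train_examples // batch_size               # 2000, i.e. len(train_loader)
steps_per_epoch = loader_batches // gradient_accumulation_steps  # 500
num_training_steps = steps_per_epoch * num_epochs           # 2500
num_warmup_steps = int(num_training_steps * warmup_ratio)   # 500
print(effective_batch, num_training_steps, num_warmup_steps)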
scripts/training/trainer.py CHANGED
@@ -11,7 +11,7 @@ from datetime import datetime
 from torch.cuda.amp import autocast, GradScaler
 
 class NarrativeTrainer:
-    """Comprehensive trainer for narrative classification with GPU memory optimizations"""
+    """Enhanced trainer with detailed metrics and optimizations"""
     def __init__(
         self,
         model,
@@ -60,28 +60,33 @@ class NarrativeTrainer:
         self.history = {
             'train_loss': [],
             'val_loss': [],
-            'val_f1': [],
-            'val_precision': [],
-            'val_recall': []
+            'metrics': [],
+            'thresholds': []
         }
-
+
     def setup_logging(self):
-        """Initialize logging configuration"""
         logging.basicConfig(
             level=logging.INFO,
             format='%(asctime)s - %(levelname)s - %(message)s',
             datefmt='%Y-%m-%d %H:%M:%S'
         )
-
+
+    def calculate_class_weights(self):
+        """Calculate weights for imbalanced classes"""
+        pos_counts = self.train_dataset.labels.sum(dim=0)
+        neg_counts = len(self.train_dataset) - pos_counts
+        pos_weight = (neg_counts / pos_counts) * self.config.pos_weight_multiplier
+        return torch.clamp(pos_weight, min=1.0, max=50.0).to(self.device)
+
     def setup_training(self):
-        """Initialize training components with memory optimizations"""
+        """Initialize training components with optimizations"""
         # Create dataloaders
         self.train_loader = DataLoader(
             self.train_dataset,
             batch_size=self.config.batch_size,
             shuffle=True,
             num_workers=4,
-            pin_memory=True  # Optimize data transfer to GPU
+            pin_memory=True
         )
 
         self.val_loader = DataLoader(
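The new `calculate_class_weights` assumes `self.train_dataset.labels` is a multi-hot float tensor of shape `[num_samples, num_labels]`. A toy run of the arithmetic, showing how the `min=1.0, max=50.0` clamp also catches labels with zero positives (where `neg / pos` divides by zero and yields `inf`):

import torch

labels = torch.tensor([[1, 0, 0],
                       [1, 0, 0],
                       [0, 1, 0],
                       [1, 0, 0]], dtype=torch.float32)  # toy multi-hot labels
pos_counts = labels.sum(dim=0)                 # tensor([3., 1., 0.])
neg_counts = len(labels) - pos_counts          # tensor([1., 3., 4.])
pos_weight = (neg_counts / pos_counts) * 5.0   # tensor([1.6667, 15.0000, inf])
pos_weight = torch.clamp(pos_weight, min=1.0, max=50.0)
print(pos_weight)                              # tensor([ 1.6667, 15.0000, 50.0000])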
@@ -91,6 +96,15 @@ class NarrativeTrainer:
             pin_memory=True
         )
 
+        # Calculate class weights
+        pos_weight = self.calculate_class_weights()
+
+        # Setup loss function with class weights
+        self.criterion = torch.nn.BCEWithLogitsLoss(
+            pos_weight=pos_weight,
+            label_smoothing=self.config.label_smoothing
+        )
+
         # Setup optimizer
         self.optimizer = torch.optim.AdamW(
             self.model.parameters(),
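One caution on this hunk: `torch.nn.BCEWithLogitsLoss` accepts `weight`, `reduction`, and `pos_weight`, but has no `label_smoothing` argument (that keyword belongs to `CrossEntropyLoss`), so this call raises a `TypeError` as written. A minimal sketch of one way to get the intended behavior, smoothing the multi-hot targets before the weighted BCE (a suggested fix, not the committed code):

import torch

class SmoothedBCEWithLogitsLoss(torch.nn.Module):
    """BCE-with-logits plus manual label smoothing (sketch, not part of this commit)."""
    def __init__(self, pos_weight=None, smoothing=0.1):
        super().__init__()
        self.bce = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        self.smoothing = smoothing

    def forward(self, logits, targets):
        # Pull 0/1 targets toward 0.5 by the smoothing factor.
        targets = targets * (1.0 - self.smoothing) + 0.5 * self.smoothing
        return self.bce(logits, targets)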
@@ -98,7 +112,7 @@ class NarrativeTrainer:
             weight_decay=self.config.weight_decay
         )
 
-        # Setup scheduler with gradient accumulation steps
+        # Setup scheduler
         num_update_steps_per_epoch = len(self.train_loader) // self.config.gradient_accumulation_steps
         num_training_steps = num_update_steps_per_epoch * self.config.num_epochs
         num_warmup_steps = int(num_training_steps * self.config.warmup_ratio)
@@ -109,17 +123,77 @@ class NarrativeTrainer:
             num_training_steps=num_training_steps
         )
 
-        self.criterion = torch.nn.BCEWithLogitsLoss()
-
+        # Initialize thresholds
+        self.label_thresholds = torch.ones(self.train_dataset.get_num_labels()).to(self.device) * 0.5
+
     def save_config(self):
-        """Save training configuration"""
         config_dict = {k: str(v) for k, v in vars(self.config).items()}
         config_path = self.output_dir / 'config.json'
         with open(config_path, 'w') as f:
             json.dump(config_dict, f, indent=4)
-
+
+    def find_optimal_thresholds(self, val_outputs, val_labels):
+        """Find optimal threshold for each label"""
+        outputs = torch.sigmoid(val_outputs).cpu().numpy()
+        labels = val_labels.cpu().numpy()
+
+        thresholds = []
+        for i in range(labels.shape[1]):
+            best_f1 = 0
+            best_threshold = 0.5
+            if labels[:, i].sum() > 0:  # Only if we have positive samples
+                for threshold in np.arange(0.1, 0.9, 0.05):
+                    preds = (outputs[:, i] > threshold).astype(int)
+                    f1 = f1_score(labels[:, i], preds)
+                    if f1 > best_f1:
+                        best_f1 = f1
+                        best_threshold = threshold
+            thresholds.append(best_threshold)
+        return torch.tensor(thresholds).to(self.device)
+
+    def calculate_detailed_metrics(self, all_labels, all_preds, all_probs=None):
+        """Calculate detailed metrics for model evaluation"""
+        metrics = {}
+
+        # Basic metrics
+        metrics['micro'] = {
+            'precision': precision_score(all_labels, all_preds, average='micro'),
+            'recall': recall_score(all_labels, all_preds, average='micro'),
+            'f1': f1_score(all_labels, all_preds, average='micro')
+        }
+
+        metrics['macro'] = {
+            'precision': precision_score(all_labels, all_preds, average='macro'),
+            'recall': recall_score(all_labels, all_preds, average='macro'),
+            'f1': f1_score(all_labels, all_preds, average='macro')
+        }
+
+        metrics['weighted'] = {
+            'precision': precision_score(all_labels, all_preds, average='weighted'),
+            'recall': recall_score(all_labels, all_preds, average='weighted'),
+            'f1': f1_score(all_labels, all_preds, average='weighted')
+        }
+
+        # Per-class metrics
+        per_class_metrics = {}
+        precisions = precision_score(all_labels, all_preds, average=None)
+        recalls = recall_score(all_labels, all_preds, average=None)
+        f1s = f1_score(all_labels, all_preds, average=None)
+
+        for i in range(len(f1s)):
+            per_class_metrics[f'class_{i}'] = {
+                'precision': float(precisions[i]),
+                'recall': float(recalls[i]),
+                'f1': float(f1s[i]),
+                'support': int(all_labels[:, i].sum())
+            }
+
+        metrics['per_class'] = per_class_metrics
+
+        return metrics
+
     def train_epoch(self):
-        """Train for one epoch with memory optimizations"""
+        """Train for one epoch with optimizations"""
        self.model.train()
        total_loss = 0
        self.optimizer.zero_grad()
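`find_optimal_thresholds` scans 16 candidate thresholds per label (0.10 to 0.85 in steps of 0.05) and keeps the one with the best per-label F1; labels with no positives in the validation split keep the 0.5 default. One caveat for `calculate_detailed_metrics`: with imbalanced labels, sklearn's scorers emit `UndefinedMetricWarning` and silently substitute 0 whenever a class has no predicted or no true positives; passing `zero_division=0` (available since scikit-learn 0.22) makes that choice explicit. A hedged variant of the per-class block, same logic but quieter:

from sklearn.metrics import f1_score, precision_score, recall_score

# Same computation as the committed per-class block, with explicit zero-division handling.
precisions = precision_score(all_labels, all_preds, average=None, zero_division=0)
recalls = recall_score(all_labels, all_preds, average=None, zero_division=0)
f1s = f1_score(all_labels, all_preds, average=None, zero_division=0)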
@@ -129,10 +203,8 @@ class NarrativeTrainer:
                     desc=f'Epoch {self.current_epoch + 1}/{self.config.num_epochs}')
 
         for step, batch in pbar:
-            # Move batch to device
             batch = {k: v.to(self.device, non_blocking=True) for k, v in batch.items()}
 
-            # Mixed precision forward pass
             with autocast(enabled=self.config.fp16):
                 outputs = self.model(
                     input_ids=batch['input_ids'],
@@ -142,10 +214,8 @@ class NarrativeTrainer:
                 loss = self.criterion(outputs, batch['labels'])
                 loss = loss / self.config.gradient_accumulation_steps
 
-            # Scaled backward pass
             self.scaler.scale(loss).backward()
 
-            # Update weights if we've accumulated enough gradients
             if (step + 1) % self.config.gradient_accumulation_steps == 0:
                 self.scaler.unscale_(self.optimizer)
                 torch.nn.utils.clip_grad_norm_(
@@ -158,33 +228,29 @@ class NarrativeTrainer:
                 self.scheduler.step()
                 self.optimizer.zero_grad()
 
-            # Update metrics
             total_loss += loss.item() * self.config.gradient_accumulation_steps
             avg_loss = total_loss / (step + 1)
             pbar.set_postfix({'loss': f'{avg_loss:.4f}'})
 
             self.global_step += 1
 
-            # Evaluate if needed
             if self.global_step % self.config.eval_steps == 0:
                 self.evaluate()
 
-            # Clear memory periodically
             if step % 10 == 0:
                 torch.cuda.empty_cache()
 
-            # Clear unnecessary tensors
             del outputs
             del loss
 
         return total_loss / len(self.train_loader)
-
+
     @torch.no_grad()
     def evaluate(self):
-        """Evaluate model with memory optimizations"""
+        """Evaluate model with detailed metrics"""
         self.model.eval()
         total_loss = 0
-        all_preds, all_labels = [], []
+        all_outputs, all_labels = [], []
 
         for batch in tqdm(self.val_loader, desc="Evaluating"):
             batch = {k: v.to(self.device, non_blocking=True) for k, v in batch.items()}
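One side effect to flag in `train_epoch`: the mid-epoch call to `self.evaluate()` switches the model to eval mode and nothing switches it back, so dropout stays disabled for the rest of the epoch. A minimal guard (a suggested fix, not part of this commit):

if self.global_step % self.config.eval_steps == 0:
    self.evaluate()
    self.model.train()  # restore training mode after the mid-epoch eval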
@@ -198,40 +264,38 @@ class NarrativeTrainer:
                 loss = self.criterion(outputs, batch['labels'])
 
             total_loss += loss.item()
+            all_outputs.append(outputs.cpu())
+            all_labels.append(batch['labels'].cpu())
 
-            # CPU computations for predictions
-            preds = (torch.sigmoid(outputs) > 0.5).cpu().numpy()
-            labels = batch['labels'].cpu().numpy()
-
-            all_preds.append(preds)
-            all_labels.append(labels)
-
-            # Clear memory
             del outputs
             del loss
             torch.cuda.empty_cache()
 
-        # Compute metrics
-        all_preds = np.concatenate(all_preds, axis=0)
-        all_labels = np.concatenate(all_labels, axis=0)
+        all_outputs = torch.cat(all_outputs, dim=0)
+        all_labels = torch.cat(all_labels, dim=0)
 
-        metrics = {
-            'loss': total_loss / len(self.val_loader),
-            'f1': f1_score(all_labels, all_preds, average='micro'),
-            'precision': precision_score(all_labels, all_preds, average='micro'),
-            'recall': recall_score(all_labels, all_preds, average='micro')
-        }
+        if self.global_step % (self.config.eval_steps * 2) == 0:
+            self.label_thresholds = self.find_optimal_thresholds(all_outputs, all_labels)
 
-        self.logger.info(f"Step {self.global_step} - Validation metrics: {metrics}")
+        all_probs = torch.sigmoid(all_outputs).numpy()
+        all_preds = (all_probs > self.label_thresholds.cpu().unsqueeze(0).numpy())
+        all_labels = all_labels.numpy()
 
-        if metrics['f1'] > self.best_val_f1:
-            self.best_val_f1 = metrics['f1']
+        metrics = self.calculate_detailed_metrics(all_labels, all_preds, all_probs)
+        metrics['loss'] = total_loss / len(self.val_loader)
+
+        self.logger.info(f"Step {self.global_step} - Validation metrics:")
+        self.logger.info(f"Loss: {metrics['loss']:.4f}")
+        self.logger.info(f"Micro F1: {metrics['micro']['f1']:.4f}")
+        self.logger.info(f"Macro F1: {metrics['macro']['f1']:.4f}")
+
+        if metrics['micro']['f1'] > self.best_val_f1:
+            self.best_val_f1 = metrics['micro']['f1']
             self.save_model('best_model.pt', metrics)
 
         return metrics
-
+
     def save_model(self, filename: str, metrics: dict = None):
-        """Save model checkpoint"""
         save_path = self.output_dir / filename
         torch.save({
             'model_state_dict': self.model.state_dict(),
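In `evaluate`, `all_probs` has shape `[num_samples, num_labels]` and the thresholds broadcast as `[1, num_labels]`, so `all_preds` comes out as a boolean array that sklearn accepts directly. Because the tuned thresholds are saved with each checkpoint (see `save_model` below), inference should reuse them rather than a flat 0.5. A sketch, assuming the model is rebuilt the same way as in training and that `input_ids`/`attention_mask` are already tokenized (placeholder names):

import torch

ckpt = torch.load('outputs/best_model.pt', map_location='cpu')
model.load_state_dict(ckpt['model_state_dict'])  # `model` constructed as in training
model.eval()

with torch.no_grad():
    logits = model(input_ids=input_ids, attention_mask=attention_mask)
probs = torch.sigmoid(logits)
preds = probs > ckpt['thresholds'].cpu()  # per-label thresholds saved by this commit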
@@ -241,10 +305,11 @@ class NarrativeTrainer:
             'epoch': self.current_epoch,
             'global_step': self.global_step,
             'best_val_f1': self.best_val_f1,
-            'metrics': metrics
+            'metrics': metrics,
+            'thresholds': self.label_thresholds
         }, save_path)
         self.logger.info(f"Model saved to {save_path}")
-
+
     def train(self):
         """Run complete training loop"""
         self.logger.info("Starting training...")
@@ -257,14 +322,11 @@ class NarrativeTrainer:
             self.history['train_loss'].append(train_loss)
 
             val_metrics = self.evaluate()
-            self.history['val_loss'].append(val_metrics['loss'])
-            self.history['val_f1'].append(val_metrics['f1'])
-            self.history['val_precision'].append(val_metrics['precision'])
-            self.history['val_recall'].append(val_metrics['recall'])
+            self.history['metrics'].append(val_metrics)
+            self.history['thresholds'].append(self.label_thresholds.cpu().tolist())
 
             self.save_model(f'checkpoint_epoch_{epoch+1}.pt', val_metrics)
 
-            # Save training history
             history_path = self.output_dir / 'history.json'
             with open(history_path, 'w') as f:
                 json.dump(self.history, f, indent=4)
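After this change the `'val_loss'` list is still created in `__init__` but nothing appends to it; validation loss now lives inside each per-epoch metrics dict. Per-class values are cast to plain `float`/`int` and thresholds are stored via `.tolist()`, so `json.dump` succeeds. Illustrative shape of the new `history.json` (values invented):

# Illustrative structure only; numbers are made up.
history = {
    'train_loss': [0.412, 0.305],
    'val_loss': [],  # left empty by this commit
    'metrics': [     # one detailed dict per epoch
        {'loss': 0.37, 'micro': {'precision': 0.62, 'recall': 0.55, 'f1': 0.58},
         'macro': {'...': '...'}, 'weighted': {'...': '...'}, 'per_class': {'...': '...'}},
    ],
    'thresholds': [[0.45, 0.60, 0.50]],  # per-label thresholds, one list per epoch
}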
 