Mohammaderfan koupaei committed
Commit 937a410 · 1 Parent(s): 3ab6d8e
Files changed (4)
  1. app.py +74 -45
  2. requirements.txt +1 -0
  3. scripts/config/config.py +4 -13
  4. scripts/training/trainer.py +128 -104
app.py CHANGED
@@ -1,44 +1,51 @@
 import sys
 import logging
 from pathlib import Path
+import os
+import torch
 from transformers import set_seed
 
+# Set environment variables for memory optimization
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
+os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
 # Import the necessary modules from your project
-sys.path.append("./scripts")  # Adjust path if needed
+sys.path.append("./scripts")
 from scripts.models.model import NarrativeClassifier
 from scripts.models.dataset import NarrativeDataset
 from scripts.config.config import TrainingConfig
-from scripts.data_processing.data_preparation import AdvancedNarrativeProcessor
+from scripts.data_processing.advanced_preprocessor import AdvancedNarrativeProcessor
 from scripts.training.trainer import NarrativeTrainer
 
+def setup_logging():
+    """Setup logging configuration"""
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(asctime)s - %(levelname)s - %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S'
+    )
+    return logging.getLogger(__name__)
+
 def main():
     # Set up logging
-    logging.basicConfig(level=logging.INFO)
-    logger = logging.getLogger(__name__)
+    logger = setup_logging()
     logger.info("Initializing training process...")
-    import os
-
-    # Set up logging
-    logging.basicConfig(level=logging.INFO)
-    logger = logging.getLogger(__name__)
-    logger.info("Initializing training process...")
-    import os
-    import spacy
-
-    # Download and load SpaCy model dynamically
-    try:
-        spacy.load("en_core_web_sm")
-    except OSError:
-        logger.info("Downloading SpaCy model 'en_core_web_sm'...")
-        os.system("python -m spacy download en_core_web_sm")
-
-    # Set a random seed for reproducibility
+    # Set random seeds for reproducibility
     set_seed(42)
+    torch.manual_seed(42)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(42)
+
+    # Clear GPU cache if available
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        logger.info(f"CUDA available. Using GPU: {torch.cuda.get_device_name(0)}")
+        logger.info(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
 
     # Load and process the dataset
-    annotations_file = "./data/subtask-2-annotations.txt"  # Adjust path as needed
-    raw_dir = "./data/raw"  # Adjust path as needed
+    annotations_file = "./data/subtask-2-annotations.txt"
+    raw_dir = "./data/raw"
     logger.info("Loading and processing dataset...")
 
     processor = AdvancedNarrativeProcessor(
@@ -47,41 +54,63 @@ def main():
     )
     processed_data = processor.load_and_process_data()
 
-    # Split processed data into training and validation sets
+    # Create datasets
    train_dataset = NarrativeDataset(processed_data['train'])
     val_dataset = NarrativeDataset(processed_data['val'])
     logger.info(f"Loaded dataset with {len(train_dataset)} training samples and {len(val_dataset)} validation samples.")
 
-    # Initialize the model
+    # Initialize model
     logger.info("Initializing the model...")
-    model = NarrativeClassifier(num_labels=train_dataset.get_num_labels())
+    model = NarrativeClassifier(
+        num_labels=train_dataset.get_num_labels(),
+        model_name="microsoft/deberta-v3-large"
+    )
 
-    # Define training configuration
+    # Define optimized training configuration
     config = TrainingConfig(
-        output_dir=Path("./output"),  # Save outputs in this directory
+        output_dir=Path("./output"),
         num_epochs=5,
-        batch_size=16,
+        batch_size=4,  # Reduced batch size for memory
         learning_rate=2e-5,
         warmup_ratio=0.1,
         weight_decay=0.01,
         max_grad_norm=1.0,
-        eval_steps=100,
-        save_steps=100
+        eval_steps=50,
+        save_steps=50,
+        fp16=True,  # Enable mixed precision
+        gradient_accumulation_steps=4,  # Gradient accumulation
+        max_length=256  # Reduced sequence length
     )
-    logger.info(f"Training configuration: {config}")
+    logger.info("Training configuration:")
+    for key, value in vars(config).items():
+        logger.info(f" {key}: {value}")
 
-    # Initialize the trainer
-    trainer = NarrativeTrainer(
-        model=model,
-        train_dataset=train_dataset,
-        val_dataset=val_dataset,
-        config=config
-    )
-
-    # Start the training process
-    logger.info("Starting the training process...")
-    trainer.train()
-    logger.info("Training completed successfully!")
+    try:
+        # Initialize trainer
+        trainer = NarrativeTrainer(
+            model=model,
+            train_dataset=train_dataset,
+            val_dataset=val_dataset,
+            config=config
+        )
+
+        # Start training
+        logger.info("Starting the training process...")
+        history = trainer.train()
+
+        # Log final metrics
+        logger.info("Training completed successfully!")
+        logger.info("Final metrics:")
+        logger.info(f" Best validation F1: {trainer.best_val_f1:.4f}")
+        logger.info(f" Final training loss: {history['train_loss'][-1]:.4f}")
+
+    except Exception as e:
+        logger.error(f"Training failed with error: {str(e)}")
+        raise
+    finally:
+        # Clean up
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
 if __name__ == "__main__":
-    main()
+    main()
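Editor's note: the allocator settings above only take effect if they are in the environment before torch initializes CUDA, which is why this commit sets them at import time, ahead of the torch import. A minimal standalone sketch of the same startup pattern (the seed_everything helper is illustrative, not part of this repo):

import os

# Must be set before torch creates a CUDA context, so keep at module top.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import torch
from transformers import set_seed

def seed_everything(seed: int = 42) -> None:
    """Seed Python, NumPy, and torch (CPU and all GPUs) for reproducibility."""
    set_seed(seed)           # seeds random, numpy, and torch in one call
    torch.manual_seed(seed)  # redundant with set_seed; kept to mirror the commit
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

if __name__ == "__main__":
    seed_everything(42)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(f"GPU: {torch.cuda.get_device_name(0)}")

The max_split_size_mb option tells PyTorch's caching allocator not to split blocks larger than the given size, which mitigates fragmentation-related OOMs on small GPUs at some cost in block reuse.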
requirements.txt CHANGED
@@ -8,3 +8,4 @@ sentencepiece
 pandas
 numpy
 spacy
+accelerate
scripts/config/config.py CHANGED
@@ -12,11 +12,14 @@ class TrainingConfig:
 
     # Training parameters
     num_epochs: int = 5
-    batch_size: int = 8
     learning_rate: float = 2e-5
     warmup_ratio: float = 0.1
     weight_decay: float = 0.01
     max_grad_norm: float = 1.0
+    gradient_accumulation_steps: int = 4
+    fp16: bool = True  # Enable mixed precision training
+    max_length: int = 256  # Reduce from 512
+    batch_size: int = 4  # Reduce from 8
 
     # Data parameters
     max_length: int = 512
@@ -45,15 +48,3 @@ if __name__ == "__main__":
     print(f"Learning rate: {default_config.learning_rate}")
     print(f"Device: {default_config.device}")
 
-    # Create custom config
-    custom_config = TrainingConfig(
-        batch_size=16,
-        num_epochs=10,
-        learning_rate=1e-5
-    )
-
-    print("\n=== Custom Configuration ===")
-    print(f"Model name: {custom_config.model_name}")  # Uses default
-    print(f"Batch size: {custom_config.batch_size}")  # Customized
-    print(f"Learning rate: {custom_config.learning_rate}")  # Customized
-    print(f"Number of epochs: {custom_config.num_epochs}")  # Customized
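Editor's note: after this change TrainingConfig declares max_length twice — 256 under the training parameters and 512 under the data parameters. A Python class body executes top to bottom, so the second annotation silently wins and instances default to max_length=512; the intended 256 cap only applies where app.py passes max_length=256 explicitly. A minimal sketch of the pitfall (toy class, not the repo's config):

from dataclasses import dataclass

@dataclass
class DupConfig:
    # Annotating the same name twice does not create two fields;
    # the later default replaces the earlier one.
    max_length: int = 256  # the "reduced" value this commit intends
    max_length: int = 512  # the default that actually takes effect

print(DupConfig())  # DupConfig(max_length=512)

Dropping the old 512 line from the data-parameters block would make the default match the commit's intent.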
scripts/training/trainer.py CHANGED
@@ -8,12 +8,10 @@ import numpy as np
 from sklearn.metrics import f1_score, precision_score, recall_score
 import json
 from datetime import datetime
-
+from torch.cuda.amp import autocast, GradScaler
 
 class NarrativeTrainer:
-    """
-    Comprehensive trainer for narrative classification with GPU support.
-    """
+    """Comprehensive trainer for narrative classification with GPU memory optimizations"""
     def __init__(
         self,
         model,
@@ -21,29 +19,43 @@
         val_dataset,
         config,
     ):
+        # Setup basics
         self.setup_logging()
         self.logger = logging.getLogger(__name__)
 
-        # Set device
+        # Store config first
+        self.config = config
+
+        # Setup device
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.logger.info(f"Using device: {self.device}")
 
+        # Clear GPU cache if using CUDA
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
         # Initialize model and components
         self.model = model.to(self.device)
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
-        self.config = config
 
+        # Initialize training state
         self.current_epoch = 0
         self.global_step = 0
        self.best_val_f1 = 0.0
 
+        # Initialize mixed precision training
+        self.scaler = GradScaler(enabled=self.config.fp16)
+
+        # Setup training components
         self.setup_training()
 
+        # Setup output directory
         self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
         self.output_dir = Path(config.output_dir) / self.timestamp
         self.output_dir.mkdir(parents=True, exist_ok=True)
 
+        # Save config and initialize history
         self.save_config()
         self.history = {
             'train_loss': [],
@@ -54,6 +66,7 @@
         }
 
     def setup_logging(self):
+        """Initialize logging configuration"""
         logging.basicConfig(
             level=logging.INFO,
             format='%(asctime)s - %(levelname)s - %(message)s',
@@ -61,27 +74,33 @@
         )
 
     def setup_training(self):
-        """Initialize dataloaders, optimizer, and scheduler."""
+        """Initialize training components with memory optimizations"""
+        # Create dataloaders
         self.train_loader = DataLoader(
             self.train_dataset,
             batch_size=self.config.batch_size,
             shuffle=True,
-            num_workers=4
+            num_workers=4,
+            pin_memory=True  # Optimize data transfer to GPU
         )
 
         self.val_loader = DataLoader(
             self.val_dataset,
             batch_size=self.config.batch_size,
-            num_workers=4
+            num_workers=4,
+            pin_memory=True
         )
 
+        # Setup optimizer
         self.optimizer = torch.optim.AdamW(
             self.model.parameters(),
             lr=self.config.learning_rate,
             weight_decay=self.config.weight_decay
         )
 
-        num_training_steps = len(self.train_loader) * self.config.num_epochs
+        # Setup scheduler with gradient accumulation steps
+        num_update_steps_per_epoch = len(self.train_loader) // self.config.gradient_accumulation_steps
+        num_training_steps = num_update_steps_per_epoch * self.config.num_epochs
         num_warmup_steps = int(num_training_steps * self.config.warmup_ratio)
 
         self.scheduler = get_linear_schedule_with_warmup(
@@ -93,66 +112,106 @@
         self.criterion = torch.nn.BCEWithLogitsLoss()
 
     def save_config(self):
-        """Save training configuration."""
+        """Save training configuration"""
         config_dict = {k: str(v) for k, v in vars(self.config).items()}
         config_path = self.output_dir / 'config.json'
         with open(config_path, 'w') as f:
             json.dump(config_dict, f, indent=4)
 
     def train_epoch(self):
-        """Train model for one epoch."""
+        """Train for one epoch with memory optimizations"""
         self.model.train()
         total_loss = 0
-        pbar = tqdm(self.train_loader, desc=f'Epoch {self.current_epoch + 1}/{self.config.num_epochs}')
+        self.optimizer.zero_grad()
 
-        for batch in pbar:
-            batch = {k: v.to(self.device) for k, v in batch.items()}
+        pbar = tqdm(enumerate(self.train_loader),
+                    total=len(self.train_loader),
+                    desc=f'Epoch {self.current_epoch + 1}/{self.config.num_epochs}')
 
-            self.optimizer.zero_grad()
-            outputs = self.model(
-                input_ids=batch['input_ids'],
-                attention_mask=batch['attention_mask'],
-                features=batch['features']
-            )
+        for step, batch in pbar:
+            # Move batch to device
+            batch = {k: v.to(self.device, non_blocking=True) for k, v in batch.items()}
 
-            loss = self.criterion(outputs, batch['labels'])
-            loss.backward()
+            # Mixed precision forward pass
+            with autocast(enabled=self.config.fp16):
+                outputs = self.model(
+                    input_ids=batch['input_ids'],
+                    attention_mask=batch['attention_mask'],
+                    features=batch['features']
+                )
+                loss = self.criterion(outputs, batch['labels'])
+                loss = loss / self.config.gradient_accumulation_steps
 
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
-            self.optimizer.step()
-            self.scheduler.step()
+            # Scaled backward pass
+            self.scaler.scale(loss).backward()
 
-            total_loss += loss.item()
-            pbar.set_postfix({'loss': total_loss / (pbar.n + 1)})
+            # Update weights if we've accumulated enough gradients
+            if (step + 1) % self.config.gradient_accumulation_steps == 0:
+                self.scaler.unscale_(self.optimizer)
+                torch.nn.utils.clip_grad_norm_(
+                    self.model.parameters(),
+                    self.config.max_grad_norm
+                )
+
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
+                self.scheduler.step()
+                self.optimizer.zero_grad()
+
+            # Update metrics
+            total_loss += loss.item() * self.config.gradient_accumulation_steps
+            avg_loss = total_loss / (step + 1)
+            pbar.set_postfix({'loss': f'{avg_loss:.4f}'})
 
             self.global_step += 1
 
+            # Evaluate if needed
             if self.global_step % self.config.eval_steps == 0:
                 self.evaluate()
+
+            # Clear memory periodically
+            if step % 10 == 0:
+                torch.cuda.empty_cache()
+
+            # Clear unnecessary tensors
+            del outputs
+            del loss
 
         return total_loss / len(self.train_loader)
 
     @torch.no_grad()
     def evaluate(self):
-        """Evaluate model performance."""
+        """Evaluate model with memory optimizations"""
        self.model.eval()
         total_loss = 0
         all_preds, all_labels = [], []
 
         for batch in tqdm(self.val_loader, desc="Evaluating"):
-            batch = {k: v.to(self.device) for k, v in batch.items()}
-            outputs = self.model(
-                input_ids=batch['input_ids'],
-                attention_mask=batch['attention_mask'],
-                features=batch['features']
-            )
+            batch = {k: v.to(self.device, non_blocking=True) for k, v in batch.items()}
 
-            loss = self.criterion(outputs, batch['labels'])
+            with autocast(enabled=self.config.fp16):
+                outputs = self.model(
+                    input_ids=batch['input_ids'],
+                    attention_mask=batch['attention_mask'],
+                    features=batch['features']
+                )
+                loss = self.criterion(outputs, batch['labels'])
+
             total_loss += loss.item()
 
-            preds = torch.sigmoid(outputs) > 0.5
-            all_preds.append(preds.cpu().numpy())
-            all_labels.append(batch['labels'].cpu().numpy())
+            # CPU computations for predictions
+            preds = (torch.sigmoid(outputs) > 0.5).cpu().numpy()
+            labels = batch['labels'].cpu().numpy()
+
+            all_preds.append(preds)
+            all_labels.append(labels)
+
+            # Clear memory
+            del outputs
+            del loss
+            torch.cuda.empty_cache()
 
+        # Compute metrics
         all_preds = np.concatenate(all_preds, axis=0)
         all_labels = np.concatenate(all_labels, axis=0)
 
@@ -172,11 +231,13 @@
         return metrics
 
     def save_model(self, filename: str, metrics: dict = None):
+        """Save model checkpoint"""
         save_path = self.output_dir / filename
         torch.save({
             'model_state_dict': self.model.state_dict(),
             'optimizer_state_dict': self.optimizer.state_dict(),
             'scheduler_state_dict': self.scheduler.state_dict(),
+            'scaler_state_dict': self.scaler.state_dict(),
             'epoch': self.current_epoch,
             'global_step': self.global_step,
             'best_val_f1': self.best_val_f1,
@@ -185,71 +246,34 @@
         self.logger.info(f"Model saved to {save_path}")
 
     def train(self):
-        """Run training for all epochs."""
+        """Run complete training loop"""
         self.logger.info("Starting training...")
-        for epoch in range(self.config.num_epochs):
-            self.current_epoch = epoch
-            train_loss = self.train_epoch()
-            self.history['train_loss'].append(train_loss)
-
-            val_metrics = self.evaluate()
-            self.history['val_loss'].append(val_metrics['loss'])
-            self.history['val_f1'].append(val_metrics['f1'])
-            self.history['val_precision'].append(val_metrics['precision'])
-            self.history['val_recall'].append(val_metrics['recall'])
-
-            self.save_model(f'checkpoint_epoch_{epoch+1}.pt', val_metrics)
-            history_path = self.output_dir / 'history.json'
-            with open(history_path, 'w') as f:
-                json.dump(self.history, f, indent=4)
-
-        self.logger.info("Training completed!")
-        return self.history
-
-
-if __name__ == "__main__":
-    import sys
-    sys.path.append("../../")
-    from scripts.models.model import NarrativeClassifier
-    from scripts.models.dataset import NarrativeDataset
-    from scripts.config.config import TrainingConfig
-    from scripts.data_processing.data_preparation import AdvancedNarrativeProcessor
-
-    # Initialize training configuration
-    config = TrainingConfig(
-        output_dir=Path("./output"),
-        num_epochs=5,
-        batch_size=32,
-        learning_rate=5e-5,
-        weight_decay=0.01,
-        warmup_ratio=0.1,
-        max_grad_norm=1.0,
-        eval_steps=100
-    )
-
-    # Load and process data
-    processor = AdvancedNarrativeProcessor(
-        annotations_file="../../data/subtask-2-annotations.txt",
-        raw_dir="../../data/raw"
-    )
-    processed_data = processor.load_and_process_data()
-
-    # Create datasets
-    train_dataset = NarrativeDataset(processed_data['train'])
-    val_dataset = NarrativeDataset(processed_data['val'])
-
-    # Initialize model
-    model = NarrativeClassifier(num_labels=train_dataset.get_num_labels())
-
-    # Initialize trainer
-    trainer = NarrativeTrainer(
-        model=model,
-        train_dataset=train_dataset,
-        val_dataset=val_dataset,
-        config=config
-    )
-
-    # Start full training
-    print("\n=== Starting Training ===")
-    trainer.train()
-    print("\nTraining completed successfully!")
+        try:
+            for epoch in range(self.config.num_epochs):
+                self.current_epoch = epoch
+                self.logger.info(f"Starting epoch {epoch + 1}/{self.config.num_epochs}")
+
+                train_loss = self.train_epoch()
+                self.history['train_loss'].append(train_loss)
+
+                val_metrics = self.evaluate()
+                self.history['val_loss'].append(val_metrics['loss'])
+                self.history['val_f1'].append(val_metrics['f1'])
+                self.history['val_precision'].append(val_metrics['precision'])
+                self.history['val_recall'].append(val_metrics['recall'])
+
+                self.save_model(f'checkpoint_epoch_{epoch+1}.pt', val_metrics)
+
+                # Save training history
+                history_path = self.output_dir / 'history.json'
+                with open(history_path, 'w') as f:
+                    json.dump(self.history, f, indent=4)
+
+                self.logger.info(f"Epoch {epoch + 1} completed. Train loss: {train_loss:.4f}")
+
+            self.logger.info("Training completed successfully!")
+            return self.history
+
+        except Exception as e:
+            self.logger.error(f"Training failed with error: {str(e)}")
+            raise
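Editor's note: the train_epoch changes follow the standard PyTorch recipe for combining automatic mixed precision with gradient accumulation: divide the loss by the accumulation factor, backward through the GradScaler on every micro-batch, and only unscale, clip, step, and zero gradients once per accumulation window — which is also why setup_training divides len(train_loader) by gradient_accumulation_steps when sizing the scheduler. A condensed, self-contained sketch of that loop (toy model and random data, illustrative only, not this repo's classes):

import torch
from torch.cuda.amp import autocast, GradScaler

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_amp = device.type == 'cuda'  # AMP is a no-op on CPU here
model = torch.nn.Linear(16, 4).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = torch.nn.BCEWithLogitsLoss()
scaler = GradScaler(enabled=use_amp)
accum_steps = 4

optimizer.zero_grad()
for step in range(32):  # stand-in for iterating a DataLoader
    x = torch.randn(8, 16, device=device)
    y = torch.randint(0, 2, (8, 4), device=device).float()

    with autocast(enabled=use_amp):                  # fp16 forward where safe
        loss = criterion(model(x), y) / accum_steps  # average over micro-batches

    scaler.scale(loss).backward()                    # accumulate scaled gradients

    if (step + 1) % accum_steps == 0:                # one update per accumulation window
        scaler.unscale_(optimizer)                   # so clipping sees true gradient norms
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)                       # skips the step if grads overflowed
        scaler.update()
        optimizer.zero_grad()

Two caveats worth recording: newer PyTorch releases expose the same API as torch.amp.autocast('cuda') and torch.amp.GradScaler('cuda') (torch.cuda.amp still works but is deprecated), and in the committed code a mid-epoch call to evaluate() leaves the model in eval() mode for the rest of the epoch, since train_epoch never switches it back to train().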