Spaces:
Runtime error
Runtime error
Mohammaderfan koupaei
commited on
Commit
·
61d58d1
1
Parent(s):
06ca50d
second
Browse files
scripts/training/trainer.py
CHANGED
@@ -45,7 +45,7 @@ class NarrativeTrainer:
|
|
45 |
self.best_val_f1 = 0.0
|
46 |
|
47 |
# Initialize mixed precision training
|
48 |
-
self.scaler = GradScaler(enabled=self.config.fp16)
|
49 |
|
50 |
# Setup training components
|
51 |
self.setup_training()
|
@@ -99,10 +99,9 @@ class NarrativeTrainer:
|
|
99 |
# Calculate class weights
|
100 |
pos_weight = self.calculate_class_weights()
|
101 |
|
102 |
-
# Setup loss function with class weights
|
103 |
self.criterion = torch.nn.BCEWithLogitsLoss(
|
104 |
-
pos_weight=pos_weight
|
105 |
-
label_smoothing=self.config.label_smoothing
|
106 |
)
|
107 |
|
108 |
# Setup optimizer
|
|
|
45 |
self.best_val_f1 = 0.0
|
46 |
|
47 |
# Initialize mixed precision training
|
48 |
+
self.scaler = GradScaler('cuda', enabled=self.config.fp16)
|
49 |
|
50 |
# Setup training components
|
51 |
self.setup_training()
|
|
|
99 |
# Calculate class weights
|
100 |
pos_weight = self.calculate_class_weights()
|
101 |
|
102 |
+
# Setup loss function with class weights only
|
103 |
self.criterion = torch.nn.BCEWithLogitsLoss(
|
104 |
+
pos_weight=pos_weight
|
|
|
105 |
)
|
106 |
|
107 |
# Setup optimizer
|