Commit
·
5dbb891
1
Parent(s):
6c014d0
Trying class weights
Browse files
train.py
CHANGED
@@ -4,7 +4,18 @@ from transformers import BertTokenizer, Trainer, TrainingArguments
|
|
4 |
from datasets import load_dataset
|
5 |
import numpy as np
|
6 |
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
|
|
|
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
# Load dataset dynamically or from a config
|
9 |
dataset_name = "NicolaiSivesind/human-vs-machine"
|
10 |
dataset = load_dataset(dataset_name)
|
|
|
4 |
from datasets import load_dataset
|
5 |
import numpy as np
|
6 |
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from sklearn.utils.class_weight import compute_class_weight
|
9 |
|
10 |
+
# Other imports and code remain the same...
|
11 |
+
|
12 |
+
# Compute balanced class weights from the training labels so that
# under-represented classes contribute proportionally more to the loss.
# The label column is read once and reused for both arguments.
train_labels = np.asarray(train_dataset['labels'])
class_weights = compute_class_weight(
    'balanced', classes=np.unique(train_labels), y=train_labels)
class_weights = torch.tensor(class_weights, dtype=torch.float)


# NOTE: class weights must be applied to the loss, not written into the model.
# Assigning them to ``model.classifier.weight.data`` would replace the
# classification head's learned weight matrix with a 1-D vector of per-class
# weights, corrupting the head instead of re-weighting the objective.
class WeightedLossTrainer(Trainer):
    """Trainer that uses cross-entropy weighted by ``class_weights``.

    To take effect, construct this class instead of ``Trainer`` wherever the
    trainer is created; all constructor arguments are unchanged.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the class-weighted cross-entropy loss for one batch.

        Args:
            model: The model being trained; its output must expose ``logits``.
            inputs: Batch mapping that includes a ``"labels"`` entry.
            return_outputs: When true, also return the model outputs.
            **kwargs: Extra arguments some ``transformers`` versions pass
                (e.g. ``num_items_in_batch``); accepted and ignored.

        Returns:
            The loss tensor, or ``(loss, outputs)`` if ``return_outputs``.
        """
        from torch.nn.functional import cross_entropy

        labels = inputs["labels"]
        # Withhold the labels so the model does not also compute its own
        # (unweighted) loss; the batch dict itself is left unmodified.
        features = {key: value for key, value in inputs.items() if key != "labels"}
        outputs = model(**features)
        logits = outputs.logits
        loss = cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            # The weights are created on CPU; move them next to the logits.
            weight=class_weights.to(logits.device),
        )
        return (loss, outputs) if return_outputs else loss
|
19 |
# Load dataset dynamically or from a config
|
20 |
dataset_name = "NicolaiSivesind/human-vs-machine"
|
21 |
dataset = load_dataset(dataset_name)
|