jaynopponep commited on
Commit
5dbb891
·
1 Parent(s): 6c014d0

Trying class weights

Browse files
Files changed (1) hide show
  1. train.py +11 -0
train.py CHANGED
@@ -4,7 +4,18 @@ from transformers import BertTokenizer, Trainer, TrainingArguments
4
  from datasets import load_dataset
5
  import numpy as np
6
  from sklearn.metrics import accuracy_score, precision_recall_fscore_support
 
 
7
 
 
 
 
 
 
 
 
 
 
8
  # Load dataset dynamically or from a config
9
  dataset_name = "NicolaiSivesind/human-vs-machine"
10
  dataset = load_dataset(dataset_name)
 
4
  from datasets import load_dataset
5
  import numpy as np
6
  from sklearn.metrics import accuracy_score, precision_recall_fscore_support
7
+ from torch.utils.data import DataLoader
8
+ from sklearn.utils.class_weight import compute_class_weight
9
 
10
+ # Other imports and code remain the same...
11
+
12
+ # Compute class weights
13
+ class_weights = compute_class_weight(
14
+ 'balanced', classes=np.unique(train_dataset['labels']), y=train_dataset['labels'])
15
+ class_weights = torch.tensor(class_weights, dtype=torch.float)
16
+
17
+ # Update the model's classifier with class weights
18
+ model.classifier.weight.data = class_weights
19
  # Load dataset dynamically or from a config
20
  dataset_name = "NicolaiSivesind/human-vs-machine"
21
  dataset = load_dataset(dataset_name)