Multilingual BERT fine-tuned for multi-label emotion classification.
The model was trained on the lv_go_emotions dataset, a Latvian machine translation of the GoEmotions dataset produced with Google Translate.
The original 27 emotions were mapped to 6 basic emotions following Dr. Ekman's theory; neutral is kept as a seventh label.
Labels predicted by the classifier:
0: anger
1: disgust
2: fear
3: joy
4: sadness
5: surprise
6: neutral
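Because the task is multi-label, each of the seven outputs is scored independently, so a text can receive several labels or none. The sketch below shows one way raw logits could be turned into label names. The 0.5 threshold, the `decode` helper, and the example logits are illustrative assumptions, not values taken from this model.

```python
import math

# Label order copied from the list above (index = class id).
ID2LABEL = ["anger", "disgust", "fear", "joy", "sadness", "surprise", "neutral"]


def sigmoid(x: float) -> float:
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def decode(logits, threshold=0.5):
    """Return every label whose sigmoid probability reaches the threshold.

    `threshold=0.5` is an assumed default; the model card does not state
    which threshold was used for evaluation.
    """
    if len(logits) != len(ID2LABEL):
        raise ValueError(f"expected {len(ID2LABEL)} logits, got {len(logits)}")
    return [label for label, z in zip(ID2LABEL, logits) if sigmoid(z) >= threshold]


# Made-up logits for illustration only.
example_logits = [-2.1, -3.0, -1.5, 1.2, -0.7, 0.3, -1.0]
print(decode(example_logits))  # ['joy', 'surprise']
```

Lowering the threshold trades precision for recall, which may matter for low-recall classes such as disgust.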
Mapping from the 27 GoEmotions emotions (plus neutral) to the 6 basic Ekman emotions:
GoEmotion | Ekman |
---|---|
admiration | joy |
amusement | joy |
anger | anger |
annoyance | anger |
approval | joy |
caring | joy |
confusion | surprise |
curiosity | surprise |
desire | joy |
disappointment | sadness |
disapproval | anger |
disgust | disgust |
embarrassment | sadness |
excitement | joy |
fear | fear |
gratitude | joy |
grief | sadness |
joy | joy |
love | joy |
nervousness | fear |
optimism | joy |
pride | joy |
realization | surprise |
relief | joy |
remorse | sadness |
sadness | sadness |
surprise | surprise |
neutral | neutral |
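The mapping table can be expressed as a plain dictionary. The following is a minimal sketch of how a set of GoEmotions labels could be collapsed into a 7-dimensional multi-hot target; the helper name `to_multi_hot` is hypothetical, and how the targets were actually built for training is not stated here.

```python
# Output label order, matching the class ids listed earlier.
EKMAN_LABELS = ["anger", "disgust", "fear", "joy", "sadness", "surprise", "neutral"]

# Copied row by row from the mapping table above.
GOEMOTIONS_TO_EKMAN = {
    "admiration": "joy", "amusement": "joy", "anger": "anger",
    "annoyance": "anger", "approval": "joy", "caring": "joy",
    "confusion": "surprise", "curiosity": "surprise", "desire": "joy",
    "disappointment": "sadness", "disapproval": "anger", "disgust": "disgust",
    "embarrassment": "sadness", "excitement": "joy", "fear": "fear",
    "gratitude": "joy", "grief": "sadness", "joy": "joy", "love": "joy",
    "nervousness": "fear", "optimism": "joy", "pride": "joy",
    "realization": "surprise", "relief": "joy", "remorse": "sadness",
    "sadness": "sadness", "surprise": "surprise", "neutral": "neutral",
}


def to_multi_hot(go_labels):
    """Collapse GoEmotions label names into a multi-hot vector over EKMAN_LABELS."""
    target = [0] * len(EKMAN_LABELS)
    for name in go_labels:
        target[EKMAN_LABELS.index(GOEMOTIONS_TO_EKMAN[name])] = 1
    return target


print(to_multi_hot(["gratitude", "curiosity"]))  # [0, 0, 0, 1, 0, 1, 0]
```

Several source labels can collapse onto the same target: a sample tagged with both annoyance and disapproval yields a single active anger entry.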
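The random number generators were seeded with 42:

    import random
    import numpy as np
    import torch

    def set_seed(seed=42):
        # Seed Python, NumPy, and PyTorch (CPU and all GPUs) for reproducibility.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)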
Training parameters:
max_length: null
batch_size: 64
shuffle: True
num_workers: 8
pin_memory: False
drop_last: False
optimizer: adam
lr: 0.00001
weight_decay: 0
problem_type: multi_label_classification
num_epochs: 4
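With `problem_type: multi_label_classification`, Hugging Face Transformers sequence-classification heads compute a binary cross-entropy loss over the logits (`BCEWithLogitsLoss`), treating each label as an independent yes/no decision. The sketch below reproduces that loss in plain Python for a single sample, assuming the built-in loss was used during training; the function name and the example values are illustrative.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over labels, computed from raw logits.

    Uses the numerically stable identity
        max(z, 0) - z * y + log(1 + exp(-|z|))
    which equals -[y * log(sigmoid(z)) + (1 - y) * log(1 - sigmoid(z))].
    """
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same length")
    total = 0.0
    for z, y in zip(logits, targets):
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


# Hypothetical sample whose only true label is joy (index 3).
target = [0, 0, 0, 1, 0, 0, 0]
confident = [-4.0, -4.0, -4.0, 4.0, -4.0, -4.0, -4.0]
uninformed = [0.0] * 7

print(bce_with_logits(confident, target) < bce_with_logits(uninformed, target))  # True
```

All-zero logits give a loss of exactly ln 2 per label, a handy baseline when checking whether training has started to learn anything.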
Evaluation results on the test split of lv_go_emotions:
Label | Precision | Recall | F1-Score | AUC-ROC | Support |
---|---|---|---|---|---|
anger | 0.58 | 0.36 | 0.45 | 0.83 | 726 |
disgust | 0.88 | 0.12 | 0.21 | 0.90 | 123 |
fear | 0.75 | 0.48 | 0.58 | 0.93 | 98 |
joy | 0.82 | 0.76 | 0.79 | 0.90 | 2104 |
sadness | 0.69 | 0.46 | 0.55 | 0.88 | 379 |
surprise | 0.61 | 0.51 | 0.55 | 0.87 | 677 |
neutral | 0.65 | 0.62 | 0.64 | 0.83 | 1787 |
micro avg | 0.71 | 0.60 | 0.65 | 0.92 | 5894 |
macro avg | 0.71 | 0.47 | 0.54 | 0.88 | 5894 |
weighted avg | 0.71 | 0.60 | 0.64 | 0.87 | 5894 |
samples avg | 0.63 | 0.62 | 0.62 | nan | 5894 |
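The macro and weighted rows can be recovered from the per-class rows: the macro average is the unweighted mean over the seven classes, while the weighted average weights each class by its support. The sketch below recomputes both from the rounded values in the table, so agreement is only expected to two decimals. The micro and samples averages need the raw predictions and cannot be rebuilt this way.

```python
# (precision, recall, f1, support) copied from the per-class rows above.
ROWS = {
    "anger":    (0.58, 0.36, 0.45, 726),
    "disgust":  (0.88, 0.12, 0.21, 123),
    "fear":     (0.75, 0.48, 0.58, 98),
    "joy":      (0.82, 0.76, 0.79, 2104),
    "sadness":  (0.69, 0.46, 0.55, 379),
    "surprise": (0.61, 0.51, 0.55, 677),
    "neutral":  (0.65, 0.62, 0.64, 1787),
}

total_support = sum(row[3] for row in ROWS.values())


def macro_avg(col):
    """Unweighted mean of one metric column across classes."""
    return sum(row[col] for row in ROWS.values()) / len(ROWS)


def weighted_avg(col):
    """Mean of one metric column, weighted by each class's support."""
    return sum(row[col] * row[3] for row in ROWS.values()) / total_support


print(total_support)                                     # 5894
print(" ".join(f"{macro_avg(c):.2f}" for c in range(3)))     # 0.71 0.47 0.54
print(" ".join(f"{weighted_avg(c):.2f}" for c in range(3)))  # 0.71 0.60 0.64
```

The gap between macro recall (0.47) and weighted recall (0.60) reflects the low recall on the small disgust, fear, and sadness classes, which count equally in the macro average but carry little weight in the weighted one.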