--- |
|
license: mit |
|
language: multilingual |
|
library_name: torch |
|
tags: [] |
|
base_model: BAAI/bge-m3 |
|
datasets: |
|
- philipp-zettl/GGU-xx |
|
- philipp-zettl/sentiment |
|
metrics: |
|
|
|
- accuracy |
|
- precision |
|
- recall |
|
- f1
|
model_name: Multi-Head Sequence Classification Model |
|
pipeline_tag: text-classification |
|
widget:
- text: "Hello, how are you?"
  example_title: "[GGU] Greeting"
- text: "Thank you for your help"
  example_title: "[GGU] Gratitude"
- text: "Hallo, wie geht es dir?"
  example_title: "[GGU] Greeting (de)"
- text: "Danke dir."
  example_title: "[GGU] Gratitude (de)"
- text: "I am not sure what you mean"
  example_title: "[GGU] Other"
- text: "Generate me an image of a dog!"
  example_title: "[GGU] Other"
- text: "What is the weather like today?"
  example_title: "[GGU] Other"
- text: "Wie ist das Wetter heute?"
  example_title: "[GGU] Other (de)"
|
--- |
|
# Multi-Head Sequence Classification Model |
|
## Model description |
|
This is a lightweight multi-head sequence classification model. It builds on the hidden states of a pre-trained transformer backbone and attaches several classification heads, one per task, to the backbone's output.
|
|
|
### Model architecture |
|
The backbone is [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3), which produces 1024-dimensional output embeddings.

Two classification heads are attached to this output: `GGU` (3 labels) and `sentiment` (3 labels).
|
|
|
The label mappings are as follows:
|
|
|
**GGU** |
|
- 0: Greeting |
|
- 1: Gratitude |
|
- 2: Other |
|
|
|
**sentiment** |
|
- 0: Positive |
|
- 1: Negative |
|
- 2: Neutral |
|
|
|
The joint architecture was trained with the `MultiHeadClassificationTrainer` implementation provided in this repository. A minimal sketch of the overall architecture is shown below.
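The following is a hypothetical reconstruction for illustration only; the actual implementation lives in this repository's `model.py`, and the class name, pooling strategy, and constructor signature here are assumptions:

```python
import torch.nn as nn


class MultiHeadClassifierSketch(nn.Module):
    """Illustrative sketch: a shared backbone with one linear head per task."""

    def __init__(self, backbone, head_config, dropout=0.25):
        super().__init__()
        self.backbone = backbone  # e.g. AutoModel.from_pretrained('BAAI/bge-m3')
        hidden_size = backbone.config.hidden_size  # 1024 for BAAI/bge-m3
        self.dropout = nn.Dropout(dropout)
        # one classification head per task, e.g. {'GGU': 3, 'sentiment': 3}
        self.heads = nn.ModuleDict({
            name: nn.Linear(hidden_size, num_labels)
            for name, num_labels in head_config.items()
        })

    def forward(self, inputs, head_names=None):
        hidden_states = self.backbone(**inputs).last_hidden_state
        pooled = self.dropout(hidden_states[:, 0])  # CLS-token pooling (assumption)
        names = head_names or list(self.heads.keys())
        return {name: self.heads[name](pooled) for name in names}
```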
|
|
|
### Use cases |
|
The model covers two use cases via its heads: intent-style text classification (Greeting / Gratitude / Other) with the `GGU` head and sentiment analysis with the `sentiment` head.
|
|
|
## Model Inference |
|
Use the following code to run inference with the model:
|
|
|
```python
import torch
from transformers import AutoTokenizer

# assumes model.py from this repository is available on the import path
from model import MultiHeadSequenceClassificationModel

model = MultiHeadSequenceClassificationModel.from_pretrained('philipp-zettl/multi-head-sequence-classification-model')
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')


def predict(text):
    inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():  # inference only, no gradient tracking needed
        outputs = model(**inputs)
    return outputs
```
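Assuming the model returns one logits tensor per head (as the trainer's evaluation code further below suggests), the raw outputs can be mapped back to labels roughly like this; the exact output structure is an assumption:

```python
# hypothetical usage sketch; assumes `outputs` is a dict mapping head name -> logits
label_maps = {
    'GGU': {0: 'Greeting', 1: 'Gratitude', 2: 'Other'},
    'sentiment': {0: 'Positive', 1: 'Negative', 2: 'Neutral'},
}

outputs = predict("Hello, how are you?")
for head, logits in outputs.items():
    predicted_class = logits.argmax(dim=-1).item()
    print(head, label_maps[head][predicted_class])  # e.g. "GGU Greeting"
```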
|
|
|
## Model Training |
|
|
|
### Confusion Matrix
|
**GGU** |
|
![Confusion Matrix GGU](assets/confusion_matrix_GGU.png) |
|
|
|
**sentiment** |
|
![Confusion Matrix sentiment](assets/confusion_matrix_sentiment.png) |
|
|
|
### Training Loss
|
**GGU** |
|
![Loss GGU](assets/loss_plot_GGU.png) |
|
|
|
**sentiment** |
|
![Loss sentiment](assets/loss_plot_sentiment.png) |
|
|
|
|
|
### Training data |
|
The model has been trained on the following datasets: |
|
- [philipp-zettl/GGU-xx](https://huggingface.co/datasets/philipp-zettl/GGU-xx) |
|
- [philipp-zettl/sentiment](https://huggingface.co/datasets/philipp-zettl/sentiment) |
|
|
|
Training used the implementation provided by `MultiHeadClassificationTrainer`. The datasets can be loaded as sketched below.
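A minimal loading sketch with the `datasets` library (split handling is left to the trainer):

```python
from datasets import load_dataset

# load both training datasets from the Hugging Face Hub
ggu_ds = load_dataset('philipp-zettl/GGU-xx')
sentiment_ds = load_dataset('philipp-zettl/sentiment')
```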
|
|
|
### Training procedure |
|
The following code has been executed to train the model: |
|
|
|
```python
# Imports assumed by this training script; MultiHeadClassificationTrainer
# ships with this repository (its module path here is an assumption).
from copy import deepcopy

import torch
from transformers import AutoModel, AutoTokenizer

from trainer import MultiHeadClassificationTrainer  # hypothetical module path


def train_classifier():
    backbone = AutoModel.from_pretrained('BAAI/bge-m3').to(torch.float16)
|
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3') |
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
ggu_label_map = { |
|
0: 'Greeting', |
|
1: 'Gratitude', |
|
2: 'Other' |
|
} |
|
sentiment_label_map = { |
|
0: 'Positive', |
|
1: 'Negative', |
|
2: 'Neutral' |
|
} |
|
|
|
num_labels = len(ggu_label_map.keys()) |
|
|
|
# HParams |
|
|
|
dropout = 0.25 |
|
learning_rate = 3e-5 |
|
momentum = 0.9 |
|
l2_reg = 0.25 |
|
|
|
l2_loss_weight = 0.25 |
|
|
|
model_conf = { |
|
'backbone': backbone, |
|
'head_config': { |
|
'GGU': num_labels, |
|
}, |
|
'dropout': dropout, |
|
'l2_reg': l2_reg, |
|
} |
|
|
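    # optimizer settings; the 'momentum' key suggests an SGD-style optimizer inside the trainer (assumption)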
|
optimizer_conf = { |
|
'lr': learning_rate, |
|
'momentum': momentum |
|
} |
|
|
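    # LR scheduler settings; these keys match torch.optim.lr_scheduler.ReduceLROnPlateau (assumption)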
|
scheduler_conf = { |
|
'factor': 0.2, |
|
'patience': 3, |
|
'min_lr': 1e-8 |
|
} |
|
|
|
train_run = 1000 |
|
trainer = MultiHeadClassificationTrainer( |
|
model_conf=model_conf, |
|
        optimizer_conf={**optimizer_conf, 'lr': 1e-4},  # note: overrides the learning_rate of 3e-5 set above
|
scheduler_conf=scheduler_conf, |
|
num_epochs=35, |
|
l2_loss_weight=l2_loss_weight, |
|
use_lr_scheduler=True, |
|
train_run=train_run, |
|
auto_find_batch_size=False |
|
) |
|
|
|
new_model, history = trainer.train(dataset_name='philipp-zettl/GGU-xx', target_heads=['GGU']) |
|
metrics = history['metrics'] |
|
history['loss_plot'] = trainer._plot_history(**metrics) |
|
res = trainer.eval({'GGU': ggu_label_map}) |
|
history['evaluation'] = res['GGU'] |
|
|
|
total_history = { |
|
'GGU': deepcopy(history), |
|
} |
|
|
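    # continual training step: add a second head for sentiment and fine-tune it on the sentiment dataset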
|
trainer.classifier.add_head('sentiment', 3) |
|
trainer.auto_find_batch_size = False |
|
new_model, history = trainer.train(dataset_name='philipp-zettl/sentiment', target_heads=['sentiment'], sample_key='text', num_epochs=10, lr=1e-4) |
|
metrics = history['metrics'] |
|
history['loss_plot'] = trainer._plot_history(**metrics) |
|
res = trainer.eval({'sentiment': sentiment_label_map}, sample_key='text') |
|
history['evaluation'] = res['sentiment'] |
|
|
|
total_history['sentiment'] = deepcopy(history) |
|
|
|
label_maps = { |
|
'GGU': ggu_label_map, |
|
'sentiment': sentiment_label_map, |
|
} |
|
|
|
return new_model, total_history, trainer, label_maps |
|
|
|
``` |
|
|
|
## Evaluation
|
### Evaluation data |
|
For model evaluation, a 20% validation split of the training data was held out.
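A split of this kind can be produced with the `datasets` API. This is a sketch, assuming the datasets expose a `train` split and using an illustrative seed:

```python
from datasets import load_dataset

# hold out 20% of the training data for validation (seed is an assumption)
dataset = load_dataset('philipp-zettl/GGU-xx', split='train')
split = dataset.train_test_split(test_size=0.2, seed=42)
train_ds, val_ds = split['train'], split['test']
```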
|
|
|
### Evaluation procedure |
|
The model was evaluated with the `eval` method of the `MultiHeadClassificationTrainer` class, which delegates to the `_eval_model` routine shown below:
|
|
|
```python |
|
def _eval_model(self, dataloader, label_map, sample_key, label_key): |
|
self.classifier.train(False) |
|
eval_heads = list(label_map.keys()) |
|
y_pred = {h: [] for h in eval_heads} |
|
y_test = {h: [] for h in eval_heads} |
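    # collect per-head predictions and gold labels over the validation dataloader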
|
for sample in tqdm(dataloader, total=len(dataloader), desc='Evaluating model...'): |
|
labels = {name: sample[label_key] for name in eval_heads} |
|
embeddings = BatchEncoding({k: torch.stack(v, dim=1).to(self.device) for k, v in sample.items() if k not in [label_key, sample_key]}) |
|
        output = self.classifier(embeddings.to(self.device), head_names=eval_heads)  # use the trainer's configured device rather than hardcoding 'cuda'
|
for head in eval_heads: |
|
y_pred[head].extend(output[head].argmax(dim=1).cpu()) |
|
y_test[head].extend(labels[head]) |
|
torch.cuda.empty_cache() |
|
|
|
accuracies = {h: accuracy_score(y_test[h], y_pred[h]) for h in eval_heads} |
|
f1_scores = {h: f1_score(y_test[h], y_pred[h], average="macro") for h in eval_heads} |
|
recalls = {h: recall_score(y_test[h], y_pred[h], average='macro') for h in eval_heads} |
|
|
|
report = {} |
|
for head in eval_heads: |
|
cm = confusion_matrix(y_test[head], y_pred[head], labels=list(label_map[head].keys())) |
|
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=list(label_map[head].values())) |
|
clf_report = classification_report( |
|
y_test[head], y_pred[head], output_dict=True, target_names=list(label_map[head].values()) |
|
) |
|
del clf_report["accuracy"] |
|
clf_report = pd.DataFrame(clf_report).T.reset_index() |
|
report[head] = dict( |
|
clf_report=clf_report, confusion_matrix=disp, metrics={'accuracy': accuracies[head], 'f1': f1_scores[head], 'recall': recalls[head]} |
|
) |
|
return report |
|
|
|
``` |
|
|
|
### Metrics |
|
For evaluation, we used the following metrics: accuracy, precision, recall, and F1-score. Detailed classification reports for both heads are given below:
|
|
|
**GGU:** |
|
| class        |   precision |   recall |   f1-score |   support |
|:-------------|------------:|---------:|-----------:|----------:|
| Greeting     |    0.904762 | 0.974359 |   0.938272 |        39 |
| Gratitude    |    0.958333 | 0.851852 |   0.901961 |        27 |
| Other        |    1.000000 | 1.000000 |   1.000000 |        39 |
| macro avg    |    0.954365 | 0.942070 |   0.946744 |       105 |
| weighted avg |    0.953912 | 0.952381 |   0.951862 |       105 |
|
|
|
**sentiment:** |
|
| class        |   precision |   recall |   f1-score |   support |
|:-------------|------------:|---------:|-----------:|----------:|
| Positive     |    0.783088 | 0.861878 |   0.820596 |     12851 |
| Negative     |    0.802105 | 0.819524 |   0.810721 |     14229 |
| Neutral      |    0.787400 | 0.691300 |   0.736227 |     13126 |
| macro avg    |    0.790864 | 0.790901 |   0.789181 |     40206 |
| weighted avg |    0.791226 | 0.791200 |   0.789557 |     40206 |
|
|
|
|