---
license: apache-2.0
tags:
- transformers
- pytorch
datasets:
- conv_ai_2
model-index:
- name: distillbert_conv_quality_score
  results: []
language:
- en
---

# distillbert_conv_quality_score

This model is a fine-tuned version of [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased) on the conv_ai_2 dataset.
It was trained to generate a score for a conversation. The score is a float between 0 and 1 (the dataset's 1-5 `eval_score`, rescaled to this range).

It achieves the following results at the final training step:

- training/loss: 0.0165
- validation/loss: 0.0149

## Model description

More information needed

## Usage

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_name = "alespalla/distillbert_conv_quality_score"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

conversation = '''
Q: Begin
A: lol ! do you think it is strange to feel like you have been through life before ?
Q: Hellow
A: I don't understand you. Also, try to guess: i like to ...
Q: How are you?
A: make time stop, funny you :)
Q: What is your name?
A: jessie. hows your day going ?
'''

# Tokenize the transcript and read the score from the single-logit head
inputs = tokenizer(conversation, return_tensors='pt', truncation=True)
score = model(**inputs).logits.item()
print(f"Score: {score}")
```
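
Since the model was trained as a regression head on labels in `[0, 1]` (see the data preparation below), the score is read directly from `logits`; as a raw regression output it can occasionally fall slightly outside that range.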

## Training and evaluation data

The training data was generated from `conv_ai_2` using the following function:

```python
from datasets import load_dataset


def get_dataset(regression=False):
    db = load_dataset("conv_ai_2")

    def generate_conversation(elem):
        # Interleave the dialog turns into a Q/A transcript
        text = ""
        for idx, txt in enumerate(elem["dialog"]):
            if idx % 2:
                text += f"A: {txt['text']}\n"
            else:
                text += f"Q: {txt['text']}\n"
        if regression:
            # Rescale eval_score from [1, 5] to a float in [0, 1]
            return {'text': text, "labels": (elem['eval_score'] - 1) / 4}
        return {'text': text, "labels": elem['eval_score'] - 1}

    # Keep only conversations that actually received an evaluation score
    db = db.filter(lambda example: example["eval_score"] > 0)
    db = db.map(generate_conversation, remove_columns=db['train'].column_names)
    db = db['train'].train_test_split(test_size=0.2).shuffle(seed=42)

    return db
```
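
For example, building the regression-formatted splits and inspecting a training example (a minimal usage sketch):

```python
db = get_dataset(regression=True)

example = db["train"][0]
print(example["text"])    # Q/A transcript
print(example["labels"])  # float score in [0, 1]
```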

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training; a sketch of an equivalent `Trainer` setup follows the list:

- epochs: 40
- batch_size: 16
- learning_rate: 0.0002
- eval_steps: 82
- log_steps: 82
- save_steps: 41
- gradient_accumulation_steps: 1
- warmup_steps: 0
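
The original training script is not part of this card; the sketch below shows one way these settings map onto the standard `transformers` `Trainer` API. Hyperparameter values come from the list above; `output_dir`, the tokenization step, and `num_labels=1` are assumptions.

```python
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# num_labels=1 gives a single-logit head; float labels make Trainer use MSE loss
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=1
)

db = get_dataset(regression=True)  # from the section above
db = db.map(lambda e: tokenizer(e["text"], truncation=True), batched=True)

args = TrainingArguments(
    output_dir="distillbert_conv_quality_score",  # assumed
    num_train_epochs=40,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-4,
    evaluation_strategy="steps",
    eval_steps=82,
    logging_steps=82,
    save_steps=41,
    gradient_accumulation_steps=1,
    warmup_steps=0,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=db["train"],
    eval_dataset=db["test"],
    tokenizer=tokenizer,
)
trainer.train()
```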

### Training results

| step | training/loss | validation/loss |
|:----:|:-------------:|:---------------:|
| 81   | 0.1020        | 0.0794          |
| 163  | 0.0800        | 0.0713          |
| 245  | 0.0553        | 0.0491          |
| 327  | 0.0362        | 0.0440          |
| 409  | 0.0282        | 0.0352          |
| 491  | 0.0282        | 0.0412          |
| 573  | 0.0256        | 0.0293          |
| 655  | 0.0238        | 0.0252          |
| 737  | 0.0175        | 0.0226          |
| 819  | 0.0154        | 0.0228          |
| 901  | 0.0116        | 0.0205          |
| 983  | 0.0160        | 0.0202          |
| 1065 | 0.0146        | 0.0240          |
| 1147 | 0.0182        | 0.0180          |
| 1229 | 0.0171        | 0.0192          |
| 1311 | 0.0091        | 0.0174          |
| 1393 | 0.0171        | 0.0158          |
| 1475 | 0.0137        | 0.0158          |
| 1557 | 0.0158        | 0.0148          |
| 1639 | 0.0165        | 0.0149          |

### Framework versions

- Transformers 4.26.1
- Datasets 2.10.1
- Tokenizers 0.13.2