# Fine-tuning for multi-label text classification
**Note: This notebook was run in Google Colab**

This notebook demonstrates how to fine-tune a `bert-base-uncased` model for multi-label toxic-comment classification, using the Kaggle [Jigsaw Toxic Comment Classification Challenge dataset](https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge).

The Colab link is here: https://colab.research.google.com/drive/1_tOvmArkigdQpxhZhzVIhR58InDHrxPz

## Setup Environment
We first install and import all the necessary libraries and modules.

In [None]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.14.1 tokenizers-0.13.3 transformers-4.28.1


---------------------------------------------------------

In [None]:
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Load Dataset
Here we load `train.csv`, randomly sample 25,000 comments together with their labels, and split them into training and validation sets (80% / 20%).

In [None]:
# Mount Google Drive so the dataset path below is accessible (Colab only)
from google.colab import drive
drive.mount('/content/drive')

# Read dataset and extract all texts and labels
df = pd.read_csv("/content/drive/MyDrive/AI_project/data/train.csv")

all_texts = df["comment_text"].values
labels = df.columns[2:]  # toxic, severe_toxic, obscene, threat, insult, identity_hate
id2label = {idx: label for idx, label in enumerate(labels)}
label2id = {label: idx for idx, label in enumerate(labels)}
all_labels = df[labels].values

# Randomly select 25,000 rows; one shared set of indices keeps texts and labels aligned
np.random.seed(18)
sample_idx = np.random.choice(len(df), size=25000, replace=False)
small_train_texts = all_texts[sample_idx]
small_train_labels = all_labels[sample_idx]

# Split data into training data and validation data with a percentage of 80% vs 20%
train_texts, val_texts, train_labels, val_labels = train_test_split(small_train_texts, small_train_labels, test_size=.2)
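When subsampling, the key requirement is that every text stays paired with its own label row. The most direct way to guarantee this is to draw one set of row indices and use it to index both arrays. Below is a stdlib-only sketch of that idea followed by an 80/20 split. The toy rows and the helper name `sample_and_split` are hypothetical and not part of scikit-learn:

```python
import random

def sample_and_split(texts, labels, n_samples, val_fraction=0.2, seed=18):
    """Sample n_samples aligned (text, label) pairs, then split them train/val."""
    assert len(texts) == len(labels)
    rng = random.Random(seed)
    # One set of row indices, reused for texts AND labels, keeps the pairs intact
    idx = rng.sample(range(len(texts)), n_samples)
    pairs = [(texts[i], labels[i]) for i in idx]
    # The sample is already in random order, so slicing yields a random split
    n_val = round(n_samples * val_fraction)
    return pairs[n_val:], pairs[:n_val]

# Hypothetical toy data: the first label of row i is i % 2, so misalignment is detectable
texts = [f"comment {i}" for i in range(30000)]
labels = [[i % 2, 0, 0, 0, 0, 0] for i in range(30000)]

train, val = sample_and_split(texts, labels, n_samples=25000)
print(len(train), len(val))  # 20000 5000
```

With 25,000 sampled comments, the split leaves 20,000 for training and 5,000 for validation.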

## Data Preprocessing
Models like BERT don't take raw text as input. Instead, they expect token IDs (`input_ids`) plus companion tensors such as `attention_mask`. We therefore tokenize the text with the tokenizer that matches the checkpoint, and `AutoTokenizer` loads it automatically from the checkpoint name on the Hugging Face Hub. We then wrap the texts and labels in a custom PyTorch `Dataset` class, `TextDataset`, which tokenizes each comment when it is requested.

In [None]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


In [None]:
# Define a dataset class that tokenizes each comment on the fly
class TextDataset(Dataset):
  def __init__(self, texts, labels):
    self.texts = texts
    self.labels = labels

  def __getitem__(self, idx):
    # Truncate long comments and pad short ones to BERT's 512-token limit
    encodings = tokenizer(self.texts[idx], truncation=True, padding="max_length", max_length=512)
    item = {key: torch.tensor(val) for key, val in encodings.items()}
    # Multi-label targets must be floats for the binary cross-entropy loss
    item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float32)
    return item

  def __len__(self):
    return len(self.labels)

train_dataset = TextDataset(train_texts, train_labels)
val_dataset = TextDataset(val_texts, val_labels)
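`TextDataset.__getitem__` calls the tokenizer with `truncation=True, padding="max_length"`, so every example comes back at the same fixed length (512 tokens for `bert-base-uncased`). The stdlib sketch below shows what that does to an already tokenized sequence. It assumes BERT's special-token IDs (`[CLS]`=101, `[SEP]`=102, `[PAD]`=0). `pad_and_truncate` is a hypothetical helper, and the real tokenizer also returns `token_type_ids`:

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # bert-base-uncased special-token IDs

def pad_and_truncate(token_ids, max_length):
    """Add [CLS]/[SEP], truncate to max_length, then pad with [PAD]."""
    # Reserve two positions for the special tokens
    body = token_ids[: max_length - 2]
    input_ids = [CLS_ID] + body + [SEP_ID]
    # 1 marks real tokens the model attends to, 0 marks padding
    attention_mask = [1] * len(input_ids)
    n_pad = max_length - len(input_ids)
    input_ids += [PAD_ID] * n_pad
    attention_mask += [0] * n_pad
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Hypothetical word-piece IDs for a short and a long comment
short = pad_and_truncate([7592, 2088], max_length=8)
long = pad_and_truncate(list(range(1000, 1020)), max_length=8)
print(short["input_ids"])       # [101, 7592, 2088, 102, 0, 0, 0, 0]
print(short["attention_mask"])  # [1, 1, 1, 1, 0, 0, 0, 0]
print(long["input_ids"])        # [101, 1000, 1001, 1002, 1003, 1004, 1005, 102]
```

Padding every comment to 512 tokens is simple, but it is costly when most comments are much shorter. Padding each batch only to its longest member (dynamic padding) is a common alternative.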

In [None]:
# Logging into HuggingFace with token so trained model can be pushed
from huggingface_hub import notebook_login

notebook_login()


## Train the model using Trainer
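We load `bert-base-uncased` with a newly initialized classification head of `len(labels)` outputs and set `problem_type="multi_label_classification"`. With this setting, each label is predicted independently with a sigmoid, and the model is trained with a binary cross-entropy loss (`BCEWithLogitsLoss`) instead of softmax cross-entropy. We then train it with `Trainer`, which is configured through `TrainingArguments`. These arguments hold the training hyperparameters (learning rate, batch sizes, number of epochs) as well as `push_to_hub=True`, so that checkpoints are uploaded to the Hugging Face Hub.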
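In the multi-label setting, the model scores each of the six labels independently. It applies a sigmoid per label and averages binary cross-entropy over the labels, rather than applying one softmax across all labels. The stdlib sketch below computes that loss for a single example in the numerically stable form. `bce_with_logits` is a hypothetical helper that mirrors the math of PyTorch's `BCEWithLogitsLoss` (mean reduction), not its implementation:

```python
import math

def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over labels, computed from raw logits.

    Uses max(x, 0) - x*y + log(1 + exp(-|x|)), which equals
    -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))] without overflow.
    """
    losses = [
        max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
        for x, y in zip(logits, targets)
    ]
    return sum(losses) / len(losses)

# Label order: toxic, severe_toxic, obscene, threat, insult, identity_hate.
# A comment can be positive for several labels at once (here toxic and insult).
targets = [1, 0, 0, 0, 1, 0]
print(round(bce_with_logits([0.0] * 6, targets), 4))  # 0.6931 (ln 2: an uninformed model)
print(bce_with_logits([8.0, -8.0, -8.0, -8.0, 8.0, -8.0], targets) < 0.001)  # True
```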

In [None]:
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", 
                                                           problem_type="multi_label_classification", 
                                                           num_labels=len(labels),
                                                           id2label=id2label,
                                                           label2id=label2id)
model.to(device)

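training_args = TrainingArguments(
    output_dir="finetuned-bert-uncased",
    evaluation_strategy="epoch",     # evaluate on the validation set after every epoch
    save_strategy="epoch",           # must match evaluation_strategy for load_best_model_at_end
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    load_best_model_at_end=True,     # restore the checkpoint with the lowest validation loss
    push_to_hub=True,                # upload checkpoints to the Hugging Face Hub
)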

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer
)

trainer.train()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


| Epoch | Training Loss | Validation Loss |
|------:|--------------:|----------------:|
| 1 | 0.0525 | 0.048165 |
| 2 | 0.037 | 0.044507 |
| 3 | 0.0275 | 0.048948 |
| 4 | 0.0188 | 0.04908 |
| 5 | 0.0146 | 0.050677 |


TrainOutput(global_step=6250, training_loss=0.03428032752990723, metrics={'train_runtime': 10293.1001, 'train_samples_per_second': 9.715, 'train_steps_per_second': 0.607, 'total_flos': 2.63120504832e+16, 'train_loss': 0.03428032752990723, 'epoch': 5.0})
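The numbers in `TrainOutput` can be cross-checked against the setup. Assume training ran on a single device, with no gradient accumulation (none is configured in `TrainingArguments`). Then 20,000 training examples at a batch size of 16 give 1,250 steps per epoch, and five epochs give the reported `global_step=6250`. Dividing by `train_runtime` recovers the reported throughput:

```python
import math

n_train = 25000 - 5000          # 80% of the 25,000 sampled comments
batch_size = 16                 # per_device_train_batch_size, single device assumed
epochs = 5                      # num_train_epochs
train_runtime = 10293.1001      # seconds, from TrainOutput

steps_per_epoch = math.ceil(n_train / batch_size)
total_steps = steps_per_epoch * epochs
print(steps_per_epoch, total_steps)                # 1250 6250
print(round(n_train * epochs / train_runtime, 3))  # 9.715 samples/s
print(round(total_steps / train_runtime, 3))       # 0.607 steps/s
```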

## Post-train
Now we can push the trained model to the Hugging Face Hub.

In [None]:
trainer.push_to_hub()


To https://huggingface.co/andyqin18/finetuned-bert-uncased
   fe51685..cf13f54  main -> main

To https://huggingface.co/andyqin18/finetuned-bert-uncased
   cf13f54..ceb65db  main -> main



'https://huggingface.co/andyqin18/finetuned-bert-uncased/commit/cf13f54cbb5ba93b74ad449c875efa268f647504'
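The model returns one raw logit per label. Because the labels are independent, each logit goes through its own sigmoid and is compared against a threshold. A threshold of 0.5 is a common default; it is not a value tuned in this notebook. The stdlib sketch below uses hypothetical logits. `decode_logits` is a hypothetical helper, and the label mapping follows the column order of `train.csv`:

```python
import math

# Jigsaw train.csv label columns, in order (mirrors id2label above)
id2label = {0: "toxic", 1: "severe_toxic", 2: "obscene",
            3: "threat", 4: "insult", 5: "identity_hate"}

def sigmoid(x):
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def decode_logits(logits, threshold=0.5):
    """Return every label whose sigmoid probability reaches the threshold."""
    return [id2label[i] for i, x in enumerate(logits) if sigmoid(x) >= threshold]

# Hypothetical logits for one comment (not actual model output)
print(decode_logits([2.3, -3.1, 0.4, -5.0, 1.2, -4.4]))  # ['toxic', 'obscene', 'insult']
print(decode_logits([-4.0] * 6))                         # []
```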

In [None]:
trainer.evaluate()

{'eval_loss': 0.04450742155313492,
 'eval_runtime': 177.3039,
 'eval_samples_per_second': 28.2,
 'eval_steps_per_second': 1.765,
 'epoch': 5.0}
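`trainer.evaluate()` reports only the loss, because no `compute_metrics` function was passed to `Trainer`. The `eval_loss` of 0.04451 equals the epoch-2 validation loss in the training log above. This is consistent with `load_best_model_at_end=True` restoring the checkpoint with the lowest validation loss. To also track a classification metric, `Trainer` accepts a `compute_metrics` callable whose argument unpacks into predictions and labels. The stdlib sketch below computes micro-averaged F1 at a 0.5 sigmoid threshold; both the metric and the threshold are assumptions, not choices made in this notebook:

```python
import math

def compute_metrics(eval_pred, threshold=0.5):
    """Micro-averaged F1 for multi-label predictions.

    eval_pred unpacks into (logits, labels): one row per example and one
    column per label. Works on nested lists or 2-D NumPy arrays.
    """
    logits, labels = eval_pred
    # sigmoid(x) >= threshold  <=>  x >= log(threshold / (1 - threshold))
    cutoff = math.log(threshold / (1 - threshold))
    tp = fp = fn = 0
    for row_logits, row_labels in zip(logits, labels):
        for x, y in zip(row_logits, row_labels):
            pred = x >= cutoff
            if pred and y == 1:
                tp += 1
            elif pred:
                fp += 1
            elif y == 1:
                fn += 1
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"f1_micro": f1}

# Hypothetical logits and labels for two comments and three labels
logits = [[2.0, -1.0, 0.5], [-2.0, 1.5, -0.5]]
labels = [[1, 0, 0], [0, 1, 1]]
print(compute_metrics((logits, labels)))
```

It would be passed as `Trainer(..., compute_metrics=compute_metrics)`. On a large validation set, a vectorized NumPy version of the same counts would be noticeably faster than this per-element loop.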