{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "K6KNj8R5pFOi",
    "outputId": "73e388e8-294f-438d-ddc2-06ae7132580a"
   },
   "outputs": [],
   "source": [
    "!kaggle competitions download -c jigsaw-toxic-comment-classification-challenge\n",
    "!unzip jigsaw-toxic-comment-classification-challenge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "-a6Sx13TqW2h",
    "outputId": "eb6bb305-7b66-4f59-e1e3-24858c1309c4"
   },
   "outputs": [],
   "source": [
    "!unzip test.csv.zip  \n",
    "!unzip test_labels.csv.zip  \n",
    "!unzip train.csv.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "Jt-aOqhVqavv"
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "import pandas as pd\n",
    "import torch\n",
    "import numpy as np\n",
    "from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\n",
    "from sklearn.model_selection import train_test_split\n",
    "from transformers import Trainer, TrainingArguments\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mps:0\n"
     ]
    }
   ],
   "source": [
    "# Use GPU\n",
    "#device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "device = \"mps:0\" if torch.backends.mps.is_available() else \"cpu\"\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "zMDF7x0H4VFW"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>comment_text</th>\n",
       "      <th>toxic</th>\n",
       "      <th>severe_toxic</th>\n",
       "      <th>obscene</th>\n",
       "      <th>threat</th>\n",
       "      <th>insult</th>\n",
       "      <th>identity_hate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0000997932d777bf</td>\n",
       "      <td>Explanation\\nWhy the edits made under my usern...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 id                                       comment_text  toxic  \\\n",
       "0  0000997932d777bf  Explanation\\nWhy the edits made under my usern...      0   \n",
       "\n",
       "   severe_toxic  obscene  threat  insult  identity_hate  \n",
       "0             0        0       0       0              0  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load training text and label dataset\n",
    "# Preprocess data\n",
    "\n",
    "#test_texts = pd.read_csv(\"test.csv\").values.tolist()\n",
    "#test_labels = pd.read_csv('test_labels.csv').values.tolist()\n",
    "\n",
    "train = pd.read_csv('train.csv')\n",
    "train.head(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "159571 (159571, 8)\n",
      "id               False\n",
      "comment_text     False\n",
      "toxic            False\n",
      "severe_toxic     False\n",
      "obscene          False\n",
      "threat           False\n",
      "insult           False\n",
      "identity_hate    False\n",
      "dtype: bool\n",
      "False\n"
     ]
    }
   ],
   "source": [
    "# Any duplicates?\n",
    "print(len(train['comment_text'].unique()), train.shape)\n",
    "\n",
    "# Any missing values?\n",
    "print(train.isnull().any())\n",
    "print(train.isnull().values.any())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>comment_text</th>\n",
       "      <th>toxic</th>\n",
       "      <th>severe_toxic</th>\n",
       "      <th>obscene</th>\n",
       "      <th>threat</th>\n",
       "      <th>insult</th>\n",
       "      <th>identity_hate</th>\n",
       "      <th>grouped_labels</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0000997932d777bf</td>\n",
       "      <td>Explanation\\nWhy the edits made under my usern...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[0, 0, 0, 0, 0, 0]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 id                                       comment_text  toxic  \\\n",
       "0  0000997932d777bf  Explanation\\nWhy the edits made under my usern...      0   \n",
       "\n",
       "   severe_toxic  obscene  threat  insult  identity_hate      grouped_labels  \n",
       "0             0        0       0       0              0  [0, 0, 0, 0, 0, 0]  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Group labels to get right format for training\n",
    "labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']\n",
    "train['grouped_labels'] = train[labels].values.tolist()\n",
    "train.head(1)"
   ]
  },
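  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (added sketch, not part of the original pipeline): the six labels are\n",
    "# heavily imbalanced, which is worth keeping in mind when judging training loss and metrics.\n",
    "print(train[labels].sum())                        # positive count per label\n",
    "print((train[labels].sum(axis=1) == 0).mean())    # fraction of comments with no positive label"
   ]
  },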
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert to list from dataframe\n",
    "train_texts = train['comment_text'].values.tolist()\n",
    "train_labels = train['grouped_labels'].values.tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "vkxJ6NkFlc46",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Use distilbert, a faster model of BERT which keeps 95% of the performance\n",
    "model_name = \"bert-base-uncased\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 0, 1, 1, 0, 0]    11\n",
      "[1, 1, 0, 1, 0, 0]    11\n",
      "[1, 0, 0, 1, 0, 1]     7\n",
      "[1, 1, 0, 0, 1, 1]     7\n",
      "[1, 1, 1, 0, 0, 1]     6\n",
      "[1, 1, 1, 1, 0, 0]     4\n",
      "[0, 0, 0, 1, 1, 0]     3\n",
      "[1, 0, 0, 1, 1, 1]     3\n",
      "[1, 1, 0, 0, 0, 1]     3\n",
      "[0, 0, 1, 0, 0, 1]     3\n",
      "[0, 0, 1, 1, 0, 0]     2\n",
      "[0, 0, 1, 1, 1, 0]     2\n",
      "[1, 1, 0, 1, 1, 0]     1\n",
      "[1, 1, 0, 1, 0, 1]     1\n",
      "Name: grouped_labels, dtype: int64\n",
      "df label indices with only one instance:  [159029, 158498, 157010, 154553, 149180, 144159, 139501, 138026, 134459, 133505, 127410, 120395, 115766, 113304, 110056, 107881, 107096, 101089, 98699, 86746, 76454, 74607, 68264, 66350, 63687, 61934, 57594, 53408, 45101, 41461, 36141, 31191, 30566, 29445, 23374, 17187, 15977, 9487, 8979, 6316, 6063, 2374]\n"
     ]
    }
   ],
   "source": [
    "# Also do preprocessing to see if there are any unique rows\n",
    "# with that specfic combination of labels\n",
    "# If that is the case, we want to include that row in the training data\n",
    "\n",
    "# Find unique label combinations\n",
    "label_counts = train['grouped_labels'].astype(str).value_counts()\n",
    "print(label_counts[-14:])\n",
    "\n",
    "# Take low frequency labels\n",
    "low_freq = label_counts[label_counts<10].keys()\n",
    "low_freq_inds = sorted(list(train[train['grouped_labels'].astype(str).isin(low_freq)].index), reverse=True)\n",
    "print('df label indices with only one instance: ', low_freq_inds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "low_freq_train_texts = [train_texts.pop(i) for i in low_freq_inds]\n",
    "low_freq_train_labels = [train_labels.pop(i) for i in low_freq_inds]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add low freq values to training data\n",
    "train_texts.extend(low_freq_train_texts)\n",
    "train_labels.extend(low_freq_train_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split datasets for training\n",
    "train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts, train_labels, test_size=.1)"
   ]
  },
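  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (added sketch): every rare label combination should be present\n",
    "# in the training split.\n",
    "rare_combos = {str(l) for l in low_freq_train_labels}\n",
    "train_combos = {str(l) for l in train_labels}\n",
    "print('rare combinations missing from the training split:', rare_combos - train_combos)"
   ]
  },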
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shorten token to increase training speed, average is below this\n",
    "max_length = 100\n",
    "train_encodings = tokenizer(train_texts, truncation=True, padding=True, return_tensors=\"pt\", max_length=max_length).to(device)\n",
    "val_encodings = tokenizer(val_texts, truncation=True, padding=True, return_tensors=\"pt\", max_length=max_length).to(device)"
   ]
  },
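  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough check of the max_length=100 choice (added sketch): tokenize a sample of comments\n",
    "# without truncation and inspect the length distribution; the sample size of 2000 is arbitrary.\n",
    "sample_lens = [len(tokenizer(t)['input_ids']) for t in train_texts[:2000]]\n",
    "print('mean tokens:', np.mean(sample_lens), ' 95th percentile:', np.percentile(sample_lens, 95))"
   ]
  },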
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ToxicDataset(Dataset):\n",
    "  def __init__(self, encodings, labels):\n",
    "    self.encodings = encodings\n",
    "    self.labels = [[float(y) for y in x] for x in labels]\n",
    "\n",
    "  def __getitem__(self, idx):\n",
    "    item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
    "    item['labels'] = torch.tensor(self.labels[idx])\n",
    "    return item\n",
    "\n",
    "  def __len__(self):\n",
    "    return len(self.labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = ToxicDataset(train_encodings, train_labels)\n",
    "val_dataset = ToxicDataset(val_encodings, val_labels)"
   ]
  },
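  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick look at one training example (added sketch): each item is a dict of input_ids,\n",
    "# attention_mask (plus token_type_ids for BERT) and a 6-dimensional float label vector.\n",
    "sample = train_dataset[0]\n",
    "print({k: v.shape for k, v in sample.items()})"
   ]
  },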
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "model = AutoModelForSequenceClassification.from_pretrained(model_name,\n",
    "                                                           num_labels=6,\n",
    "                                                          ).to(device)"
   ]
  },
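  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative (sketch, not used below): make the multi-label objective explicit so the model\n",
    "# uses BCEWithLogitsLoss regardless of label dtype. With float labels, as prepared above,\n",
    "# BertForSequenceClassification should already infer multi_label_classification on its own.\n",
    "# model = AutoModelForSequenceClassification.from_pretrained(\n",
    "#     model_name, num_labels=6, problem_type=\"multi_label_classification\"\n",
    "# ).to(device)"
   ]
  },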
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {
    "collapsed": true,
    "id": "CI2B0V5D27gA",
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n",
      "PyTorch: setting up devices\n",
      "***** Running training *****\n",
      "  Num examples = 127656\n",
      "  Num Epochs = 1\n",
      "  Instantaneous batch size per device = 16\n",
      "  Total train batch size (w. parallel, distributed & accumulation) = 16\n",
      "  Gradient Accumulation steps = 1\n",
      "  Total optimization steps = 7979\n",
      "  Number of trainable parameters = 109486854\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='33' max='7979' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [  33/7979 00:21 < 1:33:06, 1.42 it/s, Epoch 0.00/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.605800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>0.590100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>0.550200</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[56], line 28\u001b[0m\n\u001b[1;32m      9\u001b[0m training_args \u001b[38;5;241m=\u001b[39m TrainingArgumentsWithMPSSupport(\n\u001b[1;32m     10\u001b[0m     output_dir \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m./results\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m     11\u001b[0m     num_train_epochs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     18\u001b[0m     logging_steps\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m,\n\u001b[1;32m     19\u001b[0m )\n\u001b[1;32m     21\u001b[0m trainer \u001b[38;5;241m=\u001b[39m Trainer(\n\u001b[1;32m     22\u001b[0m     model\u001b[38;5;241m=\u001b[39mmodel,\n\u001b[1;32m     23\u001b[0m     args\u001b[38;5;241m=\u001b[39mtraining_args,\n\u001b[1;32m     24\u001b[0m     train_dataset\u001b[38;5;241m=\u001b[39mtrain_dataset,\n\u001b[1;32m     25\u001b[0m     eval_dataset\u001b[38;5;241m=\u001b[39mval_dataset,\n\u001b[1;32m     26\u001b[0m )\n\u001b[0;32m---> 28\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/transformers/trainer.py:1501\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_wrapped \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\n\u001b[1;32m   1498\u001b[0m inner_training_loop \u001b[38;5;241m=\u001b[39m find_executable_batch_size(\n\u001b[1;32m   1499\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_inner_training_loop, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_train_batch_size, args\u001b[38;5;241m.\u001b[39mauto_find_batch_size\n\u001b[1;32m   1500\u001b[0m )\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1502\u001b[0m \u001b[43m    \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1503\u001b[0m \u001b[43m    \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1504\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1505\u001b[0m \u001b[43m    \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1506\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/transformers/trainer.py:1749\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   1747\u001b[0m         tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining_step(model, inputs)\n\u001b[1;32m   1748\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1749\u001b[0m     tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1751\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m   1752\u001b[0m     args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m   1753\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_tpu_available()\n\u001b[1;32m   1754\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m   1755\u001b[0m ):\n\u001b[1;32m   1756\u001b[0m     \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m   1757\u001b[0m     tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/transformers/trainer.py:2526\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m   2524\u001b[0m     loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdeepspeed\u001b[38;5;241m.\u001b[39mbackward(loss)\n\u001b[1;32m   2525\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2526\u001b[0m     \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2528\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\u001b[38;5;241m.\u001b[39mdetach()\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/_tensor.py:488\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    478\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    479\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m    480\u001b[0m         Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m    481\u001b[0m         (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    486\u001b[0m         inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m    487\u001b[0m     )\n\u001b[0;32m--> 488\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    489\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m    490\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/autograd/__init__.py:204\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    199\u001b[0m     retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m    201\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m    202\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m    203\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 204\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m    205\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    206\u001b[0m \u001b[43m    \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "class TrainingArgumentsWithMPSSupport(TrainingArguments):\n",
    "    @property\n",
    "    def device(self) -> torch.device:\n",
    "        if torch.backends.mps.is_available():\n",
    "            return torch.device(\"mps\")\n",
    "        else:\n",
    "            return torch.device(\"cpu\")\n",
    "\n",
    "training_args = TrainingArgumentsWithMPSSupport(\n",
    "    output_dir = './results',\n",
    "    num_train_epochs=1,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    warmup_steps=500,\n",
    "    learning_rate=5e-5,\n",
    "    weight_decay=0.01,\n",
    "    logging_dir='./logs',\n",
    "    logging_steps=10,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=val_dataset,\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
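  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation metrics were not wired into the Trainer above. A minimal sketch, assuming a\n",
    "# 0.5 threshold on per-label sigmoid probabilities; it could be passed to the Trainer via\n",
    "# Trainer(..., compute_metrics=compute_metrics) before calling train() or evaluate().\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    logits, label_ids = eval_pred\n",
    "    probs = torch.sigmoid(torch.tensor(logits)).numpy()\n",
    "    preds = (probs >= 0.5).astype(int)\n",
    "    return {'micro_f1': f1_score(label_ids, preds, average='micro'),\n",
    "            'macro_f1': f1_score(label_ids, preds, average='macro')}"
   ]
  },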
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to ./model_checkpoint/done\n",
      "Configuration saved in ./model_checkpoint/done/config.json\n",
      "Model weights saved in ./model_checkpoint/done/pytorch_model.bin\n"
     ]
    }
   ],
   "source": [
    "trainer.save_model('./model_checkpoint/done')"
   ]
  },
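  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The tokenizer is not saved above; saving it next to the weights (sketch) makes the checkpoint\n",
    "# self-contained and reloadable with AutoTokenizer.from_pretrained.\n",
    "tokenizer.save_pretrained('./model_checkpoint/done')"
   ]
  },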
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from transformers import BertTokenizer, BertForSequenceClassification\n",
    "#saved = DistilBertModel.from_pretrained('./model_checkpoint/trained', num_labels=6, problem_type=\"multi_label_classification\")\n",
    "saved = BertForSequenceClassification.from_pretrained('./model_checkpoint/fine_tuned')"
   ]
  },
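  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defensive step (sketch): ensure the reloaded model is in eval mode before inference so\n",
    "# dropout stays disabled; from_pretrained should already return the model in eval mode.\n",
    "saved = saved.eval()"
   ]
  },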
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'trainer' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[19], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241m.\u001b[39mevaluate()\n",
      "\u001b[0;31mNameError\u001b[0m: name 'trainer' is not defined"
     ]
    }
   ],
   "source": [
    "trainer.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[0.4601849317550659,\n",
       "  0.0626736581325531,\n",
       "  0.1962047964334488,\n",
       "  0.0715285912156105,\n",
       "  0.1363525241613388,\n",
       "  0.0730554461479187]]"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text = \"fun\"\n",
    "encoded_input = tokenizer(text, return_tensors=\"pt\")\n",
    "outputs = saved(**encoded_input)\n",
    "predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)\n",
    "predictions = predictions.cpu().detach().numpy()\n",
    "predictions.tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = [1 if x >= 0.5 else 0 for x in predictions[0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1, 0, 0, 0, 0, 0]"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res"
   ]
  },
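  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map the thresholded predictions back to the label names for readability (sketch).\n",
    "dict(zip(labels, res))"
   ]
  },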
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}