{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
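    "# Best Model\n",
    "\n",
    "Trains the ResNet speech-command classifier on MFCC features, saves the model and its preprocessing artifacts to `./model/`, reloads them, and builds a Gradio microphone demo from the reloaded objects."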
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Standard library\n",
    "from pathlib import Path\n",
    "\n",
    "# Third-party\n",
    "import gradio as gr\n",
    "import librosa\n",
    "import numpy as np\n",
    "import skorch\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from joblib import dump, load\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "# Local project modules\n",
    "from dataloading import load_dataset, to_numpy, uniformize\n",
    "from gradio_utils import load_as_librosa, predict_gradio\n",
    "from preprocessing import MfccTransformer, TorchTransform\n",
    "from resnet import ResNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook params\n",
    "SEED: int = 42\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Dataloading params\n",
    "PATHS: list = [\n",
    "    \"../Projet-ML/data/\",\n",
    "    \"../Projet-ML/new_data/JulienNestor\",\n",
    "    \"../Projet-ML/new_data/classroom_data\",\n",
    "    \"../Projet-ML/new_data/class\",\n",
    "    \"../Projet-ML/new_data/JulienRaph\",\n",
    "]\n",
    "# Labels excluded from training. Several commands appear in two spellings\n",
    "# (e.g. \"faisunflip\" / \"faisUnFlip\"), presumably because the data sources name\n",
    "# them differently -- keep both variants so neither slips through.\n",
    "REMOVE_LABEL: list = [\n",
    "    \"penduleinverse\", \"pendule\",\n",
    "    \"decollage\", \"atterrissage\",\n",
    "    \"plushaut\", \"plusbas\",\n",
    "    \"etatdurgence\",\n",
    "    \"faisunflip\",\n",
    "    \"faisUnFlip\", \"arreteToi\", \"etatDurgence\",\n",
    "    # Optionally also drop the turning / stop commands:\n",
    "    # \"tournedroite\", \"arretetoi\", \"tournegauche\",\n",
    "]\n",
    "SAMPLE_RATE: int = 16_000\n",
    "METHOD: str = \"time_stretch\"\n",
    "MAX_TIME: float = 3.0\n",
    "\n",
    "# Features Extraction params\n",
    "N_MFCC: int = 64\n",
    "# The \"LENGHT\" misspelling is intentional: the value is saved to and reloaded\n",
    "# from ./model/HOP_LENGHT.joblib under this exact name.\n",
    "HOP_LENGHT: int = 2_048"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1 - Data loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1 - Data loading: read the recordings under PATHS, skipping the REMOVE_LABEL classes.\n",
    "# `load_dataset` also returns `uniform_lambda`, the callable used to uniformize a clip\n",
    "# (presumably per METHOD / MAX_TIME, like the `uniformize`-based lambda rebuilt\n",
    "# from the saved config before the demo).\n",
    "dataset, uniform_lambda = load_dataset(\n",
    "    PATHS,\n",
    "    remove_label=REMOVE_LABEL,\n",
    "    sr=SAMPLE_RATE,\n",
    "    method=METHOD,\n",
    "    max_time=MAX_TIME,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['recule',\n",
       " 'tournedroite',\n",
       " 'arretetoi',\n",
       " 'tournegauche',\n",
       " 'gauche',\n",
       " 'avance',\n",
       " 'droite']"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Classes remaining after REMOVE_LABEL filtering, in order of first appearance\n",
    "dataset[\"ground_truth\"].unique().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2 - Train/test split. Seeded with the notebook SEED (was a separate hard-coded 0)\n",
    "# and stratified so every command keeps the same proportion in both sets.\n",
    "dataset_train, dataset_test = train_test_split(\n",
    "    dataset,\n",
    "    random_state=SEED,\n",
    "    stratify=dataset[\"ground_truth\"],\n",
    ")\n",
    "\n",
    "X_train = to_numpy(dataset_train[\"y_uniform\"])\n",
    "y_train = to_numpy(dataset_train[\"ground_truth\"])\n",
    "X_test = to_numpy(dataset_test[\"y_uniform\"])\n",
    "y_test = to_numpy(dataset_test[\"ground_truth\"])\n",
    "\n",
    "assert len(X_train) + len(X_test) == len(dataset), \"split lost or duplicated rows\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2 - Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MFCC features -> torch tensors. Fitted on the training split only, then applied\n",
    "# unchanged to both splits.\n",
    "mfcc_step = MfccTransformer(N_MFCC=N_MFCC, reshape_output=False, hop_length=HOP_LENGHT)\n",
    "only_mffc_transform = Pipeline(steps=[(\"mfcc\", mfcc_step), (\"torch\", TorchTransform())])\n",
    "only_mffc_transform.fit(X_train)\n",
    "\n",
    "X_train_mfcc_torch, X_test_mfcc_torch = (\n",
    "    only_mffc_transform.transform(X) for X in (X_train, X_test)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map string labels to integer class ids; the mapping is learned from the\n",
    "# training labels only and reused for the test labels.\n",
    "label_encoder = LabelEncoder()\n",
    "y_train_enc = label_encoder.fit_transform(y_train)\n",
    "y_test_enc = label_encoder.transform(y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3 - ResNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the best accelerator that is actually *available* at runtime.\n",
    "# `torch.has_mps` / `torch.has_cuda` are deprecated and only report whether PyTorch\n",
    "# was built with that backend, so they can select a device this machine lacks.\n",
    "mps_backend = getattr(torch.backends, \"mps\", None)  # missing on older PyTorch\n",
    "if mps_backend is not None and mps_backend.is_available():\n",
    "    device = torch.device(\"mps\")\n",
    "elif torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.1 - Network definition (`nn.Module`)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The network is `ResNet`, defined in `resnet.py` and imported in the first cell. The training cell below instantiates it with one input channel and one output per class."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.2 - Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  epoch    train_loss     dur\n",
      "-------  ------------  ------\n",
      "      1        \u001b[36m2.8636\u001b[0m  1.9894\n",
      "      2        \u001b[36m1.9484\u001b[0m  0.4326\n",
      "      3        \u001b[36m1.8183\u001b[0m  0.4312\n",
      "      4        \u001b[36m1.6839\u001b[0m  0.4318\n",
      "      5        \u001b[36m1.5514\u001b[0m  0.4326\n",
      "      6        \u001b[36m1.4672\u001b[0m  0.4309\n",
      "      7        \u001b[36m1.2708\u001b[0m  0.4323\n",
      "      8        1.2842  0.4308\n",
      "      9        \u001b[36m1.0673\u001b[0m  0.4316\n",
      "     10        \u001b[36m0.9857\u001b[0m  0.4307\n",
      "     11        \u001b[36m0.9400\u001b[0m  0.4322\n",
      "     12        \u001b[36m0.9096\u001b[0m  0.4310\n",
      "     13        \u001b[36m0.7838\u001b[0m  0.4313\n",
      "     14        \u001b[36m0.7031\u001b[0m  0.4330\n",
      "     15        \u001b[36m0.6361\u001b[0m  0.4313\n",
      "     16        \u001b[36m0.5983\u001b[0m  0.4325\n",
      "     17        \u001b[36m0.5712\u001b[0m  0.4318\n",
      "     18        \u001b[36m0.4825\u001b[0m  0.4315\n",
      "     19        0.4951  0.4323\n",
      "     20        \u001b[36m0.4653\u001b[0m  0.4320\n",
      "     21        \u001b[36m0.4050\u001b[0m  0.4333\n",
      "     22        0.4351  0.4317\n",
      "     23        0.4365  0.4314\n",
      "     24        \u001b[36m0.4000\u001b[0m  0.4304\n",
      "     25        \u001b[36m0.3876\u001b[0m  0.4319\n",
      "     26        \u001b[36m0.3740\u001b[0m  0.4327\n",
      "     27        \u001b[36m0.3589\u001b[0m  0.4323\n",
      "     28        \u001b[36m0.3173\u001b[0m  0.4330\n",
      "     29        0.3412  0.4322\n",
      "     30        0.3263  0.4335\n",
      "     31        0.3313  0.4322\n",
      "     32        \u001b[36m0.3033\u001b[0m  0.4327\n",
      "     33        0.3333  0.4325\n",
      "     34        \u001b[36m0.2912\u001b[0m  0.4328\n",
      "     35        \u001b[36m0.2834\u001b[0m  0.4330\n",
      "     36        0.3150  0.4326\n",
      "     37        0.2842  0.4339\n",
      "     38        0.2854  0.4335\n",
      "     39        \u001b[36m0.2588\u001b[0m  0.4341\n",
      "     40        0.2775  0.4340\n",
      "     41        0.2823  0.4336\n",
      "     42        0.2826  0.4344\n",
      "     43        0.2723  0.4328\n",
      "     44        0.2638  0.4354\n",
      "     45        \u001b[36m0.2350\u001b[0m  0.4348\n",
      "     46        0.2463  0.4334\n",
      "     47        0.2688  0.4333\n",
      "     48        0.2652  0.4343\n",
      "     49        0.2869  0.4348\n",
      "     50        0.2833  0.4338\n",
      "     51        0.2541  0.4335\n",
      "     52        0.2796  0.4318\n",
      "     53        \u001b[36m0.2273\u001b[0m  0.4350\n",
      "     54        0.2516  0.4341\n",
      "     55        0.2392  0.4332\n",
      "     56        0.2480  0.4332\n",
      "     57        0.2341  0.4331\n",
      "     58        \u001b[36m0.2240\u001b[0m  0.4332\n",
      "     59        0.2441  0.4333\n",
      "     60        0.2313  0.4329\n",
      "     61        0.2590  0.4348\n",
      "     62        0.2412  0.4344\n",
      "     63        0.2391  0.4323\n",
      "     64        0.2591  0.4331\n",
      "     65        0.2595  0.4336\n",
      "     66        0.2356  0.4328\n",
      "     67        0.2529  0.4351\n",
      "     68        0.2262  0.4330\n",
      "     69        0.2438  0.4322\n",
      "     70        \u001b[36m0.2189\u001b[0m  0.4323\n",
      "     71        0.2283  0.4318\n",
      "     72        0.2333  0.4325\n",
      "     73        0.2327  0.4333\n",
      "     74        \u001b[36m0.2062\u001b[0m  0.4350\n",
      "     75        0.2566  0.4323\n",
      "     76        0.2373  0.4333\n",
      "     77        0.2253  0.4332\n",
      "     78        0.2446  0.4328\n",
      "     79        0.2459  0.4328\n",
      "     80        \u001b[36m0.2006\u001b[0m  0.4322\n",
      "     81        0.2170  0.4337\n",
      "     82        0.2270  0.4324\n",
      "     83        0.2177  0.4324\n",
      "     84        0.2235  0.4318\n",
      "     85        0.2326  0.4341\n",
      "     86        0.2260  0.4330\n",
      "     87        0.2479  0.4318\n",
      "     88        0.2267  0.4335\n",
      "     89        0.2544  0.4324\n",
      "     90        0.2167  0.4347\n",
      "     91        0.2280  0.4328\n",
      "     92        0.2093  0.4334\n",
      "     93        0.2035  0.4337\n",
      "     94        0.2077  0.4327\n",
      "     95        0.2437  0.4341\n",
      "     96        0.2278  0.4330\n",
      "     97        0.2265  0.4359\n",
      "     98        0.2145  0.4328\n",
      "     99        0.2239  0.4336\n",
      "    100        0.2034  0.4333\n",
      "    101        0.2286  0.4332\n",
      "    102        0.2231  0.4325\n",
      "    103        0.2169  0.4327\n",
      "    104        0.2415  0.4337\n",
      "Stopping since train_loss has not improved in the last 25 epochs.\n",
      "0.946058091286307\n"
     ]
    }
   ],
   "source": [
    "# Define net: one output unit per class known to the label encoder, so the\n",
    "# network's outputs always line up with the encoded targets.\n",
    "# NOTE(review): the printed architecture ends with Dropout after the final Linear\n",
    "# layer (defined in resnet.py), i.e. dropout on the logits while training -- confirm.\n",
    "n_labels = len(label_encoder.classes_)\n",
    "net = ResNet(in_channels=1, num_classes=n_labels)\n",
    "\n",
    "# Define model. With train_split=None there is no validation set, so early\n",
    "# stopping monitors the training loss.\n",
    "model = skorch.NeuralNetClassifier(\n",
    "    module=net,\n",
    "    criterion=nn.CrossEntropyLoss(),\n",
    "    callbacks=[skorch.callbacks.EarlyStopping(monitor=\"train_loss\", patience=25)],\n",
    "    max_epochs=200,\n",
    "    lr=0.01,\n",
    "    batch_size=128,\n",
    "    train_split=None,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "# No explicit check_data() needed: skorch validates X/y inside its fit loop.\n",
    "model.fit(X_train_mfcc_torch, y_train_enc)\n",
    "\n",
    "# Test-set accuracy\n",
    "model.score(X_test_mfcc_torch, y_test_enc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ResNet(\n",
       "  (conv1): ConvBlock(\n",
       "    (pool_block): Sequential(\n",
       "      (0): ReLU(inplace=True)\n",
       "    )\n",
       "    (block): Sequential(\n",
       "      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): Sequential(\n",
       "        (0): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (conv2): ConvBlock(\n",
       "    (pool_block): Sequential(\n",
       "      (0): ReLU(inplace=True)\n",
       "      (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    )\n",
       "    (block): Sequential(\n",
       "      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): Sequential(\n",
       "        (0): ReLU(inplace=True)\n",
       "        (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (res1): Sequential(\n",
       "    (0): ConvBlock(\n",
       "      (pool_block): Sequential(\n",
       "        (0): ReLU(inplace=True)\n",
       "      )\n",
       "      (block): Sequential(\n",
       "        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (2): Sequential(\n",
       "          (0): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (1): ConvBlock(\n",
       "      (pool_block): Sequential(\n",
       "        (0): ReLU(inplace=True)\n",
       "      )\n",
       "      (block): Sequential(\n",
       "        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (2): Sequential(\n",
       "          (0): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (conv3): ConvBlock(\n",
       "    (pool_block): Sequential(\n",
       "      (0): ReLU(inplace=True)\n",
       "    )\n",
       "    (block): Sequential(\n",
       "      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): Sequential(\n",
       "        (0): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (conv4): ConvBlock(\n",
       "    (pool_block): Sequential(\n",
       "      (0): ReLU(inplace=True)\n",
       "      (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    )\n",
       "    (block): Sequential(\n",
       "      (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): Sequential(\n",
       "        (0): ReLU(inplace=True)\n",
       "        (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (res2): Sequential(\n",
       "    (0): ConvBlock(\n",
       "      (pool_block): Sequential(\n",
       "        (0): ReLU(inplace=True)\n",
       "      )\n",
       "      (block): Sequential(\n",
       "        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (2): Sequential(\n",
       "          (0): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (1): ConvBlock(\n",
       "      (pool_block): Sequential(\n",
       "        (0): ReLU(inplace=True)\n",
       "      )\n",
       "      (block): Sequential(\n",
       "        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (2): Sequential(\n",
       "          (0): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (classifier): Sequential(\n",
       "    (0): MaxPool2d(kernel_size=(4, 4), stride=(4, 4), padding=0, dilation=1, ceil_mode=False)\n",
       "    (1): AdaptiveAvgPool2d(output_size=1)\n",
       "    (2): Flatten(start_dim=1, end_dim=-1)\n",
       "    (3): Linear(in_features=512, out_features=128, bias=True)\n",
       "    (4): Dropout(p=0.25, inplace=False)\n",
       "    (5): Linear(in_features=128, out_features=7, bias=True)\n",
       "    (6): Dropout(p=0.25, inplace=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Move the trained network to CPU before saving, so the pickled model can be\n",
    "# loaded and run on machines without the training accelerator (MPS/CUDA).\n",
    "# `module_` is the initialized module skorch actually uses for inference.\n",
    "cpu = torch.device(\"cpu\")\n",
    "model.device = cpu\n",
    "model.module_.to(cpu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['./model/HOP_LENGHT.joblib']"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Persist the trained model, the fitted preprocessing and the config values needed\n",
    "# at inference time, one joblib file per artifact under ./model/.\n",
    "model_dir = Path(\"./model\")\n",
    "model_dir.mkdir(parents=True, exist_ok=True)  # dump() fails if the folder is missing\n",
    "\n",
    "artifacts = {\n",
    "    \"model\": model,\n",
    "    \"only_mffc_transform\": only_mffc_transform,\n",
    "    \"label_encoder\": label_encoder,\n",
    "    \"SAMPLE_RATE\": SAMPLE_RATE,\n",
    "    \"METHOD\": METHOD,\n",
    "    \"MAX_TIME\": MAX_TIME,\n",
    "    \"N_MFCC\": N_MFCC,\n",
    "    \"HOP_LENGHT\": HOP_LENGHT,  # misspelling kept: it is the on-disk file name\n",
    "}\n",
    "for name, obj in artifacts.items():\n",
    "    dump(obj, model_dir / f\"{name}.joblib\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload every saved artifact, the way a separate serving process would, so the demo\n",
    "# below runs on the persisted objects rather than on the in-memory ones.\n",
    "# (joblib files are pickles: only load artifacts you produced yourself.)\n",
    "def _load_artifact(name):\n",
    "    \"\"\"Return the object stored in ./model/<name>.joblib.\"\"\"\n",
    "    return load(f\"./model/{name}.joblib\")\n",
    "\n",
    "\n",
    "model = _load_artifact(\"model\")\n",
    "only_mffc_transform = _load_artifact(\"only_mffc_transform\")\n",
    "label_encoder = _load_artifact(\"label_encoder\")\n",
    "SAMPLE_RATE = _load_artifact(\"SAMPLE_RATE\")\n",
    "METHOD = _load_artifact(\"METHOD\")\n",
    "MAX_TIME = _load_artifact(\"MAX_TIME\")\n",
    "N_MFCC = _load_artifact(\"N_MFCC\")\n",
    "HOP_LENGHT = _load_artifact(\"HOP_LENGHT\")\n",
    "\n",
    "# End-to-end inference: uniformized clip -> MFCC tensors -> ResNet prediction.\n",
    "inference_steps = [\n",
    "    (\"mfcc\", only_mffc_transform),\n",
    "    (\"model\", model),\n",
    "]\n",
    "sklearn_model = Pipeline(steps=inference_steps)\n",
    "\n",
    "\n",
    "def uniform_lambda(y, sr):\n",
    "    \"\"\"Uniformize one clip using the reloaded METHOD and MAX_TIME settings.\"\"\"\n",
    "    return uniformize(y, sr, METHOD, MAX_TIME)"
   ]
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "title = r\"ResNet 9\"\n",
    "\n",
    "description = r\"\"\"\n",
    "<center>\n",
    "The resnet9 model was trained to classify drone speech command.\n",
    "<img src=\"http://zeus.blanchon.cc/dropshare/modia.png\" width=200px>\n",
    "</center>\n",
    "\"\"\"\n",
    "article = r\"\"\"\n",
    "- [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385)\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "def predict_command(data):\n",
    "    \"\"\"Gradio callback: classify one microphone recording with the reloaded pipeline.\"\"\"\n",
    "    return predict_gradio(\n",
    "        data=data,\n",
    "        uniform_lambda=uniform_lambda,\n",
    "        sklearn_model=sklearn_model,\n",
    "        label_transform=label_encoder,\n",
    "        target_sr=SAMPLE_RATE,\n",
    "    )\n",
    "\n",
    "\n",
    "# Flagging is disabled; it can be enabled with allow_flagging=\"manual\" and a flagging_dir.\n",
    "demo_men = gr.Interface(\n",
    "    fn=predict_command,\n",
    "    inputs=gr.Audio(source=\"microphone\", type=\"numpy\"),\n",
    "    outputs=gr.Label(),\n",
    "    title=title,\n",
    "    description=description,\n",
    "    article=article,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.4 ('ml')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "vscode": {
   "interpreter": {
    "hash": "f1f34988cae7bd54e626a86efbacac2b339eeffffea662e9af12f610fca26db7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}