{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "VqZ9BtJnwear" }, "source": [ "# Task 1. Bootstrap" ] }, { "cell_type": "markdown", "metadata": { "id": "2jZ6d3Owweau" }, "source": [ "In this task, use the [Breast Cancer 🛠️[doc]](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_breast_cancer.html) dataset, a classic dataset for binary classification. Train the following models:\n", "\n", " - `DecisionTreeClassifier`\n", " - `RandomForestClassifier`\n", " - `LGBMClassifier`\n", " - `SVC`\n", " - `BaggingClassifier` with `SVC` as the base classifier.\n", "\n", "You may keep the default model parameters or set your own.\n", "\n", "For each model, compute the [Matthews correlation coefficient 📚[wiki]](https://en.wikipedia.org/wiki/Phi_coefficient) between the classes it predicts and the true classes (`sklearn.metrics.matthews_corrcoef` [🛠️[doc]](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.matthews_corrcoef.html)). MCC is a metric for evaluating binary classification that is, in particular, robust to class imbalance. Its advantages are discussed in the article:\n", "\n", "[[article] 🎓 The advantages of the Matthews correlation coefficient (MCC) over F1 score and accuracy in binary classification evaluation](https://bmcgenomics.biomedcentral.com/articles/10.1186/s12864-019-6413-7)\n", "\n", "Using the bootstrap approach, build 90% confidence intervals for the quality of the trained models. Use the `bootstrap_metric()` function from the lecture.\n", "\n", "Draw [boxplots 🛠️[doc]](https://seaborn.pydata.org/generated/seaborn.boxplot.html) of the models' quality."
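, "\n", "\n",
"For reference, MCC is computed from the confusion-matrix counts ($TP$, $TN$, $FP$, $FN$: true/false positives and negatives). It ranges from $-1$ to $1$, and $0$ corresponds to predictions no better than random:\n",
"\n",
"$$\\mathrm{MCC} = \\frac{TP \\cdot TN - FP \\cdot FN}{\\sqrt{(TP + FP)(TP + FN)(TN + FP)(TN + FN)}}$$"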
] }, { "cell_type": "markdown", "metadata": { "id": "Nan68aZIweaw" }, "source": [ "Installing and importing the required libraries:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "EobqrBMqweax", "outputId": "44e71e9b-ca42-4128-f9bb-d8c982586514" }, "outputs": [], "source": [ "!pip install -q dask[dataframe]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rNK2PHixweaz" }, "outputs": [], "source": [ "import lightgbm\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import sklearn.datasets\n", "import matplotlib.pyplot as plt\n", "\n", "from sklearn.svm import SVC\n", "from sklearn.metrics import matthews_corrcoef\n", "from sklearn.tree import DecisionTreeClassifier\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.ensemble import RandomForestClassifier, BaggingClassifier" ] }, { "cell_type": "markdown", "metadata": { "id": "c914Tk78wea0" }, "source": [ "Loading the dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "collapsed": true, "id": "nA-lRHUTwea0", "jupyter": { "outputs_hidden": true }, "outputId": "6735dae2-ad75-4dc4-b9c3-d3cc248e6a44" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ".. 
_breast_cancer_dataset:\n", "\n", "Breast cancer wisconsin (diagnostic) dataset\n", "--------------------------------------------\n", "\n", "**Data Set Characteristics:**\n", "\n", ":Number of Instances: 569\n", "\n", ":Number of Attributes: 30 numeric, predictive attributes and the class\n", "\n", ":Attribute Information:\n", " - radius (mean of distances from center to points on the perimeter)\n", " - texture (standard deviation of gray-scale values)\n", " - perimeter\n", " - area\n", " - smoothness (local variation in radius lengths)\n", " - compactness (perimeter^2 / area - 1.0)\n", " - concavity (severity of concave portions of the contour)\n", " - concave points (number of concave portions of the contour)\n", " - symmetry\n", " - fractal dimension (\"coastline approximation\" - 1)\n", "\n", " The mean, standard error, and \"worst\" or largest (mean of the three\n", " worst/largest values) of these features were computed for each image,\n", " resulting in 30 features. For instance, field 0 is Mean Radius, field\n", " 10 is Radius SE, field 20 is Worst Radius.\n", "\n", " - class:\n", " - WDBC-Malignant\n", " - WDBC-Benign\n", "\n", ":Summary Statistics:\n", "\n", "===================================== ====== ======\n", " Min Max\n", "===================================== ====== ======\n", "radius (mean): 6.981 28.11\n", "texture (mean): 9.71 39.28\n", "perimeter (mean): 43.79 188.5\n", "area (mean): 143.5 2501.0\n", "smoothness (mean): 0.053 0.163\n", "compactness (mean): 0.019 0.345\n", "concavity (mean): 0.0 0.427\n", "concave points (mean): 0.0 0.201\n", "symmetry (mean): 0.106 0.304\n", "fractal dimension (mean): 0.05 0.097\n", "radius (standard error): 0.112 2.873\n", "texture (standard error): 0.36 4.885\n", "perimeter (standard error): 0.757 21.98\n", "area (standard error): 6.802 542.2\n", "smoothness (standard error): 0.002 0.031\n", "compactness (standard error): 0.002 0.135\n", "concavity (standard error): 0.0 0.396\n", "concave points (standard 
error): 0.0 0.053\n", "symmetry (standard error): 0.008 0.079\n", "fractal dimension (standard error): 0.001 0.03\n", "radius (worst): 7.93 36.04\n", "texture (worst): 12.02 49.54\n", "perimeter (worst): 50.41 251.2\n", "area (worst): 185.2 4254.0\n", "smoothness (worst): 0.071 0.223\n", "compactness (worst): 0.027 1.058\n", "concavity (worst): 0.0 1.252\n", "concave points (worst): 0.0 0.291\n", "symmetry (worst): 0.156 0.664\n", "fractal dimension (worst): 0.055 0.208\n", "===================================== ====== ======\n", "\n", ":Missing Attribute Values: None\n", "\n", ":Class Distribution: 212 - Malignant, 357 - Benign\n", "\n", ":Creator: Dr. William H. Wolberg, W. Nick Street, Olvi L. Mangasarian\n", "\n", ":Donor: Nick Street\n", "\n", ":Date: November, 1995\n", "\n", "This is a copy of UCI ML Breast Cancer Wisconsin (Diagnostic) datasets.\n", "https://goo.gl/U2Uwz2\n", "\n", "Features are computed from a digitized image of a fine needle\n", "aspirate (FNA) of a breast mass. They describe\n", "characteristics of the cell nuclei present in the image.\n", "\n", "Separating plane described above was obtained using\n", "Multisurface Method-Tree (MSM-T) [K. P. Bennett, \"Decision Tree\n", "Construction Via Linear Programming.\" Proceedings of the 4th\n", "Midwest Artificial Intelligence and Cognitive Science Society,\n", "pp. 97-101, 1992], a classification method which uses linear\n", "programming to construct a decision tree. Relevant features\n", "were selected using an exhaustive search in the space of 1-4\n", "features and 1-3 separating planes.\n", "\n", "The actual linear program used to obtain the separating plane\n", "in the 3-dimensional space is that described in:\n", "[K. P. Bennett and O. L. 
Mangasarian: \"Robust Linear\n", "Programming Discrimination of Two Linearly Inseparable Sets\",\n", "Optimization Methods and Software 1, 1992, 23-34].\n", "\n", "This database is also available through the UW CS ftp server:\n", "\n", "ftp ftp.cs.wisc.edu\n", "cd math-prog/cpo-dataset/machine-learn/WDBC/\n", "\n", ".. dropdown:: References\n", "\n", " - W.N. Street, W.H. Wolberg and O.L. Mangasarian. Nuclear feature extraction\n", " for breast tumor diagnosis. IS&T/SPIE 1993 International Symposium on\n", " Electronic Imaging: Science and Technology, volume 1905, pages 861-870,\n", " San Jose, CA, 1993.\n", " - O.L. Mangasarian, W.N. Street and W.H. Wolberg. Breast cancer diagnosis and\n", " prognosis via linear programming. Operations Research, 43(4), pages 570-577,\n", " July-August 1995.\n", " - W.H. Wolberg, W.N. Street, and O.L. Mangasarian. Machine learning techniques\n", " to diagnose breast cancer from fine-needle aspirates. Cancer Letters 77 (1994)\n", " 163-171.\n", "\n" ] } ], "source": [ "breast_cancer = sklearn.datasets.load_breast_cancer()\n", "print(breast_cancer.DESCR)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PMzugGlGwea0" }, "outputs": [], "source": [ "x = breast_cancer.data\n", "y = breast_cancer.target\n", "x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=42)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "TxFinYhZ_PIS", "outputId": "930146d8-42b0-48a3-e400-5ec67724ad28" }, "outputs": [ { "data": { "text/plain": [ "array([[1.799e+01, 1.038e+01, 1.228e+02, 1.001e+03, 1.184e-01, 2.776e-01,\n", " 3.001e-01, 1.471e-01, 2.419e-01, 7.871e-02, 1.095e+00, 9.053e-01,\n", " 8.589e+00, 1.534e+02, 6.399e-03, 4.904e-02, 5.373e-02, 1.587e-02,\n", " 3.003e-02, 6.193e-03, 2.538e+01, 1.733e+01, 1.846e+02, 2.019e+03,\n", " 1.622e-01, 6.656e-01, 7.119e-01, 2.654e-01, 4.601e-01, 1.189e-01],\n", " [2.057e+01, 1.777e+01, 1.329e+02, 
1.326e+03, 8.474e-02, 7.864e-02,\n", " 8.690e-02, 7.017e-02, 1.812e-01, 5.667e-02, 5.435e-01, 7.339e-01,\n", " 3.398e+00, 7.408e+01, 5.225e-03, 1.308e-02, 1.860e-02, 1.340e-02,\n", " 1.389e-02, 3.532e-03, 2.499e+01, 2.341e+01, 1.588e+02, 1.956e+03,\n", " 1.238e-01, 1.866e-01, 2.416e-01, 1.860e-01, 2.750e-01, 8.902e-02],\n", " [1.969e+01, 2.125e+01, 1.300e+02, 1.203e+03, 1.096e-01, 1.599e-01,\n", " 1.974e-01, 1.279e-01, 2.069e-01, 5.999e-02, 7.456e-01, 7.869e-01,\n", " 4.585e+00, 9.403e+01, 6.150e-03, 4.006e-02, 3.832e-02, 2.058e-02,\n", " 2.250e-02, 4.571e-03, 2.357e+01, 2.553e+01, 1.525e+02, 1.709e+03,\n", " 1.444e-01, 4.245e-01, 4.504e-01, 2.430e-01, 3.613e-01, 8.758e-02]])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x[:3]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GJ6iCg4Awea1" }, "outputs": [], "source": [ "# Decision tree: fit on the training split, predict on the test split\n", "DTC = DecisionTreeClassifier(random_state=42)\n", "DTC.fit(x_train, y_train)\n", "dtc_pred = DTC.predict(x_test)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "E_tUhJkWA74e", "outputId": "bcad98fe-9205-4206-9d73-6bed18cc0316" }, "outputs": [ { "data": { "text/plain": [ "0.8963356530877563" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "matthews_corrcoef(y_pred=dtc_pred, y_true=y_test)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rN_GrSMbA7DU" }, "outputs": [], "source": [ "# Random forest with fully grown trees (max_depth=None is the default)\n", "RFC = RandomForestClassifier(max_depth=None)\n", "RFC.fit(x_train, y_train)\n", "rfc_pred = RFC.predict(x_test)\n", "\n", "# LightGBM gradient boosting\n", "LGBMC = lightgbm.LGBMClassifier(\n", "    n_estimators=2000,\n", "    learning_rate=0.1,\n", "    max_depth=-1,\n", "    num_leaves=2**5,\n", "    random_state=42,\n", "    min_child_weight=13,\n", "    n_jobs=-1,\n", "    force_col_wise=True,\n", "    verbose=-1,\n", ")\n", "LGBMC.fit(X=x_train, y=y_train)\n", "lgbmc_pred = LGBMC.predict(x_test)\n", "\n", "# Support vector classifier with default parameters (RBF kernel)\n", "svc = SVC()\n", "svc.fit(x_train, y_train)\n", "svc_pred = svc.predict(X=x_test)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "C0itI8YNBdZD" }, "outputs": [], "source": [ "# Bagging ensemble with SVC as the base estimator\n", "BagC_SVC = BaggingClassifier(estimator=SVC())\n", "BagC_SVC.fit(x_train, y_train)\n", "bagc_pred = BagC_SVC.predict(x_test)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hEvTzHOdCh3A" }, "outputs": [], "source": [ "def bootstrap_metric(y_true, y_pred, metric_fn, samples_cnt=1000, random_state=42):\n", "    # Resample (y_true, y_pred) pairs with replacement and recompute the metric on each resample\n", "    np.random.seed(random_state)\n", "    b_metric = np.zeros(samples_cnt)\n", "    for i in range(samples_cnt):\n", "        poses = np.random.choice(y_true.shape[0], size=y_true.shape[0], replace=True)\n", "\n", "        y_true_boot = y_true[poses]\n", "        y_pred_boot = y_pred[poses]\n", "        m_val = metric_fn(y_true_boot, y_pred_boot)\n", "        b_metric[i] = m_val\n", "\n", "    return b_metric" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "o5Rps5hoB4Jg" }, "outputs": [], "source": [ "boot_dtc = bootstrap_metric(y_test, dtc_pred, metric_fn=matthews_corrcoef)\n", "boot_rfc = bootstrap_metric(y_test, rfc_pred, metric_fn=matthews_corrcoef)\n", "boot_lgbmc = bootstrap_metric(y_test, lgbmc_pred, metric_fn=matthews_corrcoef)\n", "boot_bagc = bootstrap_metric(y_test, bagc_pred, metric_fn=matthews_corrcoef)\n", "boot_svc = bootstrap_metric(y_test, svc_pred, metric_fn=matthews_corrcoef)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 536 }, "id": "yXCp7ZejDAQ2", "outputId": "b25060af-c33e-4210-aba5-9e2244193b91" }, "outputs": [], "source": [ "# Boxplots of the bootstrap MCC distributions for each model\n", "plt.figure(figsize=(16, 6))\n", "sns.boxplot(\n", "    data=pd.DataFrame(\n", "        {\n", "            \"DecTree\": boot_dtc,\n", "            \"RandomForest\": boot_rfc,\n", "            \"LGBM\": boot_lgbmc,\n", "            \"SVC\": boot_svc,\n", "            \"BAG_SVC\": boot_bagc,\n", "        }\n", "    )\n", ")\n", "plt.ylabel(\"Matthews corr. coef.\", size=20)\n", "plt.tick_params(axis=\"both\", which=\"major\", labelsize=20)\n", "plt.yticks(fontsize=14)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "SlUCysufECHN", "outputId": "8fd0ccec-03aa-4e3f-e2d1-a310a7330e6e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "random_forest 0.9249968928906641 0.03256402729794103\n", "lgmbc 0.9254232010527652 0.03254820677355373\n" ] } ], "source": [ "print('random_forest', boot_rfc.mean(), boot_rfc.std())\n", "print('lgmbc', boot_lgbmc.mean(), boot_lgbmc.std())" ] }, { "cell_type": "markdown", "metadata": { "id": "NRsWZSbYwea1" }, "source": [ "Draw a conclusion about which models perform better.\n", "\n", "**Write your conclusion**\n", "\n", "The best models are RandomForest and LGBM. LGBM looks marginally better on the plot because its bootstrap distribution has slightly fewer outliers. However, the bootstrap means (0.925 for both) and standard deviations (0.033 for both) printed above are practically identical, so the difference between the two is negligible.\n", "\n", "As discussed in the lecture, SVC and BAG_SVC show similar results."
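, "\n", "\n",
"The 90% confidence intervals requested in the task can be read off the bootstrap distributions as their 5th and 95th percentiles (the percentile bootstrap). A minimal sketch, assuming the `boot_*` arrays returned by `bootstrap_metric()` above; the dictionary keys are only display labels:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Bootstrap MCC samples computed above, one array per model\n",
"boot_scores = {\n",
"    \"DecTree\": boot_dtc,\n",
"    \"RandomForest\": boot_rfc,\n",
"    \"LGBM\": boot_lgbmc,\n",
"    \"SVC\": boot_svc,\n",
"    \"BAG_SVC\": boot_bagc,\n",
"}\n",
"for name, scores in boot_scores.items():\n",
"    # 90% percentile interval: drop 5% of the bootstrap distribution from each tail\n",
"    low, high = np.percentile(scores, [5, 95])\n",
"    print(f\"{name}: mean={scores.mean():.3f}, 90% CI=[{low:.3f}, {high:.3f}]\")\n",
"```"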
] }, { "cell_type": "markdown", "metadata": { "id": "h9v_5ruEwea1" }, "source": [ "## Result format" ] }, { "cell_type": "markdown", "metadata": { "id": "5w-oVK8Pwea2" }, "source": [ "A plot showing the Matthews correlation coefficient for the following models:\n", "\n", " - `DecisionTreeClassifier`\n", " - `RandomForestClassifier`\n", " - `LGBMClassifier`\n", " - `SVC`\n", " - `BaggingClassifier` with `SVC` as the base classifier\n", "\n", "Example plot:" ] }, { "cell_type": "markdown", "metadata": { "id": "Po9ACyoRwea2" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "5AK3fbfQwea2" }, "source": [ "# Task 2. Class imbalance" ] }, { "cell_type": "markdown", "metadata": { "id": "VB722f9vwea2" }, "source": [ "In this task, we look at the specifics of training and evaluating models on data with a significant class imbalance." ] }, { "cell_type": "markdown", "metadata": { "id": "91wJV2DVwea3" }, "source": [ "Installing and importing the required libraries:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kzWORrblwea3" }, "outputs": [], "source": [ "!pip install -qU imbalanced-learn" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9kc1qeB7wea3" }, "outputs": [], "source": [ "import imblearn\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.metrics import accuracy_score, balanced_accuracy_score\n", "from sklearn.model_selection import (\n", "    train_test_split,\n", "    KFold,\n", "    StratifiedKFold,\n", "    cross_validate,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "6baaMVgZwea3" }, "source": [ "It is important to check whether the classes in a dataset are balanced.\n", "Suppose we have a dataset with the following class labels:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rTmE5J1Awea4" }, "outputs": [], "source": [ "real_labels = [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]" ] }, { "cell_type": "markdown", "metadata": { "id": "1YTxXiBPwea4" }, "source": [ "The dataset has 16 objects of class 0 and 5 objects of class 1.\n", "\n", "We trained two models. The first one always predicts 0:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xpLQ6y9Fwea4" }, "outputs": [], "source": [ "model1_res = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]" ] }, { "cell_type": "markdown", "metadata": { "id": "hDoYYLcvwea5" }, "source": [ "The second one managed to find some pattern in the features:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "70o5b1bawea5" }, "outputs": [], "source": [ "model2_res = [1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1]" ] }, { "cell_type": "markdown", "metadata": { "id": "fL7LdlICwea5" }, "source": [ "Let's compute accuracy (see Lecture 1) for these models:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "HtCdfZiAwea5", "outputId": "f746ccf9-77cd-459c-ba50-5e325dea558e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy for model1: 0.7619\n", "Accuracy for model2: 0.7619\n" ] } ], "source": [ "print(f\"Accuracy for model1: {accuracy_score(real_labels, model1_res):.4f}\")\n", "print(f\"Accuracy for model2: {accuracy_score(real_labels, model2_res):.4f}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "qPtK-Kj2wea6" }, "source": [ "Accuracy should not be used when the data are imbalanced. Above, the model that always predicts 0 gets the same accuracy as the model that found a pattern. Imbalanced data call for dedicated metrics and models. One such metric is balanced accuracy: it computes recall separately for each class and then takes the mean:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "bbc2nS4Mwea6", "outputId": "08c6f75f-8311-48ca-e991-bab7765ebb0a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Balanced accuracy for model1: 0.500\n", "Balanced accuracy for model2: 0.775\n" ] } ], "source": [ "# Balanced accuracy for model1 = (16/16+0/5)/2 = 0.5\n", "print(\n", "    f\"Balanced accuracy for model1: {balanced_accuracy_score(real_labels, model1_res):.3f}\"\n", ")\n", "# Balanced accuracy for model2 = (12/16+4/5)/2 = 0.775\n", "print(\n", "    f\"Balanced accuracy for model2: {balanced_accuracy_score(real_labels, model2_res):.3f}\"\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "IwCToaBTwea6" }, "source": [ "**Always check** whether your data are balanced and whether the metrics you chose to evaluate the model can handle imbalanced classes." ] }, { "cell_type": "markdown", "metadata": { "id": "QJBLu7X9wea6" }, "source": [ "Let's load a dataset with various biomarkers of melanoma patients. The data are anonymized and contain no information about the patients. The dataset also includes a variable equal to 1 if the patient responded to immunotherapy (the therapy helped and the tumor shrank) and 0 if not. Far fewer patients respond to the therapy than do not, so predicting a patient's response to therapy from biomarkers is an important problem in oncology. In this task, you will try to solve it."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 433 }, "id": "biJaDT0Bwea7", "outputId": "0400e0e1-6d28-4b28-b8e6-94ad805e414e" }, "outputs": [ { "data": { "text/html": [ "
\n" ], "text/plain": [ " IgG1/IgA IL21 CXCL9 CXCL10 CD8A GZMB \\\n", "sample_id \n", "SAM4b0175e8db6e 3.242746 0.001280 -0.002986 -0.036366 0.096658 0.063467 \n", "SAMd215b503f99a 2.139016 -0.000089 0.030495 0.243958 0.161128 0.565798 \n", "SAM7fb6987514a4 12.614972 0.008103 0.502043 0.530783 0.388455 0.528142 \n", "SAMd636e3461955 6.365973 -0.000139 0.024035 0.115127 0.084455 0.200038 \n", "SAMc0da5d48686d 2.764089 0.006107 0.015533 0.135470 0.067686 0.053499 \n", "\n", " KLRC2 KLRC3 KLRC4 GNLY TGFB1 Response \n", "sample_id \n", "SAM4b0175e8db6e 0.502058 -0.083862 0.053659 0.091930 61.934119 0 \n", "SAMd215b503f99a -0.203495 -0.026902 -0.035405 0.030125 103.265837 0 \n", "SAM7fb6987514a4 -0.156209 0.001147 -0.028690 0.260703 53.552817 0 \n", "SAMd636e3461955 -0.387373 -0.057837 0.045938 0.073192 80.837318 0 \n", "SAMc0da5d48686d -0.116040 0.063714 0.088201 0.082940 114.422926 0 " ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Number of patients responded to immunotherapy:\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "Response\n", "0 228\n", "1 37\n", "Name: count, dtype: int64" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "cancer = pd.read_table(\n", "    \"https://edunet.kea.su/repo/EduNet-web_dependencies/datasets/Cancer_dataset_2.tsv\",\n", "    index_col=\"sample_id\",\n", ")\n", "display(cancer.head())\n", "\n", "# split the data into features (x) and the target variable (y)\n", "y = cancer[\"Response\"]\n", "x = cancer.drop(\"Response\", axis=1)\n", "print(\"\\nNumber of patients responded to immunotherapy:\")\n", "display(y.value_counts())" ] }, { "cell_type": "markdown", "metadata": { "id": "bO6Az3A1wea7" }, "source": [ "The classes in this dataset are imbalanced: far fewer patients responded to the therapy (37 out of 265).\n", "\n", "There are two ways to work with class-imbalanced data. The first is to use stratified splits. Each class should make up the same share of samples in the training and test sets; otherwise the splits may be biased, which leads to an incorrect estimate of model quality. The second is to use special algorithms that take class imbalance into account." ] }, { "cell_type": "markdown", "metadata": { "id": "NY_xlQmKwea7" }, "source": [ "In this task, you need to demonstrate how well different approaches to imbalanced data work. To do this, you will use the three models below:\n", "\n", "1. [[doc] 🛠️](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html) `RandomForestClassifier`, sklearn library.\n", "2. [[doc] 🛠️](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html) `RandomForestClassifier` with class balancing, sklearn library. Instead of the default weight of 1 for every class, it uses weights inversely proportional to the class frequencies in the input data (see `class_weight='balanced'`).\n", "3. [[doc] 🛠️](https://imbalanced-learn.org/stable/references/generated/imblearn.ensemble.BalancedRandomForestClassifier.html) `BalancedRandomForestClassifier`, imblearn library. It draws the bootstrap samples so that the class balance in every sample fed to a tree is \"corrected\" (the majority class is randomly undersampled)." ] }, { "cell_type": "markdown", "metadata": { "id": "QOEvidjdwea7" }, "source": [ "Evaluate the approaches with cross-validation, splitting the data both with and without taking class proportions into account. Use `accuracy` and `balanced_accuracy` as the model performance metrics. Interpret the results." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "a4ZBBGJ7wea8" }, "outputs": [], "source": [ "from imblearn.ensemble import BalancedRandomForestClassifier" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "d1JMRGzFwea_", "outputId": "f3b1181f-c44d-43af-96c4-80b632343674" }, "outputs": [ { "data": { "text/plain": [ "5" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "skf = StratifiedKFold(n_splits=5, shuffle=True)\n", "skf.get_n_splits(x, y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "20iSJ6xlJ3mL" }, "outputs": [], "source": [ "from sklearn.metrics import accuracy_score" ] }, { "cell_type": "markdown", "metadata": { "id": "X4v2c9W6wea_" }, "source": [ "Objects of different classes are distributed unevenly across the dataset. For `cross_validate` to give adequate results, the data have to be shuffled. To do this, use the `shuffle=True` flag with `KFold` and `StratifiedKFold` (see the `cv` parameter of `cross_validate`)."
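, "\n", "\n",
"The sketch below compares the share of class 1 in each test fold for different splitters, to show why shuffling and stratification matter. It is a minimal illustration on hypothetical labels. `y_demo` only reproduces the class counts of the melanoma data (228 zeros, 37 ones) and is assumed to be sorted by class, which is the worst case for splitting without shuffling:\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.model_selection import KFold, StratifiedKFold\n",
"\n",
"# Hypothetical sorted labels: all positives at the end\n",
"y_demo = np.array([0] * 228 + [1] * 37)\n",
"x_demo = np.zeros((len(y_demo), 1))  # features do not affect the split\n",
"\n",
"splitters = {\n",
"    \"KFold, no shuffle\": KFold(n_splits=5),\n",
"    \"KFold, shuffle\": KFold(n_splits=5, shuffle=True, random_state=0),\n",
"    \"StratifiedKFold, shuffle\": StratifiedKFold(n_splits=5, shuffle=True, random_state=0),\n",
"}\n",
"for name, cv in splitters.items():\n",
"    # Share of class 1 in each test fold\n",
"    shares = [y_demo[test_idx].mean() for _, test_idx in cv.split(x_demo, y_demo)]\n",
"    print(name, np.round(shares, 3))\n",
"```\n",
"\n",
"Without shuffling, the last fold receives all 37 positives and the other folds receive none. `StratifiedKFold` keeps the share of class 1 close to 37/265 ≈ 0.14 in every fold."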
] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true, "id": "qG34XBBNwebA", "jupyter": { "outputs_hidden": true } }, "outputs": [], "source": [ "# Your code here\n", "cv_rf_acc = cross_validate(estimator=RandomForestClassifier(n_estimators=100), X=x, y=y, cv=skf, scoring='accuracy')\n", "cv_rf_weighted_acc = cross_validate(estimator=RandomForestClassifier(n_estimators=100, class_weight='balanced'), X=x, y=y, cv=skf, scoring='accuracy')\n", "cv_bal_rf_acc = cross_validate(estimator=BalancedRandomForestClassifier(n_estimators=100), X=x, y=y, cv=skf, scoring='accuracy')\n", "\n", "cv_rf_bal = cross_validate(estimator=RandomForestClassifier(n_estimators=100), X=x, y=y, cv=skf, scoring='balanced_accuracy')\n", "cv_rf_weighted_bal = cross_validate(estimator=RandomForestClassifier(n_estimators=100, class_weight='balanced'), X=x, y=y, cv=skf, scoring='balanced_accuracy')\n", "cv_bal_rf_bal = cross_validate(estimator=BalancedRandomForestClassifier(n_estimators=100), X=x, y=y, cv=skf, scoring='balanced_accuracy')\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "lqFDshlNIedM", "outputId": "ab03d88c-609e-4370-b2f9-1e343c041f6f" }, "outputs": [ { "data": { "text/plain": [ "{'fit_time': array([0.28391361, 0.34531474, 0.30177093, 0.28197789, 0.28857279]),\n", " 'score_time': array([0.01238203, 0.01297426, 0.0112915 , 0.01253176, 0.01142955]),\n", " 'test_score': array([0.88679245, 0.86792453, 0.8490566 , 0.81132075, 0.8490566 ])}" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cv_rf_acc" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 371 }, "id": "zOM1_-2TIzUp", "outputId": "4ceb4650-c66d-44e2-cc7e-c277164b72be" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import pandas as pd\n", "\n", "\n", "plt.figure(figsize=(16, 4))\n", "sns.boxplot(\n", " data=pd.DataFrame(\n", " {\n", " \"RandomForest_acc\": cv_rf_acc['test_score'],\n", " \"RandomForest_weighted_acc\": cv_rf_weighted_acc['test_score'],\n", " \"BalancedRandFor_acc\": cv_bal_rf_acc['test_score'],\n", " \"RandomForest_bal\": cv_rf_bal['test_score'],\n", " \"RandomForest_weighted_bal\": cv_rf_weighted_bal['test_score'],\n", " \"BalancedRandFor_bal\": cv_bal_rf_bal['test_score']\n", " }\n", " )\n", ")\n", "plt.ylabel(\"Accuracies \", size=20)\n", "plt.tick_params(axis=\"both\", which=\"major\", labelsize=10)\n", "plt.yticks(fontsize=14)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "10cFP9gNM8qy", "outputId": "97f2afd8-4ccd-4dd6-b6ce-5a6e10abb4ff" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy\n", "RFC RFC_weighted BalRFC\n", "0.8528301886792452 0.025031130493248295\n", "0.8641509433962264 0.007547169811320753\n", "0.7056603773584905 0.06381711141618024\n", "\n", "\n", "Balanced Accuracy\n", "RFC RFC_weighted BalRFC\n", "0.5421204278812974 0.06459732294998131\n", "0.4955555555555556 0.008888888888888878\n", "0.667123878536922 0.07672751479997694\n" ] } ], "source": [ "print('Accuracy')\n", "print(*['RFC', 'RFC_weighted','BalRFC'] )\n", "print(cv_rf_acc['test_score'].mean(), cv_rf_acc['test_score'].std())\n", "print(cv_rf_weighted_acc['test_score'].mean(), cv_rf_weighted_acc['test_score'].std())\n", "print(cv_bal_rf_acc['test_score'].mean(), cv_bal_rf_acc['test_score'].std())\n", "print('\\n')\n", "print('Balanced Accuracy')\n", "print(*['RFC', 'RFC_weighted','BalRFC'] )\n", "print(cv_rf_bal['test_score'].mean(), cv_rf_bal['test_score'].std())\n", "print(cv_rf_weighted_bal['test_score'].mean(), cv_rf_weighted_bal['test_score'].std())\n", 
"print(cv_bal_rf_bal['test_score'].mean(), cv_bal_rf_bal['test_score'].std())" ] }, { "cell_type": "markdown", "metadata": { "id": "3ZobGaqLwebB" }, "source": [ "Which model handles the class imbalance better?\n", "\n", "**Write your conclusion**\n", "\n", "In the boxplot, the first three boxes show `accuracy` (not balanced) and the last three show `balanced_accuracy`.\n", "\n", "**`BalancedRandomForestClassifier` from `imblearn.ensemble` handled the class imbalance best.** It has the lowest `accuracy` but the highest `balanced_accuracy`. Both `RandomForestClassifier` variants reach a high `accuracy` while their `balanced_accuracy` stays close to 0.5, which suggests they mostly predict the majority class; `class_weight='balanced'` did not improve this.\n", "\n", "* Results (mean and standard deviation over 5 stratified folds)\n", "\n", "1. `RandomForestClassifier`, sklearn library;\n", " * Accuracy mean 0.85\n", " * Accuracy std 0.025\n", " * Balanced accuracy mean 0.54\n", " * Balanced accuracy std 0.065\n", " \n", "2. `RandomForestClassifier` with class balancing, sklearn library;\n", " * Accuracy mean 0.86\n", " * Accuracy std 0.008\n", " * Balanced accuracy mean 0.50\n", " * Balanced accuracy std 0.009\n", "3. `BalancedRandomForestClassifier`, imblearn library.\n", " * Accuracy mean 0.71\n", " * Accuracy std 0.064\n", " * Balanced accuracy mean 0.67\n", " * Balanced accuracy std 0.077\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "oIl8corLwebD" }, "source": [ "## Expected result" ] }, { "cell_type": "markdown", "metadata": { "id": "vICuS2W0webD" }, "source": [ "Obtain the `accuracy` and `balanced_accuracy` values, estimated by cross-validation with and without class stratification, for the models:\n", "1. `RandomForestClassifier`, sklearn library;\n", "2. `RandomForestClassifier` with class balancing, sklearn library;\n", "3. `BalancedRandomForestClassifier`, imblearn library." ] }, { "cell_type": "markdown", "metadata": { "id": "xZ2pCvlfwebD" }, "source": [ "# Task 3. Different types of boosting" ] }, { "cell_type": "markdown", "metadata": { "id": "Ik3gvkBCwebD" }, "source": [ "In this task we use a dataset of dish ratings described by a few features.\n", "\n", "Some gradient boosting implementations can use a different training method. 
For example, XGB has the `dart` booster and LGBM has `goss`. This makes it possible to build more effective ensembles.\n", "\n", "Using cross-validation with 3 folds, train the models:\n", "* `CatBoostRegressor`\n", "* `XGBRegressor`\n", "* `LGBMRegressor`\n", "\n", "Save the model from each fold and compute `mse` on the test set with each of them. Then average the test-set predictions of all 9 models (3 algorithms × 3 folds) and compute `mse` for the averaged predictions.\n", "\n", "Write conclusions about the quality of the models." ] }, { "cell_type": "markdown", "metadata": { "id": "NCvzeTTWwebE" }, "source": [ "Installing and importing the required libraries:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "CvQEk0vawebE", "outputId": "83297572-6d0d-479a-e98f-be50263463b6" }, "outputs": [], "source": [ "!pip install -q catboost" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SFRi2aqzwebE" }, "outputs": [], "source": [ "import xgboost\n", "import catboost\n", "import lightgbm\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "from sklearn.metrics import mean_squared_error as mse\n", "from sklearn.model_selection import train_test_split, KFold, cross_validate" ] }, { "cell_type": "markdown", "metadata": { "id": "wSXmaZuPwebF" }, "source": [ "Loading the dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 424 }, "id": "8r1aYLl3webF", "outputId": "e0af53b7-2627-479a-d8c7-1a8284733393" }, "outputs": [ { "data": { "application/vnd.google.colaboratory.intrinsic+json": { "summary": "{\n \"name\": \"recipies\",\n \"rows\": 15864,\n \"fields\": [\n {\n 
\"column\": \"calories\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 359848.41786830855,\n \"min\": 0.0,\n \"max\": 30111218.0,\n \"num_unique_values\": 1858,\n \"samples\": [\n 156.0,\n 765.0,\n 1807.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"protein\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 3843.4623117544525,\n \"min\": 0.0,\n \"max\": 236489.0,\n \"num_unique_values\": 282,\n \"samples\": [\n 81.0,\n 91.0,\n 200210.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"fat\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 20459.329548921767,\n \"min\": 0.0,\n \"max\": 1722763.0,\n \"num_unique_values\": 326,\n \"samples\": [\n 1007.0,\n 221495.0,\n 313.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"sodium\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 334042.078448393,\n \"min\": 0.0,\n \"max\": 27675110.0,\n \"num_unique_values\": 2433,\n \"samples\": [\n 201.0,\n 2362.0,\n 1414.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"cakeweek\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.019444680844068227,\n \"min\": 0.0,\n \"max\": 1.0,\n \"num_unique_values\": 2,\n \"samples\": [\n 1.0,\n 0.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"wasteless\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.007939509074046341,\n \"min\": 0.0,\n \"max\": 1.0,\n \"num_unique_values\": 2,\n \"samples\": [\n 1.0,\n 0.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"rating\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1.2855179766910894,\n \"min\": 0.0,\n \"max\": 5.0,\n \"num_unique_values\": 8,\n \"samples\": [\n 4.375,\n 5.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}", "type": "dataframe", "variable_name": "recipies" }, "text/html": [ 
"
\n" ], "text/plain": [ " calories protein fat sodium cakeweek wasteless rating\n", "0 426.0 30.0 7.0 559.0 0.0 0.0 2.500\n", "1 403.0 18.0 23.0 1439.0 0.0 0.0 4.375\n", "2 165.0 6.0 7.0 165.0 0.0 0.0 3.750\n", "3 547.0 20.0 32.0 452.0 0.0 0.0 3.125\n", "4 948.0 19.0 79.0 1042.0 0.0 0.0 4.375\n", "... ... ... ... ... ... ... ...\n", "15859 28.0 2.0 2.0 64.0 0.0 0.0 3.125\n", "15860 671.0 22.0 28.0 583.0 0.0 0.0 4.375\n", "15861 563.0 31.0 38.0 652.0 0.0 0.0 4.375\n", "15862 631.0 45.0 24.0 517.0 0.0 0.0 4.375\n", "15863 560.0 73.0 10.0 3698.0 0.0 0.0 4.375\n", "\n", "[15864 rows x 7 columns]" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ "recipies = pd.read_csv(\n", " \"https://edunet.kea.su/repo/EduNet-web_dependencies/datasets/recipes.csv\"\n", ")\n", "recipies" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SvgVNo6NwebF" }, "outputs": [], "source": [ "y = recipies[\"rating\"]\n", "x = recipies.drop([\"rating\"], axis=1)\n", "\n", "x_train_all, x_test, y_train_all, y_test = train_test_split(\n", " x.values, y.values, train_size=0.7, random_state=42\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BCoU_bs2RMmg" }, "outputs": [], "source": [ "from sklearn.metrics import mean_squared_error\n", "\n", "\n", "def train_and_test_regressor(models, x_train, y_train, x_test, y_test, verb=True):\n", " boot_scores = {}\n", " for name, model in models.items():\n", " model.fit(x_train, y_train) # train the model\n", " y_pred = model.predict(x_test) # get predictions\n", " boot_scores[name] = bootstrap_metric( # calculate bootstrap score\n", " y_test,\n", " y_pred,\n", " metric_fn=mean_squared_error,\n", " )\n", " if verb:\n", " print(f\"Fitted {name} with bootstrap score {boot_scores[name].mean():.3f}\")\n", "\n", " results = pd.DataFrame(boot_scores)\n", "\n", " return results\n", "\n", "\n", "# results_rf = train_and_test_regressor(models_rf, x_train, y_train, x_test, 
y_test)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9J2LMa9XwebF" }, "outputs": [], "source": [ "# Your code here\n", "models = {}\n", "models['cat'] = catboost.CatBoostRegressor(verbose=0)\n", "# NOTE: XGBRFRegressor is XGBoost in random-forest mode; the task asks for the gradient-boosting XGBRegressor\n", "models['xgb'] = xgboost.XGBRFRegressor()\n", "models['lgbm'] = lightgbm.LGBMRegressor()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "40UIWsutRwvO", "outputId": "70cd0b61-e4f9-4af4-ff1b-38b744daf48f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitted cat with bootstrap score 1.532\n", "Fitted xgb with bootstrap score 1.535\n", "Fitted lgbm with bootstrap score 1.561\n" ] } ], "source": [ "results_preds = train_and_test_regressor(models, x_train_all, y_train_all, x_test, y_test)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "x4G9a9jXS3Ma", "outputId": "151c19e5-4cab-46af-bcf4-42f22627a72d" }, "outputs": [ { "data": { "text/plain": [ "3" ] }, "execution_count": 63, "metadata": {}, "output_type": "execute_result" } ], "source": [ "kfold = KFold(n_splits=3)\n", "kfold.get_n_splits(x_train_all, y_train_all)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-VgLsdPrSlmS" }, "outputs": [], "source": [ "results_cv = {}\n", "for name, model in models.items():\n", " results_cv[name] = cross_validate(estimator=model, X=x_train_all, y=y_train_all, cv=kfold, scoring='neg_mean_squared_error')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "tCwnho3IUHbI", "outputId": "f4eef433-bda3-4089-cf4e-c4e84e8f42d7" }, "outputs": [ { "data": { "text/plain": [ "{'cat': {'fit_time': array([7.03786325, 8.02481794, 5.47648811]),\n", " 'score_time': array([0.02814388, 0.05217481, 0.01294518]),\n", " 'test_score': array([-1.47398748, -1.60308341, -1.5522741 ])},\n", " 'xgb': {'fit_time': 
array([0.14803743, 0.14114594, 0.14647412]),\n", " 'score_time': array([0.00994015, 0.00969887, 0.01481056]),\n", " 'test_score': array([-1.46576047, -1.59082778, -1.53800993])},\n", " 'lgbm': {'fit_time': array([0.08506894, 0.08238196, 0.08508801]),\n", " 'score_time': array([0.0242455 , 0.02546668, 0.02399921]),\n", " 'test_score': array([-1.49344309, -1.61136957, -1.55886835])}}" ] }, "execution_count": 71, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results_cv" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 391 }, "id": "DB-6wXTgVo6g", "outputId": "9914db8a-304f-4b66-f1b2-1a4160b72f47" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(16, 4))\n", "# cross_validate returns the negated MSE ('neg_mean_squared_error'), so flip the sign back\n", "sns.boxplot(\n", " data=pd.DataFrame(\n", " {\n", " \"CatBoost\": -results_cv['cat']['test_score'],\n", " \"XGB\": -results_cv['xgb']['test_score'],\n", " \"LGBM\": -results_cv['lgbm']['test_score'],\n", " }\n", " )\n", ")\n", "plt.ylabel(\"Mean Squared Error\", size=20)\n", "plt.tick_params(axis=\"both\", which=\"major\", labelsize=10)\n", "plt.yticks(fontsize=14)\n", "plt.title('Cross-validation')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 391 }, "id": "SUpQZlYSWR8L", "outputId": "03475a3d-38bb-4bae-9834-e43c8374a3a8" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(16, 4))\n", "sns.boxplot(\n", " data=pd.DataFrame(\n", " {\n", " \"CatBoost\": results_preds['cat'],\n", " \"XGB\": results_preds['xgb'],\n", " \"LGBM\": results_preds['lgbm'],\n", " }\n", " )\n", ")\n", "plt.ylabel(\"Mean Squared Error\", size=20)\n", "plt.tick_params(axis=\"both\", which=\"major\", labelsize=10)\n", "plt.yticks(fontsize=14)\n", "plt.title('Test-set predictions')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "88yFFrNIX1s9", "outputId": "e2f05349-d870-4829-834f-e869cc613823" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cat mean: 1.53 \t std: 0.05\n", "XGB mean: 1.54 \t std: 0.05\n", "LGBM mean: 1.56 \t std: 0.05\n" ] } ], "source": [ "print(f'Cat mean: {results_preds[\"cat\"].mean():.2f} \\t std: {results_preds[\"cat\"].std():.2f}')\n", "print(f'XGB mean: {results_preds[\"xgb\"].mean():.2f} \\t std: {results_preds[\"xgb\"].std():.2f}')\n", "print(f'LGBM mean: {results_preds[\"lgbm\"].mean():.2f} \\t std: {results_preds[\"lgbm\"].std():.2f}')" ] }, { "cell_type": "markdown", "metadata": { "id": "n0iSxHR4XaaV" }, "source": [ "* Results\n", "\n", "On cross-validation XGB (`XGBRFRegressor` in this notebook) performed best with a mean MSE of about 1.53, followed by CatBoost (1.54) and LGBM (1.55).\n", "\n", "On the test set CatBoost was best: bootstrap mean MSE 1.532 versus 1.535 for XGB and 1.561 for LGBM.\n", "The bootstrap standard deviation is the same for all models (0.05), more than ten times the 0.003 gap between CatBoost and XGB, so these two are practically indistinguishable; CatBoost is a reasonable choice, but the data do not single it out.\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "xJDgfinEwebG" }, "source": [ "## Expected result" ] }, { "cell_type": "markdown", "metadata": { "id": "MXG4PahWwebG" }, "source": [ "Obtain the MSE of every model and the MSE of the averaged predictions of all models. Write a conclusion.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "me1sL619webH" }, "source": [ "# Task 4. 
Hyperparameter tuning" ] }, { "cell_type": "markdown", "metadata": { "id": "vtDysuP8webH" }, "source": [ "In this task, tune the hyperparameters of `CatBoostRegressor` with the `optuna` library and improve on the result obtained with the default parameters.\n", "\n", "Parameters to tune:\n", "\n", "* `depth`\n", "* `iterations`\n", "* `learning_rate`\n", "* `colsample_bylevel`\n", "* `subsample`\n", "* `l2_leaf_reg`\n", "* `min_data_in_leaf`\n", "* `max_bin`\n", "* `random_strength`\n", "* `bootstrap_type`\n", "\n", "**Important!** *The parameters must be tuned on a validation set.*" ] }, { "cell_type": "markdown", "metadata": { "id": "YffT7WXtwebI" }, "source": [ "Installing and importing the required libraries:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "execution": { "iopub.execute_input": "2024-10-24T14:04:18.891446Z", "iopub.status.busy": "2024-10-24T14:04:18.890257Z", "iopub.status.idle": "2024-10-24T14:04:47.474778Z", "shell.execute_reply": "2024-10-24T14:04:47.473762Z", "shell.execute_reply.started": "2024-10-24T14:04:18.891398Z" }, "id": "fwBhoRnSwebI", "outputId": "0bcdfb29-aec9-41cd-e96e-52cd2626fbe9" }, "outputs": [], "source": [ "%pip install -q catboost\n", "%pip install -q optuna" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-10-24T14:04:47.477221Z", "iopub.status.busy": "2024-10-24T14:04:47.476617Z", "iopub.status.idle": "2024-10-24T14:04:50.564929Z", "shell.execute_reply": "2024-10-24T14:04:50.564004Z", "shell.execute_reply.started": "2024-10-24T14:04:47.477180Z" }, "id": "E6LCZl3KwebJ" }, "outputs": [], "source": [ "import optuna\n", "import numpy as np\n", "import pandas as pd\n", "from catboost import CatBoostRegressor\n", "from optuna.samplers import RandomSampler\n", "from sklearn.metrics import mean_squared_error as mse\n", "from sklearn.model_selection import train_test_split, KFold" ] }, { "cell_type": "markdown", "metadata": { "id": "lwIlJfOOwebJ" }, "source": [ "Loading the dataset:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-10-24T14:05:18.696722Z", "iopub.status.busy": "2024-10-24T14:05:18.695630Z", "iopub.status.idle": "2024-10-24T14:05:18.865997Z", "shell.execute_reply": 
"2024-10-24T14:05:18.864897Z", "shell.execute_reply.started": "2024-10-24T14:05:18.696671Z" }, "id": "67DlWdjswebJ" }, "outputs": [], "source": [ "recipies = pd.read_csv(\n", " \"https://edunet.kea.su/repo/EduNet-web_dependencies/datasets/recipes.csv\"\n", ")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-10-24T14:05:19.014400Z", "iopub.status.busy": "2024-10-24T14:05:19.013302Z", "iopub.status.idle": "2024-10-24T14:05:19.039410Z", "shell.execute_reply": "2024-10-24T14:05:19.038399Z", "shell.execute_reply.started": "2024-10-24T14:05:19.014346Z" }, "id": "SfA8jFCIwebK" }, "outputs": [], "source": [ "y = recipies[\"rating\"]\n", "x = recipies.drop([\"rating\"], axis=1)\n", "\n", "x_train_all, x_test, y_train_all, y_test = train_test_split(\n", " x.values, y.values, train_size=0.7, random_state=42\n", ")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "execution": { "iopub.execute_input": "2024-10-24T14:05:20.014627Z", "iopub.status.busy": "2024-10-24T14:05:20.013246Z", "iopub.status.idle": "2024-10-24T14:05:20.448768Z", "shell.execute_reply": "2024-10-24T14:05:20.447766Z", "shell.execute_reply.started": "2024-10-24T14:05:20.014574Z" }, "id": "ipA3ufjNwebK", "outputId": "2fab02dc-9cd8-4744-c533-1ccfee1d7d6d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Learning rate set to 0.074308\n", "0:\tlearn: 1.2817437\ttest: 1.2774827\tbest: 1.2774827 (0)\ttotal: 52.4ms\tremaining: 52.3s\n", "Stopped by overfitting detector (100 iterations wait)\n", "\n", "bestTest = 1.242760353\n", "bestIteration = 44\n", "\n", "Shrink model to first 45 iterations.\n", "\n", "mse_score before tuning: 1.5445\n" ] } ], "source": [ "model = CatBoostRegressor(random_seed=42)\n", "\n", "model.fit(\n", " x_train_all,\n", " y_train_all,\n", " eval_set=(x_test, y_test),\n", " verbose=200,\n", " use_best_model=True,\n", " plot=False,\n", " 
early_stopping_rounds=100,\n", ")\n", "\n", "print(f\"\\nmse_score before tuning: {mse(y_test, model.predict(x_test)):.4f}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Hk8HXmDDbCTS", "outputId": "18d93de7-f4d6-4f10-c6b1-1bfad252f51f" }, "outputs": [ { "data": { "text/plain": [ "(11104, 6)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_train_all.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-8GELU4-nrIi", "outputId": "85e9a7b4-b189-4830-87e6-b8cba13460e7" }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "1e-3 == 0.001" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "execution": { "iopub.execute_input": "2024-10-24T14:12:14.674241Z", "iopub.status.busy": "2024-10-24T14:12:14.673104Z" }, "id": "g_iO_zpawebK", "outputId": "4b7f321c-a436-4b99-f2f9-9dbb44cc74a8", "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[I 2024-10-24 14:12:14,683] A new study created in memory with name: Optimizer\n", "[I 2024-10-24 14:12:15,976] Trial 0 finished with value: 1.5329984654109692 and parameters: {'depth': 10, 'min_data_in_leaf': 4, 'l2_leaf_reg': 4.140000000000001, 'random_strength': 0.7396915762758474, 'iterations': 200, 'learning_rate': 0.03500152798750839, 'colsample_bylevel': 0.4, 'subsample': 1.0, 'max_bin': 198, 'bootstrap_type': 'Bernoulli'}. 
Best is trial 0 with value: 1.5329984654109692.\n" ] } ], "source": [ "# Your code here\n", "from optuna.samplers import TPESampler\n", "from sklearn.model_selection import cross_val_score, KFold\n", "\n", "# Objective to minimize: mean MSE over 3-fold cross-validation on the training data\n", "\n", "\n", "def objective(trial):\n", " # search space of the optimizer\n", " depth = trial.suggest_int(\"depth\", 3, 15, step=1)\n", " min_data_in_leaf = trial.suggest_int(\"min_data_in_leaf\", 3, 10, step=1)\n", " l2_leaf_reg = trial.suggest_float(\"l2_leaf_reg\", 2, 8, step=0.01)\n", " random_strength = trial.suggest_float(\"random_strength\", 0.5, 2)\n", " iterations = trial.suggest_int(\"iterations\", 100, 1500, step=50)\n", " learning_rate = trial.suggest_float('learning_rate', 1e-3, 5e-2)\n", " colsample_bylevel = trial.suggest_float('colsample_bylevel', 0.1, 1., step=0.1)\n", " subsample = trial.suggest_float('subsample', 0.2, 1, step=0.1)\n", " max_bin = trial.suggest_int('max_bin', 10, 255, step=1)\n", " # `subsample` is rejected by the 'No' and 'Bayesian' bootstrap types, so 'Bernoulli' is used\n", " bootstrap_type = trial.suggest_categorical('bootstrap_type', choices=['Bernoulli'])\n", "\n", " # create a new model with the sampled parameters on every trial\n", " model = CatBoostRegressor(\n", " iterations=iterations,\n", 
" learning_rate=learning_rate,\n", " depth=depth,\n", " min_data_in_leaf=min_data_in_leaf,\n", " l2_leaf_reg=l2_leaf_reg,\n", " random_strength=random_strength,\n", " colsample_bylevel=colsample_bylevel,\n", " subsample=subsample,\n", " max_bin=max_bin,\n", " bootstrap_type=bootstrap_type,\n", " random_state=42,\n", " verbose=0,\n", " early_stopping_rounds=50\n", " )\n", " kf = KFold(n_splits=3, shuffle=True, random_state=42)\n", " neg_mse = cross_val_score(\n", " model, x_train_all, y_train_all, cv=kf,\n", " scoring=\"neg_mean_squared_error\"\n", " ).mean()\n", " error = -neg_mse\n", "\n", " return error\n", "\n", "\n", "# Create the study; pass the seed by keyword, because the first positional parameter of TPESampler is not `seed`\n", "study = optuna.create_study(\n", " direction=\"minimize\", study_name=\"Optimizer\", sampler=TPESampler(seed=42)\n", ")\n", "\n", "study.optimize(\n", " objective, n_trials=20\n", ") # more trials give a better chance of finding good hyperparameters\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": { "iopub.status.busy": "2024-10-24T14:11:58.929013Z", "iopub.status.idle": "2024-10-24T14:11:58.929437Z", "shell.execute_reply": "2024-10-24T14:11:58.929257Z", "shell.execute_reply.started": "2024-10-24T14:11:58.929236Z" } }, "outputs": [], "source": [ "study.best_params" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AXKpxM0rxzoH" }, "outputs": [], "source": [ "# Tune on a validation split carved out of the training data; the test set is kept for the final check\n", "from sklearn.metrics import mean_squared_error\n", "\n", "x_train, x_val, y_train, y_val = train_test_split(\n", " x_train_all, y_train_all, test_size=0.2, random_state=42\n", ")\n", "\n", "\n", "def tuner(trial):\n", " params = {\n", " 'depth': trial.suggest_int('depth', 4, 10),\n", " 'iterations': trial.suggest_int('iterations', 100, 1000),\n", " 'learning_rate': trial.suggest_float('learning_rate', 1e-3, 0.3, log=True),\n", " 'colsample_bylevel': trial.suggest_float('colsample_bylevel', 0.5, 1.0),\n", " 'subsample': trial.suggest_float('subsample', 0.5, 1.0),\n", " 'l2_leaf_reg': trial.suggest_int('l2_leaf_reg', 1, 10),\n", " 'min_data_in_leaf': trial.suggest_int('min_data_in_leaf', 1, 10),\n", " 'max_bin': trial.suggest_int('max_bin', 10, 255),\n", " 'random_strength': trial.suggest_float('random_strength', 1, 10),\n", " # 'No' does not accept `subsample`, so only bootstrap types that support it are searched\n", " 'bootstrap_type': trial.suggest_categorical('bootstrap_type', ['Bernoulli', 'MVS'])\n", " }\n", "\n", " # the sampled parameters must be passed to the model, otherwise every trial trains the defaults\n", " model = CatBoostRegressor(**params, random_seed=42, verbose=0)\n", " model.fit(x_train, y_train, eval_set=(x_val, y_val), early_stopping_rounds=50, use_best_model=True)\n", "\n", " preds = model.predict(x_val)\n", " return mean_squared_error(y_val, preds)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hlNoL30zyB_i" }, "outputs": [], "source": [ "study = optuna.create_study(direction='minimize')\n", "study.optimize(tuner, n_trials=20)\n", "\n", "\n", "print(\"best parameters: \", study.best_params)\n", "print(\"best validation MSE: \", study.best_value)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xgstohg3fBT9" }, "outputs": [], "source": [ "study.best_params, study.best_value" ] }, { "cell_type": "markdown", "metadata": { "id": "BMlLEJ18sLLW" }, "source": [ "* Testing incrementally here after tuning in batches of 9 trials, because of tracebacks (the best result is further below)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qlsku72wqsAO" }, "outputs": [], "source": [ "tuned_model = CatBoostRegressor(random_seed=42, use_best_model=True, **study.best_params)\n", "\n", "tuned_model.fit(\n", " x_train_all,\n", " y_train_all,\n", " eval_set=(x_test, y_test),\n", " verbose=200,\n", " early_stopping_rounds=100\n", ")\n", "\n", "print(f\"\\nmse_score after tuning: {mse(y_test, tuned_model.predict(x_test)):.4f}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "J7DsnE1xropw" }, "source": [ "* Best result (MSE = 1.5374)\n", "\n", "* UPD. 
Best result: 1.5359, in the previous cell (MSE before tuning: 1.5445)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "v-tuNEoDkAcX" }, "outputs": [], "source": [ "tuned_model = CatBoostRegressor(random_seed=42, use_best_model=True, **study.best_params)\n", "\n", "tuned_model.fit(\n", " x_train_all,\n", " y_train_all,\n", " eval_set=(x_test, y_test),\n", " verbose=200,\n", " early_stopping_rounds=100\n", ")\n", "\n", "print(f\"\\nmse_score after tuning: {mse(y_test, tuned_model.predict(x_test)):.4f}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "eCpBJz1-webL" }, "source": [ "## Expected result\n", "\n", "The `mse` with the tuned parameters is lower than with the default parameters." ] }, { "cell_type": "markdown", "metadata": { "id": "Tncs7-RLwebL" }, "source": [ "# Task 5. Ensemble learning (optional)" ] }, { "cell_type": "markdown", "metadata": { "id": "9q-ofFGCwebL" }, "source": [ "In this task you need to diagnose heart disease from patients' medical measurements ([Heart Disease 🛠️[doc]](https://www.kaggle.com/datasets/cherngs/heart-disease-cleveland-uci))."
] }, { "cell_type": "markdown", "metadata": { "id": "cSCv91_5webL" }, "source": [ "Installing and importing the required libraries:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GtEtm4H-webM" }, "outputs": [], "source": [ "!pip install -q catboost\n", "!pip install -q lightgbm==3.0" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9onL70rhwebM" }, "outputs": [], "source": [ "import catboost\n", "import lightgbm\n", "import xgboost\n", "import sklearn\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "\n", "from sklearn.svm import SVC\n", "from sklearn.naive_bayes import GaussianNB\n", "from sklearn.tree import DecisionTreeClassifier\n", "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.model_selection import (\n", " train_test_split,\n", " cross_val_score,\n", " KFold,\n", ")\n", "from sklearn.ensemble import (\n", " RandomForestClassifier,\n", " ExtraTreesClassifier,\n", " VotingClassifier,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "DJUh_blZwebM" }, "source": [ "Loading the dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NAj8rvoGwebM" }, "outputs": [], "source": [ "heart_dataset = pd.read_csv(\n", " \"https://edunet.kea.su/repo/EduNet-web_dependencies/datasets/heart.csv\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Q66b2s5owebN" }, "outputs": [], "source": [ "x = heart_dataset.drop(\"target\", axis=1)\n", "y = heart_dataset[\"target\"]\n", "x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=42)" ] }, { "cell_type": "markdown", "metadata": { "id": "EqOj73tzwebN" }, "source": [ "Train the diverse classifiers listed below, as well as a `VotingClassifier` ensemble from `sklearn.ensemble` that combines them by hard or soft voting 
(параметр `voting =` `\"hard\"` или `\"soft\"` соответственно). Оцените качество моделей с помощью кросс-валидации на тренировочном наборе, используя функцию `cross_val_score` и метрику `f1`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Atd1mHxMwebO" }, "outputs": [], "source": [ "rng = np.random.RandomState(42)\n", "\n", "dt = DecisionTreeClassifier(random_state=rng, max_depth=10, min_samples_leaf=10)\n", "rf = RandomForestClassifier(n_estimators=50, random_state=rng)\n", "etc = ExtraTreesClassifier(random_state=rng)\n", "knn = KNeighborsClassifier(n_neighbors=5, weights=\"distance\")\n", "svc_lin = SVC(kernel=\"linear\", probability=True, random_state=rng)\n", "svc_rbf = SVC(kernel=\"rbf\", probability=True, random_state=rng)\n", "cat = catboost.CatBoostClassifier(verbose=0, random_seed=42)\n", "lgbm = lightgbm.LGBMClassifier(random_state=42, verbose=-1)\n", "lgbm_rf = lightgbm.LGBMClassifier(\n", " boosting_type=\"rf\", subsample_freq=1, subsample=0.7, random_state=42, verbose=-1\n", ")\n", "xgb = xgboost.XGBClassifier(random_state=42)\n", "xgb_rf = xgboost.XGBRFClassifier(random_state=42)\n", "lr = LogisticRegression(solver=\"liblinear\", max_iter=10000)\n", "nb = GaussianNB()\n", "\n", "# Your code here\n", "\n", "voting_hard =\n", "voting_soft =\n", "# -----------\n", "\n", "\n", "for model in [voting_hard, voting_soft]:\n", " scores = cross_val_score(\n", " model,\n", " x_train,\n", " y_train,\n", " cv=KFold(n_splits=3, shuffle=True, random_state=rng),\n", " scoring=\"f1\",\n", " )\n", " print(f\"{model.__class__.__name__}: {scores.mean():.4f}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "93PJbbhQwebO" }, "source": [ "Вы можете заметить, что ансамбль показывает хорошее, но не лучшее качество предсказания, попробуем его улучшить. Как вы знаете, ансамбли работают лучше, когда модели, входящие в них, не скоррелированы друг с другом. 
Определите корреляцию предсказаний базовых моделей в ансамбле на тренировочном наборе и удалите из ансамбля те модели, чьи предсказания будут сильнее коррелировать с остальными. Можете модифицировать функцию `base_model_pair_correlation` из лекции." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kCRlENI1webP" }, "outputs": [], "source": [ "# Your code here" ] }, { "cell_type": "markdown", "metadata": { "id": "i8owETDKwebP" }, "source": [ "Создайте новый ансамбль на исправленном наборе моделей и оцените его качество с помощью кросс-валидации на тренировочном наборе, используя функцию `cross_val_score` и метрику `f1`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UvNFICpwwebP" }, "outputs": [], "source": [ "# Your code here\n", "\n", "voting_hard_2 =\n", "voting_soft_2 =\n", "# ------------\n", "\n", "for model in [voting_hard_2, voting_soft_2]:\n", " scores = cross_val_score(\n", " model,\n", " x_train,\n", " y_train,\n", " cv=KFold(n_splits=3, shuffle=True, random_state=rng),\n", " scoring=\"f1\",\n", " )\n", " print(f\"{model.__class__.__name__}: {scores.mean():.4f}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "nE4lZPMQwebR" }, "source": [ "Обучите все получившиеся модели на тренировочном наборе и испытайте их качество на тестовом наборе. Получилось ли у улучшенных версий ансамблевого классификатора превзойти базовые модели, входящие в него, и свои предыдущие версии?" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yHJ0HaphwebR" }, "outputs": [], "source": [ "# Your code here" ] }, { "cell_type": "markdown", "metadata": { "id": "plBM3EorwebS" }, "source": [ "Какие ансамбли работают лучше? 
Всегда ли больше моделей значит лучше?\n", "\n", "**Напишите вывод**" ] }, { "cell_type": "markdown", "metadata": { "id": "bjCCek3DwebS" }, "source": [ "## Формат результата" ] }, { "cell_type": "markdown", "metadata": { "id": "_2DU6GQswebS" }, "source": [ "Получить значения качества для ансамблей и моделей." ] } ], "metadata": { "colab": { "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "DataSphere Kernel", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }