{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "BL8OGgI9CmQy"
},
"source": [
"# Lab 8\n",
"\n",
"Now it is your turn to work with language models and self-supervised learning.\n",
"\n",
"## Mandatory exercise (5 points)\n",
"\n",
"The mandatory exercise consists of using the text classification dataset from the previous lab to build a model, following the steps given in the instructions notebook. You must also compare the results with those you obtained in the previous lab.\n",
"\n",
"## Optional exercise 1 (0.5 points)\n",
"\n",
"Create a new Space on HuggingFace with the model you have built.\n",
"\n",
"## Optional exercise 2 (4.5 points)\n",
"\n",
"This optional exercise investigates how to use self-supervised learning for image classification. Create a new notebook in which you reproduce the results of the mandatory part of Lab 1, and then complete the following tasks:\n",
"- Train the model from scratch (check what the `pretrained` parameter of the `cnn_learner` method does).\n",
"- Using the [Self Supervised Learning Fastai Extension](https://github.com/KeremTurgutlu/self_supervised) library, train three self-supervised models with the SimCLR, BYOL, and SwAV algorithms that it provides.\n",
"- Starting from the self-supervised models trained in the previous step, build new classification models.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tijRfybvCmQ2"
},
"source": [
"When you finish, remember to save your changes to GitHub with File -> Save a copy in GitHub.\n"
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "N3DDTYS45C9h"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "5ZDQ6rKcC2kq"
},
"source": [
"## Libraries\n",
"\n",
"We start by upgrading the fastai library and installing the HuggingFace `datasets` library. Once the installation finishes, restart the kernel (menu Runtime -> Restart runtime)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "y_vFVAFzC2kq",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "8063282b-ccae-44b3-9357-9a88f23a2727"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m468.7/468.7 kB\u001b[0m \u001b[31m7.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m110.5/110.5 kB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m24.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m224.2/224.2 kB\u001b[0m \u001b[31m10.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m10.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.9/132.9 kB\u001b[0m \u001b[31m1.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m158.8/158.8 kB\u001b[0m \u001b[31m10.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m269.3/269.3 kB\u001b[0m \u001b[31m11.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m114.2/114.2 kB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h"
]
}
],
"source": [
"!pip install fastai -Uqq\n",
"!pip install datasets -Uqq"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JAo2DOLlC2kr"
},
"source": [
"Next we import the libraries needed for this lab: the natural language processing part of fastai, pandas, and the HuggingFace function for loading datasets."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "-5sb8kBaC2ks"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"from fastai.text.all import *\n",
"from datasets import load_dataset"
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "p1aBx-Sk62AB"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Dataset\n",
"\n",
"For this example we use the [clinc_oos](https://huggingface.co/datasets/clinc_oos) dataset, which is designed for detecting 150 intent classes across 10 domains and also includes a label for the \"out-of-scope\" intent. The first labels are:\n",
"\n",
"| Label Id | Label name |\n",
"|--- |--- |\n",
"| 0 | restaurant_reviews |\n",
"| 1 | nutrition_info |\n",
"| 2 | account_blocked |\n",
"| 3 | oil_change_how |\n",
"| 4 | time |\n",
"| 5 | weather |\n",
"| 6 | redeem_rewards |\n",
"| 7 | interest_rate |\n",
"| 8 | gas_type |\n",
"| 9 | accept_reservations |\n",
"| 10 | smart_home |\n",
"| 11 | user_name |\n",
"| 12 | report_lost_card |\n",
"| 13 | repeat |\n",
"| 14 | whisper_mode |\n",
"| 15 | what_are_your_hobbies |\n",
"| 16 | order |\n",
"| 17 | jump_start |\n",
"| 18 | schedule_meeting |\n",
"| 19 | meeting_schedule |\n",
"| 20 | freeze_account |\n",
"| 21 | what_song |\n",
"| 22 | meaning_of_life |\n",
"| 23 | restaurant_reservation |\n",
"| 24 | traffic |\n",
"| 25 | make_call |\n",
"| 26 | text |\n",
"| 27 | bill_balance |\n",
"| ... | ... |\n",
"\n",
"Download the dataset (its `small` configuration) with the following command."
],
"metadata": {
"id": "V3SGfSUaEHv0"
}
},
{
"cell_type": "code",
"source": [
"clinc_oos_dataset = load_dataset(\"clinc_oos\",'small')"
],
"metadata": {
"id": "ephhRD5rDnCd",
"outputId": "999b8928-7ceb-4f17-bd1c-f94d92193fd4",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 244,
"referenced_widgets": [
"acea7c14a43a4d98ac3fa7d55bcf6ee6",
"2da90ab6b0f54d49bd1c97944d7aa83e",
"bc81e5b43dac454a9b043edcbe8bdd22",
"a8520126760f425cb07c030d6ea2c273",
"dc6a959e6b694c5cbc8a539272f29ad7",
"b52ea6a60f84468d982c195927cbb7ef",
"11e54992e4fd4d40bae190ab927514b7",
"50b6f98995b940c0881aef1870ade8ab",
"ffdd49432e9744f7bd5c23bff0244b54",
"ba2477e9ee9f48c8be6f2f74f8bc2cd3",
"0b28d318f69844d88afd25bd867d827e",
"a03bc89197ea4573b0f6784ca5b4e80b",
"1b0a75b3b78b468aa8e75f3074d58dfb",
"603fffb705344e41bbf2cb31fc1e5bfc",
"e74351f03d53422789bbdbbfeb764efa",
"c65151b08bec4cf48d0e0d6ee1250f59",
"78fabf12226a4f9cab33f1885b963534",
"e34d7ab3f4834d17ae863d288445d9be",
"ff14ace405e4409589bc3531a789dda1",
"08c58056296849fdbab23840a9196979",
"84ff998cb0f8461daa8deaf55dc17db5",
"63cd9ad6473b43f0856596c29c6a0cf3",
"a4f85e3bf0eb45c4857f6173b1adebcd",
"8edf734aa6d64f5f89906b464d512650",
"94dba14a5bc342bb8711f6843b6d7ed1",
"fc6d1aa44bd94a3ab9ca40c6a66c8ffb",
"a15bf77b783d4833b9c71ad2b2d9a9a4",
"76f7f82fe6134040a90c7413538aafe5",
"30d05e576f3747e08d932ea00fed01be",
"ffda192e0bc24c40aa272eea8011e1a1",
"89692d1d3c914672a458832e51be6d2d",
"1ebc784efa24427b8af27c9fdda2686b",
"f7d8ef7ca553491c9fba3ab382af530c",
"fd342b9fb4a84009b32ab4c7bcb377a4",
"daf05dc94f98455684bc96752c216e9c",
"575724ca3bb24bc79d64ed65544bc043",
"2f8aa3e45a55423d83606705f5010348",
"7e55727bedd94e849a125e4406f6a980",
"f595b20d8db44b9998732a1cc713a1e7",
"00d5b27795f448d485da956dbd8a73d4",
"3d5308aafb4744508c3390d59f8c6992",
"cbcbf6c645a14f61b625154537e9e528",
"0bc1a29b59384aec931b7d0c611c5ec2",
"8ff6ad225eba4c1b9baf01aa9db1a7d3",
"18935a837f49426dac45432e83a5913c",
"c6a107ffcc754e82921ad75bc6643898",
"2e382b16620c4bb09534050dae7a9d6e",
"997bd650baad41819aa3a713fa4341f3",
"29e85d0ef3754c7f8366469cd530cacb",
"5bee187efff34f8eab2c434eb21211ea",
"275b01de8e6444a69420369d2a96b76d",
"71ab8921857b496791db2a2286fab6ba",
"225e43ff721745c5b01dd7fe5dfe7e4d",
"a6900524efc04d8d8019c415196d4a46",
"a682df09134f4c74aef12a2e294e2cae",
"b329b0dc353247f58bf8c38b24c4e051",
"2ac54f2315934367ae8bd34aba136f21",
"579357f9baec40f2915d8f3021fc20eb",
"4e1c1bd8b4d74978bff0cfa348a1bff0",
"68bae7a3d76f4e68afe3e6fac15ccfb4",
"720d5c006f2f46d982be76cb50eb8f17",
"d2a91625d68f43479aa16f24b386af74",
"19601448d0704949aaa8459fae71b878",
"cdbef121730444b8abd227f4dbd42d3c",
"c24578c84bae4ab5ac0145c742f0747d",
"843ad42380cb40b989659ac691e67d65",
"f2d9294ee6f44814a8e8d9a87cf08c64",
"532efa9cd0d8459eac8b03bba83aad51",
"f6d159342ce047df904907d04cceb9ec",
"035514d9dad34ea6b645018dad177693",
"09394797973941e68705559ddcd9b802",
"571f0233507544148b23a0dcf251daa3",
"8fcd79171edb43b28b3291c06440b5cb",
"cf801abb3ce94c9b8bf91b4d1efa39db",
"116121b6f33b4883884fa748e515d969",
"a2320a4c9d0e492c93cc9b06e758e0d5",
"10e14d896714494da3d8d37dd8ed9282",
"c8aca95b4c1d4734a86f912b7ea43c6e",
"55a5af05b8234f938ff6104af90f2460",
"f66bcff001e04b199a1bbc9d92f42367",
"ff9eb64ef78b4722b373285f8bce32ed",
"70aad33e2cd043c08a94c8b8a93a3342",
"892450aebc2b4611aec22ffb88a8cba9",
"1e6314ba4f4a445b9b68928c6debb77c",
"ef7d3624b5fe44eea8d40f6b6a02898b",
"e652c021449c485484fb899b17acd0ef",
"0661f3b9e7bc408f82030c5b7c164fcb",
"ab50e73bf5404380baba92cf1aa40976"
]
}
},
"execution_count": 3,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading builder script: 0%| | 0.00/8.57k [00:00, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "acea7c14a43a4d98ac3fa7d55bcf6ee6"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading metadata: 0%| | 0.00/14.4k [00:00, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "a03bc89197ea4573b0f6784ca5b4e80b"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading readme: 0%| | 0.00/23.4k [00:00, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "a4f85e3bf0eb45c4857f6173b1adebcd"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Downloading and preparing dataset clinc_oos/small to /root/.cache/huggingface/datasets/clinc_oos/small/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1...\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading data: 0%| | 0.00/217k [00:00, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "fd342b9fb4a84009b32ab4c7bcb377a4"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Generating train split: 0%| | 0/7600 [00:00, ? examples/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "18935a837f49426dac45432e83a5913c"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Generating validation split: 0%| | 0/3100 [00:00, ? examples/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "b329b0dc353247f58bf8c38b24c4e051"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Generating test split: 0%| | 0/5500 [00:00, ? examples/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "f2d9294ee6f44814a8e8d9a87cf08c64"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Dataset clinc_oos downloaded and prepared to /root/.cache/huggingface/datasets/clinc_oos/small/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1. Subsequent calls will reuse this data.\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
" 0%| | 0/3 [00:00, ?it/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "c8aca95b4c1d4734a86f912b7ea43c6e"
}
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"clinc_oos_dataset"
],
"metadata": {
"id": "Du4cdYWVI0J7",
"outputId": "cbe0b060-b88a-474e-efc6-d2a5de555f2b",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['text', 'intent'],\n",
" num_rows: 7600\n",
" })\n",
" validation: Dataset({\n",
" features: ['text', 'intent'],\n",
" num_rows: 3100\n",
" })\n",
" test: Dataset({\n",
" features: ['text', 'intent'],\n",
" num_rows: 5500\n",
" })\n",
"})"
]
},
"metadata": {},
"execution_count": 4
}
]
},
{
"cell_type": "markdown",
"source": [
"The dataset is a `DatasetDict` object, which behaves like a dictionary. Its three keys, `train`, `validation` and `test`, hold the training, validation and test sets respectively (the key names can vary between datasets). Each of these subsets is a `Dataset`, which can be indexed like a list: for example, `clinc_oos_dataset['train'][0]` returns the first training example as a dictionary with the `text` and `intent` fields."
],
"metadata": {
"id": "ZVJsMgNUJe4T"
}
},
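{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following cell is a minimal sketch of that indexing. It assumes `clinc_oos_dataset` was loaded with the `load_dataset` call above. Indexing a split with an integer returns one example as a dictionary, while slicing returns a dictionary of lists (one list per field)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# First training example: a dict with the 'text' and 'intent' fields\n",
"print(clinc_oos_dataset['train'][0])\n",
"\n",
"# Slicing returns a dict of lists, one list per field\n",
"print(clinc_oos_dataset['train'][:3])"
]
},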
{
"cell_type": "markdown",
"metadata": {
"id": "0XB8Dl8jjwbN"
},
"source": [
"## Loading the data\n",
"\n",
"Next we convert each split into a pandas DataFrame, a format that fastai can read directly."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "mePgxpeLjwbO"
},
"outputs": [],
"source": [
"train_df = clinc_oos_dataset[\"train\"].to_pandas()\n",
"valid_df = clinc_oos_dataset[\"validation\"].to_pandas()\n",
"test_df = clinc_oos_dataset[\"test\"].to_pandas()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FN-QG9CljwbO"
},
"source": [
"We can inspect the contents of the training set with the following command."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 424
},
"id": "veoRjZOWjwbP",
"outputId": "e790ac73-5a62-461e-cfe4-3bebe7b1b25f"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text \\\n",
"0 can you walk me through setting up direct deposits to my bank of internet savings account \n",
"1 i want to switch to direct deposit \n",
"2 set up direct deposit for me \n",
"3 how do i go about setting up direct deposit \n",
"4 i need to get my paycheck direct deposited to my chase account \n",
"... ... \n",
"7595 what percentage of species display cold blooded traits \n",
"7596 what does it mean to be an alpha male \n",
"7597 what animals have alpha males \n",
"7598 why do males want to be alpha \n",
"7599 what's the average battery life of an android phone \n",
"\n",
" intent \n",
"0 108 \n",
"1 108 \n",
"2 108 \n",
"3 108 \n",
"4 108 \n",
"... ... \n",
"7595 42 \n",
"7596 42 \n",
"7597 42 \n",
"7598 42 \n",
"7599 42 \n",
"\n",
"[7600 rows x 2 columns]"
]
},
"metadata": {},
"execution_count": 10
}
],
"source": [
"train_df"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sLJE_AtyjwbP"
},
"source": [
"We are interested in two fields of the dataset: `text` (the user's query) and `intent` (its intent label, stored as an integer id)."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ivLo2dIHjwbP",
"outputId": "f5120ac6-b4ff-4535-92f9-b3a8b4557ebe"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0 can you walk me through setting up direct deposits to my bank of internet savings account\n",
"1 i want to switch to direct deposit\n",
"2 set up direct deposit for me\n",
"3 how do i go about setting up direct deposit\n",
"4 i need to get my paycheck direct deposited to my chase account\n",
" ... \n",
"7595 what percentage of species display cold blooded traits\n",
"7596 what does it mean to be an alpha male\n",
"7597 what animals have alpha males\n",
"7598 why do males want to be alpha\n",
"7599 what's the average battery life of an android phone\n",
"Name: text, Length: 7600, dtype: object"
]
},
"metadata": {},
"execution_count": 11
}
],
"source": [
"train_df['text']"
]
},
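{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `intent` column only stores integer ids. Below is a minimal sketch for recovering the label names. It assumes, as the label table in the Dataset section suggests, that the `intent` feature is a `datasets.ClassLabel`, whose `names` attribute lists the label names in id order. The `intent_names` variable and the `intent_name` column are hypothetical helpers, not part of the original notebook. The column is added to a sampled copy, so `train_df` itself is left unchanged."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Assumption: the 'intent' feature is a datasets.ClassLabel\n",
"intent_feature = clinc_oos_dataset['train'].features['intent']\n",
"intent_names = intent_feature.names  # label names, indexed by integer id\n",
"print(len(intent_names), 'labels')\n",
"\n",
"# Hypothetical helper column on a copy of a random sample; train_df is not modified\n",
"preview = train_df.sample(5, random_state=0).copy()\n",
"preview['intent_name'] = preview['intent'].map(lambda i: intent_names[i])\n",
"preview"
]
},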
{
"cell_type": "markdown",
"metadata": {
"id": "uw3h-1zPC2kw"
},
"source": [
"## Language model\n",
"\n",
"Fine-tuning fastai's language model follows the same process as in previous labs. We start by creating a `DataBlock` from our DataFrame."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "AW74KDBJC2kw"
},
"outputs": [],
"source": [
"db_lm = DataBlock(\n",
" blocks=TextBlock.from_df('text', is_lm=True, max_vocab=100000), # We are going to work with a language model\n",
" get_items=ColReader('text'), # Column of the dataframe that contains the text\n",
" splitter=RandomSplitter(0.1) # Split the dataset into training (90%) and validation (10%)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "leP76qI4C2kw"
},
"source": [
"We now create our `dataloader` (this may take several seconds)."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "e_Uzg_gRC2kx"
},
"outputs": [],
"source": [
"dls_lm = db_lm.dataloaders(train_valid_df, bs=128, seq_len=80)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ESmIQ-pkC2kx"
},
"source": [
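"We can now show a batch from this `dataloader`. The model's input (`text`) is a fixed-length sequence of tokens in which several texts are concatenated (`xxbos` marks the beginning of each text and `xxunk` replaces out-of-vocabulary words), and the target (`text_`) is the same sequence shifted by one token (it starts at the second token of the input), so at each position the model has to predict the next word."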
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"id": "WLkzmtyYC2kx",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 216
},
"outputId": "d778b76f-edb1-4058-aab6-3b1d1e5274c2"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>text</th><th>text_</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>xxbos xxunk when is my flight going to land xxbos do i have meetings with anyone today xxbos how soon does my visa card expire xxbos set a new timer for ten xxunk xxbos what can i get for 10 usd in cad xxbos when was the last date that i was paid xxbos can you help me find someone to look at my car 's engine because the light is on xxbos tell the restaurant to cancel my reservation</td><td>xxunk when is my flight going to land xxbos do i have meetings with anyone today xxbos how soon does my visa card expire xxbos set a new timer for ten xxunk xxbos what can i get for 10 usd in cad xxbos when was the last date that i was paid xxbos can you help me find someone to look at my car 's engine because the light is on xxbos tell the restaurant to cancel my reservation xxbos</td></tr>\n",
"    <tr><th>1</th><td>beef and xxunk xxbos what can i do if i lost my luggage xxbos can you get me a recipe for mashed potatoes xxbos share the nutrition info for pizza with me xxbos can my 401k rollover or not xxbos what is the lowest amount i can pay for my cable bill xxbos are there a lot of calories in muffins xxbos cancel that xxbos how long until my oil needs to be changed xxbos what month does my card</td><td>and xxunk xxbos what can i do if i lost my luggage xxbos can you get me a recipe for mashed potatoes xxbos share the nutrition info for pizza with me xxbos can my 401k rollover or not xxbos what is the lowest amount i can pay for my cable bill xxbos are there a lot of calories in muffins xxbos cancel that xxbos how long until my oil needs to be changed xxbos what month does my card stop</td></tr>\n",
"  </tbody>\n",
"</table>"
]
},
"metadata": {}
}
],
"source": [
"dls_lm.show_batch(max_n=2)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"learn = language_model_learner(\n",
" dls_lm, # The dataloader we use\n",
" AWD_LSTM, # The architecture, the same one used in the previous lab\n",
" drop_mult=0.3, # Scale the dropout probabilities to reduce overfitting\n",
" metrics=[accuracy, Perplexity()] # We use accuracy and perplexity as metrics\n",
").to_fp16()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GZeI2zGAC2ky"
},
"source": [
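"Finally, we train the model. `fine_tune` first trains for one epoch with the pretrained layers frozen and then for 10 epochs with the whole model unfrozen, which is why two result tables are shown."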
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"id": "wRXPjZB9C2ky",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 566
},
"outputId": "17be9a39-4039-4989-d74d-74bdb28a8ab9"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.9/dist-packages/torch/amp/autocast_mode.py:204: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n",
" warnings.warn('User provided device_type of \\'cuda\\', but CUDA is not available. Disabling')\n",
"/usr/local/lib/python3.9/dist-packages/torch/cuda/amp/grad_scaler.py:120: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.\n",
" warnings.warn(\"torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.\")\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"text/html": [
"\n",
"\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: left;\"><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>accuracy</th><th>perplexity</th><th>time</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><td>0</td><td>5.696906</td><td>4.259863</td><td>0.198017</td><td>70.800308</td><td>04:09</td></tr>\n",
"  </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.9/dist-packages/torch/amp/autocast_mode.py:204: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n",
" warnings.warn('User provided device_type of \\'cuda\\', but CUDA is not available. Disabling')\n",
"/usr/local/lib/python3.9/dist-packages/torch/cuda/amp/grad_scaler.py:120: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.\n",
" warnings.warn(\"torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.\")\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"text/html": [
"\n",
"\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: left;\"><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>accuracy</th><th>perplexity</th><th>time</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><td>0</td><td>3.862475</td><td>3.561986</td><td>0.299479</td><td>35.233109</td><td>04:13</td></tr>\n",
"    <tr><td>1</td><td>3.606975</td><td>3.192473</td><td>0.342748</td><td>24.348566</td><td>04:14</td></tr>\n",
"    <tr><td>2</td><td>3.388048</td><td>2.945242</td><td>0.369992</td><td>19.015263</td><td>04:16</td></tr>\n",
"    <tr><td>3</td><td>3.182660</td><td>2.770793</td><td>0.400341</td><td>15.971290</td><td>04:20</td></tr>\n",
"    <tr><td>4</td><td>3.000447</td><td>2.687873</td><td>0.407853</td><td>14.700380</td><td>04:18</td></tr>\n",
"    <tr><td>5</td><td>2.841043</td><td>2.651250</td><td>0.414864</td><td>14.171744</td><td>04:02</td></tr>\n",
"    <tr><td>6</td><td>2.706536</td><td>2.625234</td><td>0.423878</td><td>13.807803</td><td>04:02</td></tr>\n",
"    <tr><td>7</td><td>2.591934</td><td>2.609441</td><td>0.425581</td><td>13.591455</td><td>04:03</td></tr>\n",
"    <tr><td>8</td><td>2.498425</td><td>2.608087</td><td>0.427284</td><td>13.573062</td><td>04:03</td></tr>\n",
"    <tr><td>9</td><td>2.420030</td><td>2.607645</td><td>0.427784</td><td>13.567060</td><td>04:02</td></tr>\n",
"  </tbody>\n",
"</table>"
]
},
"metadata": {}
}
],
"source": [
"learn.fine_tune(10, base_lr=2e-2) # 1 epoch with the pretrained layers frozen, then 10 epochs unfrozen"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Pt2h96oNC2ky"
},
"source": [
"Once the model is trained, we save its `encoder`, which we will later reuse in our classification model (this is analogous to what we saw for image classification models)."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"id": "KHT9LdeBC2ky"
},
"outputs": [],
"source": [
"learn.save_encoder('finetuned')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MuF9fJQeC2ky"
},
"source": [
"## Training a classification model\n",
"\n",
"We now create our text classification model. The process is the same as in the previous lab, except that before we start training we load the `encoder` saved in the previous step.\n",
"\n",
"We start by defining a `DataBlock` built from our dataframe `train_valid_df`."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"id": "GWQt6jF0C2ky"
},
"outputs": [],
"source": [
"sentiment_clas = DataBlock(\n",
" blocks=(TextBlock.from_df('text', vocab=dls_lm.vocab), # The input of the model is text, using the same\n",
" # vocabulary as the language model\n",
" CategoryBlock), # and the output is a class\n",
" get_x=ColReader('text'), # Column of the dataframe that contains the text\n",
" get_y=ColReader('intent'), # Column of the dataframe that contains the class\n",
" splitter=ColSplitter('set') # Rows where the 'set' column is True go to validation, the rest to training\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A_yKEURkC2kz"
},
"source": [
"Now we define our dataloader from the `DataBlock` we have just created."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"id": "9ORL1uwgC2kz",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 17
},
"outputId": "a3222cb5-d572-497d-d4ab-c8b4b937b1b1"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"text/html": [
"\n",
"\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"text/html": []
},
"metadata": {}
}
],
"source": [
"dls = sentiment_clas.dataloaders(train_valid_df, bs=64)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Hy2jPfJjC2kz"
},
"source": [
"We can show a batch from our dataloader."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"id": "GB8nwXKbC2kz",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 112
},
"outputId": "59e115b9-68b5-4960-de61-8462cc3411df"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>text</th><th>category</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>xxbos can you tell me what to do as i am in the airport and have been for some time and there is still no xxunk of my xxunk</td><td>113</td></tr>\n",
"    <tr><th>1</th><td>xxbos i xxunk like a rental car in denver xxunk between january 1st and january 3rd and i xxunk like a ford if possible</td><td>40</td></tr>\n",
"  </tbody>\n",
"</table>"
]
},
"metadata": {}
}
],
"source": [
"dls.show_batch(max_n=2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WARPAlm5C2kz"
},
"source": [
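"We now create our `learner` with the `text_classifier_learner` method, passing [AWD_LSTM](https://arxiv.org/abs/1708.02182) as the network architecture and applying dropout to the model (`drop_mult=0.5`). We also add two callbacks: `ShowGraphCallback`, which plots the training and validation losses, and `SaveModelCallback`, which saves the model with the lowest validation loss."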
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"id": "07dPozDGC2kz"
},
"outputs": [],
"source": [
"callbacks = [ShowGraphCallback(), # Plot the training and validation losses\n",
" SaveModelCallback()] # Save the model with the lowest validation loss\n",
"\n",
"learnClass = text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy, cbs=callbacks).to_fp16()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UE76n_MkC2kz"
},
"source": [
"Next, we load the `encoder` of the language model."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"id": "mUHHd338C2k0"
},
"outputs": [],
"source": [
"learnClass = learnClass.load_encoder('finetuned')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lWPtCffDC2k0"
},
"source": [
"We can now use all the functionality we already saw for image classification. For example, we can search for a suitable learning rate to train our model."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"id": "mF9qM3rgC2k0",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 541
},
"outputId": "31592431-2195-41cd-c129-f18a0f097263"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.9/dist-packages/torch/amp/autocast_mode.py:204: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n",
" warnings.warn('User provided device_type of \\'cuda\\', but CUDA is not available. Disabling')\n",
"/usr/local/lib/python3.9/dist-packages/torch/cuda/amp/grad_scaler.py:120: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.\n",
" warnings.warn(\"torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.\")\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"text/html": [
"\n",
"\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"text/html": []
},
"metadata": {}
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"SuggestedLRs(valley=0.004365158267319202)"
]
},
"metadata": {},
"execution_count": 30
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
],
"source": [
"learnClass.lr_find()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mRUbbfWYC2k0"
},
"source": [
"And then we apply fine-tuning."
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"id": "wqN1ap8uC2k0",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "a5e241da-de4c-4773-dfa3-229c7f2be2d5"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.9/dist-packages/torch/amp/autocast_mode.py:204: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n",
" warnings.warn('User provided device_type of \\'cuda\\', but CUDA is not available. Disabling')\n",
"/usr/local/lib/python3.9/dist-packages/torch/cuda/amp/grad_scaler.py:120: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.\n",
" warnings.warn(\"torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.\")\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"text/html": [
"\n",
"\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"text/html": [
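"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: left;\"><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>accuracy</th><th>time</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><td>0</td><td>2.663182</td><td>1.729634</td><td>0.542903</td><td>01:56</td></tr>\n",
"  </tbody>\n",
"</table>"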
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Better model found at epoch 0 with valid_loss value: 1.7296338081359863.\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.9/dist-packages/torch/amp/autocast_mode.py:204: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n",
" warnings.warn('User provided device_type of \\'cuda\\', but CUDA is not available. Disabling')\n",
"/usr/local/lib/python3.9/dist-packages/torch/cuda/amp/grad_scaler.py:120: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.\n",
" warnings.warn(\"torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.\")\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"text/html": [
"\n",
"\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"text/html": [
"