{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Install" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: uv in /Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages (0.1.42)\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install uv" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!uv pip install dagshub setuptools accelerate toml torch torchvision transformers mlflow datasets ipywidgets python-dotenv evaluate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Initialized MLflow to track repo \"amaye15/CanineNet\"\n",
"
\n"
],
"text/plain": [
"Initialized MLflow to track repo \u001b[32m\"amaye15/CanineNet\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Repository amaye15/CanineNet initialized!\n", "\n" ], "text/plain": [ "Repository amaye15/CanineNet initialized!\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import os\n", "import toml\n", "import torch\n", "import mlflow\n", "import dagshub\n", "import datasets\n", "import evaluate\n", "from dotenv import load_dotenv\n", "from torchvision.transforms import v2\n", "from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer\n", "\n", "ENV_PATH = \"/Users/andrewmayes/Openclassroom/CanineNet/.env\"\n", "CONFIG_PATH = \"/Users/andrewmayes/Openclassroom/CanineNet/code/config.toml\"\n", "CONFIG = toml.load(CONFIG_PATH)\n", "\n", "load_dotenv(ENV_PATH)\n", "\n", "dagshub.init(repo_name=os.environ['MLFLOW_TRACKING_PROJECTNAME'], repo_owner=os.environ['MLFLOW_TRACKING_USERNAME'], mlflow=True, dvc=True)\n", "\n", "os.environ['MLFLOW_TRACKING_USERNAME'] = \"amaye15\"\n", "\n", "mlflow.set_tracking_uri(f'https://dagshub.com/' + os.environ['MLFLOW_TRACKING_USERNAME']\n", " + '/' + os.environ['MLFLOW_TRACKING_PROJECTNAME'] + '.mlflow')\n", "\n", "CREATE_DATASET = True\n", "ORIGINAL_DATASET = \"Alanox/stanford-dogs\"\n", "MODIFIED_DATASET = \"amaye15/stanford-dogs\"\n", "REMOVE_COLUMNS = [\"name\", \"annotations\"]\n", "RENAME_COLUMNS = {\"image\":\"pixel_values\", \"target\":\"label\"}\n", "SPLIT = 0.2\n", "\n", "METRICS = [\"accuracy\", \"f1\", \"precision\", \"recall\"]\n", "# MODELS = 'google/vit-base-patch16-224'\n", "# MODELS = \"google/siglip-base-patch16-224\"\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Affenpinscher: 0\n", "Afghan Hound: 1\n", "African Hunting Dog: 2\n", "Airedale: 3\n", "American Staffordshire Terrier: 4\n", "Appenzeller: 5\n", "Australian Terrier: 6\n", "Basenji: 7\n", "Basset: 8\n", "Beagle: 9\n", "Bedlington Terrier: 10\n", "Bernese Mountain Dog: 11\n", "Black And Tan Coonhound: 12\n", "Blenheim Spaniel: 13\n", "Bloodhound: 14\n", "Bluetick: 15\n", "Border Collie: 16\n", "Border Terrier: 17\n", "Borzoi: 18\n", "Boston Bull: 19\n", "Bouvier Des Flandres: 20\n", "Boxer: 21\n", "Brabancon Griffon: 22\n", "Briard: 23\n", "Brittany Spaniel: 24\n", "Bull Mastiff: 25\n", "Cairn: 26\n", "Cardigan: 27\n", "Chesapeake Bay Retriever: 28\n", "Chihuahua: 29\n", "Chow: 30\n", "Clumber: 31\n", "Cocker Spaniel: 32\n", "Collie: 33\n", "Curly Coated Retriever: 34\n", "Dandie Dinmont: 35\n", "Dhole: 36\n", "Dingo: 37\n", "Doberman: 38\n", "English Foxhound: 39\n", "English Setter: 40\n", "English Springer: 41\n", "Entlebucher: 42\n", "Eskimo Dog: 43\n", "Flat Coated Retriever: 44\n", "French Bulldog: 45\n", "German Shepherd: 46\n", "German Short Haired Pointer: 47\n", "Giant Schnauzer: 48\n", "Golden Retriever: 49\n", "Gordon Setter: 50\n", "Great Dane: 51\n", "Great Pyrenees: 52\n", "Greater Swiss Mountain Dog: 53\n", "Groenendael: 54\n", "Ibizan Hound: 55\n", "Irish Setter: 56\n", "Irish Terrier: 57\n", "Irish Water Spaniel: 58\n", "Irish Wolfhound: 59\n", "Italian Greyhound: 60\n", "Japanese Spaniel: 61\n", "Keeshond: 62\n", "Kelpie: 63\n", "Kerry Blue Terrier: 64\n", "Komondor: 65\n", "Kuvasz: 66\n", "Labrador Retriever: 67\n", "Lakeland Terrier: 68\n", "Leonberg: 69\n", "Lhasa: 70\n", "Malamute: 71\n", "Malinois: 72\n", "Maltese Dog: 73\n", "Mexican Hairless: 74\n", "Miniature Pinscher: 75\n", "Miniature Poodle: 76\n", "Miniature Schnauzer: 77\n", "Newfoundland: 78\n", "Norfolk Terrier: 79\n", "Norwegian Elkhound: 80\n", "Norwich Terrier: 81\n", "Old English Sheepdog: 82\n", "Otterhound: 83\n", "Papillon: 84\n", "Pekinese: 85\n", "Pembroke: 86\n", "Pomeranian: 87\n", "Pug: 88\n", "Redbone: 89\n", "Rhodesian Ridgeback: 90\n", "Rottweiler: 91\n", "Saint Bernard: 92\n", "Saluki: 93\n", "Samoyed: 94\n", "Schipperke: 95\n", "Scotch Terrier: 96\n", "Scottish Deerhound: 97\n", "Sealyham Terrier: 98\n", "Shetland Sheepdog: 99\n", "Shih Tzu: 100\n", "Siberian Husky: 101\n", "Silky Terrier: 102\n", "Soft Coated Wheaten Terrier: 103\n", "Staffordshire Bullterrier: 104\n", "Standard Poodle: 105\n", "Standard Schnauzer: 106\n", "Sussex Spaniel: 107\n", "Tibetan Mastiff: 108\n", "Tibetan Terrier: 109\n", "Toy Poodle: 110\n", "Toy Terrier: 111\n", "Vizsla: 112\n", "Walker Hound: 113\n", "Weimaraner: 114\n", "Welsh Springer Spaniel: 115\n", "West Highland White Terrier: 116\n", "Whippet: 117\n", "Wire Haired Fox Terrier: 118\n", "Yorkshire Terrier: 119\n" ] } ], "source": [ "if CREATE_DATASET:\n", " ds = datasets.load_dataset(ORIGINAL_DATASET, token=os.getenv(\"HF_TOKEN\"), split=\"full\", trust_remote_code=True)\n", " ds = ds.remove_columns(REMOVE_COLUMNS).rename_columns(RENAME_COLUMNS)\n", "\n", " labels = ds.select_columns(\"label\").to_pandas().sort_values(\"label\").get(\"label\").unique().tolist()\n", " numbers = range(len(labels))\n", " label2int = dict(zip(labels, numbers))\n", " int2label = dict(zip(numbers, labels))\n", "\n", " for key, val in label2int.items():\n", " print(f\"{key}: {val}\")\n", "\n", " ds = ds.class_encode_column(\"label\")\n", " ds = ds.align_labels_with_mapping(label2int, \"label\")\n", "\n", " ds = ds.train_test_split(test_size=SPLIT, stratify_by_column = \"label\")\n", " #ds.push_to_hub(MODIFIED_DATASET, token=os.getenv(\"HF_TOKEN\"))\n", "\n", " CONFIG[\"label2int\"] = str(label2int)\n", " CONFIG[\"int2label\"] = str(int2label)\n", "\n", " # with open(\"output.toml\", \"w\") as toml_file:\n", " # toml.dump(toml.dumps(CONFIG), toml_file)\n", "\n", " #ds = datasets.load_dataset(MODIFIED_DATASET, token=os.getenv(\"HF_TOKEN\"), trust_remote_code=True, streaming=True)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", " warnings.warn(\n", "Some weights of SiglipForImageClassification were not initialized from the model checkpoint at google/siglip-base-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", "max_steps is given, it will override any value given in num_train_epochs\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "343b3d32fc774b0f9a2b0dee471ec262", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1000 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 4.822, 'grad_norm': 11.180054664611816, 'learning_rate': 4.9500000000000004e-05, 'epoch': 0.16}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "48d9a7623928469c9bd9e61ad4a5573b", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 4.254875183105469, 'eval_accuracy': 0.0782312925170068, 'eval_f1': 0.04927852996179247, 'eval_precision': 0.09874043278607707, 'eval_recall': 0.07264375052644872, 'eval_runtime': 55.3923, 'eval_samples_per_second': 74.306, 'eval_steps_per_second': 1.173, 'epoch': 0.16}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 4.236, 'grad_norm': 17.628389358520508, 'learning_rate': 4.9e-05, 'epoch': 0.31}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b7de237f61e54c158691316da90ae5a5", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 3.5278525352478027, 'eval_accuracy': 0.19071914480077745, 'eval_f1': 0.15072037668109098, 'eval_precision': 0.22011886017337598, 'eval_recall': 0.18300531751649823, 'eval_runtime': 55.5125, 'eval_samples_per_second': 74.145, 'eval_steps_per_second': 1.171, 'epoch': 0.31}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 3.5066, 'grad_norm': 19.224912643432617, 'learning_rate': 4.85e-05, 'epoch': 0.47}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "02737fd7a3554302b14c8fd8be6aad37", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 2.531590223312378, 'eval_accuracy': 0.33187560738581146, 'eval_f1': 0.2941424042839124, 'eval_precision': 0.4180298856360352, 'eval_recall': 0.320509389455932, 'eval_runtime': 56.5067, 'eval_samples_per_second': 72.841, 'eval_steps_per_second': 1.15, 'epoch': 0.47}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 2.8064, 'grad_norm': 22.580602645874023, 'learning_rate': 4.8e-05, 'epoch': 0.62}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5a4a4e46c1ec4352a0d248e7f8d53a9e", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 2.1243278980255127, 'eval_accuracy': 0.4361030126336249, 'eval_f1': 0.409040351489377, 'eval_precision': 0.5324247354698377, 'eval_recall': 0.4282087854976091, 'eval_runtime': 56.8186, 'eval_samples_per_second': 72.441, 'eval_steps_per_second': 1.144, 'epoch': 0.62}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 2.441, 'grad_norm': 17.738447189331055, 'learning_rate': 4.75e-05, 'epoch': 0.78}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "45d70e5072f1485199b309a10b746045", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 1.5798275470733643, 'eval_accuracy': 0.5510204081632653, 'eval_f1': 0.5250154943185481, 'eval_precision': 0.6242284529813324, 'eval_recall': 0.5437767896591994, 'eval_runtime': 57.0171, 'eval_samples_per_second': 72.189, 'eval_steps_per_second': 1.14, 'epoch': 0.78}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 2.0985, 'grad_norm': 18.94181251525879, 'learning_rate': 4.7e-05, 'epoch': 0.93}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "83e2ce0b615d4939819bf0db71dd745a", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 1.42424476146698, 'eval_accuracy': 0.5843051506316812, 'eval_f1': 0.557705493250987, 'eval_precision': 0.6400162236362443, 'eval_recall': 0.5768419977409593, 'eval_runtime': 57.2333, 'eval_samples_per_second': 71.916, 'eval_steps_per_second': 1.136, 'epoch': 0.93}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 1.8689, 'grad_norm': 15.593049049377441, 'learning_rate': 4.6500000000000005e-05, 'epoch': 1.09}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e0e5069d05484298906c53c0409ae3b4", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 1.1481006145477295, 'eval_accuracy': 0.6625364431486881, 'eval_f1': 0.6455514206859728, 'eval_precision': 0.7142859368225736, 'eval_recall': 0.6564757487305617, 'eval_runtime': 54.7373, 'eval_samples_per_second': 75.196, 'eval_steps_per_second': 1.187, 'epoch': 1.09}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 1.6588, 'grad_norm': 18.39203453063965, 'learning_rate': 4.600000000000001e-05, 'epoch': 1.24}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f3c04a3fbe7149e89e25a26ce4d543a7", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 1.1937264204025269, 'eval_accuracy': 0.6465014577259475, 'eval_f1': 0.6361000380324133, 'eval_precision': 0.7061715448588218, 'eval_recall': 0.6438641849166267, 'eval_runtime': 55.5191, 'eval_samples_per_second': 74.137, 'eval_steps_per_second': 1.171, 'epoch': 1.24}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 1.5807, 'grad_norm': 15.319233894348145, 'learning_rate': 4.55e-05, 'epoch': 1.4}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d8e25a97554844cb8d12799591b55105", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 0.9817520976066589, 'eval_accuracy': 0.70578231292517, 'eval_f1': 0.6890227220667341, 'eval_precision': 0.7438497507413404, 'eval_recall': 0.6980582780473442, 'eval_runtime': 54.4988, 'eval_samples_per_second': 75.525, 'eval_steps_per_second': 1.193, 'epoch': 1.4}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 1.4851, 'grad_norm': 15.890103340148926, 'learning_rate': 4.5e-05, 'epoch': 1.55}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9ccacdd73fc6461b9a3a11bbf3c91179", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 1.0180633068084717, 'eval_accuracy': 0.6999514091350826, 'eval_f1': 0.6838587523312375, 'eval_precision': 0.7373019639077568, 'eval_recall': 0.6959074888023662, 'eval_runtime': 55.2973, 'eval_samples_per_second': 74.434, 'eval_steps_per_second': 1.175, 'epoch': 1.55}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 1.5033, 'grad_norm': 17.170801162719727, 'learning_rate': 4.4500000000000004e-05, 'epoch': 1.71}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5fbaecaf6e7e4a89b14a917a95a32c48", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 1.0169070959091187, 'eval_accuracy': 0.6914480077745384, 'eval_f1': 0.6845415929736886, 'eval_precision': 0.7489788726612852, 'eval_recall': 0.6883375806361393, 'eval_runtime': 54.525, 'eval_samples_per_second': 75.488, 'eval_steps_per_second': 1.192, 'epoch': 1.71}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 1.3022, 'grad_norm': 15.557647705078125, 'learning_rate': 4.4000000000000006e-05, 'epoch': 1.86}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f0ec4f69dabe4268a130a513f8556e8f", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 0.9087187051773071, 'eval_accuracy': 0.7276482021379981, 'eval_f1': 0.7169827898093813, 'eval_precision': 0.7642639115410531, 'eval_recall': 0.722171618202087, 'eval_runtime': 54.7556, 'eval_samples_per_second': 75.17, 'eval_steps_per_second': 1.187, 'epoch': 1.86}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 1.3106, 'grad_norm': 15.203620910644531, 'learning_rate': 4.35e-05, 'epoch': 2.02}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d4fdc66d719e477c91950d7c477e44a4", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 0.8385488986968994, 'eval_accuracy': 0.7431972789115646, 'eval_f1': 0.7352483871752059, 'eval_precision': 0.7666810806987456, 'eval_recall': 0.7363282855094594, 'eval_runtime': 57.5486, 'eval_samples_per_second': 71.522, 'eval_steps_per_second': 1.129, 'epoch': 2.02}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 1.1721, 'grad_norm': 18.051284790039062, 'learning_rate': 4.3e-05, 'epoch': 2.17}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "dc6976d80b7e40bbb3cc0031eef848b8", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 0.8956524133682251, 'eval_accuracy': 0.7128279883381924, 'eval_f1': 0.7025737877793609, 'eval_precision': 0.7591947211938203, 'eval_recall': 0.7074780847115492, 'eval_runtime': 58.6317, 'eval_samples_per_second': 70.201, 'eval_steps_per_second': 1.109, 'epoch': 2.17}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 1.131, 'grad_norm': 16.522109985351562, 'learning_rate': 4.25e-05, 'epoch': 2.33}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f602bc78c97842f5b7a7221559cb52f7", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 0.8729854226112366, 'eval_accuracy': 0.7259475218658892, 'eval_f1': 0.7148538617252097, 'eval_precision': 0.7687155689784482, 'eval_recall': 0.719605645045331, 'eval_runtime': 54.6147, 'eval_samples_per_second': 75.364, 'eval_steps_per_second': 1.19, 'epoch': 2.33}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 1.1223, 'grad_norm': 16.727994918823242, 'learning_rate': 4.2e-05, 'epoch': 2.48}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e7b0d1b380a145b8bf58631628e8830e", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 0.8132386803627014, 'eval_accuracy': 0.7546161321671526, 'eval_f1': 0.7457409451548778, 'eval_precision': 0.7855075954723149, 'eval_recall': 0.7482023394172116, 'eval_runtime': 54.8217, 'eval_samples_per_second': 75.08, 'eval_steps_per_second': 1.186, 'epoch': 2.48}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 1.0688, 'grad_norm': 14.611897468566895, 'learning_rate': 4.15e-05, 'epoch': 2.64}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2f071a57e8b14c4a9b85178f57885e03", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 0.7485197186470032, 'eval_accuracy': 0.7704081632653061, 'eval_f1': 0.7600821180493263, 'eval_precision': 0.7863249503261968, 'eval_recall': 0.7631296317667979, 'eval_runtime': 56.4165, 'eval_samples_per_second': 72.957, 'eval_steps_per_second': 1.152, 'epoch': 2.64}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 1.0686, 'grad_norm': 17.756242752075195, 'learning_rate': 4.1e-05, 'epoch': 2.79}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f4d3dc2d893047508b51c00948d8057c", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 0.7559003233909607, 'eval_accuracy': 0.7650631681243926, 'eval_f1': 0.7586751052497263, 'eval_precision': 0.7920018070685718, 'eval_recall': 0.7609226412898984, 'eval_runtime': 53.1768, 'eval_samples_per_second': 77.402, 'eval_steps_per_second': 1.222, 'epoch': 2.79}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 0.9733, 'grad_norm': 14.432697296142578, 'learning_rate': 4.05e-05, 'epoch': 2.95}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "35ef518721bb4ff29f1d1d6286fc5f75", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/65 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'eval_loss': 0.7778576612472534, 'eval_accuracy': 0.7553449951409135, 'eval_f1': 0.7458481776644565, 'eval_precision': 0.7797152587168623, 'eval_recall': 0.7521461869682271, 'eval_runtime': 54.2043, 'eval_samples_per_second': 75.935, 'eval_steps_per_second': 1.199, 'epoch': 2.95}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " warnings.warn(\n" ] } ], "source": [ "metrics = {metric: evaluate.load(metric) for metric in METRICS}\n", "\n", "\n", "# for lr in [5e-3, 5e-4, 5e-5]: # 5e-5\n", "# for batch in [64]: # 32\n", "# for model_name in [\"google/vit-base-patch16-224\", \"microsoft/swinv2-base-patch4-window16-256\", \"google/siglip-base-patch16-224\"]: # \"facebook/dinov2-base\"\n", "\n", "lr = 5e-5\n", "batch = 64\n", "model_name = \"google/siglip-base-patch16-224\"\n", "\n", "image_processor = AutoImageProcessor.from_pretrained(model_name)\n", "model = AutoModelForImageClassification.from_pretrained(\n", "model_name,\n", "num_labels=len(label2int),\n", "id2label=int2label,\n", "label2id=label2int,\n", "ignore_mismatched_sizes=True,\n", ")\n", "\n", "# Then, in your transformations:\n", "def train_transform(examples, num_ops=10, magnitude=9, num_magnitude_bins=31):\n", "\n", " transformation = v2.Compose(\n", " [\n", " v2.RandAugment(\n", " num_ops=num_ops,\n", " magnitude=magnitude,\n", " num_magnitude_bins=num_magnitude_bins,\n", " )\n", " ]\n", " )\n", " # Ensure each image has three dimensions (in this case, ensure it's RGB)\n", " examples[\"pixel_values\"] = [\n", " image.convert(\"RGB\") for image in examples[\"pixel_values\"]\n", " ]\n", " # Apply transformations\n", " examples[\"pixel_values\"] = [\n", " image_processor(transformation(image), return_tensors=\"pt\")[\n", " \"pixel_values\"\n", " ].squeeze()\n", " for image in examples[\"pixel_values\"]\n", " ]\n", " return examples\n", "\n", "\n", "def test_transform(examples):\n", " # Ensure each image is RGB\n", " examples[\"pixel_values\"] = [\n", " image.convert(\"RGB\") for image in examples[\"pixel_values\"]\n", " ]\n", " # Apply processing\n", " examples[\"pixel_values\"] = [\n", " image_processor(image, return_tensors=\"pt\")[\"pixel_values\"].squeeze()\n", " for image in examples[\"pixel_values\"]\n", " ]\n", " return examples\n", "\n", "\n", "def compute_metrics(eval_pred):\n", " predictions, labels = eval_pred\n", " # predictions = np.argmax(logits, axis=-1)\n", " results = {}\n", " for key, val in metrics.items():\n", " if \"accuracy\" == key:\n", " result = next(\n", " iter(val.compute(predictions=predictions, references=labels).items())\n", " )\n", " if \"accuracy\" != key:\n", " result = next(\n", " iter(\n", " val.compute(\n", " predictions=predictions, references=labels, average=\"macro\"\n", " ).items()\n", " )\n", " )\n", " results[result[0]] = result[1]\n", " return results\n", "\n", "\n", "def collate_fn(examples):\n", " pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n", " labels = torch.tensor([example[\"label\"] for example in examples])\n", " return {\"pixel_values\": pixel_values, \"labels\": labels}\n", "\n", "\n", "def preprocess_logits_for_metrics(logits, labels):\n", " \"\"\"\n", " Original Trainer may have a memory leak.\n", " This is a workaround to avoid storing too many tensors that are not needed.\n", " \"\"\"\n", " pred_ids = torch.argmax(logits, dim=-1)\n", " return pred_ids\n", "\n", "ds[\"train\"].set_transform(train_transform)\n", "ds[\"test\"].set_transform(test_transform)\n", "\n", "training_args = TrainingArguments(**CONFIG[\"training_args\"])\n", "training_args.per_device_train_batch_size = batch\n", "training_args.per_device_eval_batch_size = batch\n", "training_args.hub_model_id = f\"amaye15/{model_name.replace('/','-')}-batch{batch}-lr{lr}-standford-dogs\"\n", "\n", "mlflow.start_run(run_name=f\"{model_name.replace('/','-')}-batch{batch}-lr{lr}\")\n", "\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=ds[\"train\"],\n", " eval_dataset=ds[\"test\"],\n", " tokenizer=image_processor,\n", " data_collator=collate_fn,\n", " compute_metrics=compute_metrics,\n", " # callbacks=[early_stopping_callback],\n", " preprocess_logits_for_metrics=preprocess_logits_for_metrics,\n", ")\n", "\n", "# Train the model\n", "trainer.train()\n", "\n", "trainer.push_to_hub()\n", "\n", "mlflow.end_run()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'mlflow' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmlflow\u001b[49m\u001b[38;5;241m.\u001b[39mend_run()\n", "\u001b[0;31mNameError\u001b[0m: name 'mlflow' is not defined" ] } ], "source": [ "mlflow.end_run()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# training_args = TrainingArguments(**CONFIG[\"training_args\"])\n", "\n", "# image_processor = AutoImageProcessor.from_pretrained(MODELS)\n", "# model = AutoModelForImageClassification.from_pretrained(\n", "# MODELS,\n", "# num_labels=len(CONFIG[\"label2int\"]),\n", "# id2label=CONFIG[\"label2int\"],\n", "# label2id=CONFIG[\"int2label\"],\n", "# ignore_mismatched_sizes=True,\n", "# )\n", "\n", "\n", "# training_args = TrainingArguments(**CONFIG[\"training_args\"])\n", "\n", "# trainer = Trainer(\n", "# model=model,\n", "# args=training_args,\n", "# train_dataset=ds[\"train\"],\n", "# eval_dataset=ds[\"test\"],\n", "# tokenizer=image_processor,\n", "# data_collator=collate_fn,\n", "# compute_metrics=compute_metrics,\n", "# # callbacks=[early_stopping_callback],\n", "# preprocess_logits_for_metrics=preprocess_logits_for_metrics,\n", "# )\n", "\n", "# # Train the model\n", "# trainer.train()\n", "\n", "# mlflow.end_run()" ] } ], "metadata": { "kernelspec": { "display_name": "env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 2 }