{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Base Configurations" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import torch\n", "from transformers import SegformerForSemanticSegmentation\n", "from dataclasses import dataclass\n", "\n", "\n", "@dataclass\n", "class Configs:\n", " NUM_CLASSES = 4\n", " MODEL_PATH: str = \"nvidia/segformer-b4-finetuned-ade-512-512\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Model To Inspect Parameter Names" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "\n", "\n", "def get_model(*, model_path, num_classes):\n", " model = SegformerForSemanticSegmentation.from_pretrained(\n", " model_path,\n", " num_labels=num_classes,\n", " ignore_mismatched_sizes=True,\n", " )\n", " return model" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b4-finetuned-ade-512-512 and are newly initialized because the shapes did not match:\n", "- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([4, 768, 1, 1]) in the model instantiated\n", "- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([4]) in the model instantiated\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "segformer.encoder.patch_embeddings.0.proj.weight\n", "segformer.encoder.patch_embeddings.0.proj.bias\n", "segformer.encoder.patch_embeddings.0.layer_norm.weight\n", "segformer.encoder.patch_embeddings.0.layer_norm.bias\n", "segformer.encoder.patch_embeddings.1.proj.weight\n", 
"segformer.encoder.patch_embeddings.1.proj.bias\n" ] } ], "source": [ "model = get_model(model_path=Configs.MODEL_PATH, num_classes=Configs.NUM_CLASSES)\n", "model_state_dict = model.state_dict()\n", "\n", "print()\n", "for i, (key, val) in enumerate(model_state_dict.items()):\n", " print(key)\n", " if i == 5:\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download & load PyTorch-Lightning Checkpoint and Inspect Parameter Names" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mveb-101\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.15.5" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in c:\\Users\\vaibh\\OneDrive\\Desktop\\Work\\BigVision\\BLOG_POSTS\\Medical_segmentation\\GRADIO_APP\\UWMGI_Medical_Image_Segmentation\\wandb\\run-20230719_204221-w5qu5rqw" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run ethereal-bush-2 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation/runs/w5qu5rqw" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-fpgquxev:v0, 977.89MB. 1 files... \n", "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n", "Done. 0:1:5.3\n" ] } ], "source": [ "import wandb\n", "\n", "run = wandb.init()\n", "artifact = run.use_artifact(\"veb-101/UM_medical_segmentation/model-fpgquxev:v0\", type=\"model\")\n", "artifact_dir = artifact.download()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'MixedPrecisionPlugin', 'hparams_name', 'hyper_parameters'])\n" ] } ], "source": [ "CKPT = torch.load(os.path.join(artifact_dir, \"model.ckpt\"), map_location=\"cpu\")\n", "print(CKPT.keys())" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "model.segformer.encoder.patch_embeddings.0.proj.weight\n", "model.segformer.encoder.patch_embeddings.0.proj.bias\n", "model.segformer.encoder.patch_embeddings.0.layer_norm.weight\n", "model.segformer.encoder.patch_embeddings.0.layer_norm.bias\n", "model.segformer.encoder.patch_embeddings.1.proj.weight\n", "model.segformer.encoder.patch_embeddings.1.proj.bias\n" ] } ], "source": [ "TRAINED_CKPT_state_dict = CKPT[\"state_dict\"]\n", "\n", "for i, (key, val) in 
enumerate(TRAINED_CKPT_state_dict.items()):\n", "    print(key)\n", "    if i == 5:\n", "        break" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "**The pytorch-lightning `state_dict()` has an extra `model.` string at the front that refers to the object/variable name that was holding the model in the `LightningModule` class.**\n",
  "\n",
  "We can simply iterate over the parameters and strip that leading `model.` prefix from each parameter key name. We'll create a new `OrderedDict` for it."
 ] },
 { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
  "from collections import OrderedDict\n",
  "\n",
  "# The checkpoint keys printed above carry a leading \"model.\" that the plain\n",
  "# Hugging Face model does not use. Remove it from the front of the key only:\n",
  "# str.replace would also delete any later \"model.\" inside a parameter name.\n",
  "LIGHTNING_PREFIX = \"model.\"\n",
  "\n",
  "new_state_dict = OrderedDict(\n",
  "    (\n",
  "        key_name[len(LIGHTNING_PREFIX):] if key_name.startswith(LIGHTNING_PREFIX) else key_name,\n",
  "        value,\n",
  "    )\n",
  "    for key_name, value in CKPT[\"state_dict\"].items()\n",
  ")"
 ] },
 { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.load_state_dict(new_state_dict)" ] },
 { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# torch.save(model.state_dict(), \"Segformer_best_state_dict.ckpt\")" ] },
 { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "model.save_pretrained(\"segformer_trained_weights\")" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "To load the saved model, we simply need to pass the path to the directory \"segformer_trained_weights\"."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# model = get_model(model_path=os.path.join(os.getcwd(), \"segformer_trained_weights\"), num_classes=Configs.NUM_CLASSES)" ] } ], "metadata": { "kernelspec": { "display_name": "pytorchx", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }