{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4beac401",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from models.yolo import Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1a8399f",
   "metadata": {},
   "source": [
    "## Convert YOLOv9-S"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7a40f10",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cpu\")\n",
    "cfg = \"./models/detect/gelan-s.yaml\"\n",
    "model = Model(cfg, ch=3, nc=80, anchors=3)\n",
    "#model = model.half()\n",
    "model = model.to(device)\n",
    "_ = model.eval()\n",
    "ckpt = torch.load('./yolov9-s.pt', map_location='cpu')\n",
    "model.names = ckpt['model'].names\n",
    "model.nc = ckpt['model'].nc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b046bb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = 0\n",
    "for k, v in model.state_dict().items():\n",
    "    if \"model.{}.\".format(idx) in k:\n",
    "        if idx < 22:\n",
    "            kr = k.replace(\"model.{}.\".format(idx), \"model.{}.\".format(idx))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif \"model.{}.cv2.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.cv2.\".format(idx), \"model.{}.cv4.\".format(idx+7))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif \"model.{}.cv3.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.cv3.\".format(idx), \"model.{}.cv5.\".format(idx+7))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif \"model.{}.dfl.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.dfl.\".format(idx), \"model.{}.dfl2.\".format(idx+7))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "    else:\n",
    "        while True:\n",
    "            idx += 1\n",
    "            if \"model.{}.\".format(idx) in k:\n",
    "                break\n",
    "        if idx < 22:\n",
    "            kr = k.replace(\"model.{}.\".format(idx), \"model.{}.\".format(idx))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif \"model.{}.cv2.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.cv2.\".format(idx), \"model.{}.cv4.\".format(idx+7))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif \"model.{}.cv3.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.cv3.\".format(idx), \"model.{}.cv5.\".format(idx+7))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif \"model.{}.dfl.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.dfl.\".format(idx), \"model.{}.dfl2.\".format(idx+7))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "_ = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07eb0cde",
   "metadata": {},
   "outputs": [],
   "source": [
    "m_ckpt = {'model': model.half(),\n",
    "          'optimizer': None,\n",
    "          'best_fitness': None,\n",
    "          'ema': None,\n",
    "          'updates': None,\n",
    "          'opt': None,\n",
    "          'git': None,\n",
    "          'date': None,\n",
    "          'epoch': -1}\n",
    "torch.save(m_ckpt, \"./yolov9-s-converted.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba87d10f",
   "metadata": {},
   "source": [
    "## Convert YOLOv9-M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc41b027",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cpu\")\n",
    "cfg = \"./models/detect/gelan-m.yaml\"\n",
    "model = Model(cfg, ch=3, nc=80, anchors=3)\n",
    "#model = model.half()\n",
    "model = model.to(device)\n",
    "_ = model.eval()\n",
    "ckpt = torch.load('./yolov9-m.pt', map_location='cpu')\n",
    "model.names = ckpt['model'].names\n",
    "model.nc = ckpt['model'].nc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf7c3978",
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = 0\n",
    "for k, v in model.state_dict().items():\n",
    "    if \"model.{}.\".format(idx) in k:\n",
    "        if idx < 22:\n",
    "            kr = k.replace(\"model.{}.\".format(idx), \"model.{}.\".format(idx+1))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif \"model.{}.cv2.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.cv2.\".format(idx), \"model.{}.cv4.\".format(idx+16))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif \"model.{}.cv3.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.cv3.\".format(idx), \"model.{}.cv5.\".format(idx+16))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif \"model.{}.dfl.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.dfl.\".format(idx), \"model.{}.dfl2.\".format(idx+16))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "    else:\n",
    "        while True:\n",
    "            idx += 1\n",
    "            if \"model.{}.\".format(idx) in k:\n",
    "                break\n",
    "        if idx < 22:\n",
    "            kr = k.replace(\"model.{}.\".format(idx), \"model.{}.\".format(idx+1))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif \"model.{}.cv2.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.cv2.\".format(idx), \"model.{}.cv4.\".format(idx+16))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif \"model.{}.cv3.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.cv3.\".format(idx), \"model.{}.cv5.\".format(idx+16))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif \"model.{}.dfl.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.dfl.\".format(idx), \"model.{}.dfl2.\".format(idx+16))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "_ = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00a92a45",
   "metadata": {},
   "outputs": [],
   "source": [
    "m_ckpt = {'model': model.half(),\n",
    "          'optimizer': None,\n",
    "          'best_fitness': None,\n",
    "          'ema': None,\n",
    "          'updates': None,\n",
    "          'opt': None,\n",
    "          'git': None,\n",
    "          'date': None,\n",
    "          'epoch': -1}\n",
    "torch.save(m_ckpt, \"./yolov9-m-converted.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8680f822",
   "metadata": {},
   "source": [
    "## Convert YOLOv9-C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59f0198d",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cpu\")\n",
    "cfg = \"./models/detect/gelan-c.yaml\"\n",
    "model = Model(cfg, ch=3, nc=80, anchors=3)\n",
    "#model = model.half()\n",
    "model = model.to(device)\n",
    "_ = model.eval()\n",
    "ckpt = torch.load('./yolov9-c.pt', map_location='cpu')\n",
    "model.names = ckpt['model'].names\n",
    "model.nc = ckpt['model'].nc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2de7e1be",
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = 0\n",
    "for k, v in model.state_dict().items():\n",
    "    if \"model.{}.\".format(idx) in k:\n",
    "        if idx < 22:\n",
    "            kr = k.replace(\"model.{}.\".format(idx), \"model.{}.\".format(idx+1))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "        elif \"model.{}.cv2.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.cv2.\".format(idx), \"model.{}.cv4.\".format(idx+16))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "        elif \"model.{}.cv3.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.cv3.\".format(idx), \"model.{}.cv5.\".format(idx+16))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "        elif \"model.{}.dfl.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.dfl.\".format(idx), \"model.{}.dfl2.\".format(idx+16))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "    else:\n",
    "        while True:\n",
    "            idx += 1\n",
    "            if \"model.{}.\".format(idx) in k:\n",
    "                break\n",
    "        if idx < 22:\n",
    "            kr = k.replace(\"model.{}.\".format(idx), \"model.{}.\".format(idx+1))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "        elif \"model.{}.cv2.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.cv2.\".format(idx), \"model.{}.cv4.\".format(idx+16))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "        elif \"model.{}.cv3.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.cv3.\".format(idx), \"model.{}.cv5.\".format(idx+16))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "        elif \"model.{}.dfl.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.dfl.\".format(idx), \"model.{}.dfl2.\".format(idx+16))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "_ = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "960796e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "m_ckpt = {'model': model.half(),\n",
    "          'optimizer': None,\n",
    "          'best_fitness': None,\n",
    "          'ema': None,\n",
    "          'updates': None,\n",
    "          'opt': None,\n",
    "          'git': None,\n",
    "          'date': None,\n",
    "          'epoch': -1}\n",
    "torch.save(m_ckpt, \"./yolov9-c-converted.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47c6e6ae",
   "metadata": {},
   "source": [
    "## Convert YOLOv9-E"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "801a1b7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cpu\")\n",
    "cfg = \"./models/detect/gelan-e.yaml\"\n",
    "model = Model(cfg, ch=3, nc=80, anchors=3)\n",
    "#model = model.half()\n",
    "model = model.to(device)\n",
    "_ = model.eval()\n",
    "ckpt = torch.load('./yolov9-e.pt', map_location='cpu')\n",
    "model.names = ckpt['model'].names\n",
    "model.nc = ckpt['model'].nc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2ef4fe6",
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = 0\n",
    "for k, v in model.state_dict().items():\n",
    "    if \"model.{}.\".format(idx) in k:\n",
    "        if idx < 29:\n",
    "            kr = k.replace(\"model.{}.\".format(idx), \"model.{}.\".format(idx))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif idx < 42:\n",
    "            kr = k.replace(\"model.{}.\".format(idx), \"model.{}.\".format(idx+7))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif \"model.{}.cv2.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.cv2.\".format(idx), \"model.{}.cv4.\".format(idx+7))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif \"model.{}.cv3.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.cv3.\".format(idx), \"model.{}.cv5.\".format(idx+7))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif \"model.{}.dfl.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.dfl.\".format(idx), \"model.{}.dfl2.\".format(idx+7))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "    else:\n",
    "        while True:\n",
    "            idx += 1\n",
    "            if \"model.{}.\".format(idx) in k:\n",
    "                break\n",
    "        if idx < 29:\n",
    "            kr = k.replace(\"model.{}.\".format(idx), \"model.{}.\".format(idx))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif idx < 42:\n",
    "            kr = k.replace(\"model.{}.\".format(idx), \"model.{}.\".format(idx+7))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif \"model.{}.cv2.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.cv2.\".format(idx), \"model.{}.cv4.\".format(idx+7))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif \"model.{}.cv3.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.cv3.\".format(idx), \"model.{}.cv5.\".format(idx+7))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "        elif \"model.{}.dfl.\".format(idx) in k:\n",
    "            kr = k.replace(\"model.{}.dfl.\".format(idx), \"model.{}.dfl2.\".format(idx+7))\n",
    "            model.state_dict()[k] -= model.state_dict()[k]\n",
    "            model.state_dict()[k] += ckpt['model'].state_dict()[kr]\n",
    "            print(k, \"perfectly matched!!\")\n",
    "_ = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27bc1869",
   "metadata": {},
   "outputs": [],
   "source": [
    "m_ckpt = {'model': model.half(),\n",
    "          'optimizer': None,\n",
    "          'best_fitness': None,\n",
    "          'ema': None,\n",
    "          'updates': None,\n",
    "          'opt': None,\n",
    "          'git': None,\n",
    "          'date': None,\n",
    "          'epoch': -1}\n",
    "torch.save(m_ckpt, \"./yolov9-e-converted.pt\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}