diff --git "a/ViTPose/easy_ViTPose/easy_ViTPose/ViTPose_Inference.ipynb" "b/ViTPose/easy_ViTPose/easy_ViTPose/ViTPose_Inference.ipynb" new file mode 100644--- /dev/null +++ "b/ViTPose/easy_ViTPose/easy_ViTPose/ViTPose_Inference.ipynb" @@ -0,0 +1,2423 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "dcfcb0a2-93fb-4f3c-aa40-56fbe6a5dcff", + "metadata": {}, + "outputs": [], + "source": [ + "import cv2\n", + "from easy_ViTPose import VitInference\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Image to run inference RGB format\n", + "img = cv2.imread('testVITPOSE.jpg')\n", + "img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n", + "\n", + "# set is_video=True to enable tracking in video inference\n", + "# be sure to use VitInference.reset() function to reset the tracker after each video\n", + "# There are a few flags that allows to customize VitInference, be sure to check the class definition\n", + "model_path = r'C:\\Users\\user\\ViTPose/ckpts/vitpose-s-coco_25.pth'\n", + "yolo_path = r'C:\\Users\\user\\ViTPose/yolov8s.pt'\n", + "\n", + "# If you want to use MPS (on new macbooks) use the torch checkpoints for both ViTPose and Yolo\n", + "# If device is None will try to use cuda -> mps -> cpu (otherwise specify 'cpu', 'mps' or 'cuda')\n", + "# dataset and det_class parameters can be inferred from the ckpt name, but you can specify them.\n", + "model = VitInference(model_path, yolo_path, model_name='s', yolo_size=320, is_video=False, device=\"cuda\")\n", + "\n", + "# Infer keypoints, output is a dict where keys are person ids and values are keypoints (np.ndarray (25, 3): (y, x, score))\n", + "# If is_video=True the IDs will be consistent among the ordered video frames.\n", + "keypoints = model.inference(img)\n", + "\n", + "# call model.reset() after each video\n", + "\n", + "img = model.draw(show_yolo=True) # Returns RGB image with drawings\n", + "plt.imshow(img)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "772c119d-0e34-488a-bcec-40e0007155aa", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9e2a99d2-ece2-4f00-b9e1-e130099026bd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "torch.cuda.is_available()" + ] + }, + { + "cell_type": "markdown", + "id": "ea96beea-c174-45c4-9119-e3db40f18793", + "metadata": {}, + "source": [ + "# Training the ViT_Pose" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ba4a27fc-20e0-433f-9de5-65320f963af9", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) OpenMMLab. All rights reserved.\n", + "import argparse\n", + "import copy\n", + "import os\n", + "import os.path as osp\n", + "import time\n", + "import warnings\n", + "import click\n", + "import yaml\n", + "\n", + "from glob import glob\n", + "\n", + "import torch\n", + "import torch.distributed as dist\n", + "\n", + "from vit_utils.util import init_random_seed, set_random_seed\n", + "from vit_utils.dist_util import get_dist_info, init_dist\n", + "from vit_utils.logging import get_root_logger\n", + "\n", + "import configs.ViTPose_small_coco_256x192 as s_cfg\n", + "# import configs.ViTPose_base_coco_256x192 as b_cfg\n", + "# import configs.ViTPose_large_coco_256x192 as l_cfg\n", + "# import configs.ViTPose_huge_coco_256x192 as h_cfg\n", + "\n", + "from vit_models.model import ViTPose\n", + "from datasets.COCO import COCODataset\n", + "from vit_utils.train_valid_fn import train_model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4ef1a26d-9303-4859-9112-bedea1dd46e8", + "metadata": {}, + "outputs": [], + "source": [ + "__file__ = r\"C:\\Users\\user\\ViTPose\\easy_ViTPose\\easy_ViTPose\"\n", + "CUR_PATH = osp.dirname(__file__)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "21be3367-277e-4d8c-8fca-d7524235c21b", + "metadata": {}, + "outputs": [], + "source": [ + "model_name = 's'\n", + "config_path = 'config.yaml'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16e707e1-e55b-4c44-abc3-fd217a60b381", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "2d5c73bb-6486-4513-af1c-3c37c08e80f8", + "metadata": {}, + "source": [ + "### Loading the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "41ea408e-c350-48d5-bf88-d9a2a68322c2", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "# Load the JSON file\n", + "with open(r\"D:\\ViTPose\\Evaluating\\annotations\\person_keypoints_val2017.json\", 'r') as f:\n", + " coco_data = json.load(f)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7bee2d84-71c7-4f53-8154-3b7340bd8708", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "590ad908-cc80-4c4e-a2af-88450a2aa77e", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c1fa1c4b-75ff-4ed9-a46d-90c639acae41", + "metadata": {}, + "outputs": [], + "source": [ + "# Create a mapping of image_id to file_name\n", + "image_id_to_filename = {img['id']: img['file_name'] for img in coco_data['images']}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d8b2585-0ef4-4dc2-86ef-ab6d5d799075", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b83f079-efd1-416b-ab6c-29e80e668cff", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1dde6eac-a62d-49b1-ad27-a39c114e0f1b", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "37899bdc-69a4-4271-8fd7-aa03148883a2", + "metadata": {}, + "outputs": [], + "source": [ + "# # Example: Process keypoints for one annotation\n", + "# for ann in annotations:\n", + "# keypoints = ann['keypoints']\n", + "# keypoints_array = [keypoints[i:i + 3] for i in range(0, len(keypoints), 3)]\n", + "# print(\"Keypoints:\", keypoints_array)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9942e11-9394-4f67-a5ed-2ee700d33625", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b030648f-8e1f-4abe-8351-a7ae34cb3bb2", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9b30c5c1-6657-4102-8cd6-3f68b100bf61", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from PIL import Image\n", + "import numpy as np\n", + "\n", + "data_dir = r'D:\\ViTPose\\Evaluating\\val2017\\\\'\n", + "dataset = []\n", + "\n", + "for ann in coco_data['annotations']:\n", + " image_id = ann['image_id']\n", + " #print(\"image_id: \", image_id)\n", + " file_name = image_id_to_filename[image_id]\n", + " #print(\"file_name: \", file_name)\n", + " image_path = os.path.join(data_dir, file_name)\n", + " #print(\"image_path: \", image_path)\n", + " # Load the image\n", + " if not os.path.exists(image_path):\n", + " continue\n", + " image = Image.open(image_path).convert('RGB')\n", + " \n", + " # Process keypoints\n", + " keypoints = ann['keypoints']\n", + " keypoints_array = np.array([keypoints[i:i + 3] for i in range(0, len(keypoints), 3)])\n", + " \n", + " # Collect data\n", + " dataset.append((image, keypoints_array))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a9c3c26-7d13-4103-97c1-f5d51cef5584", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "47734b7b-3c0f-4e58-abe6-bc6a70dc29c9", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "397133\n", + "000000397133.jpg\n" + ] + } + ], + "source": [ + "print(coco_data['images'][0]['id'])\n", + "print(coco_data['images'][0]['file_name'])" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9463a1ea-09be-4138-b6d7-501a4f25c2f1", + "metadata": {}, + "outputs": [], + "source": [ + "# Apply scaling transformation for each keypoint\n", + "def resize_keypoints(keypoints, scale_w, scale_h):\n", + " resized_keypoints = keypoints.clone()\n", + " \n", + " for j in range(keypoints.shape[0]):\n", + " x, y, visibility = keypoints[j]\n", + " # Only resize if visibility > 0 (to ignore invisible keypoints)\n", + " if visibility > 0:\n", + " resized_keypoints[j, 0] = int(x * scale_w)\n", + " resized_keypoints[j, 1] = int(y * scale_h)\n", + " \n", + " return resized_keypoints\n", + "\n", + "\n", + "\n", + "def transformKeypoint(img, target_shape, keypoints):\n", + " orig_width, orig_height = img.width, img.height\n", + " (target_width, target_height) = target_shape\n", + " \n", + " # Scaling factors for width and height\n", + " scale_w = target_width / orig_width\n", + " scale_h = target_height / orig_height\n", + " # Resized keypoints\n", + " resized_keypoints = resize_keypoints(keypoints, scale_w, scale_h)\n", + " \n", + " # Print the resized keypoints\n", + " #print(\"Resized Keypoints:\\n\", resized_keypoints)\n", + " return resized_keypoints\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "012be853-653e-4651-9861-fc2e11a00a00", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be1f8f9d-2c11-4ddc-a646-e193b79d3829", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "3faf2805-5a62-4f5a-967b-4e7ec1c32c87", + "metadata": {}, + "outputs": [], + "source": [ + "target_shape = (208, 208)\n", + "import torch\n", + "from torch.utils.data import Dataset\n", + "\n", + "class COCOKeypointsDataset(Dataset):\n", + " def __init__(self, json_path, images_dir, transform=None, transformKP = None):\n", + " with open(json_path, 'r') as f:\n", + " coco_data = json.load(f)\n", + " \n", + " self.image_id_to_filename = {img['id']: img['file_name'] for img in coco_data['images']}\n", + " self.annotations = coco_data['annotations']\n", + " self.images_dir = images_dir\n", + " self.transform = transform\n", + "\n", + " def __len__(self):\n", + " return len(self.annotations)\n", + "\n", + " def __getitem__(self, idx):\n", + " # Get annotation\n", + " ann = self.annotations[idx]\n", + " image_id = ann['image_id']\n", + " file_name = self.image_id_to_filename[image_id]\n", + " image_path = os.path.join(self.images_dir, file_name)\n", + " \n", + " # Load image\n", + " image = Image.open(image_path).convert('RGB')\n", + " \n", + " # Process keypoints\n", + " keypoints = ann['keypoints']\n", + " keypoints = torch.tensor([keypoints[i:i + 3] for i in range(0, len(keypoints), 3)], dtype=torch.float32) \n", + " keypoints = transformKeypoint(image, target_shape, keypoints)\n", + " #print(\"keypoints: \", keypoints)\n", + " \n", + " # Apply transformations\n", + " if self.transform:\n", + " image = self.transform(image)\n", + " \n", + " return image, keypoints\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "960b56d0-391d-4cd5-886c-c951d5e5bf63", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ccddf3cf-5b28-4fad-a467-917a53d19d63", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47ebf732-8f35-4caa-85ec-5dabb6e95e93", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "05380b1e-f0d4-4dcc-84de-147726da3ac4", + "metadata": {}, + "outputs": [], + "source": [ + "from torchvision import transforms\n", + "\n", + "transform = transforms.Compose([\n", + " transforms.Resize((208, 208)),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", + "])\n", + "\n", + "dataset = COCOKeypointsDataset(\n", + " json_path=r\"D:\\ViTPose\\Evaluating\\annotations\\person_keypoints_val2017.json\",\n", + " images_dir=r'D:\\ViTPose\\Evaluating\\val2017\\\\',\n", + " transform=transform\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "62efbb61-7d6e-4b8d-9fe5-537ca0f7ee04", + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", + "dataloader = DataLoader(dataset, batch_size=4, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c2a8ab8-5496-462d-a58b-f87f06e427bf", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "d9d940c5-13f6-4b24-a887-28f08f370d04", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Images shape: torch.Size([4, 3, 208, 208])\n", + "Keypoints shape: torch.Size([4, 17, 3])\n", + "keypoints: tensor([[[ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [150., 2., 2.],\n", + " [146., 10., 2.],\n", + " [ 0., 0., 0.],\n", + " [145., 18., 2.],\n", + " [ 0., 0., 0.],\n", + " [148., 20., 2.],\n", + " [151., 19., 2.],\n", + " [146., 33., 1.],\n", + " [151., 33., 1.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.]],\n", + "\n", + " [[ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.]],\n", + "\n", + " [[ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.]],\n", + "\n", + " [[152., 67., 2.],\n", + " [155., 65., 2.],\n", + " [150., 64., 2.],\n", + " [158., 65., 2.],\n", + " [ 0., 0., 0.],\n", + " [163., 81., 2.],\n", + " [146., 76., 2.],\n", + " [166., 104., 2.],\n", + " [140., 88., 2.],\n", + " [160., 123., 2.],\n", + " [141., 99., 2.],\n", + " [156., 119., 2.],\n", + " [146., 115., 2.],\n", + " [154., 149., 2.],\n", + " [142., 146., 2.],\n", + " [152., 177., 2.],\n", + " [144., 171., 2.]]])\n" + ] + } + ], + "source": [ + "for images, keypoints in dataloader:\n", + " print(\"Images shape:\", images.shape)\n", + " print(\"Keypoints shape:\", keypoints.shape)\n", + " print(\"keypoints: \", keypoints)\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b26cfb19-15d6-47a6-8114-6ef1385d6fbc", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e5e00bc-e79f-408f-bd3c-7663225cde47", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e24981bc-7a91-45ba-bcc5-d79a7604b8b3", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "670e81c3-27a2-43ef-a149-f4d88d7214c4", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "97e0ee06-e303-4a6c-8cda-eb6087693980", + "metadata": {}, + "source": [ + "### Ending loading the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "8334085f-85d4-4889-9f7f-71e8b7b6adc5", + "metadata": {}, + "outputs": [], + "source": [ + "cfg = {'s':s_cfg}.get(model_name.lower())" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "b7b53af5-0822-47a4-be61-6f8b2de9f9c6", + "metadata": {}, + "outputs": [], + "source": [ + "# Load config.yaml\n", + "with open(config_path, 'r') as f:\n", + " cfg_yaml = yaml.load(f, Loader=yaml.SafeLoader)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "7ef13d0b-2fd5-4313-9822-8b1626da933f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'log_level': 'logging.INFO',\n", + " 'seed': 0,\n", + " 'gpu_ids': [0],\n", + " 'deterministic': True,\n", + " 'cudnn_benchmark': True,\n", + " 'resume_from': 'C:/Users/user/ViTPose/ckpts/vitpose-s-coco_25.pth',\n", + " 'launcher': 'none',\n", + " 'use_amp': False,\n", + " 'validate': True,\n", + " 'autoscale_lr': False,\n", + " 'dist_params': '...'}" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cfg_yaml" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "64402e77-c6b8-49bb-90db-6693de309268", + "metadata": {}, + "outputs": [], + "source": [ + "for k, v in cfg_yaml.items():\n", + " if hasattr(cfg, k):\n", + " raise ValueError(f\"Already exists {k} in config\")\n", + " else:\n", + " cfg.__setattr__(k, v)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "98b406f9-d79a-40ed-af8a-e73d4808776a", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "057c9e72-4693-4387-a141-a844159b6410", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88841210-3245-47d4-8fea-22a3fdff04ab", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "765d06d7-dcc9-46a0-a9bc-9569ff314eec", + "metadata": {}, + "outputs": [], + "source": [ + "# set cudnn_benchmark\n", + "if cfg.cudnn_benchmark:\n", + " torch.backends.cudnn.benchmark = True" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "14bcd3c1-01e4-4308-a693-7cba5b28ec97", + "metadata": {}, + "outputs": [], + "source": [ + "# Set work directory (session-level)\n", + "if not hasattr(cfg, 'work_dir'):\n", + " cfg.__setattr__('work_dir', f\"{CUR_PATH}/runs/train\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "fe3506fa-36d0-461e-8f18-856a35ae3268", + "metadata": {}, + "outputs": [], + "source": [ + "if not osp.exists(cfg.work_dir):\n", + " os.makedirs(cfg.work_dir)\n", + "session_list = sorted(glob(f\"{cfg.work_dir}/*\"))\n", + "if len(session_list) == 0:\n", + " session = 1\n", + "else:\n", + " session = int(os.path.basename(session_list[-1])) + 1\n", + "session_dir = osp.join(cfg.work_dir, str(session).zfill(3))\n", + "os.makedirs(session_dir)\n", + "cfg.__setattr__('work_dir', session_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "aaffeee8-3d10-4cb8-ab1b-98f79bdb9910", + "metadata": {}, + "outputs": [], + "source": [ + "if cfg.autoscale_lr:\n", + " # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)\n", + " cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c022f17-9135-4317-a8ca-e7d4b51b4d1a", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "bc6c864a-6cf0-4599-aa2a-0de7f6fd19ec", + "metadata": {}, + "outputs": [], + "source": [ + "# init distributed env first, since logger depends on the dist info.\n", + "if cfg.launcher == 'none':\n", + " distributed = False\n", + " if len(cfg.gpu_ids) > 1:\n", + " warnings.warn(\n", + " f\"We treat {cfg['gpu_ids']} as gpu-ids, and reset to \"\n", + " f\"{cfg['gpu_ids'][0:1]} as gpu-ids to avoid potential error in \"\n", + " \"non-distribute training time.\")\n", + " cfg.gpu_ids = cfg.gpu_ids[0:1]\n", + "else:\n", + " distributed = True\n", + " init_dist(cfg.launcher, **cfg.dist_params)\n", + " # re-set gpu_ids with distributed training mode\n", + " _, world_size = get_dist_info()\n", + " cfg.gpu_ids = range(world_size)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "935e56e5-d5d1-4f19-bba0-6d659665c570", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "3881144e-53b7-4a2b-b878-0d8c43a17d82", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-12-25 08:52:31,377 - vit_utils - INFO - Distributed training: False\n", + "2024-12-25 08:52:31,378 - vit_utils - INFO - Set random seed to 0, deterministic: True\n" + ] + } + ], + "source": [ + "# init the logger before other steps\n", + "timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())\n", + "log_file = osp.join(session_dir, f'{timestamp}.log')\n", + "logger = get_root_logger(log_file=log_file)\n", + "\n", + "# init the meta dict to record some important information such as\n", + "# environment info and seed, which will be logged\n", + "meta = dict()\n", + "\n", + "# log some basic info\n", + "logger.info(f'Distributed training: {distributed}')\n", + "\n", + "# set random seeds\n", + "seed = init_random_seed(cfg.seed)\n", + "logger.info(f\"Set random seed to {seed}, \"\n", + " f\"deterministic: {cfg.deterministic}\")\n", + "set_random_seed(seed, deterministic=cfg.deterministic)\n", + "meta['seed'] = seed" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0588f6f-6fe0-4ac6-9fcf-a2c744e68091", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "5878400f-1165-4301-aad8-aef907113a4a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\user\\AppData\\Local\\Temp\\ipykernel_19640\\1963230343.py:5: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " ckpt_state = torch.load(cfg.resume_from) #['state_dict']\n" + ] + } + ], + "source": [ + "# Set model\n", + "model = ViTPose(cfg.model)\n", + "if cfg.resume_from:\n", + " # Load ckpt partially\n", + " ckpt_state = torch.load(cfg.resume_from) #['state_dict']\n", + " ckpt_state.pop('keypoint_head.final_layer.bias')\n", + " ckpt_state.pop('keypoint_head.final_layer.weight')\n", + " model.load_state_dict(ckpt_state, strict=False)\n", + "\n", + " # freeze the backbone, leave the head to be finetuned\n", + " model.backbone.frozen_stages = model.backbone.depth - 1\n", + " model.backbone.freeze_ffn = True\n", + " model.backbone.freeze_attn = True\n", + " model.backbone._freeze_stages()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52cef111-bc9c-417d-9770-7595408863be", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "93e7765d-94ca-4a16-b4d6-c95085bfad35", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "ViTPose(\n", + " (backbone): ViT(\n", + " (patch_embed): PatchEmbed(\n", + " (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16), padding=(2, 2))\n", + " )\n", + " (blocks): ModuleList(\n", + " (0): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " )\n", + " (1): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath(p=0.00909090880304575)\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " )\n", + " (2): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath(p=0.0181818176060915)\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " )\n", + " (3): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath(p=0.027272727340459824)\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " )\n", + " (4): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath(p=0.036363635212183)\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " )\n", + " (5): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath(p=0.045454543083906174)\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " )\n", + " (6): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath(p=0.054545458406209946)\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " )\n", + " (7): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath(p=0.06363636255264282)\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " )\n", + " (8): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath(p=0.0727272778749466)\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " )\n", + " (9): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath(p=0.08181818574666977)\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " )\n", + " (10): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath(p=0.09090909361839294)\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " )\n", + " (11): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath(p=0.10000000149011612)\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " (last_norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " )\n", + " (keypoint_head): TopdownHeatmapSimpleHead(\n", + " (deconv_layers): Sequential(\n", + " (0): ConvTranspose2d(384, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (final_layer): Conv2d(256, 17, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + ")" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "851da2fd-707c-4dc9-bca0-99ac5530374a", + "metadata": {}, + "outputs": [], + "source": [ + "# Set dataset\n", + "datasets_train = dataloader\n", + "datasets_valid = dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d352a68-dccd-4d94-86c7-deb716545df4", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "5af32dc2-8a9c-4abf-bc31-1454d64622e2", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch 1:\n", + " - Images: torch.Size([4, 3, 208, 208])\n", + " - Labels: tensor([[[ 83., 46., 2.],\n", + " [ 0., 0., 0.],\n", + " [ 83., 44., 2.],\n", + " [ 0., 0., 0.],\n", + " [ 79., 44., 2.],\n", + " [ 83., 53., 2.],\n", + " [ 75., 54., 2.],\n", + " [ 86., 64., 2.],\n", + " [ 78., 70., 2.],\n", + " [ 90., 78., 2.],\n", + " [ 87., 79., 2.],\n", + " [ 83., 80., 2.],\n", + " [ 78., 81., 2.],\n", + " [ 0., 0., 0.],\n", + " [ 80., 99., 2.],\n", + " [ 0., 0., 0.],\n", + " [ 74., 121., 2.]],\n", + "\n", + " [[ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [200., 83., 2.],\n", + " [ 0., 0., 0.],\n", + " [192., 150., 2.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.]],\n", + "\n", + " [[ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 74., 48., 2.],\n", + " [ 63., 59., 2.],\n", + " [ 64., 60., 2.],\n", + " [ 70., 82., 2.],\n", + " [ 72., 85., 2.],\n", + " [ 83., 83., 2.],\n", + " [ 76., 62., 2.],\n", + " [ 58., 107., 2.],\n", + " [ 58., 109., 2.],\n", + " [ 83., 87., 2.],\n", + " [ 84., 98., 2.],\n", + " [ 75., 115., 1.],\n", + " [ 77., 141., 2.]],\n", + "\n", + " [[ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.],\n", + " [ 0., 0., 0.]]])\n" + ] + } + ], + "source": [ + "# Iterate Through the DataLoader\n", + "for batch_idx, (images, labels) in enumerate(dataloader):\n", + " print(f\"Batch {batch_idx + 1}:\")\n", + " print(f\" - Images: {images.shape}\") # Shape: (batch_size, 3, H, W)\n", + " print(f\" - Labels: {labels}\") # Tensor of labels\n", + " # Perform operations on images and labels (e.g., training)\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "c7000fcb-4487-4671-a88c-635da8e17d93", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([17, 3])" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "labels[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68165efc-54b3-4774-81d7-5e87b409219f", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "raw", + "id": "48811727-c7b5-48ec-8d2e-234e14cd1312", + "metadata": {}, + "source": [ + "train_model(\n", + " model=model,\n", + " datasets_train=datasets_train,\n", + " datasets_valid=datasets_valid,\n", + " cfg=cfg,\n", + " distributed=distributed,\n", + " validate=cfg.validate,\n", + " timestamp=timestamp,\n", + " meta=meta\n", + " )" + ] + }, + { + "cell_type": "raw", + "id": "62b37ca0-46a6-49e9-8d7b-5f0c69b93f20", + "metadata": { + "jupyter": { + "source_hidden": true + } + }, + "source": [ + "import os.path as osp\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "\n", + "from vit_models.losses import JointsMSELoss\n", + "from vit_models.optimizer import LayerDecayOptimizer\n", + "\n", + "from torch.nn.parallel import DataParallel, DistributedDataParallel\n", + "from torch.nn.utils import clip_grad_norm_\n", + "from torch.optim import AdamW\n", + "from torch.optim.lr_scheduler import LambdaLR, MultiStepLR\n", + "from torch.utils.data import DataLoader, Dataset\n", + "from torch.utils.data.distributed import DistributedSampler\n", + "from torch.cuda.amp import autocast, GradScaler\n", + "from tqdm import tqdm\n", + "from time import time\n", + "\n", + "\n", + "logger = get_root_logger()\n", + "\n", + " \n", + "dataloaders_train = datasets_train\n", + "dataloaders_valid = datasets_valid\n", + "# put model on gpus\n", + "if distributed:\n", + " find_unused_parameters = cfg.get('find_unused_parameters', False)\n", + " # Sets the `find_unused_parameters` parameter in\n", + " # torch.nn.parallel.DistributedDataParallel\n", + "\n", + " model = DistributedDataParallel(\n", + " module=model, \n", + " device_ids=[torch.cuda.current_device()], \n", + " broadcast_buffers=False, \n", + " find_unused_parameters=find_unused_parameters)\n", + "else:\n", + " model = DataParallel(model, device_ids=cfg.gpu_ids)\n", + "\n", + "# Loss function\n", + "criterion = JointsMSELoss(use_target_weight=cfg.model['keypoint_head']['loss_keypoint']['use_target_weight'])\n", + "\n", + "# Optimizer\n", + "optimizer = AdamW(model.parameters(), lr=cfg.optimizer['lr'], betas=cfg.optimizer['betas'], weight_decay=cfg.optimizer['weight_decay'])\n", + "\n", + "# Layer-wise learning rate decay\n", + "lr_mult = [cfg.optimizer['paramwise_cfg']['layer_decay_rate']] * cfg.optimizer['paramwise_cfg']['num_layers']\n", + "layerwise_optimizer = LayerDecayOptimizer(optimizer, lr_mult)\n", + "\n", + "\n", + "# Learning rate scheduler (MultiStepLR)\n", + "milestones = cfg.lr_config['step']\n", + "gamma = 0.1\n", + "scheduler = MultiStepLR(optimizer, milestones, gamma)\n", + "\n", + "# Warm-up scheduler\n", + "num_warmup_steps = cfg.lr_config['warmup_iters'] # Number of warm-up steps\n", + "warmup_factor = cfg.lr_config['warmup_ratio'] # Initial learning rate = warmup_factor * learning_rate\n", + "warmup_scheduler = LambdaLR(\n", + " optimizer,\n", + " lr_lambda=lambda step: warmup_factor + (1.0 - warmup_factor) * step / num_warmup_steps\n", + ")\n", + "\n", + "# AMP setting\n", + "if cfg.use_amp:\n", + " logger.info(\"Using Automatic Mixed Precision (AMP) training...\")\n", + " # Create a GradScaler object for FP16 training\n", + " scaler = GradScaler()\n", + "\n", + "# Logging config\n", + "total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "logger.info(f'''\\n\n", + "#========= [Train Configs] =========#\n", + "# - Num GPUs: {len(cfg.gpu_ids)}\n", + "# - Batch size (per gpu): {cfg.data['samples_per_gpu']}\n", + "# - LR: {cfg.optimizer['lr']: .6f}\n", + "# - Num params: {total_params:,d}\n", + "# - AMP: {cfg.use_amp}\n", + "#===================================# \n", + "''')\n", + "\n", + "global_step = 0\n", + "for dataloader in dataloaders_train:\n", + " print(\"start training\")\n", + " for epoch in range(cfg.total_epochs):\n", + " model.train()\n", + " train_pbar = tqdm(dataloader)\n", + " total_loss = 0\n", + " tic = time()\n", + " for batch_idx, batch in enumerate(train_pbar):\n", + " layerwise_optimizer.zero_grad()\n", + " \n", + " images, targets, target_weights, __ = batch\n", + " images = images.to('cuda').unsqueeze(0)\n", + " targets = targets.to('cuda').unsqueeze(0)\n", + " target_weights = target_weights.to('cuda')\n", + " \n", + " if cfg.use_amp:\n", + " with autocast():\n", + " outputs = model(images)\n", + " loss = criterion(outputs, targets, target_weights)\n", + " scaler.scale(loss).backward()\n", + " clip_grad_norm_(model.parameters(), **cfg.optimizer_config['grad_clip'])\n", + " scaler.step(layerwise_optimizer)\n", + " scaler.update()\n", + " else:\n", + " print(images.shape)\n", + " outputs = model(images)\n", + " print(\"outputs: \", outputs.shape)\n", + " print(\"targets: \", targets.shape)\n", + " loss = criterion(outputs, targets, target_weights)\n", + " loss.backward()\n", + " clip_grad_norm_(model.parameters(), **cfg.optimizer_config['grad_clip'])\n", + " layerwise_optimizer.step()\n", + " \n", + " if global_step < num_warmup_steps:\n", + " warmup_scheduler.step()\n", + " global_step += 1\n", + " \n", + " total_loss += loss.item()\n", + " train_pbar.set_description(f\"🏋️> Epoch [{str(epoch).zfill(3)}/{str(cfg.total_epochs).zfill(3)}] | Loss {loss.item():.4f} | LR {optimizer.param_groups[0]['lr']:.6f} | Step\")\n", + " scheduler.step()\n", + " \n", + " avg_loss_train = total_loss/len(dataloader)\n", + " logger.info(f\"[Summary-train] Epoch [{str(epoch).zfill(3)}/{str(cfg.total_epochs).zfill(3)}] | Average Loss (train) {avg_loss_train:.4f} --- {time()-tic:.5f} sec. elapsed\")\n", + " ckpt_name = f\"epoch{str(epoch).zfill(3)}.pth\"\n", + " ckpt_path = osp.join(cfg.work_dir, ckpt_name)\n", + " torch.save(model.module.state_dict(), ckpt_path)\n", + "\n", + " # validation\n", + " if validate:\n", + " tic2 = time()\n", + " avg_loss_valid = valid_model(model, dataloaders_valid, criterion, cfg)\n", + " logger.info(f\"[Summary-valid] Epoch [{str(epoch).zfill(3)}/{str(cfg.total_epochs).zfill(3)}] | Average Loss (valid) {avg_loss_valid:.4f} --- {time()-tic2:.5f} sec. elapsed\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5c3aaa51-efa5-414c-a0a8-9f3b475ca67e", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2f0dd2c-ceaa-40c1-943f-3392c4bb1b3d", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc8f7cda-227e-4a03-b694-97af713b73f6", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91348ce9-12d7-4882-b222-b9b60cedebec", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1e63562-cc6c-460d-aab5-f2f9bb5724cc", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "6b2dd91c-9830-41a3-bddd-6c2f09254a7c", + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device('cuda')\n", + "# Move model to device\n", + "model = model.to(device)\n", + "\n", + "# Move inputs to device\n", + "images = images.to(device)\n" + ] + }, + { + "cell_type": "markdown", + "id": "faf1e322-2867-47d7-8778-09e83ead56ef", + "metadata": {}, + "source": [ + "## Define my own training process" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4cbc9686-7e2f-4fcb-af2b-b260fb8051f5", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4ed1ef2-58d7-4d18-9966-1ebb95adc529", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5e568149-21ff-47c5-93f1-35d7290b837f", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0a0a470-7c51-4027-b1cb-85d5a52bb28e", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "430f8a82-9c8f-4014-b3ff-ac9548213294", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4b00355-74b1-4f87-abd1-2e2dc31846e6", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2de2c69e-ebc5-4e96-aa0e-5aab0cc88c31", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "raw", + "id": "867a6faa-7419-4a01-9719-2739639be673", + "metadata": { + "jupyter": { + "source_hidden": true + } + }, + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "\n", + "def generate_heatmaps(keypoints, output_size):\n", + " \"\"\"\n", + " Generate heatmaps from keypoints for training.\n", + " Args:\n", + " - keypoints: Tensor of shape (batch_size, num_keypoints, 3) containing (x, y, visibility)\n", + " - output_size: (height, width) of the heatmaps\n", + " Returns:\n", + " - heatmaps: Tensor of shape (batch_size, num_keypoints, height, width)\n", + " \"\"\"\n", + " batch_size, num_keypoints, _ = keypoints.shape\n", + " height, width = output_size\n", + " heatmaps = torch.zeros(batch_size, num_keypoints, height, width, device=keypoints.device)\n", + "\n", + " #print(\"heatmaps: \", heatmaps)\n", + " for i in range(batch_size):\n", + " for j in range(num_keypoints):\n", + " x, y, visibility = keypoints[i, j, 0], keypoints[i, j, 1], keypoints[i, j, 2]\n", + " if visibility > 0:\n", + " # Create a Gaussian heatmap for each keypoint\n", + " gaussian = generate_gaussian(x, y, height, width)\n", + " print(\"gaussian max: \", gaussian.max())\n", + " print(\"gaussian min: \", gaussian.min())\n", + " heatmaps[i, j] = gaussian\n", + "\n", + " return heatmaps\n", + "\n", + "def generate_gaussian(x, y, height, width, sigma=1):\n", + " \"\"\"\n", + " Generate a Gaussian heatmap centered at (x, y) with standard deviation sigma.\n", + " \"\"\"\n", + " grid_x, grid_y = torch.meshgrid(torch.arange(0, width), torch.arange(0, height))\n", + " grid = torch.stack([grid_x, grid_y], dim=-1).float()\n", + " \n", + " mean = torch.tensor([x, y], dtype=torch.float32)\n", + " variance = sigma ** 2\n", + " diff = grid - mean\n", + " dist = torch.sum(diff ** 2, dim=-1)\n", + " gaussian = torch.exp(-dist / (2 * variance))\n", + "\n", + " return gaussian\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bdf3b94a-370e-481b-a1a6-9c165bc06e34", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "3e554ac4-3b6f-4621-8bde-a358e68e8e25", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "images.shape: torch.Size([4, 3, 208, 208])\n", + "labels.shape: torch.Size([4, 17, 3])\n" + ] + } + ], + "source": [ + "for images, labels in dataloader:\n", + " print(\"images.shape: \", images.shape)\n", + " print(\"labels.shape: \", labels.shape)\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "f2e9ad48-4a8e-4bdd-91a2-68f810e859c4", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'plt' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[32], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mplt\u001b[49m\u001b[38;5;241m.\u001b[39mimshow(outputs[\u001b[38;5;241m0\u001b[39m][\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mnumpy())\n\u001b[0;32m 2\u001b[0m plt\u001b[38;5;241m.\u001b[39mshow()\n", + "\u001b[1;31mNameError\u001b[0m: name 'plt' is not defined" + ] + } + ], + "source": [ + "plt.imshow(outputs[0][0].cpu().detach().numpy())\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38308db6-5e0c-4d08-91c4-b7408adcb81a", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf122f87-f911-4c39-a6fc-4f004d1988ce", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "9c5386a0-b07b-47f5-be1e-9af65dcb6de2", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np\n", + "\n", + "def generate_heatmaps(keypoints, output_size, sigma=2):\n", + " \"\"\"\n", + " Generate ground truth heatmaps for keypoints.\n", + " \n", + " Args:\n", + " keypoints: Tensor of shape (batch_size, num_keypoints, 3) with (x, y, confidence).\n", + " output_size: Tuple (height, width) of the heatmap.\n", + " sigma: Standard deviation of the Gaussian.\n", + " \n", + " Returns:\n", + " heatmaps: Tensor of shape (batch_size, num_keypoints, height, width).\n", + " \"\"\"\n", + " batch_size, num_keypoints, _ = keypoints.shape\n", + " height, width = output_size\n", + " heatmaps = torch.zeros((batch_size, num_keypoints, height, width), device=keypoints.device)\n", + "\n", + " for b in range(batch_size):\n", + " for k in range(num_keypoints):\n", + " x, y, confidence = keypoints[b, k]\n", + " \n", + " # Skip keypoints with zero confidence\n", + " if confidence <= 0 or x < 0 or y < 0:\n", + " continue\n", + " \n", + " # Create a meshgrid for Gaussian generation\n", + " xx, yy = torch.meshgrid(torch.arange(width, device=keypoints.device), \n", + " torch.arange(height, device=keypoints.device), \n", + " indexing='xy')\n", + " \n", + " # Calculate the 2D Gaussian heatmap\n", + " heatmap = torch.exp(-((xx - x)**2 + (yy - y)**2) / (2 * sigma**2))\n", + " heatmaps[b, k] = heatmap\n", + "\n", + " return heatmaps" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b738642-c839-410a-8ba0-aa9698e70aa0", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "6ca3122b-2019-4bc1-a8cd-39e7230d52de", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generated heatmaps shape: torch.Size([2, 17, 255, 255])\n" + ] + } + ], + "source": [ + "# Example usage\n", + "keypoints = torch.tensor([[[226., 129., 2.], [228., 127., 2.], [225., 127., 2.], [0., 0., 0.], [0., 0., 0.],\n", + " [233., 128., 2.], [218., 130., 2.], [239., 135., 2.], [213., 136., 2.], [243., 139., 2.],\n", + " [211., 137., 2.], [232., 149., 2.], [222., 148., 2.], [232., 169., 2.], [222., 169., 2.],\n", + " [233., 188., 2.], [221., 182., 2.]],\n", + " [[584., 101., 2.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [587., 137., 2.],\n", + " [637., 137., 2.], [567., 196., 2.], [0., 0., 0.], [561., 235., 2.], [619., 214., 2.],\n", + " [589., 222., 2.], [630., 224., 2.], [579., 317., 2.], [614., 309., 2.], [586., 400., 2.],\n", + " [611., 399., 2.]]], device='cuda:0')\n", + "\n", + "heatmaps = generate_heatmaps(keypoints, output_size=(255, 255), sigma=2)\n", + "\n", + "print(\"Generated heatmaps shape:\", heatmaps.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "3615cb01-7df8-43d9-9bd1-16f644df0125", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbAAAAGiCAYAAACGUJO6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAiCElEQVR4nO3df2xV9eH/8de5ve1tC9zbldLeVn5Y/IXIjznE2ugYGw0tMKfCEnFsQ0MgstZMUedqFMQtq2HLtuhwZMki/iFumohEspEhWJjzglolCGgjfJkto7cgrPe2hd7e2/v+/oHcz66UHy1tr+/2+UhO6D3n3Nv3eafNk3PvubeOMcYIAADLuFI9AAAAeoOAAQCsRMAAAFYiYAAAKxEwAICVCBgAwEoEDABgJQIGALASAQMAWImAAQCslLKArV27VldeeaUyMzNVUlKid999N1VDAQBYKCUB++tf/6oVK1Zo1apV+uCDDzR16lSVl5fr2LFjqRgOAMBCTio+zLekpETTp0/XH/7wB0lSPB7XmDFj9MADD+jnP//5QA8HAGAh90B/w87OTtXV1am6ujqxzuVyqaysTIFAoNv7RCIRRSKRxO14PK6TJ09q5MiRchyn38cMAOhbxhi1traqqKhILlfvngwc8IB9/vnn6urqUkFBQdL6goICffLJJ93ep6amRqtXrx6I4QEABlBjY6NGjx7dq/sOeMB6o7q6WitWrEjcDoVCGjt2rG7TXLmVnsKRAQB6I6ao3tbfNGLEiF4/xoAHLC8vT2lpaWpubk5a39zcLL/f3+19PB6PPB7POevdSpfbIWAAYJ0vrr64nJeBBvwqxIyMDE2bNk3btm1LrIvH49q2bZtKS0sHejgAAEul5CnEFStWaPHixbrpppt088036/e//73a29t13333pWI4AAALpSRgd999t44fP66VK1cqGAzq61//urZs2XLOhR0AAJxPSt4HdrnC4bB8Pp9m6g5eAwMAC8VMVLXapFAoJK/X26vH4LMQAQBWImAAACsRMACAlQgYAMBKBAwAYCUCBgCwEgEDAFiJgAEArETAAABWImAAACsRMACAlQgYAMBKBAwAYCUCBgCwEgEDAFiJgAEArETAAABWImAAACsRMACAlQgYAMBKBAwAYCUCBgCwEgEDAFiJgAEArETAAABWImAAACsRMACAlQgYAMBKBAwAYCUCBgCwEgEDAFiJgAEArETAAABWImAAACsRMACAlQgYAMBKBAwAYCUCBgCwEgEDAFiJgAEArETAAABWImAAACsRMACAlQgYAMBKBAwAYCUCBgCwEgEDAFiJgAEArETAAABWImAAACsRMACAlQgYAMBKBAwAYCUCBgCwEgEDAFiJgAEArETAAABWImAAACsRMACAlQgYAMBKBAwAYKU+D9hTTz0lx3GSlgkTJiS2d3R0qLKyUiNHjtTw4cO1YMECNTc39/UwAACDXL+cgd1www1qampKLG+//XZi20MPPaQ33nhDr776qnbs2KGjR49q/vz5/TEMAMAg5u6XB3W75ff7z1kfCoX05z//WRs2bNB3vvMdSdILL7yg66+/Xrt27dItt9zSH8MBAAxC/XIG9umnn6qoqEjjx4/XokWL1NDQIEmqq6tTNBpVWVlZYt8JEyZo7NixCgQC/TEUAMAg1ednYCUlJVq/fr2uu+46NTU1afXq1frmN7+pffv2KRgMKiMjQzk5OUn3KSgoUDAYPO9jRiIRRSKRxO1wONzXwwYAWKbPAzZnzpzE11OmTFFJSYnGjRunV155RVlZWb16zJqaGq1evbqvhggAGAT6/TL6nJwcXXvttTp48KD8fr86OzvV0tKStE9zc3O3r5mdVV1drVAolFgaGxv7edQAgK+6fg9YW1ubDh06pMLCQk2bNk3p6enatm1bYnt9fb0aGhpUWlp63sfweDzyer1JCwBgaOvzpxAfeeQR3X777Ro3bpyOHj2qVatWKS0tTffcc498Pp+WLFmiFStWKDc3V16vVw888IBKS0u5AhEA0CN9HrAjR47onnvu0YkTJzRq1Cjddttt2rVrl0aNGiVJ+t3vfieXy6UFCxYoEomovLxczz//fF8PAwAwyDnGGJPqQfRUOByWz+fTTN0ht5Oe6uEAAHooZqKq1SaFQqFevyzEZyECAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGClHgds586duv3221VUVCTHcfT6668nbTfGaOXKlSosLFRWVpbKysr06aefJu1z8uRJLVq0SF6vVzk5OVqyZIna2tou60AAAENLjwPW3t6uqVOnau3atd1uX7NmjZ599lmtW7dOu3fv1rBhw1ReXq6Ojo7EPosWLdL+/fu1detWbd68WTt37tSyZct6fxQAgCHHMcaYXt/ZcbRx40bdeeedks6cfRUVFenhhx/WI488IkkKhUIqKCjQ+vXrtXDhQn388ceaOHGi3nvvPd10002SpC1btmju3Lk6cuSIioqKLvp9w+GwfD6fZuoOuZ303g4fAJAiMRNVrTYpFArJ6/X26jH69DWww4cPKxgMqqysLLHO5/OppKREgUBAkhQIBJSTk5OIlySVlZXJ5XJp9+7d3T5uJBJROBxOWgAAQ1ufBiwYDEqSCgoKktYXFBQktgWDQeXn5ydtd7vdys3NTezzZTU1NfL5fIllzJgxfTlsAICFrLgKsbq6WqFQKLE0NjamekgAgBTr04D5/X5JUnNzc9L65ubmxDa/369jx44lbY/FYjp58mRiny/zeDzyer1JCwBgaOvTgBUXF8vv92vbtm2JdeFwWLt371ZpaakkqbS0VC0tLaqrq0vss337dsXjcZWUlPTlcAAAg5i7p3doa2vTwYMHE7cPHz6sPXv2KDc3V2PHjtWDDz6oX/7yl7rmmmtUXFysJ598UkVFRYkrFa+//npVVFRo6dKlWrdunaLRqKqqqrRw4cJLugIRAACpFwF7//339e1vfztxe8WKFZKkxYsXa/369frZz36m9vZ2LVu2TC0tLbrtttu0ZcsWZWZmJu7z0ksvqaqqSrNmzZLL5dKCBQv07LPP9sHhAACGist6H1iq8D4wALDbV+59YAAADBQCBgCwEgEDAFiJgAEArETAAABWImAAACsRMACAlQgYAMBKBAwAYCUCBgCwEgEDAFiJgAEArETAAABWImAAACsRMACAlQgYAMBKBAwAYCUCBgCwEgEDAFiJgAEArETAAABWImAAACsRMACAlQgYAMBKBAwAYCUCBgCwEgEDAFiJgAEArETAAABWImAAACsRMACAlQgYAMBKBAwAYCUCBgCwEgEDAFiJgAEArETAAABWImAAACsRMACAlQgYAMBK7lQPAABgAcf54t9uzntMXDJmYMcjAgYAuBDHkRyXHNcX/6a5JNcXEYvHZYyR4kamq2vAQ0bAAAAX5KSlSS5HjtstJyNDcjlnwtYVl9PVdSZenZLpkqSBixgBAwCc6+yZV1qanEyPHE+GnMxMmexMyZ0m43bJiXbJiUTlRDpl2ttlOqMy0ZhMLDogESNgAIBz/U+8XN4RMtmZin1tmCIjMxXPcBRPd5TWYZTeHpO7tVNpx9OktnbJiZw5IxuAMzECBgBI5jhyXI6cNJecjHSZ7EzFc4apIz9L7f40dWU66vJI7lNGnpBLnsw0ZZ+OyumKS11dcjrTzjydeOY5xX5DwAAA50pLk9LT5WRnK5o3Qh0FHoXGudU6Pi4zIqr0rKiiIY8yjqcpO5iutMhweSQ5XV1yOiKSicsYp1/PwggYACDZ2acPM9JlsjzqzM1Qe0Ga2q6Mq/iGo7rae1zjsz5XXWis9hy5Qm3pw5T1ebrSTmcq/XREakuXuuKSOAMDAKSCK01ypymW6VIs21HcG9UNOU36xvDPNNHzH0lSsN2rxv9mKpaZJpPhknGnyXEcGenMe8b68WlEPokDAHAux5FcjozbpXi6o1imlD6sUxOzj+obmQ2alpGmqVkNGj28Rc6wmLo8UtztktJcZy6zdzn9PkTOwAAA5zLmzAUZsbjSIkbuU9KpsEd1rVeqSy4d9xxVoP1q/b/QSCmcLvdpydUZl2JdUvzMm5v7GwEDAJwrHj/zOlZnVO5TXcpodSn9pFt1zaPV3DFC+7Ov0IH/+nXsmE8ZJ13KaI8rrSMmJxpTvOuLpw1NvF+HSMAAAMnMFx8RFYvJ1dGpjFCnsjJdiv7HrZDna/rvcK8OZBeqK5SuzGNuZTUbef4blas9IkU6pbMfK9XPCBgAIJkxMtHYmS/bT8nd7Nawjpjcp7OVedKtWKZb8Qy33KclT7hLGS0xZRwNywm3ybSfkonFvvhsRN7IDAAYaObMm5JNJCKdOi2XMfIYo7RIprrSXYqnu5QWict9KipXW6ectlMyHRGZzk6ZAXj9SyJgAIDzMHEjRWPSqdNnnk6MxpR+KqJ0d9qZT6SPnnnNSx2RM2denZ0ysdiAfSo9AQMAnMsYyXTJmLh0Oi7TEZFzukNOq/v/LpGPmzMXbHR1nXnaMG4G9E+qEDAAwPmZL/7Wl3MmTqbrS29M/iJaAx0viYABAC7GGEnxMx+q8eXXt85ebchfZAYAfCWdDVQ/f8J8T/BRUgAAK/U4YDt37tTtt9+uoqIiOY6j119/PWn7vffeK8dxkpaKioqkfU6ePKlFixbJ6/UqJydHS5YsUVtb22UdCABgaOlxwNrb2zV16lStXbv2vPtUVFSoqakpsbz88stJ2xctWqT9+/dr69at2rx5s3bu3Klly5b1fPQAgCGrx6+BzZkzR3PmzLngPh6PR36/v9ttH3/8sbZs2aL33ntPN910kyTpueee09y5c/Wb3/xGRUVFPR0SAGAI6pfXwGpra5Wfn6/rrrtOy5cv14kTJxLbAoGAcnJyEvGSpLKyMrlcLu3evbs/hgMAGIT6/CrEiooKzZ8/X8XFxTp06JAef/xxzZkzR4FAQGlpaQoGg8rPz08ehNut3NxcBYPBbh8zEokoEokkbofD4b4eNgDAMn0esIULFya+njx5sqZMmaKrrrpKtbW1mjVrVq8es6amRqtXr+6rIQIABoF+v4x+/PjxysvL08GDByVJfr9fx44dS9onFovp5MmT533drLq6WqFQKLE0Njb297ABAF9x/R6wI0eO6MSJEyosLJQklZaWqqWlRXV1dYl9tm/frng8rpKSkm4fw+PxyOv1Ji0AgKGtx08htrW1Jc6mJOnw4cPas2ePcnNzlZubq9WrV2vBggXy+/06dOiQfvazn+nqq69WeXm5JOn6669XRUWFli5dqnXr1ikajaqqqkoLFy7kCkQAwCXr8RnY+++/rxtvvFE33nijJGnFihW68cYbtXLlSqWlpWnv3r363ve+p2uvvVZLlizRtGnT9M9//lMejyfxGC+99JImTJigWbNmae7cubrtttv0pz/9qe+OCgAw6DnGpOATGC9TOByWz+fTTN0ht5Oe6uEAAHooZqKq1SaFQqFevyzEZyECAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGClHgWspqZG06dP14gRI5Sfn68777xT9fX1Sft0dHSosrJSI0eO1PDhw7VgwQI1Nzcn7dPQ0KB58+YpOztb+fn5evTRRxWLxS7/aAAAQ0aPArZjxw5VVlZq165d2rp1q6LRqGbPnq329vbEPg899JDeeOMNvfrqq9qxY4eOHj2q+fPnJ7Z3dXVp3rx56uzs1DvvvKMXX3xR69ev18qVK/vuqAAAg55jjDG9vfPx48eVn5+vHTt2aMaMGQqFQho1apQ2bNig73//+5KkTz75RNdff70CgYBuueUW/f3vf9d3v/tdHT16VAUFBZKkdevW6bHHHtPx48eVkZFx0e8bDofl8/k0U3fI7aT3dvgAgBSJmahqtUmhUEher7dXj3FZr4GFQiFJUm5uriSprq5O0WhUZWVliX0mTJigsWPHKhAISJICgYAmT56ciJcklZeXKxwOa//+/d1+n0gkonA4nLQAAIa2XgcsHo/rwQcf1K233qpJkyZJkoLBoDIyMpSTk5O0b0FBgYLBYGKf/43X2e1nt3WnpqZGPp8vsYwZM6a3wwYADBK9DlhlZaX27dunv/zlL305nm5VV1crFAollsbGxn7/ngCArzZ3b+5UVVWlzZs3a+fOnRo9enRivd/vV2dnp1paWpLOwpqbm+X3+xP7vPvuu0mPd/YqxbP7fJnH45HH4+nNUAEAg1SPzsCMMaqqqtLGjRu1fft2FRcXJ22fNm2a0tPTtW3btsS6+vp6NTQ0qLS0VJJUWlqqjz76SMeOHUvss3XrVnm9Xk2cOPFyjgUAMIT06AyssrJSGzZs0KZNmzRixIjEa1Y+n09ZWVny+XxasmSJVqxYodzcXHm9Xj3wwAMqLS3VLbfcIkmaPXu2Jk6cqB/96Edas2aNgsGgnnjiCVVWVnKWBQC4ZD26jN5xnG7Xv/DCC7r33nslnXkj88MPP6yXX35ZkUhE5eXlev7555OeHvzss8+0fPly1dbWatiwYVq8eLGeeeYZud2X1lMuowcAu/XFZfSX9T6wVCFgAGC3lL8PDACAVCFgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGAlAgYAsBIBAwBYiYABAKxEwAAAViJgAAArETAAgJUIGADASgQMAGClHgWspqZG06dP14gRI5Sfn68777xT9fX1SfvMnDlTjuMkLffff3/SPg0NDZo3b56ys7OVn5+vRx99VLFY7PKPBgAwZLh7svOOHTtUWVmp6dOnKxaL6fHHH9fs2bN14MABDRs2LLHf0qVL9fTTTyduZ2dnJ77u6urSvHnz5Pf79c4776ipqUk//vGPlZ6erl/96ld9cEgAgKGgRwHbsmVL0u3169crPz9fdXV1mjFjRmJ9dna2/H5/t4/xj3/8QwcOHNCbb76pgoICff3rX9cvfvELPfbYY3rqqaeUkZHRi8MAAAw1l/UaWCgUkiTl5uYmrX/ppZeUl5enSZMmqbq6WqdOnUpsCwQCmjx5sgoKChLrysvLFQ6HtX///ssZDgBgCOnRGdj/isfjevDBB3Xrrbdq0qRJifU/+MEPNG7cOBUVFWnv3r167LHHVF9fr9dee02SFAwGk+IlKXE7GAx2+70ikYgikUjidjgc7u2wAQCDRK8DVllZqX379untt99OWr9s2bLE15MnT1ZhYaFmzZqlQ4cO6aqrrurV96qpqdHq1at7O1QAwCDUq6cQq6qqtHnzZr311lsaPXr0BfctKSmRJB08eFCS5Pf71dzcnLTP2dvne92surpaoVAosTQ2NvZm2ACAQaRHATPGqKqqShs3btT27dtVXFx80fvs2bNHklRYWChJKi0t1UcffaRjx44l9tm6dau8Xq8mTpzY7WN4PB55vd6kBQAwtPXoKcTKykpt2LBBmzZt0ogRIxKvWfl8PmVlZenQoUPasGGD5s6dq5EjR2rv3r166KGHNGPGDE2ZMkWSNHv2bE2cOFE/+tGPtGbNGgWDQT3xxBOqrKyUx+Pp+yMEAAxKjjHGXPLOjtPt+hdeeEH33nuvGhsb9cMf/lD79u1Te3u7xowZo7vuuktPPPFE0lnTZ599puXLl6u2tlbDhg3T4sWL9cwzz8jtvrSehsNh+Xw+zdQdcjvplzp8AMBXRMxEVatNCoVCvX5WrUcB+6oIhULKycnRbZortwgYANgmpqje1t/U0tIin8/Xq8fo9VWIqdTa2ipJelt/S/FIAACXo7W1tdcBs/IMLB6Pq76+XhMnTlRjYyMXdXQjHA5rzJgxzM8FMEcXxvxcHHN0YReaH2OMWltbVVRUJJerd5+pYeUZmMvl0hVXXCFJXJV4EczPxTFHF8b8XBxzdGHnm5/ennmdxZ9TAQBYiYABAKxkbcA8Ho9WrVrFe8fOg/m5OObowpifi2OOLqy/58fKizgAALD2DAwAMLQRMACAlQgYAMBKBAwAYCUrA7Z27VpdeeWVyszMVElJid59991UDyllnnrqKTmOk7RMmDAhsb2jo0OVlZUaOXKkhg8frgULFpzz99gGk507d+r2229XUVGRHMfR66+/nrTdGKOVK1eqsLBQWVlZKisr06effpq0z8mTJ7Vo0SJ5vV7l5ORoyZIlamtrG8Cj6F8Xm6N77733nJ+pioqKpH0G8xzV1NRo+vTpGjFihPLz83XnnXeqvr4+aZ9L+b1qaGjQvHnzlJ2drfz8fD366KOKxWIDeSj94lLmZ+bMmef8DN1///1J+/TF/FgXsL/+9a9asWKFVq1apQ8++EBTp05VeXl50t8XG2puuOEGNTU1JZb//SvZDz30kN544w29+uqr2rFjh44ePar58+encLT9q729XVOnTtXatWu73b5mzRo9++yzWrdunXbv3q1hw4apvLxcHR0diX0WLVqk/fv3a+vWrdq8ebN27tyZ9JfGbXexOZKkioqKpJ+pl19+OWn7YJ6jHTt2qLKyUrt27dLWrVsVjUY1e/Zstbe3J/a52O9VV1eX5s2bp87OTr3zzjt68cUXtX79eq1cuTIVh9SnLmV+JGnp0qVJP0Nr1qxJbOuz+TGWufnmm01lZWXidldXlykqKjI1NTUpHFXqrFq1ykydOrXbbS0tLSY9Pd28+uqriXUff/yxkWQCgcAAjTB1JJmNGzcmbsfjceP3+82vf/3rxLqWlhbj8XjMyy+/bIwx5sCBA0aSee+99xL7/P3vfzeO45j//Oc/Azb2gfLlOTLGmMWLF5s77rjjvPcZanN07NgxI8ns2LHDGHNpv1d/+9vfjMvlMsFgMLHPH//4R+P1ek0kEhnYA+hnX54fY4z51re+ZX7605+e9z59NT9WnYF1dnaqrq5OZWVliXUul0tlZWUKBAIpHFlqffrppyoqKtL48eO1aNEiNTQ0SJLq6uoUjUaT5mvChAkaO3bskJyvw4cPKxgMJs2Hz+dTSUlJYj4CgYBycnJ00003JfYpKyuTy+XS7t27B3zMqVJbW6v8/Hxdd911Wr58uU6cOJHYNtTmKBQKSZJyc3MlXdrvVSAQ0OTJk1VQUJDYp7y8XOFwWPv37x/A0fe/L8/PWS+99JLy8vI0adIkVVdX69SpU4ltfTU/Vn2Y7+eff66urq6kg5akgoICffLJJykaVWqVlJRo/fr1uu6669TU1KTVq1frm9/8pvbt26dgMKiMjAzl5OQk3aegoCDx17SHkrPH3N3Pz9ltwWBQ+fn5Sdvdbrdyc3OHzJxVVFRo/vz5Ki4u1qFDh/T4449rzpw5CgQCSktLG1JzFI/H9eCDD+rWW2/VpEmTJOmSfq+CwWC3P2dntw0W3c2PJP3gBz/QuHHjVFRUpL179+qxxx5TfX29XnvtNUl9Nz9WBQznmjNnTuLrKVOmqKSkROPGjdMrr7yirKysFI4Mtlq4cGHi68mTJ2vKlCm66qqrVFtbq1mzZqVwZAOvsrJS+/btS3pdGf/nfPPzv6+HTp48WYWFhZo1a5YOHTqkq666qs++v1VPIebl5SktLe2cq32am5vl9/tTNKqvlpycHF177bU6ePCg/H6/Ojs71dLSkrTPUJ2vs8d8oZ8fv99/zgVBsVhMJ0+eHJJzJknjx49XXl6eDh48KGnozFFVVZU2b96st956S6NHj06sv5TfK7/f3+3P2dltg8H55qc7JSUlkpT0M9QX82NVwDIyMjRt2jRt27YtsS4ej2vbtm0qLS1N4ci+Otra2nTo0CEVFhZq2rRpSk9PT5qv+vp6NTQ0DMn5Ki4ult/vT5qPcDis3bt3J+ajtLRULS0tqqurS+yzfft2xePxxC/hUHPkyBGdOHFChYWFkgb/HBljVFVVpY0bN2r79u0qLi5O2n4pv1elpaX66KOPkkK/detWeb1eTZw4cWAOpJ9cbH66s2fPHklK+hnqk/npxUUnKfWXv/zFeDwes379enPgwAGzbNkyk5OTk3Q1y1Dy8MMPm9raWnP48GHzr3/9y5SVlZm8vDxz7NgxY4wx999/vxk7dqzZvn27ef/9901paakpLS1N8aj7T2trq/nwww/Nhx9+aCSZ3/72t+bDDz80n332mTHGmGeeecbk5OSYTZs2mb1795o77rjDFBcXm9OnTyceo6Kiwtx4441m9+7d5u233zbXXHONueeee1J1SH3uQnPU2tpqHnnkERMIBMzhw4fNm2++ab7xjW+Ya665xnR0dCQeYzDP0fLly43P5zO1tbWmqakpsZw6dSqxz8V+r2KxmJk0aZKZPXu22bNnj9myZYsZNWqUqa6uTsUh9amLzc/BgwfN008/bd5//31z+PBhs2nTJjN+/HgzY8aMxGP01fxYFzBjjHnuuefM2LFjTUZGhrn55pvNrl27Uj2klLn77rtNYWGhycjIMFdccYW5++67zcGDBxPbT58+bX7yk5+Yr33tayY7O9vcddddpqmpKYUj7l9vvfWWkXTOsnjxYmPMmUvpn3zySVNQUGA8Ho+ZNWuWqa+vT3qMEydOmHvuuccMHz7ceL1ec99995nW1tYUHE3/uNAcnTp1ysyePduMGjXKpKenm3HjxpmlS5ee8x/EwTxH3c2NJPPCCy8k9rmU36t///vfZs6cOSYrK8vk5eWZhx9+2ESj0QE+mr53sflpaGgwM2bMMLm5ucbj8Zirr77aPProoyYUCiU9Tl/MD39OBQBgJateAwMA4CwCBgCwEgEDAFiJgAEArETAAABWImAAACsRMACAlQgYAMBKBAwAYCUCBgCwEgEDAFiJgAEArPT/AVdGY0VlNRNmAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "#plt.imshow(outputs[0][0].cpu().detach().numpy())\n", + "import matplotlib.pyplot as plt\n", + "for i in range(0, 17):\n", + " plt.imshow(heatmaps[0][i].cpu().detach().numpy())\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "f238c110-8e4f-4865-bc1c-f8aef0127ef3", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'outputs' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[36], line 4\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[0;32m 3\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m17\u001b[39m):\n\u001b[1;32m----> 4\u001b[0m plt\u001b[38;5;241m.\u001b[39mimshow(\u001b[43moutputs\u001b[49m[\u001b[38;5;241m0\u001b[39m][i]\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mnumpy())\n\u001b[0;32m 5\u001b[0m plt\u001b[38;5;241m.\u001b[39mshow()\n", + "\u001b[1;31mNameError\u001b[0m: name 'outputs' is not defined" + ] + } + ], + "source": [ + "#plt.imshow(outputs[0][0].cpu().detach().numpy())\n", + "import matplotlib.pyplot as plt\n", + "for i in range(0, 17):\n", + " plt.imshow(outputs[0][i].cpu().detach().numpy())\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89ae76ec-92b7-4773-8893-fa1028f1c0a2", + "metadata": {}, + "outputs": [], + "source": [ + "#plt.imshow(outputs[0][0].cpu().detach().numpy())\n", + "import matplotlib.pyplot as plt\n", + "plt.imshow(gt_heatmaps[0][5].cpu().detach().numpy())\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d55e9df0-d46a-445a-8032-6ff5bd566ea7", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "d55944fe-f7bd-463c-bda9-848d3ac29275", + "metadata": {}, + "source": [ + "## convert keypoint" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "83242f5c-71ae-4c95-a8fa-d385904d00d7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[[0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 0.],\n", + " [0., 0., 0.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.]],\n", + "\n", + " [[0., 0., 2.],\n", + " [0., 0., 0.],\n", + " [0., 0., 0.],\n", + " [0., 0., 0.],\n", + " [0., 0., 0.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 0.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.],\n", + " [0., 0., 2.]]], device='cuda:0')\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "\n", + "# Original and target sizes\n", + "original_size = (208, 208) # Replace with actual dimensions\n", + "target_size = (52, 52)\n", + "\n", + "# Resizing function\n", + "def resize_keypoints_new(keypoints, original_size, target_size):\n", + " original_height, original_width = original_size\n", + " target_height, target_width = target_size\n", + " \n", + " scale_x = int(target_width / original_width)\n", + " scale_y = int(target_height / original_height)\n", + " \n", + " resized_keypoints = keypoints.clone()\n", + " resized_keypoints[..., 0] *= scale_x\n", + " resized_keypoints[..., 1] *= scale_y\n", + " \n", + " return resized_keypoints\n", + "\n", + "# Resized keypoints\n", + "resized_keypoints = resize_keypoints_new(keypoints, original_size, target_size)\n", + "print(resized_keypoints)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "718c212b-ebb5-4ddb-8b7c-88537ea6f85f", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f4ef8c2-e3d2-4e7d-8388-d4b4ddc18587", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "e2e1b2cd-c9ce-4149-a89c-ddd210196fb7", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'gt_heatmaps' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[38], line 3\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m#plt.imshow(outputs[0][0].cpu().detach().numpy())\u001b[39;00m\n\u001b[0;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[1;32m----> 3\u001b[0m plt\u001b[38;5;241m.\u001b[39mimshow(\u001b[43mgt_heatmaps\u001b[49m[\u001b[38;5;241m0\u001b[39m][\u001b[38;5;241m2\u001b[39m]\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mnumpy())\n\u001b[0;32m 4\u001b[0m plt\u001b[38;5;241m.\u001b[39mshow()\n", + "\u001b[1;31mNameError\u001b[0m: name 'gt_heatmaps' is not defined" + ] + } + ], + "source": [ + "#plt.imshow(outputs[0][0].cpu().detach().numpy())\n", + "import matplotlib.pyplot as plt\n", + "plt.imshow(gt_heatmaps[0][2].cpu().detach().numpy())\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9f12e02-03b0-47bb-85fb-d9723e50e0e5", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "3b414522-0a5e-4a11-83e6-20aa0fc04d2b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4, 3, 208, 208])" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "images.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "5dc276a4-7c8b-4972-90d0-e3c77fb9c199", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'gt_heatmaps' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[40], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mgt_heatmaps\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n", + "\u001b[1;31mNameError\u001b[0m: name 'gt_heatmaps' is not defined" + ] + } + ], + "source": [ + "gt_heatmaps.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b1c06c2b-8708-4a24-9abe-33901f646d90", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch [1/10]: 100%|██████████████████████████████████████████████████| 2751/2751 [03:34<00:00, 12.80batch/s, loss=4.61]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [1/10], Loss: 4.6097\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch [2/10]: 100%|██████████████████████████████████████████████████| 2751/2751 [03:34<00:00, 12.81batch/s, loss=3.91]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [2/10], Loss: 3.9105\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch [3/10]: 100%|██████████████████████████████████████████████████| 2751/2751 [03:31<00:00, 13.01batch/s, loss=3.79]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [3/10], Loss: 3.7892\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch [4/10]: 100%|███████████████████████████████████████████████████| 2751/2751 [03:37<00:00, 12.67batch/s, loss=3.7]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [4/10], Loss: 3.7024\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch [5/10]: 100%|██████████████████████████████████████████████████| 2751/2751 [03:30<00:00, 13.09batch/s, loss=3.63]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [5/10], Loss: 3.6335\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch [6/10]: 100%|██████████████████████████████████████████████████| 2751/2751 [03:31<00:00, 13.01batch/s, loss=3.58]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [6/10], Loss: 3.5774\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch [7/10]: 100%|██████████████████████████████████████████████████| 2751/2751 [03:25<00:00, 13.37batch/s, loss=3.52]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [7/10], Loss: 3.5167\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch [8/10]: 99%|█████████████████████████████████████████████████▌| 2730/2751 [03:26<00:01, 12.95batch/s, loss=3.48]" + ] + } + ], + "source": [ + "import torch\n", + "import torch.optim as optim\n", + "from torch.utils.data import DataLoader\n", + "from tqdm import tqdm # Import tqdm\n", + "\n", + "optimizer = optim.Adam(model.parameters(), lr=1e-4)\n", + "criterion = torch.nn.MSELoss()\n", + "\n", + "# Define training loop\n", + "num_epochs = 10 # Number of epochs\n", + "for epoch in range(num_epochs):\n", + " model.train() # Set model to training mode\n", + " running_loss = 0.0\n", + "\n", + " # Wrap the DataLoader with tqdm to show the progress\n", + " with tqdm(dataloader, desc=f\"Epoch [{epoch + 1}/{num_epochs}]\", unit='batch') as pbar:\n", + " for images, labels in pbar:\n", + " # Send images and labels to GPU if needed\n", + " images = images.cuda()\n", + " labels = labels.cuda()\n", + "\n", + " optimizer.zero_grad() # Zero the gradients\n", + "\n", + " # Forward pass\n", + " outputs = model(images) # The model outputs keypoint heatmaps\n", + " #print(outputs.shape)\n", + "\n", + " original_size = (208, 208)\n", + " target_size = (52, 52)\n", + " labels = resize_keypoints_new(labels, original_size, target_size)\n", + " gt_heatmaps = generate_heatmaps(labels, output_size=outputs.shape[2:]) # Implement this\n", + " loss = criterion(outputs*100, gt_heatmaps*100)\n", + "\n", + " # Backward pass and optimization\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " running_loss += loss.item()\n", + "\n", + " # Update the tqdm progress bar with the current loss\n", + " pbar.set_postfix(loss=running_loss / (pbar.n + 1)) # Display the average loss\n", + "\n", + " # Print loss for the current epoch\n", + " avg_loss = running_loss / len(dataloader)\n", + " print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cac52986-3416-4280-8092-9b197ac006ad", + "metadata": {}, + "outputs": [], + "source": [ + "input_tensor.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5dbc351-bb56-450f-85c7-225902b2a9b8", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12157857-e9bf-4645-a137-3e5ddf0d3141", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a27d991-83b5-485f-8e50-9273601dedab", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdd034a0-c159-4b2f-8f49-abdbd49844bd", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88eab6b5-72f5-4407-a2d1-c9f74dfb2f8b", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}