{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "from pathlib import Path\n", "\n", "import torch\n", "from hydra import compose, initialize\n", "from PIL import Image \n", "\n", "# Ensure that the necessary repository is cloned and installed. You may need to run: \n", "# git clone git@github.com:WongKinYiu/YOLO.git\n", "# cd YOLO \n", "# pip install .\n", "project_root = Path().resolve().parent\n", "sys.path.append(str(project_root))\n", "\n", "from yolo import (\n", " AugmentationComposer, \n", " Config, \n", " NMSConfig, \n", " PostProccess,\n", " bbox_nms, \n", " create_model, \n", " create_converter, \n", " custom_logger, \n", " draw_bboxes, \n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "CONFIG_PATH = \"../yolo/config\"\n", "CONFIG_NAME = \"config\"\n", "MODEL = \"v9-c\"\n", "\n", "DEVICE = 'cuda:0'\n", "CLASS_NUM = 80\n", "IMAGE_PATH = '../image.png'\n", "SLIDE = 4\n", "\n", "custom_logger()\n", "device = torch.device(DEVICE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with initialize(config_path=CONFIG_PATH, version_base=None, job_name=\"notebook_job\"):\n", " cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", f\"model={MODEL}\"])\n", " model = create_model(cfg.model, class_num=CLASS_NUM).to(device)\n", " transform = AugmentationComposer([], cfg.image_size)\n", " converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)\n", " post_proccess = PostProccess(converter, NMSConfig(0.5, 0.9))\n", " " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pil_image = Image.open(IMAGE_PATH)\n", "image, bbox, rev_tensor = transform(pil_image)\n", "image = image.to(device)[None]\n", "rev_tensor = rev_tensor.to(device)[None]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def slide_image(image, slide = 4, device = device):\n", " up_image = torch.nn.functional.interpolate(image, scale_factor=slide)\n", " image_list = [image]\n", " shift_list = []\n", " *_, w, h = up_image.shape\n", " for x_slide in range(slide):\n", " for y_slide in range(slide):\n", " left_w, right_w = w // slide * x_slide, w // slide * (x_slide + 1)\n", " left_h, right_h = h // slide * y_slide, h // slide * (y_slide + 1)\n", " slide_image = up_image[:, :, left_w: right_w, left_h: right_h]\n", " image_list.append(slide_image)\n", " shift_list.append(torch.Tensor([left_h, left_w, left_h, left_w]))\n", " total_image = torch.concat(image_list)\n", " total_shift = torch.stack(shift_list).to(device)\n", "\n", " return total_image, total_shift" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with torch.no_grad():\n", " total_image, total_shift = slide_image(image)\n", " predict = model(total_image)\n", " pred_class, _, pred_bbox = converter(predict[\"Main\"])\n", "pred_bbox[1:] = (pred_bbox[1: ] + total_shift[:, None]) / SLIDE\n", "pred_bbox = pred_bbox.view(1, -1, 4)\n", "pred_class = pred_class.view(1, -1, 80)\n", "pred_bbox = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None]\n", "predict_box = bbox_nms(pred_class, pred_bbox, NMSConfig(0.3, 0.5))\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "draw_bboxes(pil_image, predict_box, idx2label=cfg.dataset.class_list)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "yolomit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }