{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# YOLO inference with TensorRT\n",
    "\n",
    "Builds a TensorRT model from the YOLO weights with `torch2trt` (or reloads a previously saved one), runs it on a single image, and draws the boxes kept by NMS."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "from PIL import Image\n",
    "from loguru import logger\n",
    "from omegaconf import OmegaConf\n",
    "from torch2trt import TRTModule, torch2trt\n",
    "\n",
    "# Put the parent of the working directory on sys.path so that the\n",
    "# `yolo` package below can be imported when running from this folder.\n",
    "project_root = Path().resolve().parent\n",
    "sys.path.append(str(project_root))\n",
    "\n",
    "from yolo import (\n",
    "    AugmentationComposer,\n",
    "    bbox_nms,\n",
    "    create_model,\n",
    "    custom_logger,\n",
    "    create_converter,\n",
    "    draw_bboxes,\n",
    "    Vec2Box,\n",
    ")\n",
    "from yolo.config.config import NMSConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: everything a reader might want to tune lives here.\n",
    "MODEL = \"v9-c\"\n",
    "DEVICE = \"cuda:0\"\n",
    "\n",
    "WEIGHT_PATH = f\"../weights/{MODEL}.pt\"\n",
    "TRT_WEIGHT_PATH = f\"../weights/{MODEL}.trt\"\n",
    "MODEL_CONFIG = f\"../yolo/config/model/{MODEL}.yaml\"\n",
    "\n",
    "IMAGE_PATH = \"../demo/images/inference/image.png\"\n",
    "IMAGE_SIZE = (640, 640)\n",
    "\n",
    "custom_logger()\n",
    "device = torch.device(DEVICE)\n",
    "image = Image.open(IMAGE_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(MODEL_CONFIG) as stream:\n",
    "    cfg_model = OmegaConf.load(stream)\n",
    "\n",
    "if os.path.exists(TRT_WEIGHT_PATH):\n",
    "    # Reuse the TensorRT weights saved by a previous run.\n",
    "    # NOTE: torch.load unpickles the file, so only load files you created yourself.\n",
    "    model_trt = TRTModule()\n",
    "    model_trt.load_state_dict(torch.load(TRT_WEIGHT_PATH))\n",
    "else:\n",
    "    model = create_model(cfg_model, weight_path=WEIGHT_PATH)\n",
    "    model = model.to(device).eval()\n",
    "\n",
    "    # Example input handed to torch2trt for the conversion. Sized from IMAGE_SIZE\n",
    "    # (same as the former hard-coded 640x640 for the default setting).\n",
    "    # TODO confirm the axis order of IMAGE_SIZE before using a non-square size.\n",
    "    dummy_input = torch.ones((1, 3, *IMAGE_SIZE)).to(device)\n",
    "    logger.info(\"♻️ Creating TensorRT model\")\n",
    "    model_trt = torch2trt(model, [dummy_input])\n",
    "    torch.save(model_trt.state_dict(), TRT_WEIGHT_PATH)\n",
    "    # Report the path that was actually written.\n",
    "    logger.info(f\"📥 TensorRT model saved to {TRT_WEIGHT_PATH}\")\n",
    "\n",
    "transform = AugmentationComposer([], IMAGE_SIZE)\n",
    "converter = create_converter(cfg_model.name, model_trt, cfg_model.anchor, IMAGE_SIZE, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use a separate name for the tensor so the PIL `image` stays intact and\n",
    "# this cell can be re-run without restarting the kernel.\n",
    "image_tensor, bbox, rev_tensor = transform(image, torch.zeros(0, 5))\n",
    "image_tensor = image_tensor.to(device)[None]  # [None] adds a leading dimension"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    predict = model_trt(image_tensor)\n",
    "    predict = converter(predict[\"Main\"])\n",
    "# NOTE(review): the two 0.5 values are presumably the confidence and IoU\n",
    "# thresholds; confirm against the NMSConfig definition.\n",
    "predict_box = bbox_nms(predict[0], predict[2], NMSConfig(0.5, 0.5))\n",
    "draw_bboxes(image_tensor, predict_box)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sample Output:\n",
    "\n",
    "![image](../demo/images/output/visualize.png)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "yolomit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}