File size: 3,587 Bytes
7a28749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1bfd74
 
 
 
 
 
 
 
 
7a28749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1bfd74
 
7a28749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1bfd74
7a28749
 
 
 
 
 
 
 
f1bfd74
7a28749
 
 
 
 
 
 
 
 
 
 
f1bfd74
7a28749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1bfd74
7a28749
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "from PIL import Image \n",
    "from loguru import logger\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "project_root = Path().resolve().parent\n",
    "sys.path.append(str(project_root))\n",
    "\n",
    "from yolo import (\n",
    "    AugmentationComposer, \n",
    "    bbox_nms, \n",
    "    create_model, \n",
    "    custom_logger, \n",
    "    create_converter,\n",
    "    draw_bboxes, \n",
    "    Vec2Box\n",
    ")\n",
    "from yolo.config.config import NMSConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL = \"v9-c\"\n",
    "DEVICE = \"cuda:0\"\n",
    "\n",
    "WEIGHT_PATH = f\"../weights/{MODEL}.pt\" \n",
    "TRT_WEIGHT_PATH = f\"../weights/{MODEL}.trt\"\n",
    "MODEL_CONFIG = f\"../yolo/config/model/{MODEL}.yaml\"\n",
    "\n",
    "IMAGE_PATH = \"../demo/images/inference/image.png\"\n",
    "IMAGE_SIZE = (640, 640)\n",
    "\n",
    "custom_logger()\n",
    "device = torch.device(DEVICE)\n",
    "image = Image.open(IMAGE_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(MODEL_CONFIG) as stream:\n",
    "    cfg_model = OmegaConf.load(stream)\n",
    "if os.path.exists(TRT_WEIGHT_PATH):\n",
    "    from torch2trt import TRTModule\n",
    "\n",
    "    model_trt = TRTModule()\n",
    "    model_trt.load_state_dict(torch.load(TRT_WEIGHT_PATH))\n",
    "else:\n",
    "    from torch2trt import torch2trt\n",
    "\n",
    "\n",
    "    model = create_model(cfg_model, weight_path=WEIGHT_PATH)\n",
    "    model = model.to(device).eval()\n",
    "\n",
    "    dummy_input = torch.ones((1, 3, 640, 640)).to(device)\n",
    "    logger.info(f\"♻️ Creating TensorRT model\")\n",
    "    model_trt = torch2trt(model, [dummy_input])\n",
    "    torch.save(model_trt.state_dict(), TRT_WEIGHT_PATH)\n",
    "    logger.info(f\"📥 TensorRT model saved to oonx.pt\")\n",
    "\n",
    "transform = AugmentationComposer([], IMAGE_SIZE)\n",
    "converter = create_converter(cfg_model.name, model_trt, cfg_model.anchor, IMAGE_SIZE, device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image, bbox, rev_tensor = transform(image, torch.zeros(0, 5))\n",
    "image = image.to(device)[None]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    predict = model_trt(image)\n",
    "    predict = converter(predict[\"Main\"])\n",
    "predict_box = bbox_nms(predict[0], predict[2], NMSConfig(0.5, 0.5))\n",
    "draw_bboxes(image, predict_box)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sample Output:\n",
    "\n",
    "![image](../demo/images/output/visualize.png)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "yolomit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}