File size: 3,415 Bytes
7a28749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "from PIL import Image\n",
    "from loguru import logger\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "# Add the repository root (parent of the working directory) to sys.path so the\n",
    "# local `yolo` package can be imported. The membership check keeps re-runs of\n",
    "# this cell from appending duplicate entries.\n",
    "project_root = Path().resolve().parent\n",
    "if str(project_root) not in sys.path:\n",
    "    sys.path.append(str(project_root))\n",
    "\n",
    "from yolo import AugmentationComposer, bbox_nms, create_model, custom_logger, draw_bboxes, Vec2Box\n",
    "from yolo.config.config import NMSConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model variant and target device.\n",
    "MODEL = \"v9-c\"\n",
    "DEVICE = \"cuda:0\"\n",
    "device = torch.device(DEVICE)\n",
    "\n",
    "# Checkpoint, cached TensorRT engine, and model config; all named after MODEL.\n",
    "WEIGHT_PATH = f\"../weights/{MODEL}.pt\"\n",
    "TRT_WEIGHT_PATH = f\"../weights/{MODEL}.trt\"\n",
    "MODEL_CONFIG = f\"../yolo/config/model/{MODEL}.yaml\"\n",
    "\n",
    "# Demo input image and the size handed to the transform / Vec2Box.\n",
    "IMAGE_PATH = \"../demo/images/inference/image.png\"\n",
    "IMAGE_SIZE = (640, 640)\n",
    "\n",
    "custom_logger()  # project logging setup (helper from `yolo`)\n",
    "image = Image.open(IMAGE_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the TensorRT engine once and cache its state dict at TRT_WEIGHT_PATH;\n",
    "# later runs load the cached file instead of converting again.\n",
    "# NOTE: the cache path does not encode WEIGHT_PATH or IMAGE_SIZE, so delete the\n",
    "# .trt file after changing either of them to force a rebuild.\n",
    "if os.path.exists(TRT_WEIGHT_PATH):\n",
    "    from torch2trt import TRTModule\n",
    "\n",
    "    model_trt = TRTModule()\n",
    "    model_trt.load_state_dict(torch.load(TRT_WEIGHT_PATH))\n",
    "    logger.info(f\"📤 TensorRT model loaded from {TRT_WEIGHT_PATH}\")\n",
    "else:\n",
    "    from torch2trt import torch2trt\n",
    "\n",
    "    with open(MODEL_CONFIG) as stream:\n",
    "        cfg_model = OmegaConf.load(stream)\n",
    "\n",
    "    model = create_model(cfg_model, weight_path=WEIGHT_PATH)\n",
    "    model = model.to(device).eval()\n",
    "\n",
    "    # Trace with IMAGE_SIZE (also passed to AugmentationComposer and Vec2Box below)\n",
    "    # rather than a hard-coded 640x640, so the engine follows the config cell.\n",
    "    # NOTE(review): dims are unpacked in IMAGE_SIZE order; confirm it is (H, W)\n",
    "    # before using a non-square size.\n",
    "    dummy_input = torch.ones((1, 3, *IMAGE_SIZE)).to(device)\n",
    "    logger.info(\"♻️ Creating TensorRT model\")\n",
    "    model_trt = torch2trt(model, [dummy_input])\n",
    "    torch.save(model_trt.state_dict(), TRT_WEIGHT_PATH)\n",
    "    logger.info(f\"📥 TensorRT model saved to {TRT_WEIGHT_PATH}\")\n",
    "\n",
    "transform = AugmentationComposer([], IMAGE_SIZE)\n",
    "vec2box = Vec2Box(model_trt, IMAGE_SIZE, device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the model input to a new name instead of overwriting the PIL `image`,\n",
    "# so this cell can be re-run without reloading the image from disk.\n",
    "# The transform's second argument is a box tensor; an empty (0, 5) tensor is\n",
    "# passed because this demo has no labels, and the returned boxes are unused.\n",
    "image_tensor, _ = transform(image, torch.zeros(0, 5))\n",
    "image_tensor = image_tensor.to(device)[None]  # prepend a batch axis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    raw_output = model_trt(image_tensor)\n",
    "    predictions = vec2box(raw_output[\"Main\"])\n",
    "\n",
    "# bbox_nms receives elements 0 and 2 of the Vec2Box output\n",
    "# (NOTE(review): presumably class scores and decoded boxes; confirm against Vec2Box).\n",
    "# NMSConfig(0.5, 0.5) presumably holds the confidence and IoU thresholds.\n",
    "predict_box = bbox_nms(predictions[0], predictions[2], NMSConfig(0.5, 0.5))\n",
    "draw_bboxes(image_tensor, predict_box)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sample Output:\n",
    "\n",
    "![image](../demo/images/output/visualize.png)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "yolomit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.1.undefined"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}