henry000 committed
Commit b6b57c7 · 1 Parent(s): 15f0a98

🚀 [Add] the example: using sliding windows in YOLO

examples/notebook_inference.ipynb CHANGED
@@ -25,15 +25,14 @@
   "source": [
   "CONFIG_PATH = \"../yolo/config\"\n",
   "CONFIG_NAME = \"config\"\n",
+  "MODEL = \"v9-c\"\n",
   "\n",
   "DEVICE = 'cuda:0'\n",
   "CLASS_NUM = 80\n",
-  "WEIGHT_PATH = '../weights/v9-c.pt' \n",
   "IMAGE_PATH = '../demo/images/inference/image.png'\n",
   "\n",
   "custom_logger()\n",
-  "device = torch.device(DEVICE)\n",
-  "image = Image.open(IMAGE_PATH)"
+  "device = torch.device(DEVICE)"
   ]
  },
  {
@@ -43,8 +42,8 @@
   "outputs": [],
   "source": [
   "with initialize(config_path=CONFIG_PATH, version_base=None, job_name=\"notebook_job\"):\n",
-  " cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", \"model=v9-m\"])\n",
-  " model = create_model(cfg.model, class_num=CLASS_NUM, weight_path=WEIGHT_PATH, device = device)\n",
+  " cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", f\"model={MODEL}\"])\n",
+  " model = create_model(cfg.model, class_num=CLASS_NUM).to(device)\n",
   " transform = AugmentationComposer([], cfg.image_size)\n",
   " vec2box = Vec2Box(model, cfg.image_size, device)"
   ]
@@ -55,8 +54,10 @@
   "metadata": {},
   "outputs": [],
   "source": [
-  "image, bbox = transform(image, torch.zeros(0, 5))\n",
-  "image = image.to(device)[None]"
+  "pil_image = Image.open(IMAGE_PATH)\n",
+  "image, bbox, rev_tensor = transform(pil_image)\n",
+  "image = image.to(device)[None]\n",
+  "rev_tensor = rev_tensor.to(device)"
   ]
  },
  {
@@ -67,10 +68,11 @@
   "source": [
   "with torch.no_grad():\n",
   " predict = model(image)\n",
-  " predict = vec2box(predict[\"Main\"])\n",
+  " pred_class, _, pred_bbox = vec2box(predict[\"Main\"])\n",
   "\n",
-  "predict_box = bbox_nms(predict[0], predict[2], cfg.task.nms)\n",
-  "draw_bboxes(image, predict_box, idx2label=cfg.class_list)"
+  "pred_bbox = (pred_bbox / rev_tensor[0] - rev_tensor[None, None, 1:]) \n",
+  "pred_bbox = bbox_nms(pred_class, pred_bbox, cfg.task.nms)\n",
+  "draw_bboxes(pil_image, pred_bbox, idx2label=cfg.class_list)"
   ]
  },
  {
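
Note on the updated inference cells: the transform now also returns a reverse tensor, and detections are mapped from the letterboxed network input back to original-image coordinates before drawing. Below is a minimal sketch of that arithmetic on dummy values, assuming (as the diff suggests) that rev_tensor packs the resize scale followed by four padding offsets; the names and numbers are illustrative only.

import torch

# Hypothetical reverse tensor: scale 0.5, then four padding offsets of 20 px.
rev_tensor = torch.tensor([0.5, 20.0, 20.0, 20.0, 20.0])

# One detection in network-input coordinates (x_min, y_min, x_max, y_max).
pred_bbox = torch.tensor([[[120.0, 60.0, 340.0, 280.0]]])

# Same expression as the notebook: divide by the stored scale, then subtract the offsets.
pred_bbox = pred_bbox / rev_tensor[0] - rev_tensor[None, None, 1:]
print(pred_bbox)  # box expressed in original-image coordinates
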
examples/notebook_smallobject.ipynb ADDED
@@ -0,0 +1,161 @@
+ {
+  "cells": [
+  {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+  "%load_ext autoreload\n",
+  "%autoreload 2"
+  ]
+  },
+  {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+  "import sys\n",
+  "from pathlib import Path\n",
+  "\n",
+  "import torch\n",
+  "from hydra import compose, initialize\n",
+  "from PIL import Image \n",
+  "from einops import rearrange\n",
+  "\n",
+  "# Ensure that the necessary repository is cloned and installed. You may need to run: \n",
+  "# git clone git@github.com:WongKinYiu/YOLO.git\n",
+  "# cd YOLO \n",
+  "# pip install .\n",
+  "project_root = Path().resolve().parent\n",
+  "sys.path.append(str(project_root))\n",
+  "from yolo.config.config import NMSConfig\n",
+  "from yolo import AugmentationComposer, bbox_nms, Config, create_model, custom_logger, draw_bboxes, Vec2Box"
+  ]
+  },
+  {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+  "CONFIG_PATH = \"../yolo/config\"\n",
+  "CONFIG_NAME = \"config\"\n",
+  "MODEL = \"v9-c\"\n",
+  "\n",
+  "DEVICE = 'cuda:0'\n",
+  "CLASS_NUM = 80\n",
+  "IMAGE_PATH = '../image.png'\n",
+  "SLIDE = 4\n",
+  "\n",
+  "custom_logger()\n",
+  "device = torch.device(DEVICE)"
+  ]
+  },
+  {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+  "with initialize(config_path=CONFIG_PATH, version_base=None, job_name=\"notebook_job\"):\n",
+  " cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", f\"model={MODEL}\"])\n",
+  " model = create_model(cfg.model, class_num=CLASS_NUM).to(device)\n",
+  " transform = AugmentationComposer([], cfg.image_size)\n",
+  " vec2box = Vec2Box(model, cfg.image_size, device)"
+  ]
+  },
+  {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+  "pil_image = Image.open(IMAGE_PATH)\n",
+  "image, bbox, rev_tensor = transform(pil_image)\n",
+  "image = image.to(device)[None]\n",
+  "rev_tensor = rev_tensor.to(device)"
+  ]
+  },
+  {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+  "def slide_image(image, slide = 4, device = device):\n",
+  " up_image = torch.nn.functional.interpolate(image, scale_factor=slide)\n",
+  " image_list = [image]\n",
+  " shift_list = []\n",
+  " *_, w, h = up_image.shape\n",
+  " for x_slide in range(slide):\n",
+  " for y_slide in range(slide):\n",
+  " left_w, right_w = w // slide * x_slide, w // slide * (x_slide + 1)\n",
+  " left_h, right_h = h // slide * y_slide, h // slide * (y_slide + 1)\n",
+  " slide_image = up_image[:, :, left_w: right_w, left_h: right_h]\n",
+  " image_list.append(slide_image)\n",
+  " shift_list.append(torch.Tensor([left_h, left_w, left_h, left_w]))\n",
+  " total_image = torch.concat(image_list)\n",
+  " total_shift = torch.stack(shift_list).to(device)\n",
+  "\n",
+  " return total_image, total_shift"
+  ]
+  },
+  {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+  "with torch.no_grad():\n",
+  " total_image, total_shift = slide_image(image)\n",
+  " predict = model(total_image)\n",
+  " pred_class, _, pred_bbox = vec2box(predict[\"Main\"])\n",
+  "pred_bbox[1:] = (pred_bbox[1: ] + total_shift[:, None]) / SLIDE\n",
+  "pred_bbox = pred_bbox.view(1, -1, 4)\n",
+  "pred_class = pred_class.view(1, -1, 80)"
+  ]
+  },
+  {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+  "pred_bbox = (pred_bbox / rev_tensor[0] - rev_tensor[None, None, 1:]) "
+  ]
+  },
+  {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+  "predict_box = bbox_nms(pred_class, pred_bbox, NMSConfig(0.5, 0.5))\n",
+  "draw_bboxes(pil_image, predict_box, idx2label=cfg.class_list)"
+  ]
+  }
+  ],
+  "metadata": {
+  "kernelspec": {
+  "display_name": "yolomit",
+  "language": "python",
+  "name": "python3"
+  },
+  "language_info": {
+  "codemirror_mode": {
+  "name": "ipython",
+  "version": 3
+  },
+  "file_extension": ".py",
+  "mimetype": "text/x-python",
+  "name": "python",
+  "nbconvert_exporter": "python",
+  "pygments_lexer": "ipython3",
+  "version": "3.10.14"
+  }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
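
The heart of the new notebook is slide_image: the input is upsampled by SLIDE, cut into SLIDE x SLIDE crops at the original resolution, and each crop's top-left offset is recorded so that per-crop detections can later be shifted and rescaled back to full-image coordinates (the (pred_bbox[1:] + total_shift[:, None]) / SLIDE step). The following self-contained sketch replays that tiling and merge on random tensors; tile_image and the dummy shapes are illustrative stand-ins, not the notebook's exact helper.

import torch

def tile_image(image: torch.Tensor, slide: int = 4):
    # Upsample so every crop keeps the network's input resolution.
    up_image = torch.nn.functional.interpolate(image, scale_factor=slide)
    image_list, shift_list = [image], []
    *_, w, h = up_image.shape
    for x_slide in range(slide):
        for y_slide in range(slide):
            left_w, right_w = w // slide * x_slide, w // slide * (x_slide + 1)
            left_h, right_h = h // slide * y_slide, h // slide * (y_slide + 1)
            # Crop the upsampled image and remember its top-left offset
            # (stored twice, so both box corners can be shifted at once).
            image_list.append(up_image[:, :, left_w:right_w, left_h:right_h])
            shift_list.append(torch.tensor([left_h, left_w, left_h, left_w], dtype=torch.float32))
    return torch.cat(image_list), torch.stack(shift_list)

# Dummy 640x640 image with slide factor 2 -> 1 full view + 4 crops of the same size.
image = torch.rand(1, 3, 640, 640)
total_image, total_shift = tile_image(image, slide=2)
print(total_image.shape, total_shift.shape)  # (5, 3, 640, 640) and (4, 4)

# Crop detections (random stand-ins here) are shifted by the crop offset and divided by
# the slide factor, mirroring how the notebook maps them back to full-image coordinates.
pred_bbox = torch.rand(5, 10, 4) * 640
pred_bbox[1:] = (pred_bbox[1:] + total_shift[:, None]) / 2
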
yolo/model/yolo.py CHANGED
@@ -119,7 +119,7 @@ class YOLO(nn.Module):
   raise ValueError(f"Unsupported layer type: {layer_type}")
 
 
-  def create_model(model_cfg: ModelConfig, weight_path: Union[bool, str], class_num: int = 80) -> YOLO:
+  def create_model(model_cfg: ModelConfig, weight_path: Union[bool, str] = True, class_num: int = 80) -> YOLO:
   """Constructs and returns a model from a Dictionary configuration file.
 
   Args:
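
With weight_path now defaulting to True, the notebooks build a pretrained model without passing a weight path or a device, and device placement happens in the caller. A minimal usage sketch mirroring the updated cells (paths and overrides taken from the notebooks above, trimmed for brevity):

import torch
from hydra import compose, initialize
from yolo import Config, create_model  # same imports as the notebooks above

with initialize(config_path="../yolo/config", version_base=None, job_name="notebook_job"):
    cfg: Config = compose(config_name="config", overrides=["task=inference", "model=v9-c"])
    # weight_path defaults to True, so no explicit path (or device) is passed here;
    # the caller moves the model to the target device afterwards.
    model = create_model(cfg.model, class_num=80).to(torch.device("cuda:0"))
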