henry000 committed
Commit 862884c · 1 Parent(s): fd35390

♻️ [Refactor] Code of examples, use PostProccess

examples/notebook_inference.ipynb CHANGED
@@ -6,15 +6,18 @@
  "metadata": {},
  "outputs": [],
  "source": [
+ "import sys\n",
+ "from pathlib import Path\n",
+ "\n",
  "import torch\n",
  "from hydra import compose, initialize\n",
  "from PIL import Image \n",
  "\n",
- "# Ensure that the necessary repository is cloned and installed. You may need to run: \n",
- "# git clone git@github.com:WongKinYiu/YOLO.git\n",
- "# cd YOLO \n",
- "# pip install .\n",
- "from yolo import AugmentationComposer, bbox_nms, Config, create_model, custom_logger, draw_bboxes, Vec2Box"
+ "project_root = Path().resolve().parent\n",
+ "sys.path.append(str(project_root))\n",
+ "\n",
+ "from yolo import AugmentationComposer, Config, create_model, custom_logger, draw_bboxes, Vec2Box, PostProccess\n",
+ "from yolo.utils.bounding_box_utils import Anc2Box"
  ]
  },
  {
@@ -25,7 +28,7 @@
  "source": [
  "CONFIG_PATH = \"../yolo/config\"\n",
  "CONFIG_NAME = \"config\"\n",
- "MODEL = \"v9-c\"\n",
+ "MODEL = \"v7-base\"\n",
  "\n",
  "DEVICE = 'cuda:0'\n",
  "CLASS_NUM = 80\n",
@@ -45,7 +48,9 @@
  " cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", f\"model={MODEL}\"])\n",
  " model = create_model(cfg.model, class_num=CLASS_NUM).to(device)\n",
  " transform = AugmentationComposer([], cfg.image_size)\n",
- " vec2box = Vec2Box(model, cfg.image_size, device)"
+ " converter = Anc2Box(model, cfg.model.anchor, cfg.image_size, device)\n",
+ " # converter = Vec2Box(model, cfg.model.anchor, cfg.image_size, device)\n",
+ " post_proccess = PostProccess(converter, cfg.task.nms)"
  ]
  },
  {
@@ -57,7 +62,7 @@
  "pil_image = Image.open(IMAGE_PATH)\n",
  "image, bbox, rev_tensor = transform(pil_image)\n",
  "image = image.to(device)[None]\n",
- "rev_tensor = rev_tensor.to(device)"
+ "rev_tensor = rev_tensor.to(device)[None]"
  ]
  },
  {
@@ -68,10 +73,8 @@
  "source": [
  "with torch.no_grad():\n",
  " predict = model(image)\n",
- " pred_class, _, pred_bbox = vec2box(predict[\"Main\"])\n",
+ " pred_bbox = post_proccess(predict, rev_tensor)\n",
  "\n",
- "pred_bbox = (pred_bbox / rev_tensor[0] - rev_tensor[None, None, 1:]) \n",
- "pred_bbox = bbox_nms(pred_class, pred_bbox, cfg.task.nms)\n",
  "draw_bboxes(pil_image, pred_bbox, idx2label=cfg.class_list)"
  ]
  },
@@ -83,6 +86,23 @@
  "\n",
  "![image](../demo/images/output/visualize.png)"
  ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
  }
  ],
  "metadata": {
examples/notebook_smallobject.ipynb CHANGED
@@ -22,7 +22,6 @@
  "import torch\n",
  "from hydra import compose, initialize\n",
  "from PIL import Image \n",
- "from einops import rearrange\n",
  "\n",
  "# Ensure that the necessary repository is cloned and installed. You may need to run: \n",
  "# git clone git@github.com:WongKinYiu/YOLO.git\n",
@@ -30,8 +29,8 @@
  "# pip install .\n",
  "project_root = Path().resolve().parent\n",
  "sys.path.append(str(project_root))\n",
- "from yolo.config.config import NMSConfig\n",
- "from yolo import AugmentationComposer, bbox_nms, Config, create_model, custom_logger, draw_bboxes, Vec2Box"
+ "\n",
+ "from yolo import AugmentationComposer, bbox_nms, Config, create_model, custom_logger, draw_bboxes, Vec2Box, NMSConfig, PostProccess"
  ]
  },
  {
@@ -63,7 +62,9 @@
  " cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", f\"model={MODEL}\"])\n",
  " model = create_model(cfg.model, class_num=CLASS_NUM).to(device)\n",
  " transform = AugmentationComposer([], cfg.image_size)\n",
- " vec2box = Vec2Box(model, cfg.image_size, device)"
+ " vec2box = Vec2Box(model, cfg.image_size, device)\n",
+ " post_proccess = PostProccess(vec2box, NMSConfig(0.5, 0.9))\n",
+ " "
  ]
  },
  {
@@ -75,7 +76,7 @@
  "pil_image = Image.open(IMAGE_PATH)\n",
  "image, bbox, rev_tensor = transform(pil_image)\n",
  "image = image.to(device)[None]\n",
- "rev_tensor = rev_tensor.to(device)"
+ "rev_tensor = rev_tensor.to(device)[None]"
  ]
  },
  {
@@ -114,7 +115,9 @@
  " pred_class, _, pred_bbox = vec2box(predict[\"Main\"])\n",
  "pred_bbox[1:] = (pred_bbox[1: ] + total_shift[:, None]) / SLIDE\n",
  "pred_bbox = pred_bbox.view(1, -1, 4)\n",
- "pred_class = pred_class.view(1, -1, 80)"
+ "pred_class = pred_class.view(1, -1, 80)\n",
+ "pred_bbox = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None]\n",
+ "predict_box = bbox_nms(pred_class, pred_bbox, NMSConfig(0.3, 0.5))\n"
  ]
  },
  {
@@ -123,7 +126,7 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "pred_bbox = (pred_bbox / rev_tensor[0] - rev_tensor[None, None, 1:]) "
+ "draw_bboxes(pil_image, predict_box, idx2label=cfg.class_list)"
  ]
  },
  {
@@ -131,10 +134,7 @@
  "execution_count": null,
  "metadata": {},
  "outputs": [],
- "source": [
- "predict_box = bbox_nms(pred_class, pred_bbox, NMSConfig(0.5, 0.5))\n",
- "draw_bboxes(pil_image, predict_box, idx2label=cfg.class_list)"
- ]
+ "source": []
  }
  ],
  "metadata": {