π [Update] PostProccess for huggingface demo
Browse files- demo/hf_demo.py +7 -8
- yolo/__init__.py +2 -0
demo/hf_demo.py
CHANGED
@@ -10,8 +10,8 @@ sys.path.append(str(Path(__file__).resolve().parent.parent))
|
|
10 |
from yolo import (
|
11 |
AugmentationComposer,
|
12 |
NMSConfig,
|
|
|
13 |
Vec2Box,
|
14 |
-
bbox_nms,
|
15 |
create_model,
|
16 |
draw_bboxes,
|
17 |
)
|
@@ -37,7 +37,7 @@ transform = AugmentationComposer([])
|
|
37 |
|
38 |
|
39 |
def predict(model_name, image, nms_confidence, nms_iou):
|
40 |
-
global DEFAULT_MODEL, model, device, v2b, class_list
|
41 |
if model_name != DEFAULT_MODEL:
|
42 |
model = load_model(model_name, device)
|
43 |
v2b = Vec2Box(model, IMAGE_SIZE, device)
|
@@ -46,16 +46,15 @@ def predict(model_name, image, nms_confidence, nms_iou):
|
|
46 |
image_tensor, _, rev_tensor = transform(image)
|
47 |
|
48 |
image_tensor = image_tensor.to(device)[None]
|
49 |
-
rev_tensor = rev_tensor.to(device)
|
|
|
|
|
|
|
50 |
|
51 |
with torch.no_grad():
|
52 |
predict = model(image_tensor)
|
53 |
-
|
54 |
-
|
55 |
-
nms_config = NMSConfig(nms_confidence, nms_iou)
|
56 |
|
57 |
-
pred_bbox = pred_bbox / rev_tensor[0] - rev_tensor[None, None, 1:]
|
58 |
-
pred_bbox = bbox_nms(pred_class, pred_bbox, nms_config)
|
59 |
result_image = draw_bboxes(image, pred_bbox, idx2label=class_list)
|
60 |
|
61 |
return result_image
|
|
|
10 |
from yolo import (
|
11 |
AugmentationComposer,
|
12 |
NMSConfig,
|
13 |
+
PostProccess,
|
14 |
Vec2Box,
|
|
|
15 |
create_model,
|
16 |
draw_bboxes,
|
17 |
)
|
|
|
37 |
|
38 |
|
39 |
def predict(model_name, image, nms_confidence, nms_iou):
|
40 |
+
global DEFAULT_MODEL, model, device, v2b, class_list, post_proccess
|
41 |
if model_name != DEFAULT_MODEL:
|
42 |
model = load_model(model_name, device)
|
43 |
v2b = Vec2Box(model, IMAGE_SIZE, device)
|
|
|
46 |
image_tensor, _, rev_tensor = transform(image)
|
47 |
|
48 |
image_tensor = image_tensor.to(device)[None]
|
49 |
+
rev_tensor = rev_tensor.to(device)[None]
|
50 |
+
|
51 |
+
nms_config = NMSConfig(nms_confidence, nms_iou)
|
52 |
+
post_proccess = PostProccess(v2b, nms_config)
|
53 |
|
54 |
with torch.no_grad():
|
55 |
predict = model(image_tensor)
|
56 |
+
pred_bbox = post_proccess(predict, rev_tensor)
|
|
|
|
|
57 |
|
|
|
|
|
58 |
result_image = draw_bboxes(image, pred_bbox, idx2label=class_list)
|
59 |
|
60 |
return result_image
|
yolo/__init__.py
CHANGED
@@ -6,6 +6,7 @@ from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
|
|
6 |
from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms
|
7 |
from yolo.utils.deploy_utils import FastModelLoader
|
8 |
from yolo.utils.logging_utils import custom_logger
|
|
|
9 |
|
10 |
all = [
|
11 |
"create_model",
|
@@ -22,4 +23,5 @@ all = [
|
|
22 |
"ModelTester",
|
23 |
"ModelTrainer",
|
24 |
"ModelValidator",
|
|
|
25 |
]
|
|
|
6 |
from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms
|
7 |
from yolo.utils.deploy_utils import FastModelLoader
|
8 |
from yolo.utils.logging_utils import custom_logger
|
9 |
+
from yolo.utils.model_utils import PostProccess
|
10 |
|
11 |
all = [
|
12 |
"create_model",
|
|
|
23 |
"ModelTester",
|
24 |
"ModelTrainer",
|
25 |
"ModelValidator",
|
26 |
+
"PostProccess",
|
27 |
]
|