henry000 commited on
Commit
c9112cd
Β·
1 Parent(s): 2222e19

πŸš€ [Update] PostProccess for huggingface demo

Browse files
Files changed (2) hide show
  1. demo/hf_demo.py +7 -8
  2. yolo/__init__.py +2 -0
demo/hf_demo.py CHANGED
@@ -10,8 +10,8 @@ sys.path.append(str(Path(__file__).resolve().parent.parent))
10
  from yolo import (
11
  AugmentationComposer,
12
  NMSConfig,
 
13
  Vec2Box,
14
- bbox_nms,
15
  create_model,
16
  draw_bboxes,
17
  )
@@ -37,7 +37,7 @@ transform = AugmentationComposer([])
37
 
38
 
39
  def predict(model_name, image, nms_confidence, nms_iou):
40
- global DEFAULT_MODEL, model, device, v2b, class_list
41
  if model_name != DEFAULT_MODEL:
42
  model = load_model(model_name, device)
43
  v2b = Vec2Box(model, IMAGE_SIZE, device)
@@ -46,16 +46,15 @@ def predict(model_name, image, nms_confidence, nms_iou):
46
  image_tensor, _, rev_tensor = transform(image)
47
 
48
  image_tensor = image_tensor.to(device)[None]
49
- rev_tensor = rev_tensor.to(device)
 
 
 
50
 
51
  with torch.no_grad():
52
  predict = model(image_tensor)
53
- pred_class, _, pred_bbox = v2b(predict["Main"])
54
-
55
- nms_config = NMSConfig(nms_confidence, nms_iou)
56
 
57
- pred_bbox = pred_bbox / rev_tensor[0] - rev_tensor[None, None, 1:]
58
- pred_bbox = bbox_nms(pred_class, pred_bbox, nms_config)
59
  result_image = draw_bboxes(image, pred_bbox, idx2label=class_list)
60
 
61
  return result_image
 
10
  from yolo import (
11
  AugmentationComposer,
12
  NMSConfig,
13
+ PostProccess,
14
  Vec2Box,
 
15
  create_model,
16
  draw_bboxes,
17
  )
 
37
 
38
 
39
  def predict(model_name, image, nms_confidence, nms_iou):
40
+ global DEFAULT_MODEL, model, device, v2b, class_list, post_proccess
41
  if model_name != DEFAULT_MODEL:
42
  model = load_model(model_name, device)
43
  v2b = Vec2Box(model, IMAGE_SIZE, device)
 
46
  image_tensor, _, rev_tensor = transform(image)
47
 
48
  image_tensor = image_tensor.to(device)[None]
49
+ rev_tensor = rev_tensor.to(device)[None]
50
+
51
+ nms_config = NMSConfig(nms_confidence, nms_iou)
52
+ post_proccess = PostProccess(v2b, nms_config)
53
 
54
  with torch.no_grad():
55
  predict = model(image_tensor)
56
+ pred_bbox = post_proccess(predict, rev_tensor)
 
 
57
 
 
 
58
  result_image = draw_bboxes(image, pred_bbox, idx2label=class_list)
59
 
60
  return result_image
yolo/__init__.py CHANGED
@@ -6,6 +6,7 @@ from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
6
  from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms
7
  from yolo.utils.deploy_utils import FastModelLoader
8
  from yolo.utils.logging_utils import custom_logger
 
9
 
10
  all = [
11
  "create_model",
@@ -22,4 +23,5 @@ all = [
22
  "ModelTester",
23
  "ModelTrainer",
24
  "ModelValidator",
 
25
  ]
 
6
  from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms
7
  from yolo.utils.deploy_utils import FastModelLoader
8
  from yolo.utils.logging_utils import custom_logger
9
+ from yolo.utils.model_utils import PostProccess
10
 
11
  all = [
12
  "create_model",
 
23
  "ModelTester",
24
  "ModelTrainer",
25
  "ModelValidator",
26
+ "PostProccess",
27
  ]