henry000 commited on
Commit
054a1c7
Β·
1 Parent(s): 868c821

πŸ› [Fix] #56 bugs, create_converter -> Vec2Box

Browse files
Files changed (1) hide show
  1. demo/hf_demo.py +9 -9
demo/hf_demo.py CHANGED
@@ -11,7 +11,7 @@ from yolo import (
11
  AugmentationComposer,
12
  NMSConfig,
13
  PostProccess,
14
- Vec2Box,
15
  create_model,
16
  draw_bboxes,
17
  )
@@ -25,22 +25,22 @@ def load_model(model_name, device):
25
  model_cfg.model.auxiliary = {}
26
  model = create_model(model_cfg, True)
27
  model.to(device).eval()
28
- return model
29
 
30
 
31
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
- model = load_model(DEFAULT_MODEL, device)
33
- v2b = Vec2Box(model, IMAGE_SIZE, device)
34
- class_list = OmegaConf.load("yolo/config/general.yaml").class_list
35
 
36
  transform = AugmentationComposer([])
37
 
38
 
39
  def predict(model_name, image, nms_confidence, nms_iou):
40
- global DEFAULT_MODEL, model, device, v2b, class_list, post_proccess
41
  if model_name != DEFAULT_MODEL:
42
- model = load_model(model_name, device)
43
- v2b = Vec2Box(model, IMAGE_SIZE, device)
44
  DEFAULT_MODEL = model_name
45
 
46
  image_tensor, _, rev_tensor = transform(image)
@@ -49,7 +49,7 @@ def predict(model_name, image, nms_confidence, nms_iou):
49
  rev_tensor = rev_tensor.to(device)[None]
50
 
51
  nms_config = NMSConfig(nms_confidence, nms_iou)
52
- post_proccess = PostProccess(v2b, nms_config)
53
 
54
  with torch.no_grad():
55
  predict = model(image_tensor)
 
11
  AugmentationComposer,
12
  NMSConfig,
13
  PostProccess,
14
+ create_converter,
15
  create_model,
16
  draw_bboxes,
17
  )
 
25
  model_cfg.model.auxiliary = {}
26
  model = create_model(model_cfg, True)
27
  model.to(device).eval()
28
+ return model, model_cfg
29
 
30
 
31
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ model, model_cfg = load_model(DEFAULT_MODEL, device)
33
+ converter = create_converter(model_cfg.name, model, model_cfg.anchor, IMAGE_SIZE, device)
34
+ class_list = OmegaConf.load("yolo/config/dataset/coco.yaml").class_list
35
 
36
  transform = AugmentationComposer([])
37
 
38
 
39
  def predict(model_name, image, nms_confidence, nms_iou):
40
+ global DEFAULT_MODEL, model, device, converter, class_list, post_proccess
41
  if model_name != DEFAULT_MODEL:
42
+ model, model_cfg = load_model(model_name, device)
43
+ converter = create_converter(model_cfg.name, model, model_cfg.anchor, IMAGE_SIZE, device)
44
  DEFAULT_MODEL = model_name
45
 
46
  image_tensor, _, rev_tensor = transform(image)
 
49
  rev_tensor = rev_tensor.to(device)[None]
50
 
51
  nms_config = NMSConfig(nms_confidence, nms_iou)
52
+ post_proccess = PostProccess(converter, nms_config)
53
 
54
  with torch.no_grad():
55
  predict = model(image_tensor)