henry000 commited on
Commit
329fd0a
Β·
1 Parent(s): 2b2044d

πŸš€ [Update] the postproccess -> PostProccess class

Browse files
Files changed (1) hide show
  1. yolo/tools/solver.py +3 -5
yolo/tools/solver.py CHANGED
@@ -116,10 +116,9 @@ class ModelTester:
116
  def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
117
  self.model = model
118
  self.device = device
119
- self.vec2box = vec2box
120
  self.progress = progress
121
 
122
- self.nms = cfg.task.nms
123
  self.save_path = os.path.join(progress.save_path, "images")
124
  os.makedirs(self.save_path, exist_ok=True)
125
  self.save_predict = getattr(cfg.task, "save_predict", None)
@@ -140,9 +139,8 @@ class ModelTester:
140
  images = images.to(self.device)
141
  with torch.no_grad():
142
  predicts = self.model(images)
143
- predicts = self.vec2box(predicts["Main"])
144
- nms_out = bbox_nms(predicts[0], predicts[2], self.nms)
145
- img = draw_bboxes(images, nms_out, idx2label=self.idx2label)
146
 
147
  if dataloader.is_stream:
148
  img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
 
116
  def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
117
  self.model = model
118
  self.device = device
 
119
  self.progress = progress
120
 
121
+ self.post_proccess = PostProccess(vec2box, cfg.task.nms)
122
  self.save_path = os.path.join(progress.save_path, "images")
123
  os.makedirs(self.save_path, exist_ok=True)
124
  self.save_predict = getattr(cfg.task, "save_predict", None)
 
139
  images = images.to(self.device)
140
  with torch.no_grad():
141
  predicts = self.model(images)
142
+ predicts = self.post_proccess(predicts, rev_tensor)
143
+ img = draw_bboxes(origin_frame, predicts, idx2label=self.idx2label)
 
144
 
145
  if dataloader.is_stream:
146
  img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)