tidalove commited on
Commit
7fa23ea
·
verified ·
1 Parent(s): 04e0ad9

clearer padding logic

Browse files
Files changed (1) hide show
  1. tools/demo_api.py +13 -8
tools/demo_api.py CHANGED
@@ -75,10 +75,9 @@ class Predictor(object):
75
 
76
  ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
77
  pad = torch.tensor([0,0,0,0])
78
- pad[1] = ( self.test_size[0] - img.shape[0] * ratio ) / 2
79
- pad[0] = ( self.test_size[1] - img.shape[1] * ratio ) / 2
80
- pad[2] = pad[0]
81
- pad[3] = pad[1]
82
  img_info["ratio"] = ratio
83
 
84
  img, _ = self.preproc(img, None, self.test_size)
@@ -99,7 +98,7 @@ class Predictor(object):
99
  self.nmsthre, class_agnostic=True
100
  )
101
  logger.info("Infer time: {:.4f}s".format(time.time() - t0))
102
- return outputs, img_info, pad
103
 
104
  def build_predictor(
105
  exp_file, model_name, ckpt_path, device="cpu", fp16=False, fuse=False, trt=False, conf=0, nms=0, tsize=None
@@ -150,8 +149,9 @@ def run_detection(predictor, path):
150
 
151
  for img_id, image_name in enumerate(files):
152
 
153
- outputs, img_info, pad = predictor.inference(image_name)
154
  ratio = img_info["ratio"]
 
155
 
156
  img_entry = {"id": img_id,
157
  "filename": image_name }
@@ -159,10 +159,15 @@ def run_detection(predictor, path):
159
 
160
  if outputs[0] is not None:
161
  for id, output in enumerate(outputs[0]):
162
- print(output)
 
 
 
 
 
163
  ann_entry = {"id": id,
164
  "image_id": img_id,
165
- "bbox": ((output[:4] - pad) / ratio).tolist(),
166
  "cls": output[6].item(),
167
  "score": (output[4] * output[5]).item() }
168
  ann_list.append(ann_entry)
 
75
 
76
  ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
77
  pad = torch.tensor([0,0,0,0])
78
+ pad_h = ( self.test_size[0] - img.shape[0] * ratio ) / 2
79
+ pad_w = ( self.test_size[1] - img.shape[1] * ratio ) / 2
80
+ img_info["pad"] = (pad_w, pad_h)
 
81
  img_info["ratio"] = ratio
82
 
83
  img, _ = self.preproc(img, None, self.test_size)
 
98
  self.nmsthre, class_agnostic=True
99
  )
100
  logger.info("Infer time: {:.4f}s".format(time.time() - t0))
101
+ return outputs, img_info
102
 
103
  def build_predictor(
104
  exp_file, model_name, ckpt_path, device="cpu", fp16=False, fuse=False, trt=False, conf=0, nms=0, tsize=None
 
149
 
150
  for img_id, image_name in enumerate(files):
151
 
152
+ outputs, img_info = predictor.inference(image_name)
153
  ratio = img_info["ratio"]
154
+ pad_w, pad_h = img_info["pad"]
155
 
156
  img_entry = {"id": img_id,
157
  "filename": image_name }
 
159
 
160
  if outputs[0] is not None:
161
  for id, output in enumerate(outputs[0]):
162
+ print(output)
163
+ x1, y1, x2, y2 = output[:4]
164
+ x1 = (x1 - pad_w) / ratio
165
+ y1 = (y1 - pad_h) / ratio
166
+ x2 = (x2 - pad_w) / ratio
167
+ y2 = (y2 - pad_h) / ratio
168
  ann_entry = {"id": id,
169
  "image_id": img_id,
170
+ "bbox": [float(x1), float(y1), float(x2), float(y2)],
171
  "cls": output[6].item(),
172
  "score": (output[4] * output[5]).item() }
173
  ann_list.append(ann_entry)