Daniel Khromov glenn-jocher commited on
Commit
3e560e2
·
unverified ·
1 Parent(s): 17ac94b

YOLOv5 PyTorch Hub results.save() method retains filenames (#2194)

Browse files

* save results with name

* debug

* save original imgs names

* Update common.py

Co-authored-by: Glenn Jocher <[email protected]>

Files changed (1) hide show
  1. models/common.py +9 -6
models/common.py CHANGED
@@ -196,10 +196,11 @@ class autoShape(nn.Module):
196
 
197
  # Pre-process
198
  n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
199
- shape0, shape1 = [], [] # image and inference shapes
200
  for i, im in enumerate(imgs):
201
  if isinstance(im, str): # filename or uri
202
  im = Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im) # open
 
203
  im = np.array(im) # to numpy
204
  if im.shape[0] < 5: # image in CHW
205
  im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
@@ -224,18 +225,19 @@ class autoShape(nn.Module):
224
  for i in range(n):
225
  scale_coords(shape1, y[i][:, :4], shape0[i])
226
 
227
- return Detections(imgs, y, self.names)
228
 
229
 
230
  class Detections:
231
  # detections class for YOLOv5 inference results
232
- def __init__(self, imgs, pred, names=None):
233
  super(Detections, self).__init__()
234
  d = pred[0].device # device
235
  gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs] # normalizations
236
  self.imgs = imgs # list of images as numpy arrays
237
  self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
238
  self.names = names # class names
 
239
  self.xyxy = pred # xyxy pixels
240
  self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
241
  self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
@@ -258,9 +260,9 @@ class Detections:
258
  if pprint:
259
  print(str.rstrip(', '))
260
  if show:
261
- img.show(f'image {i}') # show
262
  if save:
263
- f = Path(save_dir) / f'results{i}.jpg'
264
  img.save(f) # save
265
  print(f"{'Saving' * (i == 0)} {f},", end='' if i < self.n - 1 else ' done.\n')
266
  if render:
@@ -272,7 +274,8 @@ class Detections:
272
  def show(self):
273
  self.display(show=True) # show results
274
 
275
- def save(self, save_dir=''):
 
276
  self.display(save=True, save_dir=save_dir) # save results
277
 
278
  def render(self):
 
196
 
197
  # Pre-process
198
  n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
199
+ shape0, shape1, files = [], [], [] # image and inference shapes, filenames
200
  for i, im in enumerate(imgs):
201
  if isinstance(im, str): # filename or uri
202
  im = Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im) # open
203
+ files.append(Path(im.filename).with_suffix('.jpg').name if isinstance(im, Image.Image) else f'image{i}.jpg')
204
  im = np.array(im) # to numpy
205
  if im.shape[0] < 5: # image in CHW
206
  im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
 
225
  for i in range(n):
226
  scale_coords(shape1, y[i][:, :4], shape0[i])
227
 
228
+ return Detections(imgs, y, files, self.names)
229
 
230
 
231
  class Detections:
232
  # detections class for YOLOv5 inference results
233
+ def __init__(self, imgs, pred, files, names=None):
234
  super(Detections, self).__init__()
235
  d = pred[0].device # device
236
  gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs] # normalizations
237
  self.imgs = imgs # list of images as numpy arrays
238
  self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
239
  self.names = names # class names
240
+ self.files = files # image filenames
241
  self.xyxy = pred # xyxy pixels
242
  self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
243
  self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
 
260
  if pprint:
261
  print(str.rstrip(', '))
262
  if show:
263
+ img.show(self.files[i]) # show
264
  if save:
265
+ f = Path(save_dir) / self.files[i]
266
  img.save(f) # save
267
  print(f"{'Saving' * (i == 0)} {f},", end='' if i < self.n - 1 else ' done.\n')
268
  if render:
 
274
  def show(self):
275
  self.display(show=True) # show results
276
 
277
+ def save(self, save_dir='results/'):
278
+ Path(save_dir).mkdir(exist_ok=True)
279
  self.display(save=True, save_dir=save_dir) # save results
280
 
281
  def render(self):