glenn-jocher commited on
Commit
9dd33fd
·
unverified ·
1 Parent(s): dd62e2d

AutoShape PosixPath support (#4047)

Browse files

* AutoShape PosixPath support

Usage example:

````python
from pathlib import Path

model = ...
file = Path('data/images/zidane.jpg')

results = model(file)
```

* Update common.py

Files changed (1) hide show
  1. models/common.py +5 -5
models/common.py CHANGED
@@ -1,7 +1,7 @@
1
  # YOLOv5 common modules
2
 
3
  from copy import copy
4
- from pathlib import Path
5
 
6
  import math
7
  import numpy as np
@@ -232,8 +232,8 @@ class AutoShape(nn.Module):
232
  @torch.no_grad()
233
  def forward(self, imgs, size=640, augment=False, profile=False):
234
  # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
235
- # filename: imgs = 'data/images/zidane.jpg'
236
- # URI: = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
237
  # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
238
  # PIL: = Image.open('image.jpg') # HWC x(640,1280,3)
239
  # numpy: = np.zeros((640,1280,3)) # HWC
@@ -251,8 +251,8 @@ class AutoShape(nn.Module):
251
  shape0, shape1, files = [], [], [] # image and inference shapes, filenames
252
  for i, im in enumerate(imgs):
253
  f = f'image{i}' # filename
254
- if isinstance(im, str): # filename or uri
255
- im, f = Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im), im
256
  im = np.asarray(exif_transpose(im))
257
  elif isinstance(im, Image.Image): # PIL Image
258
  im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
 
1
  # YOLOv5 common modules
2
 
3
  from copy import copy
4
+ from pathlib import Path, PosixPath
5
 
6
  import math
7
  import numpy as np
 
232
  @torch.no_grad()
233
  def forward(self, imgs, size=640, augment=False, profile=False):
234
  # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
235
+ # filename: imgs = 'data/images/zidane.jpg' # str or PosixPath
236
+ # URI: = 'https://ultralytics.com/images/zidane.jpg'
237
  # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
238
  # PIL: = Image.open('image.jpg') # HWC x(640,1280,3)
239
  # numpy: = np.zeros((640,1280,3)) # HWC
 
251
  shape0, shape1, files = [], [], [] # image and inference shapes, filenames
252
  for i, im in enumerate(imgs):
253
  f = f'image{i}' # filename
254
+ if isinstance(im, (str, PosixPath)): # filename or uri
255
+ im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
256
  im = np.asarray(exif_transpose(im))
257
  elif isinstance(im, Image.Image): # PIL Image
258
  im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f