glenn-jocher commited on
Commit
15a1060
·
unverified ·
1 Parent(s): 187f7c2

autoshape() update for PIL greyscale inputs (#1279)

Browse files

* autoshape update for PIL greyscale inputs

* autoshape update for PIL greyscale inputs

Files changed (1) hide show
  1. models/common.py +3 -1
models/common.py CHANGED
@@ -1,6 +1,7 @@
1
  # This file contains modules common to various models
2
 
3
  import math
 
4
  import numpy as np
5
  import torch
6
  import torch.nn as nn
@@ -144,7 +145,8 @@ class autoShape(nn.Module):
144
  shape0, shape1 = [], [] # image and inference shapes
145
  batch = range(len(x)) # batch size
146
  for i in batch:
147
- x[i] = np.array(x[i])[:, :, :3] # up to 3 channels if png
 
148
  s = x[i].shape[:2] # HWC
149
  shape0.append(s) # image shape
150
  g = (size / max(s)) # gain
 
1
  # This file contains modules common to various models
2
 
3
  import math
4
+
5
  import numpy as np
6
  import torch
7
  import torch.nn as nn
 
145
  shape0, shape1 = [], [] # image and inference shapes
146
  batch = range(len(x)) # batch size
147
  for i in batch:
148
+ x[i] = np.array(x[i]) # to numpy
149
+ x[i] = x[i][:, :, :3] if x[i].ndim == 3 else np.tile(x[i][:, :, None], 3) # enforce 3ch input
150
  s = x[i].shape[:2] # HWC
151
  shape0.append(s) # image shape
152
  g = (size / max(s)) # gain