Paul-Edouard Sarlin commited on
Commit
9aff230
·
unverified ·
1 Parent(s): 09f8fa3

Fix bug with torchvision resize (#37)

Browse files
Files changed (2) hide show
  1. maploc/data/image.py +2 -2
  2. maploc/demo.py +3 -3
maploc/data/image.py CHANGED
@@ -69,11 +69,11 @@ def resize_image(
69
  if fn is not None:
70
  assert isinstance(size, int)
71
  scale = size / fn(h, w)
72
- h_new, w_new = int(round(h * scale)), int(round(w * scale))
73
  scale = (scale, scale)
74
  else:
75
  if isinstance(size, (collections.abc.Sequence, np.ndarray)):
76
- w_new, h_new = size
77
  elif isinstance(size, int):
78
  w_new = h_new = size
79
  else:
 
69
  if fn is not None:
70
  assert isinstance(size, int)
71
  scale = size / fn(h, w)
72
+ h_new, w_new = (int(round(x * scale)) for x in (h, w))
73
  scale = (scale, scale)
74
  else:
75
  if isinstance(size, (collections.abc.Sequence, np.ndarray)):
76
+ w_new, h_new = (int(x) for x in size)
77
  elif isinstance(size, int):
78
  w_new = h_new = size
79
  else:
maploc/demo.py CHANGED
@@ -170,14 +170,14 @@ class Demo:
170
  pitch=-pitch,
171
  )
172
  image, _, camera, *maybe_valid = resize_image(
173
- image, size.numpy(), camera=camera, valid=valid
174
  )
175
  valid = None if valid is None else maybe_valid
176
 
177
  max_stride = max(self.model.image_encoder.layer_strides)
178
- size = (np.ceil((size / max_stride)) * max_stride).int()
179
  image, valid, camera = pad_image(
180
- image, size.numpy(), camera, crop_and_center=True
181
  )
182
 
183
  return dict(
 
170
  pitch=-pitch,
171
  )
172
  image, _, camera, *maybe_valid = resize_image(
173
+ image, size.tolist(), camera=camera, valid=valid
174
  )
175
  valid = None if valid is None else maybe_valid
176
 
177
  max_stride = max(self.model.image_encoder.layer_strides)
178
+ size = (torch.ceil(size / max_stride) * max_stride).int()
179
  image, valid, camera = pad_image(
180
+ image, size.tolist(), camera, crop_and_center=True
181
  )
182
 
183
  return dict(