sarlinpe committed on
Commit
c42e81b
·
1 Parent(s): c222d08

Use the PerspectiveFields inference model

Browse files
Files changed (3) hide show
  1. demo.ipynb +0 -0
  2. maploc/demo.py +77 -85
  3. requirements/demo.txt +1 -1
demo.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
maploc/demo.py CHANGED
@@ -1,9 +1,10 @@
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
3
- from typing import Optional, Tuple
4
 
5
  import numpy as np
6
  import torch
 
7
 
8
  from . import logger
9
  from .data.image import pad_image, rectify_image, resize_image
@@ -23,61 +24,51 @@ try:
23
  except ImportError:
24
  geolocator = None
25
 
26
- try:
27
- from gradio_client import Client
28
-
29
- calibrator = Client("https://jinlinyi-perspectivefields.hf.space/")
30
- except (ImportError, ValueError):
31
- calibrator = None
32
-
33
-
34
- def image_calibration(image_path):
35
- logger.info("Calling the PerspectiveFields calibrator, this may take some time.")
36
- result = calibrator.predict(
37
- image_path, "NEW:Paramnet-360Cities-edina-centered", api_name="/predict"
38
- )
39
- result = dict(r.rsplit(" ", 1) for r in result[1].split("\n"))
40
- roll_pitch = float(result["roll"]), float(result["pitch"])
41
- return roll_pitch, float(result["vertical fov"])
42
-
43
-
44
- def camera_from_exif(exif: EXIF, fov: Optional[float] = None) -> Camera:
45
- w, h = image_size = exif.extract_image_size()
46
- _, f_ratio = exif.extract_focal()
47
- if f_ratio == 0:
48
- if fov is not None:
49
- # This is the vertical FoV.
50
- f = h / 2 / np.tan(np.deg2rad(fov) / 2)
51
- else:
52
- return None
53
- else:
54
- f = f_ratio * max(image_size)
55
- return Camera.from_dict(
56
- dict(
57
- model="SIMPLE_PINHOLE",
58
- width=w,
59
- height=h,
60
- params=[f, w / 2 + 0.5, h / 2 + 0.5],
61
  )
62
- )
63
 
64
 
65
- def read_input_image(
66
- image_path: str,
67
  prior_latlon: Optional[Tuple[float, float]] = None,
68
  prior_address: Optional[str] = None,
69
- fov: Optional[float] = None,
70
- tile_size_meters: int = 64,
71
- ):
72
- image = read_image(image_path)
73
- with open(image_path, "rb") as fid:
74
- exif = EXIF(fid, lambda: image.shape[:2])
75
-
76
  latlon = None
77
  if prior_latlon is not None:
78
  latlon = prior_latlon
79
  logger.info("Using prior latlon %s.", prior_latlon)
80
- if prior_address is not None:
81
  if geolocator is None:
82
  raise ValueError("geocoding unavailable, install geopy.")
83
  location = geolocator.geocode(prior_address)
@@ -93,32 +84,11 @@ def read_input_image(
93
  latlon = (geo["latitude"], geo["longitude"], alt)
94
  logger.info("Using prior location from EXIF.")
95
  else:
96
- logger.info("Could not find any prior location in the image EXIF metadata.")
97
- if latlon is None:
98
- raise ValueError(
99
- "No location prior given or found in the image EXIF metadata: "
100
- "maybe provide the name of a street, building or neighborhood?"
101
- )
102
- latlon = np.array(latlon)
103
-
104
- roll_pitch = None
105
- if calibrator is not None:
106
- roll_pitch, fov = image_calibration(image_path)
107
- else:
108
- logger.info("Could not call PerspectiveFields, maybe install gradio_client?")
109
- if roll_pitch is not None:
110
- logger.info("Using (roll, pitch) %s.", roll_pitch)
111
-
112
- camera = camera_from_exif(exif, fov)
113
- if camera is None:
114
- raise ValueError(
115
- "No camera intrinsics found in the EXIF, provide an FoV guess."
116
- )
117
-
118
- proj = Projection(*latlon)
119
- center = proj.project(latlon)
120
- bbox = BoundaryBox(center, center) + tile_size_meters
121
- return image, camera, roll_pitch, proj, bbox, latlon
122
 
123
 
124
  class Demo:
@@ -141,19 +111,41 @@ class Demo:
141
  model.load_state_dict(state, strict=True)
142
  if device is None:
143
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
144
- model = model.to(device)
 
 
145
 
146
- self.model = model
147
  self.config = config
148
  self.device = device
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  def prepare_data(
151
  self,
152
  image: np.ndarray,
153
  camera: Camera,
154
  canvas: Canvas,
155
- roll_pitch: Optional[Tuple[float]] = None,
156
- ):
157
  assert image.shape[:2][::-1] == tuple(camera.size.tolist())
158
  target_focal_length = self.config.data.resize_image / 2
159
  factor = target_focal_length / camera.f
@@ -161,8 +153,8 @@ class Demo:
161
 
162
  image = torch.from_numpy(image).permute(2, 0, 1).float().div_(255)
163
  valid = None
164
- if roll_pitch is not None:
165
- roll, pitch = roll_pitch
166
  image, valid = rectify_image(
167
  image,
168
  camera.float(),
@@ -180,12 +172,12 @@ class Demo:
180
  image, size.tolist(), camera, crop_and_center=True
181
  )
182
 
183
- return dict(
184
- image=image,
185
- map=torch.from_numpy(canvas.raster).long(),
186
- camera=camera.float(),
187
- valid=valid,
188
- )
189
 
190
  def localize(self, image: np.ndarray, camera: Camera, canvas: Canvas, **kwargs):
191
  data = self.prepare_data(image, camera, canvas, **kwargs)
 
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
3
+ from typing import Dict, Optional, Tuple
4
 
5
  import numpy as np
6
  import torch
7
+ from perspective2d import PerspectiveFields
8
 
9
  from . import logger
10
  from .data.image import pad_image, rectify_image, resize_image
 
24
  except ImportError:
25
  geolocator = None
26
 
27
+
28
+ class ImageCalibrator(PerspectiveFields):
29
+ def __init__(self, version: str = "Paramnet-360Cities-edina-centered"):
30
+ super().__init__(version)
31
+ self.eval()
32
+
33
+ def run(
34
+ self,
35
+ image_rgb: np.ndarray,
36
+ focal_length: Optional[float] = None,
37
+ exif: Optional[EXIF] = None,
38
+ ) -> Tuple[Tuple[float, float], Camera]:
39
+ h, w, *_ = image_rgb.shape
40
+ if focal_length is None and exif is not None:
41
+ _, focal_ratio = exif.extract_focal()
42
+ if focal_ratio != 0:
43
+ focal_length = focal_ratio * max(h, w)
44
+
45
+ calib = self.inference(img_bgr=image_rgb[..., ::-1])
46
+ roll_pitch = (calib["pred_roll"].item(), calib["pred_pitch"].item())
47
+ if focal_length is None:
48
+ vfov = calib["pred_vfov"].item()
49
+ focal_length = h / 2 / np.tan(np.deg2rad(vfov) / 2)
50
+
51
+ camera = Camera.from_dict(
52
+ {
53
+ "model": "SIMPLE_PINHOLE",
54
+ "width": w,
55
+ "height": h,
56
+ "params": [focal_length, w / 2 + 0.5, h / 2 + 0.5],
57
+ }
 
 
 
 
58
  )
59
+ return roll_pitch, camera
60
 
61
 
62
+ def parse_location_prior(
63
+ exif: EXIF,
64
  prior_latlon: Optional[Tuple[float, float]] = None,
65
  prior_address: Optional[str] = None,
66
+ ) -> np.ndarray:
 
 
 
 
 
 
67
  latlon = None
68
  if prior_latlon is not None:
69
  latlon = prior_latlon
70
  logger.info("Using prior latlon %s.", prior_latlon)
71
+ elif prior_address is not None:
72
  if geolocator is None:
73
  raise ValueError("geocoding unavailable, install geopy.")
74
  location = geolocator.geocode(prior_address)
 
84
  latlon = (geo["latitude"], geo["longitude"], alt)
85
  logger.info("Using prior location from EXIF.")
86
  else:
87
+ raise ValueError(
88
+ "No location prior given or found in the image EXIF metadata: "
89
+ "maybe provide the name of a street, building or neighborhood?"
90
+ )
91
+ return np.array(latlon)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
 
94
  class Demo:
 
111
  model.load_state_dict(state, strict=True)
112
  if device is None:
113
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
114
+ self.model = model.to(device)
115
+
116
+ self.calibrator = ImageCalibrator().to(device)
117
 
 
118
  self.config = config
119
  self.device = device
120
 
121
+ def read_input_image(
122
+ self,
123
+ image_path: str,
124
+ prior_latlon: Optional[Tuple[float, float]] = None,
125
+ prior_address: Optional[str] = None,
126
+ focal_length: Optional[float] = None,
127
+ tile_size_meters: int = 64,
128
+ ) -> Tuple[np.ndarray, Camera, Tuple[str, str], Projection, BoundaryBox]:
129
+ image = read_image(image_path)
130
+ with open(image_path, "rb") as fid:
131
+ exif = EXIF(fid, lambda: image.shape[:2])
132
+
133
+ gravity, camera = self.calibrator.run(image, focal_length, exif)
134
+ logger.info("Using (roll, pitch) %s.", gravity)
135
+
136
+ latlon = parse_location_prior(exif, prior_latlon, prior_address)
137
+ proj = Projection(*latlon)
138
+ center = proj.project(latlon)
139
+ bbox = BoundaryBox(center, center) + tile_size_meters
140
+ return image, camera, gravity, proj, bbox
141
+
142
  def prepare_data(
143
  self,
144
  image: np.ndarray,
145
  camera: Camera,
146
  canvas: Canvas,
147
+ gravity: Optional[Tuple[float]] = None,
148
+ ) -> Dict[str, torch.Tensor]:
149
  assert image.shape[:2][::-1] == tuple(camera.size.tolist())
150
  target_focal_length = self.config.data.resize_image / 2
151
  factor = target_focal_length / camera.f
 
153
 
154
  image = torch.from_numpy(image).permute(2, 0, 1).float().div_(255)
155
  valid = None
156
+ if gravity is not None:
157
+ roll, pitch = gravity
158
  image, valid = rectify_image(
159
  image,
160
  camera.float(),
 
172
  image, size.tolist(), camera, crop_and_center=True
173
  )
174
 
175
+ return {
176
+ "image": image,
177
+ "map": torch.from_numpy(canvas.raster).long(),
178
+ "camera": camera.float(),
179
+ "valid": valid,
180
+ }
181
 
182
  def localize(self, image: np.ndarray, camera: Camera, canvas: Canvas, **kwargs):
183
  data = self.prepare_data(image, camera, canvas, **kwargs)
requirements/demo.txt CHANGED
@@ -16,5 +16,5 @@ rtree
16
  scikit-learn
17
  geopy
18
  exifread
19
- gradio_client
20
  urllib3>=2
 
 
16
  scikit-learn
17
  geopy
18
  exifread
 
19
  urllib3>=2
20
+ perspective2d @ git+https://github.com/jinlinyi/PerspectiveFields.git