veichta commited on
Commit
0fc7b40
·
verified ·
1 Parent(s): 1e1a6e8

Upload folder using huggingface_hub

Browse files
geocalib/extractor.py CHANGED
@@ -22,14 +22,9 @@ class GeoCalib(nn.Module):
22
  weights (str): trained variant, "pinhole" (default) or "distorted".
23
  """
24
  super().__init__()
25
- if weights == "pinhole":
26
- url = "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-pinhole.tar"
27
- elif weights == "distorted":
28
- url = (
29
- "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-simple_radial.tar"
30
- )
31
- else:
32
  raise ValueError(f"Unknown weights: {weights}")
 
33
 
34
  # load checkpoint
35
  model_dir = f"{torch.hub.get_dir()}/geocalib"
 
22
  weights (str): trained variant, "pinhole" (default) or "distorted".
23
  """
24
  super().__init__()
25
+ if weights not in {"pinhole", "distorted"}:
 
 
 
 
 
 
26
  raise ValueError(f"Unknown weights: {weights}")
27
+ url = f"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-{weights}.tar"
28
 
29
  # load checkpoint
30
  model_dir = f"{torch.hub.get_dir()}/geocalib"
gradio_app.py CHANGED
@@ -8,7 +8,7 @@ import numpy as np
8
  import spaces
9
  import torch
10
 
11
- from geocalib import viz2d
12
  from geocalib.camera import camera_models
13
  from geocalib.extractor import GeoCalib
14
  from geocalib.perspective_fields import get_perspective_field
@@ -77,7 +77,9 @@ def format_output(results):
77
  @spaces.GPU(duration=10)
78
  def inference(img, camera_model):
79
  out = model.calibrate(img.to(device), camera_model=camera_model)
80
- save_keys = ["camera", "gravity"] + [f"{k}_uncertainty" for k in ["roll", "pitch", "vfov", "focal"]]
 
 
81
  res = {k: v.cpu() for k, v in out.items() if k in save_keys}
82
  # not converting to numpy results in gpu abort
83
  res["up_confidence"] = out["up_confidence"].cpu().numpy()
@@ -100,10 +102,9 @@ def process_results(
100
  raise gr.Error("Please upload an image first.")
101
 
102
  img = model.load_image(image_path)
103
- print("Running inference...")
104
  start = time()
105
  inference_result = inference(img, camera_model)
106
- print(f"Done ({time() - start:.2f}s)")
107
  inference_result["image"] = img.cpu()
108
 
109
  if inference_result is None:
@@ -158,7 +159,9 @@ def update_plot(
158
  viz2d.plot_confidences([torch.tensor(inference_result["up_confidence"][0])], axes=[ax[0]])
159
 
160
  if plot_latitude_confidence:
161
- viz2d.plot_confidences([torch.tensor(inference_result["latitude_confidence"][0])], axes=[ax[0]])
 
 
162
 
163
  fig.canvas.draw()
164
  img = np.array(fig.canvas.renderer.buffer_rgba())
 
8
  import spaces
9
  import torch
10
 
11
+ from geocalib import logger, viz2d
12
  from geocalib.camera import camera_models
13
  from geocalib.extractor import GeoCalib
14
  from geocalib.perspective_fields import get_perspective_field
 
77
  @spaces.GPU(duration=10)
78
  def inference(img, camera_model):
79
  out = model.calibrate(img.to(device), camera_model=camera_model)
80
+ save_keys = ["camera", "gravity"] + [
81
+ f"{k}_uncertainty" for k in ["roll", "pitch", "vfov", "focal"]
82
+ ]
83
  res = {k: v.cpu() for k, v in out.items() if k in save_keys}
84
  # not converting to numpy results in gpu abort
85
  res["up_confidence"] = out["up_confidence"].cpu().numpy()
 
102
  raise gr.Error("Please upload an image first.")
103
 
104
  img = model.load_image(image_path)
 
105
  start = time()
106
  inference_result = inference(img, camera_model)
107
+ logger.info(f"Calibration took {time() - start:.2f} sec. ({camera_model})")
108
  inference_result["image"] = img.cpu()
109
 
110
  if inference_result is None:
 
159
  viz2d.plot_confidences([torch.tensor(inference_result["up_confidence"][0])], axes=[ax[0]])
160
 
161
  if plot_latitude_confidence:
162
+ viz2d.plot_confidences(
163
+ [torch.tensor(inference_result["latitude_confidence"][0])], axes=[ax[0]]
164
+ )
165
 
166
  fig.canvas.draw()
167
  img = np.array(fig.canvas.renderer.buffer_rgba())
siclib/models/extractor.py CHANGED
@@ -22,14 +22,9 @@ class GeoCalib(nn.Module):
22
  weights (str, optional): Weights to load. Defaults to "pinhole".
23
  """
24
  super().__init__()
25
- if weights == "pinhole":
26
- url = "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-pinhole.tar"
27
- elif weights == "distorted":
28
- url = (
29
- "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-simple_radial.tar"
30
- )
31
- else:
32
  raise ValueError(f"Unknown weights: {weights}")
 
33
 
34
  # load checkpoint
35
  model_dir = f"{torch.hub.get_dir()}/geocalib"
 
22
  weights (str, optional): Weights to load. Defaults to "pinhole".
23
  """
24
  super().__init__()
25
+ if weights not in {"pinhole", "distorted"}:
 
 
 
 
 
 
26
  raise ValueError(f"Unknown weights: {weights}")
27
+ url = f"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-{weights}.tar"
28
 
29
  # load checkpoint
30
  model_dir = f"{torch.hub.get_dir()}/geocalib"