GeoCalib / siclib /eval /utils.py
veichta's picture
Upload folder using huggingface_hub
205a7af verified
raw
history blame
3.71 kB
import logging
import shutil
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
logger = logging.getLogger(__name__)
# flake8: noqa
# mypy: ignore-errors
def download_and_extract_benchmark(name: str, url: Path, output: Path) -> None:
benchmark_dir = output / name
if not output.exists():
output.mkdir(parents=True)
if benchmark_dir.exists():
logger.info(f"Benchmark {name} already exists at {benchmark_dir}, skipping download.")
return
if name == "stanford2d3d":
# prompt user to sign data sharing and usage terms
txt = "\n" + "#" * 108 + "\n\n"
txt += "To download the Stanford2D3D dataset, you must agree to the terms of use:\n\n"
txt += (
"https://docs.google.com/forms/d/e/"
+ "1FAIpQLScFR0U8WEUtb7tgjOhhnl31OrkEs73-Y8bQwPeXgebqVKNMpQ/viewform?c=0&w=1\n\n"
)
txt += "#" * 108 + "\n\n"
txt += "Did you fill out the data sharing and usage terms? [y/n] "
choice = input(txt)
if choice.lower() != "y":
raise ValueError(
"You must agree to the terms of use to download the Stanford2D3D dataset."
)
zip_file = output / f"{name}.zip"
if not zip_file.exists():
logger.info(f"Downloading benchmark {name} to {zip_file} from {url}.")
torch.hub.download_url_to_file(url, zip_file)
logger.info(f"Extracting benchmark {name} in {output}.")
shutil.unpack_archive(zip_file, output, format="zip")
zip_file.unlink()
def check_keys_recursive(d, pattern):
if isinstance(pattern, dict):
{check_keys_recursive(d[k], v) for k, v in pattern.items()}
else:
for k in pattern:
assert k in d.keys()
def plot_scatter_grid(
results, x_keys, y_keys, name=None, diag=False, ax=None, line_idx=0, show_means=True
): # sourcery skip: low-code-quality
if ax is None:
N, M = len(y_keys), len(x_keys)
fig, ax = plt.subplots(N, M, figsize=(M * 6, N * 5))
if N == 1:
ax = np.array(ax)
ax = ax.reshape(1, -1)
if M == 1:
ax = np.array(ax)
ax = ax.reshape(-1, 1)
else:
fig = None
for j, kx in enumerate(x_keys):
for i, ky in enumerate(y_keys):
ax[i, j].scatter(
results[kx],
results[ky],
s=1,
alpha=0.5,
label=name or None,
)
ax[i, j].set_xlabel(f"{' '.join(kx.split('_')).title()}")
ax[i, j].set_ylabel(f"{' '.join(ky.split('_')).title()}")
low = min(ax[i, j].get_xlim()[0], ax[i, j].get_ylim()[0])
high = max(ax[i, j].get_xlim()[1], ax[i, j].get_ylim()[1])
if diag == "all" or (i == j and diag):
ax[i, j].plot([low, high], [low, high], ls="--", c="red", label="y=x")
if name or diag == "all" or (i == j and diag):
ax[i, j].legend()
if not show_means:
return fig, ax
means = {"y": {}, "x": {}}
for kx in x_keys:
for ky in y_keys:
means["x"][kx] = np.mean(results[kx])
means["y"][ky] = np.mean(results[ky])
for j, kx in enumerate(x_keys):
for i, ky in enumerate(y_keys):
xlim = np.min(results[kx]), np.max(results[kx])
ylim = np.min(results[ky]), np.max(results[ky])
means_x = [means["x"][kx]]
means_y = [means["y"][ky]]
color = plt.cm.tab10(line_idx)
ax[i, j].vlines(means_x, *ylim, colors=[color])
ax[i, j].hlines(means_y, *xlim, colors=[color])
return fig, ax