|
import logging |
|
import shutil |
|
from pathlib import Path |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import torch |
|
|
|
# Module-level logger, named after this module; used by the download helper below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
def download_and_extract_benchmark(name: str, url: str, output: Path) -> None:
    """Download a zipped benchmark and extract it into ``output``.

    The archive is fetched from ``url`` to ``output / f"{name}.zip"``, unpacked
    into ``output`` and then deleted. Nothing is downloaded when
    ``output / name`` already exists, and a zip file that is already present
    at the target path is reused instead of being downloaded again.

    Args:
        name: Benchmark name; used for the target directory and zip file name.
        url: URL of the zip archive (forwarded to
            ``torch.hub.download_url_to_file``).
        output: Directory the benchmark is extracted into; created if missing.

    Raises:
        ValueError: For ``name == "stanford2d3d"`` when the user does not
            confirm that they agreed to the dataset's terms of use.
    """
    benchmark_dir = output / name
    # exist_ok avoids crashing if the directory is created concurrently
    # between an existence check and mkdir.
    output.mkdir(parents=True, exist_ok=True)

    if benchmark_dir.exists():
        logger.info(
            "Benchmark %s already exists at %s, skipping download.", name, benchmark_dir
        )
        return

    if name == "stanford2d3d":
        # This dataset requires explicit agreement to its terms of use.
        txt = "\n" + "#" * 108 + "\n\n"
        txt += "To download the Stanford2D3D dataset, you must agree to the terms of use:\n\n"
        txt += (
            "https://docs.google.com/forms/d/e/"
            + "1FAIpQLScFR0U8WEUtb7tgjOhhnl31OrkEs73-Y8bQwPeXgebqVKNMpQ/viewform?c=0&w=1\n\n"
        )
        txt += "#" * 108 + "\n\n"
        txt += "Did you fill out the data sharing and usage terms? [y/n] "
        choice = input(txt)
        # Tolerate surrounding whitespace and a spelled-out "yes".
        if choice.strip().lower() not in ("y", "yes"):
            raise ValueError(
                "You must agree to the terms of use to download the Stanford2D3D dataset."
            )

    zip_file = output / f"{name}.zip"

    if not zip_file.exists():
        logger.info("Downloading benchmark %s to %s from %s.", name, zip_file, url)
        torch.hub.download_url_to_file(url, str(zip_file))

    logger.info("Extracting benchmark %s in %s.", name, output)
    # NOTE(review): if extraction fails midway, a partial ``benchmark_dir`` may
    # remain and later calls will skip the download; consider cleaning it up.
    shutil.unpack_archive(zip_file, output, format="zip")
    zip_file.unlink()
|
|
|
|
|
def check_keys_recursive(d, pattern):
    """Check that the (possibly nested) mapping ``d`` contains the keys in ``pattern``.

    ``pattern`` is either an iterable of keys that must all be present in
    ``d``, or a dict mapping a key of ``d`` to the sub-pattern that ``d[key]``
    must satisfy (checked recursively).

    Raises:
        AssertionError: If a required key is missing at any nesting level.
            Raised explicitly (not via ``assert``) so the check still runs
            under ``python -O``.
    """
    if isinstance(pattern, dict):
        for key, sub_pattern in pattern.items():
            # Check membership first so a missing nested key reports the same
            # error type as a missing leaf key instead of a bare KeyError.
            if key not in d:
                raise AssertionError(f"Missing key {key!r}")
            check_keys_recursive(d[key], sub_pattern)
    else:
        for key in pattern:
            if key not in d:
                raise AssertionError(f"Missing key {key!r}")
|
|
|
|
|
def plot_scatter_grid(
    results, x_keys, y_keys, name=None, diag=False, ax=None, line_idx=0, show_means=True
):
    """Scatter-plot every ``results[y]`` against every ``results[x]`` on a grid.

    Subplot ``ax[i, j]`` shows ``results[y_keys[i]]`` (y axis) versus
    ``results[x_keys[j]]`` (x axis).

    Args:
        results: Mapping from key to a sequence/array of values. Presumably
            all plotted keys have equal length (TODO confirm with callers).
        x_keys: Keys plotted on the x axis, one per grid column.
        y_keys: Keys plotted on the y axis, one per grid row.
        name: Optional legend label for this scatter series.
        diag: ``"all"`` draws a dashed red ``y=x`` line on every subplot; any
            other truthy value draws it only on subplots where ``i == j``.
        ax: Optional existing axes grid to draw into. A 2-D array is used
            as-is; a single Axes or a 1-D sequence of Axes is reshaped to
            ``(len(y_keys), len(x_keys))``.
        line_idx: Index into the ``tab10`` colormap used for the mean lines.
        show_means: If true, draw a vertical line at the mean of each x key
            and a horizontal line at the mean of each y key.

    Returns:
        ``(fig, ax)``: ``fig`` is the newly created figure, or ``None`` when
        ``ax`` was supplied; ``ax`` is the 2-D array of axes.
    """
    if ax is None:
        n_rows, n_cols = len(y_keys), len(x_keys)
        # squeeze=False always yields a 2-D array, even for 1xM, Nx1 or 1x1 grids.
        fig, ax = plt.subplots(
            n_rows, n_cols, figsize=(n_cols * 6, n_rows * 5), squeeze=False
        )
    else:
        fig = None
        ax = np.asarray(ax)
        # Allow a single Axes or a flat sequence of Axes to be indexed as a grid.
        if ax.ndim < 2:
            ax = ax.reshape(len(y_keys), len(x_keys))

    for j, kx in enumerate(x_keys):
        for i, ky in enumerate(y_keys):
            sub = ax[i, j]
            sub.scatter(results[kx], results[ky], s=1, alpha=0.5, label=name or None)
            sub.set_xlabel(kx.replace("_", " ").title())
            sub.set_ylabel(ky.replace("_", " ").title())

            draw_diag = diag == "all" or (i == j and diag)
            if draw_diag:
                # Span the union of both axis ranges so y=x covers all the data.
                low = min(sub.get_xlim()[0], sub.get_ylim()[0])
                high = max(sub.get_xlim()[1], sub.get_ylim()[1])
                sub.plot([low, high], [low, high], ls="--", c="red", label="y=x")

            if name or draw_diag:
                sub.legend()

    if not show_means:
        return fig, ax

    # Per-key statistics are computed once each, not once per subplot.
    x_means = {kx: np.mean(results[kx]) for kx in x_keys}
    y_means = {ky: np.mean(results[ky]) for ky in y_keys}
    x_ranges = {kx: (np.min(results[kx]), np.max(results[kx])) for kx in x_keys}
    y_ranges = {ky: (np.min(results[ky]), np.max(results[ky])) for ky in y_keys}
    color = plt.cm.tab10(line_idx)

    for j, kx in enumerate(x_keys):
        for i, ky in enumerate(y_keys):
            sub = ax[i, j]
            # Mean of x as a vertical line across the y data range, and vice versa.
            sub.vlines([x_means[kx]], *y_ranges[ky], colors=[color])
            sub.hlines([y_means[ky]], *x_ranges[kx], colors=[color])

    return fig, ax
|
|