# Source: AniDoc / LightGlue / benchmark.py
# Uploaded by fffiloni (migrated from GitHub), commit c705408, 8.95 kB
# Benchmark script for LightGlue on real images
import argparse
import time
from collections import defaultdict
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch._dynamo
from lightglue import LightGlue, SuperPoint
from lightglue.utils import load_image
# The benchmark only runs inference, so autograd is switched off for the
# whole script (this also covers the warm-up calls in `measure`).
torch.set_grad_enabled(mode=False)
def measure(matcher, data, device="cuda", r=100, warmup=10):
    """Time ``matcher(data)`` and return latency statistics in milliseconds.

    The matcher is first called ``warmup`` times with its results discarded,
    so one-off costs such as lazy initialisation or compilation are excluded.
    It is then called ``r`` more times, and each of those calls is timed.

    Args:
        matcher: callable invoked as ``matcher(data)``.
        data: input passed unchanged to ``matcher`` on every call.
        device: ``torch.device`` or device string (``"cuda"``, ``"cpu"``,
            ``"mps"``, ...). It selects the timing method: CUDA events on
            ``cuda``, and a wall-clock timer (with a device sync on ``mps``)
            otherwise.
        r: number of timed repetitions; must be positive.
        warmup: number of untimed warm-up calls.

    Returns:
        dict with keys ``"mean"`` and ``"std"``: mean and standard deviation
        of the per-call runtime in milliseconds.

    Raises:
        ValueError: if ``r`` is not positive.
    """
    if r <= 0:
        raise ValueError(f"r must be a positive number of repetitions, got {r}")
    # Normalise so plain strings work too; the default "cuda" is a str and
    # has no ``.type`` attribute.
    device = torch.device(device)
    use_cuda_events = device.type == "cuda"
    if use_cuda_events:
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)

    def _sync():
        # MPS kernels run asynchronously. Wait for them so the wall-clock
        # timer measures finished work rather than just kernel launches.
        if device.type == "mps":
            torch.mps.synchronize()

    timings = np.zeros(r)
    with torch.no_grad():
        for _ in range(warmup):
            matcher(data)
        _sync()
        for rep in range(r):
            if use_cuda_events:
                starter.record()
                matcher(data)
                ender.record()
                # elapsed_time is only valid once both events have completed.
                torch.cuda.synchronize()
                timings[rep] = starter.elapsed_time(ender)
            else:
                start = time.perf_counter()
                matcher(data)
                _sync()
                timings[rep] = (time.perf_counter() - start) * 1e3
    return {"mean": timings.mean(), "std": timings.std()}
def print_as_table(d, title, cnames):
    """Print a fixed-width table of numeric rows to stdout.

    The output is, in order:
    - a blank line;
    - a header: ``title`` padded to a width of 30, then the column names
      right-aligned in 7-character cells;
    - a dashed rule as wide as the header;
    - one line per entry of ``d``, with its values right-aligned to one
      decimal place.

    Args:
        d: mapping of row label -> sequence of numbers (one per column).
        title: text placed in the header's label column.
        cnames: column names shown in the header.
    """
    def cells(items, spec):
        return " ".join(format(item, spec) for item in items)

    print()
    header = "{:30} {}".format(title, cells(cnames, ">7"))
    print(header)
    print("-" * len(header))
    for label, values in d.items():
        print("{:30} {}".format(label, cells(values, ">7.1f")))
if __name__ == "__main__":
    # Benchmark LightGlue (and optionally SuperGlue from hloc) on two image
    # pairs for several keypoint budgets, print the results as tables and
    # plot latency/throughput versus number of keypoints.
    parser = argparse.ArgumentParser(description="Benchmark script for LightGlue")
    parser.add_argument(
        "--device",
        choices=["auto", "cuda", "cpu", "mps"],
        default="auto",
        help="device to benchmark on",
    )
    parser.add_argument("--compile", action="store_true", help="Compile LightGlue runs")
    parser.add_argument(
        "--no_flash", action="store_true", help="disable FlashAttention"
    )
    parser.add_argument(
        "--no_prune_thresholds",
        action="store_true",
        help="disable pruning thresholds (i.e. always do pruning)",
    )
    parser.add_argument(
        "--add_superglue",
        action="store_true",
        help="add SuperGlue to the benchmark (requires hloc)",
    )
    parser.add_argument(
        "--measure", default="time", choices=["time", "log-time", "throughput"]
    )
    parser.add_argument(
        "--repeat", "--r", type=int, default=100, help="repetitions of measurements"
    )
    parser.add_argument(
        "--num_keypoints",
        nargs="+",
        type=int,
        default=[256, 512, 1024, 2048, 4096],
        help="number of keypoints (list separated by spaces)",
    )
    parser.add_argument(
        "--matmul_precision", default="highest", choices=["highest", "high", "medium"]
    )
    parser.add_argument(
        "--save", default=None, type=str, help="path where figure should be saved"
    )
    args = parser.parse_intermixed_args()

    # "auto" uses CUDA when available and otherwise falls back to the CPU.
    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    print("Running benchmark on device:", device)

    # Prefer ./assets (so launching from the repo root keeps working). If it
    # is missing, use the assets folder next to this script so the benchmark
    # can be started from any working directory.
    images = Path("assets")
    if not images.is_dir():
        images = Path(__file__).resolve().parent / "assets"
    inputs = {
        "easy": (
            load_image(images / "DSC_0411.JPG"),
            load_image(images / "DSC_0410.JPG"),
        ),
        "difficult": (
            load_image(images / "sacre_coeur1.jpg"),
            load_image(images / "sacre_coeur2.jpg"),
        ),
    }

    # LightGlue keyword overrides per benchmarked variant; a value of -1
    # passed for depth_confidence/width_confidence is used for the "full"
    # (non-adaptive) run. Alternatives kept for reference:
    #   'LG-prune': {'width_confidence': -1}
    #   'LG-depth': {'depth_confidence': -1}
    configs = {
        "LightGlue-full": {
            "depth_confidence": -1,
            "width_confidence": -1,
        },
        "LightGlue-adaptive": {},
    }
    if args.compile:
        # Benchmark every variant a second time as a compiled module.
        configs = {**configs, **{k + "-compile": v for k, v in configs.items()}}
    sg_configs = {
        # 'SuperGlue': {},
        "SuperGlue-fast": {"sinkhorn_iterations": 5}
    }

    torch.set_float32_matmul_precision(args.matmul_precision)

    # results[pair_name][matcher_name] -> one value per entry of num_keypoints
    results = {k: defaultdict(list) for k in inputs}

    extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1)
    extractor = extractor.eval().to(device)

    figsize = (len(inputs) * 4.5, 4.5)
    # squeeze=False always yields a 2-D array, so a single pair needs no
    # special case.
    fig, axes = plt.subplots(
        1, len(inputs), sharey=True, figsize=figsize, squeeze=False
    )
    axes = axes[0]
    fig.canvas.manager.set_window_title(f"LightGlue benchmark ({device.type})")
    for title, ax in zip(inputs.keys(), axes):
        ax.set_xscale("log", base=2)
        bases = [2**x for x in range(7, 16)]
        ax.set_xticks(bases, bases)
        ax.grid(which="major")
        if args.measure == "log-time":
            ax.set_yscale("log")
            yticks = [10**x for x in range(6)]
            ax.set_yticks(yticks, yticks)
            # Minor ticks at 2..9 x 10^k, labelled only at 2 and 5.
            mpos = [10**x * i for x in range(6) for i in range(2, 10)]
            mlabel = [
                10**x * i if i in [2, 5] else None
                for x in range(6)
                for i in range(2, 10)
            ]
            ax.set_yticks(mpos, mlabel, minor=True)
            ax.grid(which="minor", linewidth=0.2)
        ax.set_title(title)
        ax.set_xlabel("# keypoints")
        if args.measure == "throughput":
            ax.set_ylabel("Throughput [pairs/s]")
        else:
            ax.set_ylabel("Latency [ms]")

    def run_benchmark(name, matcher, build_input):
        """Time ``matcher`` on every pair and keypoint budget, record and plot.

        ``build_input(image0, image1, feats0, feats1)`` converts the SuperPoint
        features into the input dict that ``matcher`` expects.
        """
        for pair_name, ax in zip(inputs.keys(), axes):
            image0, image1 = [x.to(device) for x in inputs[pair_name]]
            for num_kpts in args.num_keypoints:
                extractor.conf.max_num_keypoints = num_kpts
                feats0 = extractor.extract(image0)
                feats1 = extractor.extract(image1)
                data = build_input(image0, image1, feats0, feats1)
                runtime = measure(matcher, data, device=device, r=args.repeat)["mean"]
                # runtime is in ms, so 1000 / runtime gives pairs per second.
                results[pair_name][name].append(
                    1000 / runtime if args.measure == "throughput" else runtime
                )
            ax.plot(
                args.num_keypoints, results[pair_name][name], label=name, marker="o"
            )

    def lightglue_input(image0, image1, feats0, feats1):
        """LightGlue consumes the extractor outputs directly."""
        return {"image0": feats0, "image1": feats1}

    def superglue_input(image0, image1, feats0, feats1):
        """Flatten features into hloc's SuperGlue input convention."""
        data = {
            "image0": image0[None],
            "image1": image1[None],
            **{k + "0": v for k, v in feats0.items()},
            **{k + "1": v for k, v in feats1.items()},
        }
        data["scores0"] = data["keypoint_scores0"]
        data["scores1"] = data["keypoint_scores1"]
        # Swap the last two dims of the descriptors (SuperGlue's layout).
        data["descriptors0"] = data["descriptors0"].transpose(-1, -2).contiguous()
        data["descriptors1"] = data["descriptors1"].transpose(-1, -2).contiguous()
        return data

    for name, conf in configs.items():
        print("Run benchmark for:", name)
        torch.cuda.empty_cache()
        matcher = LightGlue(features="superpoint", flash=not args.no_flash, **conf)
        if args.no_prune_thresholds:
            matcher.pruning_keypoint_thresholds = {
                k: -1 for k in matcher.pruning_keypoint_thresholds
            }
        matcher = matcher.eval().to(device)
        if name.endswith("compile"):
            torch._dynamo.reset()  # avoid buffer overflow
            matcher.compile()
        run_benchmark(name, matcher, lightglue_input)
        del matcher

    if args.add_superglue:
        from hloc.matchers.superglue import SuperGlue

        for name, conf in sg_configs.items():
            print("Run benchmark for:", name)
            matcher = SuperGlue(conf).eval().to(device)
            run_benchmark(name, matcher, superglue_input)
            del matcher

    for name, runtimes in results.items():
        print_as_table(runtimes, name, args.num_keypoints)

    axes[0].legend()
    fig.tight_layout()
    if args.save:
        plt.savefig(args.save, dpi=fig.dpi)
    plt.show()