Spaces:
Sleeping
Sleeping
import collections.abc as collections | |
from pathlib import Path | |
from types import SimpleNamespace | |
from typing import Callable, List, Optional, Tuple, Union | |
import cv2 | |
import kornia | |
import numpy as np | |
import torch | |
class ImagePreprocessor:
    """Optionally resize an image tensor and report the applied scale.

    Configuration (defaults in ``default_conf``, overridable via keyword
    arguments to the constructor):
        resize: target edge length, or None to skip resizing entirely.
        side: which edge ``resize`` refers to; forwarded to kornia's resize.
        interpolation: interpolation mode; forwarded to kornia's resize.
        align_corners: forwarded to kornia's resize.
        antialias: forwarded to kornia's resize.
    """

    default_conf = {
        "resize": None,  # target edge length, None for no resizing
        "side": "long",
        "interpolation": "bilinear",
        "align_corners": None,
        "antialias": True,
    }

    def __init__(self, **conf) -> None:
        super().__init__()
        # Merge into a fresh dict so the shared class-level defaults are never mutated.
        self.conf = SimpleNamespace(**{**self.default_conf, **conf})

    def __call__(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Resize ``img`` per the configuration and return it with its scale.

        Args:
            img: image tensor whose last two dimensions are (height, width).

        Returns:
            ``(img, scale)`` where ``img`` is the (possibly) resized tensor and
            ``scale`` is ``[new_w / w, new_h / h]``, placed on ``img``'s
            device and dtype. When ``resize`` is None, ``img`` is returned
            as-is and ``scale`` is ``[1, 1]``.
        """
        h, w = img.shape[-2:]
        if self.conf.resize is not None:
            img = kornia.geometry.transform.resize(
                img,
                self.conf.resize,
                # Previously the configured interpolation was silently ignored.
                interpolation=self.conf.interpolation,
                side=self.conf.side,
                antialias=self.conf.antialias,
                align_corners=self.conf.align_corners,
            )
        scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
        return img, scale
def map_tensor(input_, func: Callable):
    """Apply ``func`` to every ``torch.Tensor`` inside a nested container.

    Mappings are rebuilt as plain ``dict`` objects and other sequences
    (lists, tuples, ...) as plain ``list`` objects. Strings, bytes and any
    other non-tensor leaves are returned unchanged.
    """
    # str/bytes are Sequences too; intercept them before the generic
    # Sequence branch so we never recurse into individual characters.
    if isinstance(input_, (str, bytes)):
        return input_
    if isinstance(input_, collections.Mapping):
        return {key: map_tensor(value, func) for key, value in input_.items()}
    if isinstance(input_, collections.Sequence):
        return [map_tensor(item, func) for item in input_]
    return func(input_) if isinstance(input_, torch.Tensor) else input_
def batch_to_device(batch: dict, device: str = "cpu", non_blocking: bool = True):
    """Return a copy of ``batch`` with every tensor moved to ``device``.

    Nested containers are walked by ``map_tensor``. Each tensor is transferred
    with ``Tensor.to(device=..., non_blocking=...)`` and then detached from
    the autograd graph.
    """
    return map_tensor(
        batch,
        lambda tensor: tensor.to(device=device, non_blocking=non_blocking).detach(),
    )
def rbd(data: dict) -> dict:
    """Strip the leading batch dimension from each entry of ``data``.

    Tensors, NumPy arrays and lists are replaced by their first element
    (index 0). Values of any other type are copied over untouched.
    Key order is preserved.
    """
    batched_types = (torch.Tensor, np.ndarray, list)
    unbatched = {}
    for key, value in data.items():
        if isinstance(value, batched_types):
            unbatched[key] = value[0]
        else:
            unbatched[key] = value
    return unbatched
def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
    """Load the image stored at ``path`` as an RGB or single-channel array.

    Raises:
        FileNotFoundError: if nothing exists at ``path``.
        IOError: if OpenCV fails to decode the file.
    """
    if not Path(path).exists():
        raise FileNotFoundError(f"No image at path {path}.")
    flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    loaded = cv2.imread(str(path), flag)
    if loaded is None:
        raise IOError(f"Could not read image at {path}.")
    # OpenCV decodes color images in BGR order; reversing the last axis
    # yields RGB. The result is a strided view, not a contiguous copy.
    return loaded if grayscale else loaded[..., ::-1]
def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
    """Convert a NumPy image to a channel-first float32 tensor scaled by 1/255.

    A 3-D input is treated as HxWxC and transposed to CxHxW. A 2-D input
    gains a leading channel axis (1xHxW). Any other rank raises ValueError.
    """
    rank = image.ndim
    if rank == 2:
        chw = image[None]  # add channel axis
    elif rank == 3:
        chw = image.transpose((2, 0, 1))  # HxWxC to CxHxW
    else:
        raise ValueError(f"Not an image: {image.shape}")
    normalized = chw / 255.0
    return torch.tensor(normalized, dtype=torch.float)
def resize_image(
    image: np.ndarray,
    size: Union[List[int], int],
    fn: str = "max",
    interp: str = "area",
) -> Tuple[np.ndarray, Tuple[float, float]]:
    """Resize an image to a fixed size, or so its max/min edge equals ``size``.

    Args:
        image: array whose first two axes are (height, width).
        size: either an integer (Python ``int`` or NumPy integer) giving the
            target length of the edge chosen by ``fn``, with aspect ratio
            preserved, or a ``(height, width)`` list/tuple giving the exact
            output size.
        fn: ``"max"`` or ``"min"``; which edge an integer ``size`` refers to.
        interp: one of ``"linear"``, ``"cubic"``, ``"nearest"``, ``"area"``.

    Returns:
        ``(resized, scale)`` where ``scale`` is
        ``(new_width / width, new_height / height)``.

    Raises:
        ValueError: if ``size`` is not an integer or a (height, width)
            list/tuple.
        KeyError: if ``fn`` or ``interp`` is not one of the supported names.
    """
    h, w = image.shape[:2]
    edge_fn = {"max": max, "min": min}[fn]
    # Accept NumPy integers too (e.g. sizes read from arrays or configs);
    # previously they fell through to the "Incorrect new size" error.
    if isinstance(size, (int, np.integer)):
        ratio = size / edge_fn(h, w)
        h_new, w_new = int(round(h * ratio)), int(round(w * ratio))
    elif isinstance(size, (tuple, list)):
        h_new, w_new = size
    else:
        raise ValueError(f"Incorrect new size: {size}")
    # Report the scale actually applied (after rounding), not the requested ratio.
    scale = (w_new / w, h_new / h)
    mode = {
        "linear": cv2.INTER_LINEAR,
        "cubic": cv2.INTER_CUBIC,
        "nearest": cv2.INTER_NEAREST,
        "area": cv2.INTER_AREA,
    }[interp]
    # cv2.resize expects the destination size as (width, height).
    return cv2.resize(image, (w_new, h_new), interpolation=mode), scale
def load_image(
    path: Path, resize: Optional[Union[int, List[int]]] = None, **kwargs
) -> torch.Tensor:
    """Read an RGB image from disk and convert it to a normalized torch tensor.

    Args:
        path: location of the image file.
        resize: forwarded to ``resize_image`` as ``size``; None keeps the
            original resolution.
        **kwargs: extra keyword arguments for ``resize_image``
            (e.g. ``fn``, ``interp``).

    Returns:
        A CxHxW float tensor whose values are the pixel values divided by 255.
        The resize scale factors are discarded; use ``resize_image`` directly
        if they are needed.
    """
    image = read_image(path)
    if resize is not None:
        image, _ = resize_image(image, resize, **kwargs)
    return numpy_image_to_torch(image)
class Extractor(torch.nn.Module):
    """Base ``nn.Module`` for feature extractors with a resize-aware ``extract``.

    This class does not define ``default_conf``, ``preprocess_conf`` or
    ``forward``. They are presumably supplied by subclasses; confirm against
    the concrete extractors.
    """

    def __init__(self, **conf):
        super().__init__()
        # Overrides passed as keyword arguments win over the class defaults.
        merged = dict(self.default_conf)
        merged.update(conf)
        self.conf = SimpleNamespace(**merged)

    def extract(self, img: torch.Tensor, **conf) -> dict:
        """Run ``forward`` on a single image, resizing it on the fly.

        A 3-D ``img`` gets a batch axis added; the batch size must be 1.
        ``conf`` overrides ``preprocess_conf`` for the ImagePreprocessor.
        The returned dict gains ``"image_size"`` (the original width, height)
        and has its ``"keypoints"`` mapped back to original-image coordinates.
        """
        batched = img[None] if img.dim() == 3 else img
        assert batched.dim() == 4 and batched.shape[0] == 1
        # (width, height) of the input before any resizing.
        orig_wh = batched.shape[-2:][::-1]
        preprocessor = ImagePreprocessor(**{**self.preprocess_conf, **conf})
        resized, scales = preprocessor(batched)
        feats = self.forward({"image": resized})
        feats["image_size"] = torch.tensor(orig_wh)[None].to(resized).float()
        # Undo the resize; the +/-0.5 shift presumably treats keypoints as
        # pixel-center coordinates (verify against the extractor convention).
        feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
        return feats
def match_pair(
    extractor,
    matcher,
    image0: torch.Tensor,
    image1: torch.Tensor,
    device: str = "cpu",
    **preprocess,
):
    """Extract features from two images and match them.

    ``preprocess`` options are forwarded to ``extractor.extract`` for both
    images. Returns ``(feats0, feats1, matches01)``, each with its leading
    batch dimension removed and its tensors moved to ``device``.
    """
    feats0 = extractor.extract(image0, **preprocess)
    feats1 = extractor.extract(image1, **preprocess)
    matches01 = matcher({"image0": feats0, "image1": feats1})
    # remove batch dim and move to target device
    return tuple(
        batch_to_device(rbd(result), device)
        for result in (feats0, feats1, matches01)
    )