EfficientLoFTR

PyTorch

Overview

The EfficientLoFTR model was proposed in Efficient LoFTR: Semi-Dense Local Feature Matching with Sparse-Like Speed by Yifan Wang, Xingyi He, Sida Peng, Dongli Tan and Xiaowei Zhou.

The model matches two images by finding pixel correspondences between them, which can then be used to estimate the relative pose between the two views. It is useful for tasks such as image matching and homography estimation.
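Because the output is a set of matched pixel coordinates, downstream geometry can be fitted with standard tools. As an illustration, here is a minimal NumPy sketch of the normalized direct linear transform (DLT), a textbook homography estimator that is not part of EfficientLoFTR or the Transformers API. The helper names (apply_homography, normalization_matrix, estimate_homography_dlt) are hypothetical, and the sketch assumes the matches are already outlier-free; real pipelines usually wrap such a fit in RANSAC, for example with OpenCV's cv2.findHomography. Matched coordinates such as the keypoints0 and keypoints1 tensors returned by post-processing can be converted to NumPy with .numpy().

```python
import numpy as np


def apply_homography(H: np.ndarray, points: np.ndarray) -> np.ndarray:
    """Map (N, 2) points through a 3x3 homography."""
    homogeneous = np.hstack([points, np.ones((points.shape[0], 1))])
    projected = homogeneous @ H.T
    return projected[:, :2] / projected[:, 2:3]


def normalization_matrix(points: np.ndarray) -> np.ndarray:
    """Similarity transform moving the centroid to the origin with mean distance sqrt(2)."""
    centroid = points.mean(axis=0)
    scale = np.sqrt(2.0) / np.linalg.norm(points - centroid, axis=1).mean()
    return np.array([
        [scale, 0.0, -scale * centroid[0]],
        [0.0, scale, -scale * centroid[1]],
        [0.0, 0.0, 1.0],
    ])


def estimate_homography_dlt(points0: np.ndarray, points1: np.ndarray) -> np.ndarray:
    """Fit H with points1 ~ H @ points0 from N >= 4 clean (x, y) correspondences."""
    if points0.shape != points1.shape or points0.shape[0] < 4:
        raise ValueError("need at least 4 matched (x, y) pairs with equal shapes")
    T0, T1 = normalization_matrix(points0), normalization_matrix(points1)
    p0, p1 = apply_homography(T0, points0), apply_homography(T1, points1)
    rows = []
    for (x, y), (u, v) in zip(p0, p1):
        # Each correspondence gives two linear equations in the 9 entries of H.
        rows.append([-x, -y, -1.0, 0.0, 0.0, 0.0, u * x, u * y, u])
        rows.append([0.0, 0.0, 0.0, -x, -y, -1.0, v * x, v * y, v])
    # Least-squares solution: right singular vector of the smallest singular value.
    _, _, vt = np.linalg.svd(np.asarray(rows))
    H_normalized = vt[-1].reshape(3, 3)
    H = np.linalg.inv(T1) @ H_normalized @ T0  # undo the coordinate normalization
    return H / H[2, 2]


# Synthetic check: project points with a known homography, then recover it.
H_true = np.array([[1.1, 0.05, 12.0], [-0.03, 0.95, -7.0], [1e-4, 2e-4, 1.0]])
points0 = np.array([[10.0, 20.0], [300.0, 40.0], [280.0, 400.0], [30.0, 380.0], [150.0, 200.0]])
points1 = apply_homography(H_true, points0)
H_est = estimate_homography_dlt(points0, points1)
print(np.abs(apply_homography(H_est, points0) - points1).max())  # reprojection error, near 0
```

With real matches the reprojection error will not be zero, and a robust estimator should be used to reject wrong correspondences before trusting the fitted homography.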

The abstract from the paper is the following:

We present a novel method for efficiently producing semidense matches across images. Previous detector-free matcher LoFTR has shown remarkable matching capability in handling large-viewpoint change and texture-poor scenarios but suffers from low efficiency. We revisit its design choices and derive multiple improvements for both efficiency and accuracy. One key observation is that performing the transformer over the entire feature map is redundant due to shared local information, therefore we propose an aggregated attention mechanism with adaptive token selection for efficiency. Furthermore, we find spatial variance exists in LoFTR’s fine correlation module, which is adverse to matching accuracy. A novel two-stage correlation layer is proposed to achieve accurate subpixel correspondences for accuracy improvement. Our efficiency optimized model is ∼ 2.5× faster than LoFTR which can even surpass state-of-the-art efficient sparse matching pipeline SuperPoint + LightGlue. Moreover, extensive experiments show that our method can achieve higher accuracy compared with competitive semi-dense matchers, with considerable efficiency benefits. This opens up exciting prospects for large-scale or latency-sensitive applications such as image retrieval and 3D reconstruction. Project page: https://zju3dv.github.io/efficientloftr/.

How to use

Here is a quick example of using the model.

import torch

from transformers import AutoImageProcessor, AutoModelForKeypointMatching
from transformers.image_utils import load_image


image1 = load_image("https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg")
image2 = load_image("https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg")

images = [image1, image2]

processor = AutoImageProcessor.from_pretrained("stevenbucaille/efficientloftr")
model = AutoModelForKeypointMatching.from_pretrained("stevenbucaille/efficientloftr")

inputs = processor(images, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

You can use the post_process_keypoint_matching method from the ImageProcessor to get the keypoints and matches in a more readable format:

image_sizes = [[(image.height, image.width) for image in images]]
outputs = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
for i, output in enumerate(outputs):
    print("For the image pair", i)
    for keypoint0, keypoint1, matching_score in zip(
            output["keypoints0"], output["keypoints1"], output["matching_scores"]
    ):
        print(
            f"Keypoint at coordinate {keypoint0.numpy()} in the first image matches with keypoint at coordinate {keypoint1.numpy()} in the second image with a score of {matching_score}."
        )

From the post-processed outputs, you can visualize the matches between the two images using the following code:

images_with_matching = processor.visualize_keypoint_matching(images, outputs)

This model was contributed by stevenbucaille. The original code can be found here.

EfficientLoFTRConfig

class transformers.EfficientLoFTRConfig

( stage_num_blocks: typing.Optional[list[int]] = None out_features: typing.Optional[list[int]] = None stage_stride: typing.Optional[list[int]] = None hidden_size: int = 256 activation_function: str = 'relu' q_aggregation_kernel_size: int = 4 kv_aggregation_kernel_size: int = 4 q_aggregation_stride: int = 4 kv_aggregation_stride: int = 4 num_attention_layers: int = 4 num_attention_heads: int = 8 attention_dropout: float = 0.0 attention_bias: bool = False mlp_activation_function: str = 'leaky_relu' coarse_matching_skip_softmax: bool = False coarse_matching_threshold: float = 0.2 coarse_matching_temperature: float = 0.1 coarse_matching_border_removal: int = 2 fine_kernel_size: int = 8 batch_norm_eps: float = 1e-05 embedding_size: typing.Optional[list[int]] = None rope_theta: float = 10000.0 partial_rotary_factor: float = 4.0 rope_scaling: typing.Optional[dict] = None fine_matching_slice_dim: int = 8 fine_matching_regress_temperature: float = 10.0 initializer_range: float = 0.02 **kwargs )

Parameters

  • stage_num_blocks (List, optional, defaults to [1, 2, 4, 14]) — The number of blocks in each stage.
  • out_features (List, optional, defaults to [64, 64, 128, 256]) — The number of channels in each stage.
  • stage_stride (List, optional, defaults to [2, 1, 2, 2]) — The stride used in each stage.
  • hidden_size (int, optional, defaults to 256) — The dimension of the descriptors.
  • activation_function (str, optional, defaults to "relu") — The activation function used in the backbone.
  • q_aggregation_kernel_size (int, optional, defaults to 4) — The kernel size used to aggregate the query states in the fusion network.
  • kv_aggregation_kernel_size (int, optional, defaults to 4) — The kernel size used to aggregate the key and value states in the fusion network.
  • q_aggregation_stride (int, optional, defaults to 4) — The stride used to aggregate the query states in the fusion network.
  • kv_aggregation_stride (int, optional, defaults to 4) — The stride used to aggregate the key and value states in the fusion network.
  • num_attention_layers (int, optional, defaults to 4) — Number of attention layers in the LocalFeatureTransformer.
  • num_attention_heads (int, optional, defaults to 8) — The number of heads in each attention layer.
  • attention_dropout (float, optional, defaults to 0.0) — The dropout ratio for the attention probabilities.
  • attention_bias (bool, optional, defaults to False) — Whether to use a bias in the query, key, value and output projection layers during attention.
  • mlp_activation_function (str, optional, defaults to "leaky_relu") — The activation function used in the MLP of the attention layers.
  • coarse_matching_skip_softmax (bool, optional, defaults to False) — Whether to skip the softmax at the coarse matching step.
  • coarse_matching_threshold (float, optional, defaults to 0.2) — The threshold for the minimum score required for a match.
  • coarse_matching_temperature (float, optional, defaults to 0.1) — The temperature applied to the coarse similarity matrix.
  • coarse_matching_border_removal (int, optional, defaults to 2) — The size of the border to remove during coarse matching.
  • fine_kernel_size (int, optional, defaults to 8) — Kernel size used for the fine feature matching.
  • batch_norm_eps (float, optional, defaults to 1e-05) — The epsilon used by the batch normalization layers.
  • embedding_size (List, optional, defaults to [15, 20]) — The size (height, width) of the embedding for the position embeddings.
  • rope_theta (float, optional, defaults to 10000.0) — The base period of the RoPE embeddings.
  • partial_rotary_factor (float, optional, defaults to 4.0) — Dimension factor for the RoPE embeddings. In EfficientLoFTR, frequencies are generated for the whole hidden_size, so this factor is used to compensate.
  • rope_scaling (Dict, optional) — Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new RoPE type and expect the model to work on a longer max_position_embeddings, we recommend updating this value accordingly. Expected contents: rope_type (str): the sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3', '2d'], with 'default' being the original RoPE implementation. dim (int): the dimension of the RoPE embeddings.
  • fine_matching_slice_dim (int, optional, defaults to 8) — The size of the slice used to divide the fine features for the first and second fine matching stages.
  • fine_matching_regress_temperature (float, optional, defaults to 10.0) — The temperature applied to the fine similarity matrix.
  • initializer_range (float, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
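The default sizes in this list fit together, which is easiest to see with a little arithmetic. The sketch below assumes that each backbone stage divides the spatial resolution by its stage_stride entry, that out_features gives each stage's channel count, and that query aggregation pools the coarse map by q_aggregation_stride; these are readings of the parameter descriptions, not checks against the library code, and stage_output_shapes is a hypothetical helper. Under those assumptions, a 480×640 input (the image processor's default size) gives a 60×80 coarse map at 1/8 resolution with 256 channels (equal to hidden_size), and a 15×20 aggregated grid, which matches the embedding_size default of [15, 20].

```python
# Default values taken from the EfficientLoFTRConfig parameter list above.
stage_stride = [2, 1, 2, 2]
out_features = [64, 64, 128, 256]
q_aggregation_stride = 4
image_height, image_width = 480, 640  # EfficientLoFTRImageProcessor default size


def stage_output_shapes(height, width, strides, channels):
    """Return (channels, height, width) after each backbone stage.

    Assumes each stage divides the spatial size by its stride (integer division).
    """
    shapes = []
    for stride, num_channels in zip(strides, channels):
        height, width = height // stride, width // stride
        shapes.append((num_channels, height, width))
    return shapes


shapes = stage_output_shapes(image_height, image_width, stage_stride, out_features)
coarse_channels, coarse_height, coarse_width = shapes[-1]
aggregated_grid = (coarse_height // q_aggregation_stride, coarse_width // q_aggregation_stride)
print(shapes)           # [(64, 240, 320), (64, 240, 320), (128, 120, 160), (256, 60, 80)]
print(aggregated_grid)  # (15, 20)
```

If you change the processor's size, the same arithmetic suggests how embedding_size would need to change to stay consistent.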

This is the configuration class to store the configuration of an EfficientLoFTRForKeypointMatching model. It is used to instantiate an EfficientLoFTR model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the EfficientLoFTR zju-community/efficientloftr architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
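The coarse_matching_temperature and coarse_matching_threshold parameters above belong to the coarse matching step. The original LoFTR scores coarse candidates with a dual softmax (a row-wise softmax multiplied by a column-wise softmax of the temperature-scaled similarity matrix), then keeps mutual nearest neighbours whose confidence exceeds a threshold. The following NumPy sketch illustrates that generic scheme with the default values 0.1 and 0.2; dual_softmax_matches is a hypothetical helper, border removal (coarse_matching_border_removal) is omitted, and the library's implementation may differ in its details.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def dual_softmax_matches(desc0, desc1, temperature=0.1, threshold=0.2):
    """Coarse matching sketch: dual softmax + mutual nearest neighbours + threshold.

    desc0: (N0, D) and desc1: (N1, D) coarse descriptors of the two images.
    Returns (i, j, confidence) tuples, i indexing desc0 and j indexing desc1.
    """
    similarity = desc0 @ desc1.T / temperature
    # High only if j is likely for i (row softmax) AND i is likely for j (column softmax).
    confidence = softmax(similarity, axis=1) * softmax(similarity, axis=0)
    matches = []
    for i in range(confidence.shape[0]):
        j = int(confidence[i].argmax())
        is_mutual = int(confidence[:, j].argmax()) == i
        if is_mutual and confidence[i, j] > threshold:
            matches.append((i, j, float(confidence[i, j])))
    return matches


# Toy example: image 1 contains the same 5 unit descriptors in a shuffled order.
desc0 = np.eye(5)
desc1 = desc0[[3, 0, 4, 1, 2]]
for i, j, conf in dual_softmax_matches(desc0, desc1):
    print(i, j, round(conf, 3))
```

The product of the two softmaxes penalizes ambiguous candidates: a descriptor that is similar to many others gets a low confidence in at least one direction and tends to fall below the threshold.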

Examples:

>>> from transformers import EfficientLoFTRConfig, EfficientLoFTRForKeypointMatching

>>> # Initializing an EfficientLoFTR configuration
>>> configuration = EfficientLoFTRConfig()

>>> # Initializing a model from the EfficientLoFTR configuration
>>> model = EfficientLoFTRForKeypointMatching(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
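After coarse matching, the fine stage refines each match to subpixel accuracy, and fine_matching_regress_temperature sets the temperature applied to the fine similarity matrix. A common way to turn a local similarity window into a subpixel offset is a soft-argmax: a softmax over the window followed by the expected (x, y) position, as in the original LoFTR. The NumPy sketch below shows only that generic technique, not EfficientLoFTR's two-stage correlation layer; soft_argmax_offset is a hypothetical helper, and dividing the logits by the temperature is an assumed convention.

```python
import numpy as np


def soft_argmax_offset(similarity: np.ndarray, temperature: float) -> np.ndarray:
    """Expected (dx, dy) position under a softmax over a square similarity window.

    Offsets are normalized so the window spans [-1, 1] in x and y, with 0 at the centre.
    """
    size = similarity.shape[0]
    logits = similarity.reshape(-1) / temperature
    logits = logits - logits.max()  # numerical stability
    heatmap = np.exp(logits)
    heatmap = (heatmap / heatmap.sum()).reshape(size, size)
    grid = np.linspace(-1.0, 1.0, size)
    dx = float((heatmap.sum(axis=0) * grid).sum())  # marginal over columns gives x
    dy = float((heatmap.sum(axis=1) * grid).sum())  # marginal over rows gives y
    return np.array([dx, dy])


window = np.zeros((5, 5))
window[2, 3] = 1.0  # strongest response one column right of the centre
print(soft_argmax_offset(window, temperature=0.1))
```

Unlike a hard argmax, the expectation varies smoothly with the similarities, which is what allows offsets between grid positions.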

EfficientLoFTRImageProcessor

class transformers.EfficientLoFTRImageProcessor

( do_resize: bool = True size: typing.Optional[dict[str, int]] = None resample: Resampling = <Resampling.BILINEAR: 2> do_rescale: bool = True rescale_factor: float = 0.00392156862745098 do_grayscale: bool = True **kwargs )

Parameters

  • do_resize (bool, optional, defaults to True) — Controls whether to resize the image’s (height, width) dimensions to the specified size. Can be overridden by do_resize in the preprocess method.
  • size (Dict[str, int], optional, defaults to {"height": 480, "width": 640}) — Resolution of the output image after resize is applied. Only has an effect if do_resize is set to True. Can be overridden by size in the preprocess method.
  • resample (PILImageResampling, optional, defaults to Resampling.BILINEAR) — Resampling filter to use if resizing the image. Can be overridden by resample in the preprocess method.
  • do_rescale (bool, optional, defaults to True) — Whether to rescale the image by the specified scale rescale_factor. Can be overridden by do_rescale in the preprocess method.
  • rescale_factor (int or float, optional, defaults to 1/255) — Scale factor to use if rescaling the image. Can be overridden by rescale_factor in the preprocess method.
  • do_grayscale (bool, optional, defaults to True) — Whether to convert the image to grayscale. Can be overridden by do_grayscale in the preprocess method.

Constructs an EfficientLoFTR image processor.
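With its defaults, the processor resizes each image to 480×640, rescales pixel values by 1/255 and converts the image to grayscale. The grayscale and rescale steps can be sketched in NumPy as below. preprocess_sketch is a hypothetical helper, the ITU-R BT.601 luma weights and the single-channel output are assumptions, and resizing is left out to keep the sketch dependency-free; the library's own conversion may differ.

```python
import numpy as np


def preprocess_sketch(image: np.ndarray, rescale_factor: float = 1 / 255) -> np.ndarray:
    """Grayscale + rescale an (H, W, 3) uint8 RGB image into a (1, H, W) float array."""
    rgb = image.astype(np.float32)
    # Assumed ITU-R BT.601 luma weights; the library's exact formula may differ.
    gray = rgb @ np.array([0.299, 0.587, 0.114], dtype=np.float32)
    gray = gray * rescale_factor  # map [0, 255] to [0, 1]
    return gray[np.newaxis]  # channels-first, like ChannelDimension.FIRST


image = np.zeros((2, 3, 3), dtype=np.uint8)
image[0, 0] = [255, 255, 255]  # one white pixel
pixel_values = preprocess_sketch(image)
print(pixel_values.shape)  # (1, 2, 3)
```

Both steps are linear, so their order does not change the result in this sketch.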

post_process_keypoint_matching

( outputs: KeypointMatchingOutput target_sizes: typing.Union[transformers.utils.generic.TensorType, list[tuple]] threshold: float = 0.0 ) List[Dict]

Parameters

  • outputs (KeypointMatchingOutput) — Raw outputs of the model.
  • target_sizes (torch.Tensor or List[Tuple[Tuple[int, int]]]) — Tensor of shape (batch_size, 2, 2) or list of tuples of tuples (Tuple[int, int]) containing the target size (height, width) of each image in the batch. This must be the original image size (before any processing).
  • threshold (float, optional, defaults to 0.0) — Threshold to filter out the matches with low scores.

Returns

List[Dict]

A list of dictionaries, each dictionary containing the keypoints in the first and second image of the pair, the matching scores and the matching indices.

Converts the raw KeypointMatchingOutput into lists of matched keypoints and matching scores, with keypoint coordinates expressed in pixels of the original image sizes.
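The post-processing step amounts to score filtering followed by a change of coordinate frame. In the NumPy sketch below, filter_and_rescale is a hypothetical helper; it keeps matches whose score is strictly above threshold (whether the library uses > or >= is an assumption) and maps keypoints from a source frame to each image's original (height, width). Because this page does not pin down the frame of the raw keypoints, the source frame is a parameter: (1.0, 1.0) for normalized coordinates, or the processed size such as (480, 640) for pixel coordinates in the resized image.

```python
import numpy as np


def filter_and_rescale(keypoints0, keypoints1, scores, size0, size1,
                       threshold=0.0, source_size=(1.0, 1.0)):
    """Keep matches with score > threshold and map (x, y) keypoints to original pixels.

    keypoints0/keypoints1: (M, 2) matched (x, y) coordinates in a frame of size
    source_size = (height, width). size0/size1: original (height, width) of each image.
    """
    keep = scores > threshold
    src_h, src_w = source_size

    def rescale(points, size):
        height, width = size
        return points * np.array([width / src_w, height / src_h])  # x scales with width

    return {
        "keypoints0": rescale(keypoints0[keep], size0),
        "keypoints1": rescale(keypoints1[keep], size1),
        "matching_scores": scores[keep],
    }


kp0 = np.array([[0.5, 0.5], [0.25, 0.75], [1.0, 0.0]])
kp1 = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
scores = np.array([0.9, 0.1, 0.5])
out = filter_and_rescale(kp0, kp1, scores, size0=(480, 640), size1=(300, 400), threshold=0.2)
print(out["keypoints0"])
```

This mirrors the threshold=0.2 and image_sizes arguments used in the usage example, where image_sizes lists the original (height, width) of both images in each pair.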

preprocess

( images do_resize: typing.Optional[bool] = None size: typing.Optional[dict[str, int]] = None resample: Resampling = None do_rescale: typing.Optional[bool] = None rescale_factor: typing.Optional[float] = None do_grayscale: typing.Optional[bool] = None return_tensors: typing.Union[str, transformers.utils.generic.TensorType, NoneType] = None data_format: ChannelDimension = <ChannelDimension.FIRST: 'channels_first'> input_data_format: typing.Union[str, transformers.image_utils.ChannelDimension, NoneType] = None **kwargs )

Parameters

  • images (ImageInput) — Image pairs to preprocess. Expects either a list of 2 images or a list of lists of 2 images, with pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set do_rescale=False.
  • do_resize (bool, optional, defaults to self.do_resize) — Whether to resize the image.
  • size (dict[str, int], optional, defaults to self.size) — Size of the output image after resize has been applied, given as a dictionary of the form {"height": int, "width": int}. Only has an effect if do_resize is set to True.
  • resample (PILImageResampling, optional, defaults to self.resample) — Resampling filter to use if resizing the image. This can be one of the PILImageResampling filters. Only has an effect if do_resize is set to True.
  • do_rescale (bool, optional, defaults to self.do_rescale) — Whether to rescale the image values to [0, 1].
  • rescale_factor (float, optional, defaults to self.rescale_factor) — Rescale factor to rescale the image by if do_rescale is set to True.
  • do_grayscale (bool, optional, defaults to self.do_grayscale) — Whether to convert the image to grayscale.
  • return_tensors (str or TensorType, optional) — The type of tensors to return. Can be one of:
    • Unset: Return a list of np.ndarray.
    • TensorType.TENSORFLOW or 'tf': Return a batch of type tf.Tensor.
    • TensorType.PYTORCH or 'pt': Return a batch of type torch.Tensor.
    • TensorType.NUMPY or 'np': Return a batch of type np.ndarray.
    • TensorType.JAX or 'jax': Return a batch of type jax.numpy.ndarray.
  • data_format (ChannelDimension or str, optional, defaults to ChannelDimension.FIRST) — The channel dimension format for the output image. Can be one of:
    • "channels_first" or ChannelDimension.FIRST: image in (num_channels, height, width) format.
    • "channels_last" or ChannelDimension.LAST: image in (height, width, num_channels) format.
    • Unset: Use the channel dimension format of the input image.
  • input_data_format (ChannelDimension or str, optional) — The channel dimension format for the input image. If unset, the channel dimension format is inferred from the input image. Can be one of:
    • "channels_first" or ChannelDimension.FIRST: image in (num_channels, height, width) format.
    • "channels_last" or ChannelDimension.LAST: image in (height, width, num_channels) format.
    • "none" or ChannelDimension.NONE: image in (height, width) format.

Preprocess an image pair or a batch of image pairs.
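preprocess accepts two layouts for images: a flat list of 2 images, or a list of lists of 2 images. A plain-Python sketch of normalizing both layouts into a list of pairs follows; to_image_pairs is a hypothetical helper, not a Transformers function, and the library's own input validation may behave differently.

```python
def to_image_pairs(images):
    """Normalize `images` into a list of [image0, image1] pairs.

    Accepts a flat list of exactly 2 images or a list of 2-image lists.
    """
    if not isinstance(images, (list, tuple)) or len(images) == 0:
        raise ValueError("images must be a non-empty list")
    if all(isinstance(item, (list, tuple)) for item in images):
        pairs = [list(pair) for pair in images]  # already a batch of pairs
    else:
        pairs = [list(images)]  # a single pair
    for pair in pairs:
        if len(pair) != 2:
            raise ValueError(f"each pair must contain exactly 2 images, got {len(pair)}")
    return pairs


print(to_image_pairs(["img_a", "img_b"]))  # [['img_a', 'img_b']]
```

Strings stand in for images here; any image objects (PIL images, NumPy arrays) would be handled the same way, since only the list nesting is inspected.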

resize

( image: ndarray size: dict data_format: typing.Union[str, transformers.image_utils.ChannelDimension, NoneType] = None input_data_format: typing.Union[str, transformers.image_utils.ChannelDimension, NoneType] = None **kwargs )

Parameters

  • image (np.ndarray) — Image to resize.
  • size (dict[str, int]) — Dictionary of the form {"height": int, "width": int}, specifying the size of the output image.
  • data_format (ChannelDimension or str, optional) — The channel dimension format of the output image. If not provided, it will be inferred from the input image. Can be one of:
    • "channels_first" or ChannelDimension.FIRST: image in (num_channels, height, width) format.
    • "channels_last" or ChannelDimension.LAST: image in (height, width, num_channels) format.
    • "none" or ChannelDimension.NONE: image in (height, width) format.
  • input_data_format (ChannelDimension or str, optional) — The channel dimension format for the input image. If unset, the channel dimension format is inferred from the input image. Can be one of:
    • "channels_first" or ChannelDimension.FIRST: image in (num_channels, height, width) format.
    • "channels_last" or ChannelDimension.LAST: image in (height, width, num_channels) format.
    • "none" or ChannelDimension.NONE: image in (height, width) format.

Resize an image.

visualize_keypoint_matching

( images: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']] keypoint_matching_output: list ) List[PIL.Image.Image]

Parameters

  • images (ImageInput) — Image pairs to plot. Same as EfficientLoFTRImageProcessor.preprocess. Expects either a list of 2 images or a list of lists of 2 images, with pixel values ranging from 0 to 255.
  • keypoint_matching_output (List[Dict[str, torch.Tensor]]) — A post-processed keypoint matching output, as returned by post_process_keypoint_matching.

Returns

List[PIL.Image.Image]

A list of PIL images, each containing the image pairs side by side with the detected keypoints as well as the matching between them.

Plots the image pairs side by side with the detected keypoints as well as the matching between them.
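Drawing matches side by side is mostly coordinate bookkeeping: the canvas is as tall as the taller image and as wide as both images together, and every x coordinate in the second image is shifted by the width of the first. The plain-Python sketch below assumes a horizontal, top-aligned layout, which may not be exactly how visualize_keypoint_matching arranges its plot; side_by_side_segments is a hypothetical helper. The resulting segments could be drawn with any plotting library, for example with PIL's ImageDraw.Draw(...).line.

```python
def side_by_side_segments(size0, size1, keypoints0, keypoints1):
    """Compute the canvas size and match line segments for a side-by-side plot.

    size0/size1: (height, width) of each image. keypoints0/keypoints1: matched (x, y)
    pairs in each image's own pixel coordinates. The second image sits to the right
    of the first, so its x coordinates are shifted by the first image's width.
    """
    (h0, w0), (h1, w1) = size0, size1
    canvas = (max(h0, h1), w0 + w1)  # (height, width)
    segments = [
        ((x0, y0), (x1 + w0, y1))
        for (x0, y0), (x1, y1) in zip(keypoints0, keypoints1)
    ]
    return canvas, segments


canvas, segments = side_by_side_segments(
    (480, 640), (300, 400), [(10, 20), (600, 470)], [(5, 5), (390, 290)]
)
print(canvas)  # (480, 1040)
```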

EfficientLoFTRModel

class transformers.EfficientLoFTRModel

( config: EfficientLoFTRConfig )

Parameters

  • config (EfficientLoFTRConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

EfficientLoFTR model taking images as inputs and outputting the features of the images.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.)

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( pixel_values: FloatTensor labels: typing.Optional[torch.LongTensor] = None **kwargs: typing_extensions.Unpack[transformers.utils.generic.TransformersKwargs] ) transformers.modeling_outputs.BackboneOutput or tuple(torch.FloatTensor)

Parameters

  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size)) — The tensors corresponding to the input images. Pixel values can be obtained using EfficientLoFTRImageProcessor. See EfficientLoFTRImageProcessor.__call__() for details (processor_class uses EfficientLoFTRImageProcessor for processing images).
  • labels (torch.LongTensor, optional) — Optional labels. EfficientLoFTRConfig has no vocab_size parameter, so masked language modeling labels do not apply to this model.

Returns

transformers.modeling_outputs.BackboneOutput or tuple(torch.FloatTensor)

A transformers.modeling_outputs.BackboneOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (EfficientLoFTRConfig) and inputs.

  • feature_maps (tuple(torch.FloatTensor) of shape (batch_size, num_channels, height, width)) — Feature maps of the stages.

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size) or (batch_size, num_channels, height, width), depending on the backbone.

    Hidden-states of the model at the output of each stage plus the initial embedding outputs.

  • attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Only applicable if the backbone uses attention.

    Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

The EfficientLoFTRModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

>>> from transformers import AutoImageProcessor, AutoModel
>>> import torch
>>> from PIL import Image
>>> import requests

>>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true"
>>> image1 = Image.open(requests.get(url, stream=True).raw)
>>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true"
>>> image2 = Image.open(requests.get(url, stream=True).raw)
>>> images = [image1, image2]

>>> processor = AutoImageProcessor.from_pretrained("zju-community/efficientloftr")
>>> model = AutoModel.from_pretrained("zju-community/efficientloftr")

>>> with torch.no_grad():
...     inputs = processor(images, return_tensors="pt")
...     outputs = model(**inputs)

EfficientLoFTRForKeypointMatching

class transformers.EfficientLoFTRForKeypointMatching

( config: EfficientLoFTRConfig )

Parameters

  • config (EfficientLoFTRConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

EfficientLoFTR model taking images as inputs and outputting the matches between them.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.)

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( pixel_values: FloatTensor labels: typing.Optional[torch.LongTensor] = None **kwargs: typing_extensions.Unpack[transformers.utils.generic.TransformersKwargs] ) transformers.models.efficientloftr.modeling_efficientloftr.KeypointMatchingOutput or tuple(torch.FloatTensor)

Parameters

  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size)) — The tensors corresponding to the input images. Pixel values can be obtained using EfficientLoFTRImageProcessor. See EfficientLoFTRImageProcessor.__call__() for details (processor_class uses EfficientLoFTRImageProcessor for processing images).
  • labels (torch.LongTensor, optional) — Optional labels. EfficientLoFTRConfig has no vocab_size parameter, so masked language modeling labels do not apply to this model.

Returns

transformers.models.efficientloftr.modeling_efficientloftr.KeypointMatchingOutput or tuple(torch.FloatTensor)

A transformers.models.efficientloftr.modeling_efficientloftr.KeypointMatchingOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (EfficientLoFTRConfig) and inputs.

  • matches (torch.FloatTensor of shape (batch_size, 2, num_matches)) — Index of keypoint matched in the other image.
  • matching_scores (torch.FloatTensor of shape (batch_size, 2, num_matches)) — Scores of predicted matches.
  • keypoints (torch.FloatTensor of shape (batch_size, num_keypoints, 2)) — Absolute (x, y) coordinates of predicted keypoints in a given image.
  • hidden_states (tuple[torch.FloatTensor, ...], optional) — Tuple of torch.FloatTensor (one for the output of each stage) of shape (batch_size, 2, num_channels, num_keypoints), returned when output_hidden_states=True is passed or when config.output_hidden_states=True.
  • attentions (tuple[torch.FloatTensor, ...], optional) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, 2, num_heads, num_keypoints, num_keypoints), returned when output_attentions=True is passed or when config.output_attentions=True.
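For a single image pair, matches, matching_scores and keypoints can be combined into matched coordinates. The NumPy sketch below assumes that matches holds, for each keypoint of the first image, the index of its match in the second image, with -1 marking unmatched keypoints (a SuperGlue-style convention this page does not confirm); gather_matched_pairs is a hypothetical helper. In practice, post_process_keypoint_matching performs this step and also rescales the coordinates.

```python
import numpy as np


def gather_matched_pairs(keypoints0, keypoints1, matches0, scores0, threshold=0.0):
    """Return matched (x, y) pairs and scores for one image pair.

    matches0[i] is the index in keypoints1 of the keypoint matched to keypoints0[i];
    -1 is assumed to mark "no match".
    """
    valid = (matches0 > -1) & (scores0 > threshold)
    return keypoints0[valid], keypoints1[matches0[valid]], scores0[valid]


keypoints0 = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
keypoints1 = np.array([[10.0, 10.0], [20.0, 20.0]])
matches0 = np.array([1, -1, 0])  # keypoint 1 of image 0 is unmatched
scores0 = np.array([0.9, 0.0, 0.4])
points0, points1, scores = gather_matched_pairs(keypoints0, keypoints1, matches0, scores0, threshold=0.2)
print(points0, points1, scores)
```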

The EfficientLoFTRForKeypointMatching forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

>>> from transformers import AutoImageProcessor, AutoModelForKeypointMatching
>>> import torch
>>> from PIL import Image
>>> import requests

>>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true"
>>> image1 = Image.open(requests.get(url, stream=True).raw)
>>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true"
>>> image2 = Image.open(requests.get(url, stream=True).raw)
>>> images = [image1, image2]

>>> processor = AutoImageProcessor.from_pretrained("zju-community/efficientloftr")
>>> model = AutoModelForKeypointMatching.from_pretrained("zju-community/efficientloftr")

>>> with torch.no_grad():
...     inputs = processor(images, return_tensors="pt")
...     outputs = model(**inputs)

>>> # Map the raw matches back to the original image sizes
>>> image_sizes = [[(image.height, image.width) for image in images]]
>>> matches = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)