# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Callable, Optional, Union import numpy as np import torch from torch import nn from torch.nn.utils.rnn import pad_sequence from ...configuration_utils import PretrainedConfig from ...image_utils import ImageInput, to_numpy_array from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TensorType, auto_docstring, is_matplotlib_available, logging from ...utils.generic import can_return_tuple from ..auto import CONFIG_MAPPING, AutoConfig from ..auto.modeling_auto import AutoModelForKeypointDetection from ..clip.modeling_clip import CLIPMLP from ..cohere.modeling_cohere import apply_rotary_pos_emb from ..llama.modeling_llama import LlamaAttention, eager_attention_forward from ..superglue.image_processing_superglue import SuperGlueImageProcessor, validate_and_format_image_pairs from ..superpoint import SuperPointConfig logger = logging.get_logger(__name__) class LightGlueConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`LightGlueForKeypointMatching`]. It is used to instantiate a LightGlue model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the LightGlue [ETH-CVG/lightglue_superpoint](https://huggingface.co/ETH-CVG/lightglue_superpoint) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: keypoint_detector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SuperPointConfig`): The config object or dictionary of the keypoint detector. descriptor_dim (`int`, *optional*, defaults to 256): The dimension of the descriptors. num_hidden_layers (`int`, *optional*, defaults to 9): The number of self and cross attention layers. num_attention_heads (`int`, *optional*, defaults to 4): The number of heads in the multi-head attention. num_key_value_heads (`int`, *optional*): This is the number of key_value heads that should be used to implement Grouped Query Attention. If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. depth_confidence (`float`, *optional*, defaults to 0.95): The confidence threshold used to perform early stopping width_confidence (`float`, *optional*, defaults to 0.99): The confidence threshold used to prune points filter_threshold (`float`, *optional*, defaults to 0.1): The confidence threshold used to filter matches initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. hidden_act (`str`, *optional*, defaults to `"gelu"`): The activation function to be used in the hidden layers. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. attention_bias (`bool`, *optional*, defaults to `True`): Whether to use a bias in the query, key, value and output projection layers during self-attention. trust_remote_code (`bool`, *optional*, defaults to `False`): Whether to trust remote code when using other models than SuperPoint as keypoint detector. Examples: ```python >>> from transformers import LightGlueConfig, LightGlueForKeypointMatching >>> # Initializing a LightGlue style configuration >>> configuration = LightGlueConfig() >>> # Initializing a model from the LightGlue style configuration >>> model = LightGlueForKeypointMatching(configuration) >>> # Accessing the model configuration >>> configuration = model.config ``` """ model_type = "lightglue" sub_configs = {"keypoint_detector_config": AutoConfig} def __init__( self, keypoint_detector_config: SuperPointConfig = None, descriptor_dim: int = 256, num_hidden_layers: int = 9, num_attention_heads: int = 4, num_key_value_heads=None, depth_confidence: float = 0.95, width_confidence: float = 0.99, filter_threshold: float = 0.1, initializer_range: float = 0.02, hidden_act: str = "gelu", attention_dropout=0.0, attention_bias=True, trust_remote_code: bool = False, **kwargs, ): # LightGlue can be used with other models than SuperPoint as keypoint detector # We provide the trust_remote_code argument to allow the use of other models # that are not registered in the CONFIG_MAPPING dictionary (for example DISK) self.trust_remote_code = trust_remote_code if descriptor_dim % num_attention_heads != 0: raise ValueError("descriptor_dim % num_heads is different from zero") self.descriptor_dim = descriptor_dim self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.depth_confidence = depth_confidence self.width_confidence = width_confidence self.filter_threshold = filter_threshold self.initializer_range = initializer_range # Keypoint Detector is forced into eager attention mode because SuperPoint does not have Attention # See https://github.com/huggingface/transformers/pull/31718#discussion_r2109733153 if isinstance(keypoint_detector_config, dict): keypoint_detector_config["model_type"] = ( keypoint_detector_config["model_type"] if "model_type" in keypoint_detector_config else "superpoint" ) if keypoint_detector_config["model_type"] not in CONFIG_MAPPING: keypoint_detector_config = AutoConfig.from_pretrained( keypoint_detector_config["_name_or_path"], trust_remote_code=self.trust_remote_code ) else: keypoint_detector_config = CONFIG_MAPPING[keypoint_detector_config["model_type"]]( **keypoint_detector_config, attn_implementation="eager" ) if keypoint_detector_config is None: keypoint_detector_config = CONFIG_MAPPING["superpoint"](attn_implementation="eager") self.keypoint_detector_config = keypoint_detector_config self.hidden_size = descriptor_dim self.intermediate_size = descriptor_dim * 2 self.hidden_act = hidden_act self.attention_dropout = attention_dropout self.attention_bias = attention_bias super().__init__(**kwargs) @dataclass @auto_docstring( custom_intro=""" Base class for outputs of LightGlue keypoint matching models. Due to the nature of keypoint detection and matching, the number of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask tensor is used to indicate which values in the keypoints, matches, matching_scores and prune tensors are keypoint matching information. """ ) class LightGlueKeypointMatchingOutput(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*): Loss computed during training. matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): Index of keypoint matched in the other image. matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): Scores of predicted matches. keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`): Absolute (x, y) coordinates of predicted keypoints in a given image. prune (`torch.IntTensor` of shape `(batch_size, num_keypoints)`): Pruning mask indicating which keypoints are removed and at which layer. mask (`torch.BoolTensor` of shape `(batch_size, num_keypoints)`): Mask indicating which values in matches, matching_scores, keypoints and prune are keypoint matching information. hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*): Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels, num_keypoints)` returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True` attentions (`Tuple[torch.FloatTensor, ...]`, *optional*): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints, num_keypoints)` returned when `output_attentions=True` is passed or when `config.output_attentions=True` """ loss: Optional[torch.FloatTensor] = None matches: Optional[torch.FloatTensor] = None matching_scores: Optional[torch.FloatTensor] = None keypoints: Optional[torch.FloatTensor] = None prune: Optional[torch.IntTensor] = None mask: Optional[torch.FloatTensor] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None class LightGlueImageProcessor(SuperGlueImageProcessor): def post_process_keypoint_matching( self, outputs: LightGlueKeypointMatchingOutput, target_sizes: Union[TensorType, list[tuple]], threshold: float = 0.0, ) -> list[dict[str, torch.Tensor]]: return super().post_process_keypoint_matching(outputs, target_sizes, threshold) def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput): """ Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires matplotlib to be installed. Args: images (`ImageInput`): Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or a list of list of 2 images list with pixel values ranging from 0 to 255. outputs ([`LightGlueKeypointMatchingOutput`]): Raw outputs of the model. """ if is_matplotlib_available(): import matplotlib.pyplot as plt else: raise ImportError("Please install matplotlib to use `plot_keypoint_matching` method") images = validate_and_format_image_pairs(images) images = [to_numpy_array(image) for image in images] image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)] for image_pair, pair_output in zip(image_pairs, keypoint_matching_output): height0, width0 = image_pair[0].shape[:2] height1, width1 = image_pair[1].shape[:2] plot_image = np.zeros((max(height0, height1), width0 + width1, 3)) plot_image[:height0, :width0] = image_pair[0] / 255.0 plot_image[:height1, width0:] = image_pair[1] / 255.0 plt.imshow(plot_image) plt.axis("off") keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1) keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1) for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip( keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"] ): plt.plot( [keypoint0_x, keypoint1_x + width0], [keypoint0_y, keypoint1_y], color=plt.get_cmap("RdYlGn")(matching_score.item()), alpha=0.9, linewidth=0.5, ) plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2) plt.scatter(keypoint1_x + width0, keypoint1_y, c="black", s=2) plt.show() class LightGluePositionalEncoder(nn.Module): def __init__(self, config: LightGlueConfig): super().__init__() self.projector = nn.Linear(2, config.descriptor_dim // config.num_attention_heads // 2, bias=False) def forward( self, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: projected_keypoints = self.projector(keypoints) embeddings = projected_keypoints.repeat_interleave(2, dim=-1) cosines = torch.cos(embeddings) sines = torch.sin(embeddings) embeddings = (cosines, sines) output = (embeddings, projected_keypoints) if output_hidden_states else (embeddings,) return output class LightGlueAttention(LlamaAttention): def forward( self, hidden_states: torch.Tensor, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) is_cross_attention = encoder_hidden_states is not None current_states = encoder_hidden_states if is_cross_attention else hidden_states current_attention_mask = encoder_attention_mask if is_cross_attention else attention_mask key_states = self.k_proj(current_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(current_states).view(hidden_shape).transpose(1, 2) if position_embeddings is not None: cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, current_attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class LightGlueMLP(CLIPMLP): def __init__(self, config: LightGlueConfig): super().__init__(config) self.fc1 = nn.Linear(config.intermediate_size, config.intermediate_size) self.layer_norm = nn.LayerNorm(config.intermediate_size, elementwise_affine=True) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.fc1(hidden_states) hidden_states = self.layer_norm(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = self.fc2(hidden_states) return hidden_states class LightGlueTransformerLayer(nn.Module): def __init__(self, config: LightGlueConfig, layer_idx: int): super().__init__() self.self_attention = LightGlueAttention(config, layer_idx) self.self_mlp = LightGlueMLP(config) self.cross_attention = LightGlueAttention(config, layer_idx) self.cross_mlp = LightGlueMLP(config) def forward( self, descriptors: torch.Tensor, keypoints: torch.Tensor, attention_mask: torch.Tensor, output_hidden_states: Optional[bool] = False, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor]], Optional[tuple[torch.Tensor]]]: all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None if output_hidden_states: all_hidden_states = all_hidden_states + (descriptors,) batch_size, num_keypoints, descriptor_dim = descriptors.shape # Self attention block attention_output, self_attentions = self.self_attention( descriptors, position_embeddings=keypoints, attention_mask=attention_mask, output_attentions=output_attentions, ) intermediate_states = torch.cat([descriptors, attention_output], dim=-1) output_states = self.self_mlp(intermediate_states) self_attention_descriptors = descriptors + output_states if output_hidden_states: self_attention_hidden_states = (intermediate_states, output_states) # Reshape hidden_states to group by image_pairs : # (batch_size, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim) # Flip dimension 1 to perform cross attention : # (image0, image1) -> (image1, image0) # Reshape back to original shape : # (batch_size, 2, num_keypoints, descriptor_dim) -> (batch_size, num_keypoints, descriptor_dim) encoder_hidden_states = ( self_attention_descriptors.reshape(-1, 2, num_keypoints, descriptor_dim) .flip(1) .reshape(batch_size, num_keypoints, descriptor_dim) ) # Same for mask encoder_attention_mask = ( attention_mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints) if attention_mask is not None else None ) # Cross attention block cross_attention_output, cross_attentions = self.cross_attention( self_attention_descriptors, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, ) cross_intermediate_states = torch.cat([self_attention_descriptors, cross_attention_output], dim=-1) cross_output_states = self.cross_mlp(cross_intermediate_states) descriptors = self_attention_descriptors + cross_output_states if output_hidden_states: cross_attention_hidden_states = (cross_intermediate_states, cross_output_states) all_hidden_states = ( all_hidden_states + (self_attention_descriptors.reshape(batch_size, num_keypoints, descriptor_dim),) + self_attention_hidden_states + (descriptors.reshape(batch_size, num_keypoints, descriptor_dim),) + cross_attention_hidden_states ) if output_attentions: all_attentions = all_attentions + (self_attentions,) + (cross_attentions,) return descriptors, all_hidden_states, all_attentions def sigmoid_log_double_softmax( similarity: torch.Tensor, matchability0: torch.Tensor, matchability1: torch.Tensor ) -> torch.Tensor: """create the log assignment matrix from logits and similarity""" batch_size, num_keypoints_0, num_keypoints_1 = similarity.shape certainties = nn.functional.logsigmoid(matchability0) + nn.functional.logsigmoid(matchability1).transpose(1, 2) scores0 = nn.functional.log_softmax(similarity, 2) scores1 = nn.functional.log_softmax(similarity.transpose(-1, -2).contiguous(), 2).transpose(-1, -2) scores = similarity.new_full((batch_size, num_keypoints_0 + 1, num_keypoints_1 + 1), 0) scores[:, :num_keypoints_0, :num_keypoints_1] = scores0 + scores1 + certainties scores[:, :-1, -1] = nn.functional.logsigmoid(-matchability0.squeeze(-1)) scores[:, -1, :-1] = nn.functional.logsigmoid(-matchability1.squeeze(-1)) return scores class LightGlueMatchAssignmentLayer(nn.Module): def __init__(self, config: LightGlueConfig): super().__init__() self.descriptor_dim = config.descriptor_dim self.final_projection = nn.Linear(self.descriptor_dim, self.descriptor_dim, bias=True) self.matchability = nn.Linear(self.descriptor_dim, 1, bias=True) def forward(self, descriptors: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: batch_size, num_keypoints, descriptor_dim = descriptors.shape # Final projection and similarity computation m_descriptors = self.final_projection(descriptors) m_descriptors = m_descriptors / torch.tensor(self.descriptor_dim, device=m_descriptors.device) ** 0.25 m_descriptors = m_descriptors.reshape(batch_size // 2, 2, num_keypoints, descriptor_dim) m_descriptors0 = m_descriptors[:, 0] m_descriptors1 = m_descriptors[:, 1] similarity = m_descriptors0 @ m_descriptors1.transpose(-1, -2) if mask is not None: mask = mask.reshape(batch_size // 2, 2, num_keypoints) mask0 = mask[:, 0].unsqueeze(-1) mask1 = mask[:, 1].unsqueeze(-1).transpose(-1, -2) mask = mask0 * mask1 similarity = similarity.masked_fill(mask == 0, torch.finfo(similarity.dtype).min) # Compute matchability of descriptors matchability = self.matchability(descriptors) matchability = matchability.reshape(batch_size // 2, 2, num_keypoints, 1) matchability_0 = matchability[:, 0] matchability_1 = matchability[:, 1] # Compute scores from similarity and matchability scores = sigmoid_log_double_softmax(similarity, matchability_0, matchability_1) return scores def get_matchability(self, descriptors: torch.Tensor) -> torch.Tensor: """Get matchability of descriptors as a probability""" matchability = self.matchability(descriptors) matchability = nn.functional.sigmoid(matchability).squeeze(-1) return matchability class LightGlueTokenConfidenceLayer(nn.Module): def __init__(self, config: LightGlueConfig): super().__init__() self.token = nn.Linear(config.descriptor_dim, 1) def forward(self, descriptors: torch.Tensor) -> torch.Tensor: token = self.token(descriptors.detach()) token = nn.functional.sigmoid(token).squeeze(-1) return token @auto_docstring class LightGluePreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config: LightGlueConfig base_model_prefix = "lightglue" main_input_name = "pixel_values" supports_gradient_checkpointing = False _supports_flash_attn = True _supports_sdpa = True def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> tuple[torch.Tensor, torch.Tensor]: """obtain matches from a score matrix [Bx M+1 x N+1]""" batch_size, _, _ = scores.shape # For each keypoint, get the best match max0 = scores[:, :-1, :-1].max(2) max1 = scores[:, :-1, :-1].max(1) matches0 = max0.indices matches1 = max1.indices # Mutual check for matches indices0 = torch.arange(matches0.shape[1], device=matches0.device)[None] indices1 = torch.arange(matches1.shape[1], device=matches1.device)[None] mutual0 = indices0 == matches1.gather(1, matches0) mutual1 = indices1 == matches0.gather(1, matches1) # Get matching scores and filter based on mutual check and thresholding max0 = max0.values.exp() zero = max0.new_tensor(0) matching_scores0 = torch.where(mutual0, max0, zero) matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, matches1), zero) valid0 = mutual0 & (matching_scores0 > threshold) valid1 = mutual1 & valid0.gather(1, matches1) # Filter matches based on mutual check and thresholding of scores matches0 = torch.where(valid0, matches0, -1) matches1 = torch.where(valid1, matches1, -1) matches = torch.stack([matches0, matches1]).transpose(0, 1).reshape(batch_size * 2, -1) matching_scores = torch.stack([matching_scores0, matching_scores1]).transpose(0, 1).reshape(batch_size * 2, -1) return matches, matching_scores def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor: """ Normalize keypoints locations based on image image_shape Args: keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`): Keypoints locations in (x, y) format. height (`int`): Image height. width (`int`): Image width. Returns: Normalized keypoints locations of shape (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`). """ size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None] shift = size / 2 scale = size.max(-1).values / 2 keypoints = (keypoints - shift[..., None, :]) / scale[..., None, None] return keypoints @auto_docstring( custom_intro=""" LightGlue model taking images as inputs and outputting the matching of them. """ ) class LightGlueForKeypointMatching(LightGluePreTrainedModel): """ LightGlue is a model matching keypoints in images by leveraging detections from a keypoint detector such as SuperPoint. It is based on the SuperGlue architecture and is designed to be lightweight and efficient. It consists of : 1. Keypoint Encoder 2. A Graph Neural Network with self and cross attention layers 3. Matching Assignment layers The correspondence ids use -1 to indicate non-matching points. Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys. LightGlue: Local Feature Matching at Light Speed. In ICCV 2023. https://arxiv.org/pdf/2306.13643.pdf """ def __init__(self, config: LightGlueConfig): super().__init__(config) self.keypoint_detector = AutoModelForKeypointDetection.from_config( config.keypoint_detector_config, trust_remote_code=config.trust_remote_code ) self.keypoint_detector_descriptor_dim = config.keypoint_detector_config.descriptor_decoder_dim self.descriptor_dim = config.descriptor_dim self.num_layers = config.num_hidden_layers self.filter_threshold = config.filter_threshold self.depth_confidence = config.depth_confidence self.width_confidence = config.width_confidence if self.descriptor_dim != self.keypoint_detector_descriptor_dim: self.input_projection = nn.Linear(self.keypoint_detector_descriptor_dim, self.descriptor_dim, bias=True) else: self.input_projection = nn.Identity() self.positional_encoder = LightGluePositionalEncoder(config) self.transformer_layers = nn.ModuleList( [LightGlueTransformerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] ) self.match_assignment_layers = nn.ModuleList( [LightGlueMatchAssignmentLayer(config) for _ in range(config.num_hidden_layers)] ) self.token_confidence = nn.ModuleList( [LightGlueTokenConfidenceLayer(config) for _ in range(config.num_hidden_layers - 1)] ) self.post_init() def _get_confidence_threshold(self, layer_index: int) -> float: """scaled confidence threshold for a given layer""" threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.num_layers) return np.clip(threshold, 0, 1) def _keypoint_processing( self, descriptors: torch.Tensor, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: descriptors = descriptors.detach().contiguous() projected_descriptors = self.input_projection(descriptors) keypoint_encoding_output = self.positional_encoder(keypoints, output_hidden_states=output_hidden_states) return projected_descriptors, keypoint_encoding_output def _get_early_stopped_image_pairs( self, keypoint_confidences: torch.Tensor, layer_index: int, mask: torch.Tensor, num_points: torch.Tensor ) -> torch.Tensor: """evaluate whether we should stop inference based on the confidence of the keypoints""" batch_size, _ = mask.shape if layer_index < self.num_layers - 1: # If the current layer is not the last layer, we compute the confidence of the keypoints and check # if we should stop the forward pass through the transformer layers for each pair of images. keypoint_confidences = keypoint_confidences.masked_fill(mask == 0, 1) keypoint_confidences = keypoint_confidences.reshape(batch_size // 2, -1) threshold = self._get_confidence_threshold(layer_index) ratio_confident = 1.0 - (keypoint_confidences < threshold).float().sum(dim=1) / num_points early_stopped_pairs = ratio_confident > self.depth_confidence else: # If the current layer is the last layer, we stop the forward pass through the transformer layers for # all pairs of images. early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool) return early_stopped_pairs def _get_keypoint_matching(self, descriptors, mask, layer_index, early_stops=None): if early_stops is not None: descriptors = descriptors[early_stops] mask = mask[early_stops] scores = self.match_assignment_layers[layer_index](descriptors, mask) matches, matching_scores = get_matches_from_scores(scores, self.filter_threshold) return matches, matching_scores def _get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int) -> torch.Tensor: """mask points which should be removed""" keep = scores > (1 - self.width_confidence) if confidences is not None: # Low-confidence points are never pruned. keep |= confidences <= self._get_confidence_threshold(layer_index) return keep def _do_layer_keypoint_pruning( self, descriptors: torch.Tensor, keypoints: torch.Tensor, mask: torch.Tensor, indices: torch.Tensor, prune_output: torch.Tensor, keypoint_confidences: torch.Tensor, layer_index: int, ): """ For a given layer, prune keypoints based on the confidence of the keypoints and the matchability of the descriptors. """ batch_size, _, _ = descriptors.shape descriptors_matchability = self.match_assignment_layers[layer_index].get_matchability(descriptors) pruned_keypoints_mask = self._get_pruning_mask(keypoint_confidences, descriptors_matchability, layer_index) pruned_keypoints_mask = pruned_keypoints_mask.masked_fill(mask == 0, torch.tensor(False)) # For each image, we extract the pruned indices and the corresponding descriptors and keypoints. pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask, pruned_indices = ( [t[mask] for t, mask in zip(tensor, pruned_keypoints_mask)] for tensor in [descriptors, keypoints[0], keypoints[1], pruned_keypoints_mask, indices] ) for i in range(batch_size): prune_output[i, pruned_indices[i]] += 1 # Pad the pruned descriptors, keypoints, indices and mask to have the same shape across the batch. pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask = ( pad_sequence(pruned_tensor, batch_first=True) for pruned_tensor in [pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask] ) pruned_keypoints = (pruned_keypoints_0, pruned_keypoints_1) pruned_indices = pad_sequence(pruned_indices, batch_first=True, padding_value=-1) return pruned_descriptors, pruned_keypoints, pruned_indices, pruned_mask, prune_output def _concat_early_stopped_outputs( self, early_stops_indices, final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores, ): early_stops_indices = torch.stack(early_stops_indices) matches, final_pruned_keypoints_indices = ( pad_sequence(tensor, batch_first=True, padding_value=-1) for tensor in [matches, final_pruned_keypoints_indices] ) matching_scores, final_pruned_keypoints_iterations = ( pad_sequence(tensor, batch_first=True, padding_value=0) for tensor in [matching_scores, final_pruned_keypoints_iterations] ) matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations = ( tensor[early_stops_indices] for tensor in [ matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations, ] ) return final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores def _do_final_keypoint_pruning( self, indices: torch.Tensor, matches: torch.Tensor, matching_scores: torch.Tensor, num_keypoints: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: # (batch_size, num_keypoints) -> (batch_size // 2, 2, num_keypoints) -> 2 * (batch_size // 2, num_keypoints) to # have tensors from batch_size, _ = indices.shape indices, matches, matching_scores = ( tensor.reshape(batch_size // 2, 2, -1) for tensor in [indices, matches, matching_scores] ) indices0 = indices[:, 0] indices1 = indices[:, 1] matches0 = matches[:, 0] matches1 = matches[:, 1] matching_scores0 = matching_scores[:, 0] matching_scores1 = matching_scores[:, 1] # Prepare final matches and matching scores _matches = torch.full((batch_size // 2, 2, num_keypoints), -1, device=indices.device, dtype=matches.dtype) _matching_scores = torch.zeros( (batch_size // 2, 2, num_keypoints), device=indices.device, dtype=matching_scores.dtype ) # Fill the matches and matching scores for each image pair for i in range(batch_size // 2): _matches[i, 0, indices0[i]] = torch.where( matches0[i] == -1, -1, indices1[i].gather(0, matches0[i].clamp(min=0)) ) _matches[i, 1, indices1[i]] = torch.where( matches1[i] == -1, -1, indices0[i].gather(0, matches1[i].clamp(min=0)) ) _matching_scores[i, 0, indices0[i]] = matching_scores0[i] _matching_scores[i, 1, indices1[i]] = matching_scores1[i] return _matches, _matching_scores def _match_image_pair( self, keypoints: torch.Tensor, descriptors: torch.Tensor, height: int, width: int, mask: torch.Tensor = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple, tuple]: all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None if keypoints.shape[2] == 0: # no keypoints shape = keypoints.shape[:-1] return ( keypoints.new_full(shape, -1, dtype=torch.int), keypoints.new_zeros(shape), keypoints.new_zeros(shape), all_hidden_states, all_attentions, ) device = keypoints.device batch_size, _, initial_num_keypoints, _ = keypoints.shape num_points_per_pair = torch.sum(mask.reshape(batch_size, -1), dim=1) # (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2) keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2) mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.keypoint_detector_descriptor_dim) image_indices = torch.arange(batch_size * 2, device=device) # Keypoint normalization keypoints = normalize_keypoints(keypoints, height, width) descriptors, keypoint_encoding_output = self._keypoint_processing( descriptors, keypoints, output_hidden_states=output_hidden_states ) keypoints = keypoint_encoding_output[0] # Early stop consists of stopping the forward pass through the transformer layers when the confidence of the # keypoints is above a certain threshold. do_early_stop = self.depth_confidence > 0 # Keypoint pruning consists of removing keypoints from the input of the transformer layers when the confidence of # the keypoints is below a certain threshold. do_keypoint_pruning = self.width_confidence > 0 early_stops_indices = [] matches = [] matching_scores = [] final_pruned_keypoints_indices = [] final_pruned_keypoints_iterations = [] pruned_keypoints_indices = torch.arange(0, initial_num_keypoints, device=device).expand(batch_size * 2, -1) pruned_keypoints_iterations = torch.ones_like(pruned_keypoints_indices) for layer_index in range(self.num_layers): input_shape = descriptors.size() if mask is not None: extended_attention_mask = self.get_extended_attention_mask(mask, input_shape) else: extended_attention_mask = torch.ones((batch_size, input_shape[-2]), device=keypoints.device) layer_output = self.transformer_layers[layer_index]( descriptors, keypoints, attention_mask=extended_attention_mask, output_hidden_states=output_hidden_states, output_attentions=output_attentions, ) descriptors, hidden_states, attention = layer_output if output_hidden_states: all_hidden_states = all_hidden_states + hidden_states if output_attentions: all_attentions = all_attentions + attention if do_early_stop: if layer_index < self.num_layers - 1: # Get the confidence of the keypoints for the current layer keypoint_confidences = self.token_confidence[layer_index](descriptors) # Determine which pairs of images should be early stopped based on the confidence of the keypoints for # the current layer. early_stopped_pairs = self._get_early_stopped_image_pairs( keypoint_confidences, layer_index, mask, num_points=num_points_per_pair ) else: # Early stopping always occurs at the last layer early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool) if torch.any(early_stopped_pairs): # If a pair of images is considered early stopped, we compute the matches for the remaining # keypoints and stop the forward pass through the transformer layers for this pair of images. early_stops = early_stopped_pairs.repeat_interleave(2) early_stopped_image_indices = image_indices[early_stops] early_stopped_matches, early_stopped_matching_scores = self._get_keypoint_matching( descriptors, mask, layer_index, early_stops=early_stops ) early_stops_indices.extend(list(early_stopped_image_indices)) matches.extend(list(early_stopped_matches)) matching_scores.extend(list(early_stopped_matching_scores)) if do_keypoint_pruning: final_pruned_keypoints_indices.extend(list(pruned_keypoints_indices[early_stops])) final_pruned_keypoints_iterations.extend(list(pruned_keypoints_iterations[early_stops])) # Remove image pairs that have been early stopped from the forward pass num_points_per_pair = num_points_per_pair[~early_stopped_pairs] descriptors, keypoints_0, keypoint_1, mask, image_indices = tuple( tensor[~early_stops] for tensor in [descriptors, keypoints[0], keypoints[1], mask, image_indices] ) keypoints = (keypoints_0, keypoint_1) if do_keypoint_pruning: pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences = tuple( tensor[~early_stops] for tensor in [ pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences, ] ) # If all pairs of images are early stopped, we stop the forward pass through the transformer # layers for all pairs of images. if torch.all(early_stopped_pairs): break if do_keypoint_pruning: # Prune keypoints from the input of the transformer layers for the next iterations if the confidence of # the keypoints is below a certain threshold. descriptors, keypoints, pruned_keypoints_indices, mask, pruned_keypoints_iterations = ( self._do_layer_keypoint_pruning( descriptors, keypoints, mask, pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences, layer_index, ) ) if do_early_stop and do_keypoint_pruning: # Concatenate early stopped outputs together and perform final keypoint pruning final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores = ( self._concat_early_stopped_outputs( early_stops_indices, final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores, ) ) matches, matching_scores = self._do_final_keypoint_pruning( final_pruned_keypoints_indices, matches, matching_scores, initial_num_keypoints, ) else: matches, matching_scores = self._get_keypoint_matching(descriptors, mask, self.num_layers - 1) final_pruned_keypoints_iterations = torch.ones_like(matching_scores) * self.num_layers final_pruned_keypoints_iterations = final_pruned_keypoints_iterations.reshape( batch_size, 2, initial_num_keypoints ) return ( matches, matching_scores, final_pruned_keypoints_iterations, all_hidden_states, all_attentions, ) @can_return_tuple @auto_docstring def forward( self, pixel_values: torch.FloatTensor, labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, ) -> Union[tuple, LightGlueKeypointMatchingOutput]: loss = None if labels is not None: raise ValueError("LightGlue is not trainable, no labels should be provided.") output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) if pixel_values.ndim != 5 or pixel_values.size(1) != 2: raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)") batch_size, _, channels, height, width = pixel_values.shape pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width) keypoint_detections = self.keypoint_detector(pixel_values) keypoints, _, descriptors, mask = keypoint_detections[:4] keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values) descriptors = descriptors.reshape(batch_size, 2, -1, self.keypoint_detector_descriptor_dim).to(pixel_values) mask = mask.reshape(batch_size, 2, -1) absolute_keypoints = keypoints.clone() absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height matches, matching_scores, prune, hidden_states, attentions = self._match_image_pair( absolute_keypoints, descriptors, height, width, mask=mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) return LightGlueKeypointMatchingOutput( loss=loss, matches=matches, matching_scores=matching_scores, keypoints=keypoints, prune=prune, mask=mask, hidden_states=hidden_states, attentions=attentions, ) __all__ = ["LightGluePreTrainedModel", "LightGlueForKeypointMatching", "LightGlueConfig", "LightGlueImageProcessor"]