# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change is needed, please apply it to the
# modular_lightglue.py file directly. One of our CI checks enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, Optional, Union
import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from ...activations import ACT2FN
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring
from ...utils.generic import can_return_tuple
from ..auto.modeling_auto import AutoModelForKeypointDetection
from .configuration_lightglue import LightGlueConfig


@dataclass
@auto_docstring(
custom_intro="""
    Base class for outputs of LightGlue keypoint matching models. Because the number of detected keypoints can vary
    from image to image, batching is non-trivial: outputs are padded to the maximum number of keypoints/matches in
    the batch, and the mask tensor indicates which entries of the keypoints, matches, matching_scores and prune
    tensors contain actual keypoint matching information rather than padding.
"""
)
class LightGlueKeypointMatchingOutput(ModelOutput):
r"""
loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
Loss computed during training.
matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
Index of keypoint matched in the other image.
matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
Scores of predicted matches.
    keypoints (`torch.FloatTensor` of shape `(batch_size, 2, num_keypoints, 2)`):
        Relative (x, y) coordinates of predicted keypoints in a given image, normalized to the image width and
        height.
prune (`torch.IntTensor` of shape `(batch_size, num_keypoints)`):
Pruning mask indicating which keypoints are removed and at which layer.
    mask (`torch.BoolTensor` of shape `(batch_size, 2, num_keypoints)`):
        Mask indicating which entries of matches, matching_scores, keypoints and prune contain actual keypoint
        matching information and which are padding.
hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
num_keypoints)` returned when `output_hidden_states=True` is passed or when
`config.output_hidden_states=True`
attentions (`Tuple[torch.FloatTensor, ...]`, *optional*):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
num_keypoints)` returned when `output_attentions=True` is passed or when
`config.output_attentions=True`
"""
loss: Optional[torch.FloatTensor] = None
matches: Optional[torch.FloatTensor] = None
matching_scores: Optional[torch.FloatTensor] = None
keypoints: Optional[torch.FloatTensor] = None
prune: Optional[torch.IntTensor] = None
mask: Optional[torch.FloatTensor] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None
attentions: Optional[tuple[torch.FloatTensor]] = None
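# Note (illustrative, not part of the generated code): the padded outputs above are typically read per
# image pair with the `mask` field, e.g. for the first image of pair `b`:
#     valid = (outputs.mask[b, 0] > 0) & (outputs.matches[b, 0] != -1)
#     kpts0 = outputs.keypoints[b, 0][valid]                          # matched keypoints in image 0
#     kpts1 = outputs.keypoints[b, 1][outputs.matches[b, 0][valid]]   # their counterparts in image 1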


class LightGluePositionalEncoder(nn.Module):
def __init__(self, config: LightGlueConfig):
super().__init__()
self.projector = nn.Linear(2, config.descriptor_dim // config.num_attention_heads // 2, bias=False)
def forward(
self, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False
) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
projected_keypoints = self.projector(keypoints)
embeddings = projected_keypoints.repeat_interleave(2, dim=-1)
cosines = torch.cos(embeddings)
sines = torch.sin(embeddings)
embeddings = (cosines, sines)
output = (embeddings, projected_keypoints) if output_hidden_states else (embeddings,)
return output


def rotate_half(x):
# Split and rotate. Note that this function is different from e.g. Llama.
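    # e.g. [x0, x1, x2, x3] -> [-x1, x0, -x3, x2]: interleaved (even, odd) pairs are rotated together,
    # rather than the first and second halves of the vector as in Llama.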
x1 = x[..., ::2]
x2 = x[..., 1::2]
rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
return rot_x


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
dtype = q.dtype
q = q.float()
k = k.float()
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs: Unpack[TransformersKwargs],
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights


class LightGlueAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: LightGlueConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
is_cross_attention = encoder_hidden_states is not None
current_states = encoder_hidden_states if is_cross_attention else hidden_states
current_attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
key_states = self.k_proj(current_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(current_states).view(hidden_shape).transpose(1, 2)
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
current_attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights


class LightGlueMLP(nn.Module):
def __init__(self, config: LightGlueConfig):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.intermediate_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
self.layer_norm = nn.LayerNorm(config.intermediate_size, elementwise_affine=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states


class LightGlueTransformerLayer(nn.Module):
def __init__(self, config: LightGlueConfig, layer_idx: int):
super().__init__()
self.self_attention = LightGlueAttention(config, layer_idx)
self.self_mlp = LightGlueMLP(config)
self.cross_attention = LightGlueAttention(config, layer_idx)
self.cross_mlp = LightGlueMLP(config)
def forward(
self,
descriptors: torch.Tensor,
keypoints: torch.Tensor,
attention_mask: torch.Tensor,
output_hidden_states: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor]], Optional[tuple[torch.Tensor]]]:
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
if output_hidden_states:
all_hidden_states = all_hidden_states + (descriptors,)
batch_size, num_keypoints, descriptor_dim = descriptors.shape
# Self attention block
attention_output, self_attentions = self.self_attention(
descriptors,
position_embeddings=keypoints,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
intermediate_states = torch.cat([descriptors, attention_output], dim=-1)
output_states = self.self_mlp(intermediate_states)
self_attention_descriptors = descriptors + output_states
if output_hidden_states:
self_attention_hidden_states = (intermediate_states, output_states)
# Reshape hidden_states to group by image_pairs :
# (batch_size, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim)
# Flip dimension 1 to perform cross attention :
# (image0, image1) -> (image1, image0)
# Reshape back to original shape :
# (batch_size, 2, num_keypoints, descriptor_dim) -> (batch_size, num_keypoints, descriptor_dim)
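        # e.g. with two pairs stacked as [A0, A1, B0, B1], the encoder states become [A1, A0, B1, B0],
        # so each image cross-attends to the other image of its own pair.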
encoder_hidden_states = (
self_attention_descriptors.reshape(-1, 2, num_keypoints, descriptor_dim)
.flip(1)
.reshape(batch_size, num_keypoints, descriptor_dim)
)
# Same for mask
encoder_attention_mask = (
attention_mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints)
if attention_mask is not None
else None
)
# Cross attention block
cross_attention_output, cross_attentions = self.cross_attention(
self_attention_descriptors,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
cross_intermediate_states = torch.cat([self_attention_descriptors, cross_attention_output], dim=-1)
cross_output_states = self.cross_mlp(cross_intermediate_states)
descriptors = self_attention_descriptors + cross_output_states
if output_hidden_states:
cross_attention_hidden_states = (cross_intermediate_states, cross_output_states)
all_hidden_states = (
all_hidden_states
+ (self_attention_descriptors.reshape(batch_size, num_keypoints, descriptor_dim),)
+ self_attention_hidden_states
+ (descriptors.reshape(batch_size, num_keypoints, descriptor_dim),)
+ cross_attention_hidden_states
)
if output_attentions:
all_attentions = all_attentions + (self_attentions,) + (cross_attentions,)
return descriptors, all_hidden_states, all_attentions


def sigmoid_log_double_softmax(
similarity: torch.Tensor, matchability0: torch.Tensor, matchability1: torch.Tensor
) -> torch.Tensor:
"""create the log assignment matrix from logits and similarity"""
batch_size, num_keypoints_0, num_keypoints_1 = similarity.shape
certainties = nn.functional.logsigmoid(matchability0) + nn.functional.logsigmoid(matchability1).transpose(1, 2)
scores0 = nn.functional.log_softmax(similarity, 2)
scores1 = nn.functional.log_softmax(similarity.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
scores = similarity.new_full((batch_size, num_keypoints_0 + 1, num_keypoints_1 + 1), 0)
scores[:, :num_keypoints_0, :num_keypoints_1] = scores0 + scores1 + certainties
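    # The extra last row / column act as "dustbin" bins: they receive log(1 - sigmoid(matchability)),
    # i.e. the log-probability that a keypoint of image 0 / image 1 is left unmatched.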
scores[:, :-1, -1] = nn.functional.logsigmoid(-matchability0.squeeze(-1))
scores[:, -1, :-1] = nn.functional.logsigmoid(-matchability1.squeeze(-1))
return scores


class LightGlueMatchAssignmentLayer(nn.Module):
def __init__(self, config: LightGlueConfig):
super().__init__()
self.descriptor_dim = config.descriptor_dim
self.final_projection = nn.Linear(self.descriptor_dim, self.descriptor_dim, bias=True)
self.matchability = nn.Linear(self.descriptor_dim, 1, bias=True)
def forward(self, descriptors: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
batch_size, num_keypoints, descriptor_dim = descriptors.shape
# Final projection and similarity computation
m_descriptors = self.final_projection(descriptors)
m_descriptors = m_descriptors / torch.tensor(self.descriptor_dim, device=m_descriptors.device) ** 0.25
m_descriptors = m_descriptors.reshape(batch_size // 2, 2, num_keypoints, descriptor_dim)
m_descriptors0 = m_descriptors[:, 0]
m_descriptors1 = m_descriptors[:, 1]
similarity = m_descriptors0 @ m_descriptors1.transpose(-1, -2)
if mask is not None:
mask = mask.reshape(batch_size // 2, 2, num_keypoints)
mask0 = mask[:, 0].unsqueeze(-1)
mask1 = mask[:, 1].unsqueeze(-1).transpose(-1, -2)
mask = mask0 * mask1
similarity = similarity.masked_fill(mask == 0, torch.finfo(similarity.dtype).min)
# Compute matchability of descriptors
matchability = self.matchability(descriptors)
matchability = matchability.reshape(batch_size // 2, 2, num_keypoints, 1)
matchability_0 = matchability[:, 0]
matchability_1 = matchability[:, 1]
# Compute scores from similarity and matchability
scores = sigmoid_log_double_softmax(similarity, matchability_0, matchability_1)
return scores
def get_matchability(self, descriptors: torch.Tensor) -> torch.Tensor:
"""Get matchability of descriptors as a probability"""
matchability = self.matchability(descriptors)
matchability = nn.functional.sigmoid(matchability).squeeze(-1)
return matchability


class LightGlueTokenConfidenceLayer(nn.Module):
def __init__(self, config: LightGlueConfig):
super().__init__()
self.token = nn.Linear(config.descriptor_dim, 1)
def forward(self, descriptors: torch.Tensor) -> torch.Tensor:
token = self.token(descriptors.detach())
token = nn.functional.sigmoid(token).squeeze(-1)
return token


@auto_docstring
class LightGluePreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config: LightGlueConfig
base_model_prefix = "lightglue"
main_input_name = "pixel_values"
supports_gradient_checkpointing = False
_supports_flash_attn = True
_supports_sdpa = True


def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> tuple[torch.Tensor, torch.Tensor]:
"""obtain matches from a score matrix [Bx M+1 x N+1]"""
batch_size, _, _ = scores.shape
# For each keypoint, get the best match
max0 = scores[:, :-1, :-1].max(2)
max1 = scores[:, :-1, :-1].max(1)
matches0 = max0.indices
matches1 = max1.indices
# Mutual check for matches
indices0 = torch.arange(matches0.shape[1], device=matches0.device)[None]
indices1 = torch.arange(matches1.shape[1], device=matches1.device)[None]
mutual0 = indices0 == matches1.gather(1, matches0)
mutual1 = indices1 == matches0.gather(1, matches1)
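    # mutual0[b, i] is True iff keypoint i of image 0 and its best match j in image 1 pick each other
    # (and symmetrically for mutual1), i.e. a mutual nearest-neighbor check on the score matrix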
# Get matching scores and filter based on mutual check and thresholding
max0 = max0.values.exp()
zero = max0.new_tensor(0)
matching_scores0 = torch.where(mutual0, max0, zero)
matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, matches1), zero)
valid0 = mutual0 & (matching_scores0 > threshold)
valid1 = mutual1 & valid0.gather(1, matches1)
# Filter matches based on mutual check and thresholding of scores
matches0 = torch.where(valid0, matches0, -1)
matches1 = torch.where(valid1, matches1, -1)
matches = torch.stack([matches0, matches1]).transpose(0, 1).reshape(batch_size * 2, -1)
matching_scores = torch.stack([matching_scores0, matching_scores1]).transpose(0, 1).reshape(batch_size * 2, -1)
return matches, matching_scores


def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
    Normalize keypoint locations based on the image shape.
Args:
keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`):
Keypoints locations in (x, y) format.
height (`int`):
Image height.
width (`int`):
Image width.
Returns:
        Normalized keypoint locations (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`).
"""
size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None]
shift = size / 2
scale = size.max(-1).values / 2
keypoints = (keypoints - shift[..., None, :]) / scale[..., None, None]
return keypoints


@auto_docstring(
custom_intro="""
LightGlue model taking images as inputs and outputting the matching of them.
"""
)
class LightGlueForKeypointMatching(LightGluePreTrainedModel):
"""
LightGlue is a model matching keypoints in images by leveraging detections from a keypoint detector such as
SuperPoint. It is based on the SuperGlue architecture and is designed to be lightweight and efficient.
    It consists of:
1. Keypoint Encoder
2. A Graph Neural Network with self and cross attention layers
3. Matching Assignment layers
The correspondence ids use -1 to indicate non-matching points.
Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys. LightGlue: Local Feature Matching at Light Speed.
In ICCV 2023. https://arxiv.org/pdf/2306.13643.pdf
"""
def __init__(self, config: LightGlueConfig):
super().__init__(config)
self.keypoint_detector = AutoModelForKeypointDetection.from_config(
config.keypoint_detector_config, trust_remote_code=config.trust_remote_code
)
self.keypoint_detector_descriptor_dim = config.keypoint_detector_config.descriptor_decoder_dim
self.descriptor_dim = config.descriptor_dim
self.num_layers = config.num_hidden_layers
self.filter_threshold = config.filter_threshold
self.depth_confidence = config.depth_confidence
self.width_confidence = config.width_confidence
if self.descriptor_dim != self.keypoint_detector_descriptor_dim:
self.input_projection = nn.Linear(self.keypoint_detector_descriptor_dim, self.descriptor_dim, bias=True)
else:
self.input_projection = nn.Identity()
self.positional_encoder = LightGluePositionalEncoder(config)
self.transformer_layers = nn.ModuleList(
[LightGlueTransformerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
)
self.match_assignment_layers = nn.ModuleList(
[LightGlueMatchAssignmentLayer(config) for _ in range(config.num_hidden_layers)]
)
self.token_confidence = nn.ModuleList(
[LightGlueTokenConfidenceLayer(config) for _ in range(config.num_hidden_layers - 1)]
)
self.post_init()
def _get_confidence_threshold(self, layer_index: int) -> float:
"""scaled confidence threshold for a given layer"""
threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.num_layers)
return np.clip(threshold, 0, 1)
def _keypoint_processing(
self, descriptors: torch.Tensor, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
descriptors = descriptors.detach().contiguous()
projected_descriptors = self.input_projection(descriptors)
keypoint_encoding_output = self.positional_encoder(keypoints, output_hidden_states=output_hidden_states)
return projected_descriptors, keypoint_encoding_output
def _get_early_stopped_image_pairs(
self, keypoint_confidences: torch.Tensor, layer_index: int, mask: torch.Tensor, num_points: torch.Tensor
) -> torch.Tensor:
"""evaluate whether we should stop inference based on the confidence of the keypoints"""
batch_size, _ = mask.shape
if layer_index < self.num_layers - 1:
# If the current layer is not the last layer, we compute the confidence of the keypoints and check
# if we should stop the forward pass through the transformer layers for each pair of images.
keypoint_confidences = keypoint_confidences.masked_fill(mask == 0, 1)
keypoint_confidences = keypoint_confidences.reshape(batch_size // 2, -1)
threshold = self._get_confidence_threshold(layer_index)
ratio_confident = 1.0 - (keypoint_confidences < threshold).float().sum(dim=1) / num_points
early_stopped_pairs = ratio_confident > self.depth_confidence
else:
# If the current layer is the last layer, we stop the forward pass through the transformer layers for
# all pairs of images.
early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool)
return early_stopped_pairs
def _get_keypoint_matching(self, descriptors, mask, layer_index, early_stops=None):
if early_stops is not None:
descriptors = descriptors[early_stops]
mask = mask[early_stops]
scores = self.match_assignment_layers[layer_index](descriptors, mask)
matches, matching_scores = get_matches_from_scores(scores, self.filter_threshold)
return matches, matching_scores
def _get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int) -> torch.Tensor:
"""mask points which should be removed"""
keep = scores > (1 - self.width_confidence)
if confidences is not None: # Low-confidence points are never pruned.
keep |= confidences <= self._get_confidence_threshold(layer_index)
return keep
def _do_layer_keypoint_pruning(
self,
descriptors: torch.Tensor,
keypoints: torch.Tensor,
mask: torch.Tensor,
indices: torch.Tensor,
prune_output: torch.Tensor,
keypoint_confidences: torch.Tensor,
layer_index: int,
):
"""
For a given layer, prune keypoints based on the confidence of the keypoints and the matchability of the
descriptors.
"""
batch_size, _, _ = descriptors.shape
descriptors_matchability = self.match_assignment_layers[layer_index].get_matchability(descriptors)
pruned_keypoints_mask = self._get_pruning_mask(keypoint_confidences, descriptors_matchability, layer_index)
pruned_keypoints_mask = pruned_keypoints_mask.masked_fill(mask == 0, torch.tensor(False))
# For each image, we extract the pruned indices and the corresponding descriptors and keypoints.
pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask, pruned_indices = (
[t[mask] for t, mask in zip(tensor, pruned_keypoints_mask)]
for tensor in [descriptors, keypoints[0], keypoints[1], pruned_keypoints_mask, indices]
)
for i in range(batch_size):
prune_output[i, pruned_indices[i]] += 1
# Pad the pruned descriptors, keypoints, indices and mask to have the same shape across the batch.
pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask = (
pad_sequence(pruned_tensor, batch_first=True)
for pruned_tensor in [pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask]
)
pruned_keypoints = (pruned_keypoints_0, pruned_keypoints_1)
pruned_indices = pad_sequence(pruned_indices, batch_first=True, padding_value=-1)
return pruned_descriptors, pruned_keypoints, pruned_indices, pruned_mask, prune_output
def _concat_early_stopped_outputs(
self,
early_stops_indices,
final_pruned_keypoints_indices,
final_pruned_keypoints_iterations,
matches,
matching_scores,
):
early_stops_indices = torch.stack(early_stops_indices)
matches, final_pruned_keypoints_indices = (
pad_sequence(tensor, batch_first=True, padding_value=-1)
for tensor in [matches, final_pruned_keypoints_indices]
)
matching_scores, final_pruned_keypoints_iterations = (
pad_sequence(tensor, batch_first=True, padding_value=0)
for tensor in [matching_scores, final_pruned_keypoints_iterations]
)
matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations = (
tensor[early_stops_indices]
for tensor in [
matches,
matching_scores,
final_pruned_keypoints_indices,
final_pruned_keypoints_iterations,
]
)
return final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores
def _do_final_keypoint_pruning(
self,
indices: torch.Tensor,
matches: torch.Tensor,
matching_scores: torch.Tensor,
num_keypoints: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
        # (batch_size, num_keypoints) -> (batch_size // 2, 2, num_keypoints) -> 2 * (batch_size // 2, num_keypoints)
        # to get separate tensors for each image of the pair
batch_size, _ = indices.shape
indices, matches, matching_scores = (
tensor.reshape(batch_size // 2, 2, -1) for tensor in [indices, matches, matching_scores]
)
indices0 = indices[:, 0]
indices1 = indices[:, 1]
matches0 = matches[:, 0]
matches1 = matches[:, 1]
matching_scores0 = matching_scores[:, 0]
matching_scores1 = matching_scores[:, 1]
# Prepare final matches and matching scores
_matches = torch.full((batch_size // 2, 2, num_keypoints), -1, device=indices.device, dtype=matches.dtype)
_matching_scores = torch.zeros(
(batch_size // 2, 2, num_keypoints), device=indices.device, dtype=matching_scores.dtype
)
# Fill the matches and matching scores for each image pair
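        # `matches` index into the pruned keypoint set; `indices0` / `indices1` map those positions back to
        # the original (un-pruned) keypoint indices of each image.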
for i in range(batch_size // 2):
_matches[i, 0, indices0[i]] = torch.where(
matches0[i] == -1, -1, indices1[i].gather(0, matches0[i].clamp(min=0))
)
_matches[i, 1, indices1[i]] = torch.where(
matches1[i] == -1, -1, indices0[i].gather(0, matches1[i].clamp(min=0))
)
_matching_scores[i, 0, indices0[i]] = matching_scores0[i]
_matching_scores[i, 1, indices1[i]] = matching_scores1[i]
return _matches, _matching_scores
def _match_image_pair(
self,
keypoints: torch.Tensor,
descriptors: torch.Tensor,
height: int,
width: int,
mask: torch.Tensor = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple, tuple]:
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
if keypoints.shape[2] == 0: # no keypoints
shape = keypoints.shape[:-1]
return (
keypoints.new_full(shape, -1, dtype=torch.int),
keypoints.new_zeros(shape),
keypoints.new_zeros(shape),
all_hidden_states,
all_attentions,
)
device = keypoints.device
batch_size, _, initial_num_keypoints, _ = keypoints.shape
num_points_per_pair = torch.sum(mask.reshape(batch_size, -1), dim=1)
# (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2)
keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2)
mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None
descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.keypoint_detector_descriptor_dim)
image_indices = torch.arange(batch_size * 2, device=device)
# Keypoint normalization
keypoints = normalize_keypoints(keypoints, height, width)
descriptors, keypoint_encoding_output = self._keypoint_processing(
descriptors, keypoints, output_hidden_states=output_hidden_states
)
keypoints = keypoint_encoding_output[0]
# Early stop consists of stopping the forward pass through the transformer layers when the confidence of the
# keypoints is above a certain threshold.
do_early_stop = self.depth_confidence > 0
# Keypoint pruning consists of removing keypoints from the input of the transformer layers when the confidence of
# the keypoints is below a certain threshold.
do_keypoint_pruning = self.width_confidence > 0
early_stops_indices = []
matches = []
matching_scores = []
final_pruned_keypoints_indices = []
final_pruned_keypoints_iterations = []
pruned_keypoints_indices = torch.arange(0, initial_num_keypoints, device=device).expand(batch_size * 2, -1)
pruned_keypoints_iterations = torch.ones_like(pruned_keypoints_indices)
for layer_index in range(self.num_layers):
input_shape = descriptors.size()
if mask is not None:
extended_attention_mask = self.get_extended_attention_mask(mask, input_shape)
else:
extended_attention_mask = torch.ones((batch_size, input_shape[-2]), device=keypoints.device)
layer_output = self.transformer_layers[layer_index](
descriptors,
keypoints,
attention_mask=extended_attention_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
)
descriptors, hidden_states, attention = layer_output
if output_hidden_states:
all_hidden_states = all_hidden_states + hidden_states
if output_attentions:
all_attentions = all_attentions + attention
if do_early_stop:
if layer_index < self.num_layers - 1:
# Get the confidence of the keypoints for the current layer
keypoint_confidences = self.token_confidence[layer_index](descriptors)
# Determine which pairs of images should be early stopped based on the confidence of the keypoints for
# the current layer.
early_stopped_pairs = self._get_early_stopped_image_pairs(
keypoint_confidences, layer_index, mask, num_points=num_points_per_pair
)
else:
# Early stopping always occurs at the last layer
early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool)
if torch.any(early_stopped_pairs):
# If a pair of images is considered early stopped, we compute the matches for the remaining
# keypoints and stop the forward pass through the transformer layers for this pair of images.
early_stops = early_stopped_pairs.repeat_interleave(2)
early_stopped_image_indices = image_indices[early_stops]
early_stopped_matches, early_stopped_matching_scores = self._get_keypoint_matching(
descriptors, mask, layer_index, early_stops=early_stops
)
early_stops_indices.extend(list(early_stopped_image_indices))
matches.extend(list(early_stopped_matches))
matching_scores.extend(list(early_stopped_matching_scores))
if do_keypoint_pruning:
final_pruned_keypoints_indices.extend(list(pruned_keypoints_indices[early_stops]))
final_pruned_keypoints_iterations.extend(list(pruned_keypoints_iterations[early_stops]))
# Remove image pairs that have been early stopped from the forward pass
num_points_per_pair = num_points_per_pair[~early_stopped_pairs]
descriptors, keypoints_0, keypoint_1, mask, image_indices = tuple(
tensor[~early_stops]
for tensor in [descriptors, keypoints[0], keypoints[1], mask, image_indices]
)
keypoints = (keypoints_0, keypoint_1)
if do_keypoint_pruning:
pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences = tuple(
tensor[~early_stops]
for tensor in [
pruned_keypoints_indices,
pruned_keypoints_iterations,
keypoint_confidences,
]
)
# If all pairs of images are early stopped, we stop the forward pass through the transformer
# layers for all pairs of images.
if torch.all(early_stopped_pairs):
break
if do_keypoint_pruning:
# Prune keypoints from the input of the transformer layers for the next iterations if the confidence of
# the keypoints is below a certain threshold.
descriptors, keypoints, pruned_keypoints_indices, mask, pruned_keypoints_iterations = (
self._do_layer_keypoint_pruning(
descriptors,
keypoints,
mask,
pruned_keypoints_indices,
pruned_keypoints_iterations,
keypoint_confidences,
layer_index,
)
)
if do_early_stop and do_keypoint_pruning:
# Concatenate early stopped outputs together and perform final keypoint pruning
final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores = (
self._concat_early_stopped_outputs(
early_stops_indices,
final_pruned_keypoints_indices,
final_pruned_keypoints_iterations,
matches,
matching_scores,
)
)
matches, matching_scores = self._do_final_keypoint_pruning(
final_pruned_keypoints_indices,
matches,
matching_scores,
initial_num_keypoints,
)
else:
matches, matching_scores = self._get_keypoint_matching(descriptors, mask, self.num_layers - 1)
final_pruned_keypoints_iterations = torch.ones_like(matching_scores) * self.num_layers
final_pruned_keypoints_iterations = final_pruned_keypoints_iterations.reshape(
batch_size, 2, initial_num_keypoints
)
return (
matches,
matching_scores,
final_pruned_keypoints_iterations,
all_hidden_states,
all_attentions,
)
@can_return_tuple
@auto_docstring
def forward(
self,
pixel_values: torch.FloatTensor,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> Union[tuple, LightGlueKeypointMatchingOutput]:
loss = None
if labels is not None:
raise ValueError("LightGlue is not trainable, no labels should be provided.")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if pixel_values.ndim != 5 or pixel_values.size(1) != 2:
raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")
batch_size, _, channels, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width)
keypoint_detections = self.keypoint_detector(pixel_values)
keypoints, _, descriptors, mask = keypoint_detections[:4]
keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values)
descriptors = descriptors.reshape(batch_size, 2, -1, self.keypoint_detector_descriptor_dim).to(pixel_values)
mask = mask.reshape(batch_size, 2, -1)
absolute_keypoints = keypoints.clone()
absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width
absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height
matches, matching_scores, prune, hidden_states, attentions = self._match_image_pair(
absolute_keypoints,
descriptors,
height,
width,
mask=mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
return LightGlueKeypointMatchingOutput(
loss=loss,
matches=matches,
matching_scores=matching_scores,
keypoints=keypoints,
prune=prune,
mask=mask,
hidden_states=hidden_states,
attentions=attentions,
)


__all__ = ["LightGluePreTrainedModel", "LightGlueForKeypointMatching"]
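

# -----------------------------------------------------------------------------------------------
# Illustrative usage sketch (not part of the generated model code). The checkpoint name below is
# an assumption (the LightGlue + SuperPoint weights published on the Hugging Face Hub); in real
# use, inputs would come from the corresponding image processor (e.g. AutoImageProcessor) rather
# than random tensors.
# -----------------------------------------------------------------------------------------------
if __name__ == "__main__":
    model = LightGlueForKeypointMatching.from_pretrained("ETH-CVG/lightglue_superpoint")  # assumed checkpoint name
    model.eval()

    # forward() expects pixel values of shape (batch_size, 2, channels, height, width): one image pair per
    # batch element. A random gray image replicated over 3 channels stands in for processor output here.
    gray = torch.rand(1, 2, 1, 480, 640)
    pixel_values = gray.repeat(1, 1, 3, 1, 1)

    with torch.no_grad():
        outputs = model(pixel_values)

    # matches / matching_scores are padded to the largest number of keypoints in the batch; unmatched
    # keypoints are marked with -1 and `mask` flags the valid (non-padding) entries.
    print(outputs.matches.shape, outputs.matching_scores.shape, outputs.mask.shape)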