# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch SuperGlue model."""
import math
from dataclasses import dataclass
from typing import Optional, Union
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.superglue.configuration_superglue import SuperGlueConfig
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging
from ..auto import AutoModelForKeypointDetection
logger = logging.get_logger(__name__)
def concat_pairs(tensor_tuple0: tuple[torch.Tensor], tensor_tuple1: tuple[torch.Tensor]) -> tuple[torch.Tensor]:
"""
Concatenate two tuples of tensors pairwise
Args:
tensor_tuple0 (`tuple[torch.Tensor]`):
Tuple of tensors.
tensor_tuple1 (`tuple[torch.Tensor]`):
Tuple of tensors.
Returns:
(`tuple[torch.Tensor]`): Tuple of concatenated tensors.
"""
return tuple([torch.cat([tensor0, tensor1]) for tensor0, tensor1 in zip(tensor_tuple0, tensor_tuple1)])
def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
    Normalize keypoints locations based on the image shape.
Args:
keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`):
Keypoints locations in (x, y) format.
height (`int`):
Image height.
width (`int`):
Image width.
Returns:
        Normalized keypoints locations (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`).
"""
size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None]
center = size / 2
scaling = size.max(1, keepdim=True).values * 0.7
return (keypoints - center[:, None, :]) / scaling[:, None, :]
def log_sinkhorn_iterations(
log_cost_matrix: torch.Tensor,
log_source_distribution: torch.Tensor,
log_target_distribution: torch.Tensor,
num_iterations: int,
) -> torch.Tensor:
"""
Perform Sinkhorn Normalization in Log-space for stability
Args:
log_cost_matrix (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`):
Logarithm of the cost matrix.
log_source_distribution (`torch.Tensor` of shape `(batch_size, num_rows)`):
Logarithm of the source distribution.
log_target_distribution (`torch.Tensor` of shape `(batch_size, num_columns)`):
Logarithm of the target distribution.
Returns:
log_cost_matrix (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): Logarithm of the optimal
transport matrix.
"""
log_u_scaling = torch.zeros_like(log_source_distribution)
log_v_scaling = torch.zeros_like(log_target_distribution)
for _ in range(num_iterations):
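        # Alternately refine the row (u) and column (v) log-scalings so the transport plan's marginals
        # converge towards the source and target distributions.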
log_u_scaling = log_source_distribution - torch.logsumexp(log_cost_matrix + log_v_scaling.unsqueeze(1), dim=2)
log_v_scaling = log_target_distribution - torch.logsumexp(log_cost_matrix + log_u_scaling.unsqueeze(2), dim=1)
return log_cost_matrix + log_u_scaling.unsqueeze(2) + log_v_scaling.unsqueeze(1)
def log_optimal_transport(scores: torch.Tensor, reg_param: torch.Tensor, iterations: int) -> torch.Tensor:
"""
Perform Differentiable Optimal Transport in Log-space for stability
Args:
        scores (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`):
            Cost matrix.
        reg_param (`torch.Tensor` of shape `(batch_size, 1, 1)`):
            Regularization parameter.
        iterations (`int`):
            Number of Sinkhorn iterations.
    Returns:
        log_optimal_transport_matrix (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): Logarithm of the
        optimal transport matrix.
"""
batch_size, num_rows, num_columns = scores.shape
one_tensor = scores.new_tensor(1)
num_rows_tensor, num_columns_tensor = (num_rows * one_tensor).to(scores), (num_columns * one_tensor).to(scores)
source_reg_param = reg_param.expand(batch_size, num_rows, 1)
target_reg_param = reg_param.expand(batch_size, 1, num_columns)
reg_param = reg_param.expand(batch_size, 1, 1)
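    # Augment the score matrix with a dustbin row and column, both scored by the learnable `reg_param`
    # (bin score), so that unmatched keypoints can be assigned to the dustbin.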
couplings = torch.cat([torch.cat([scores, source_reg_param], -1), torch.cat([target_reg_param, reg_param], -1)], 1)
log_normalization = -(num_rows_tensor + num_columns_tensor).log()
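    # Marginals in log-space: each real keypoint gets mass 1/(M+N), and the dustbin absorbs the remaining
    # N/(M+N) (resp. M/(M+N)), so that each marginal sums to one.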
log_source_distribution = torch.cat(
[log_normalization.expand(num_rows), num_columns_tensor.log()[None] + log_normalization]
)
log_target_distribution = torch.cat(
[log_normalization.expand(num_columns), num_rows_tensor.log()[None] + log_normalization]
)
log_source_distribution, log_target_distribution = (
log_source_distribution[None].expand(batch_size, -1),
log_target_distribution[None].expand(batch_size, -1),
)
log_optimal_transport_matrix = log_sinkhorn_iterations(
couplings, log_source_distribution, log_target_distribution, num_iterations=iterations
)
log_optimal_transport_matrix = log_optimal_transport_matrix - log_normalization # multiply probabilities by M+N
return log_optimal_transport_matrix
def arange_like(x, dim: int) -> torch.Tensor:
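    # Equivalent to torch.arange(x.shape[dim]), created with the same dtype and device as `x`.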
return x.new_ones(x.shape[dim]).cumsum(0) - 1
@dataclass
@auto_docstring(
custom_intro="""
Base class for outputs of keypoint matching models. Due to the nature of keypoint detection and matching, the number
of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of
images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask tensor is
used to indicate which values in the keypoints, matches and matching_scores tensors are keypoint matching
information.
"""
)
class KeypointMatchingOutput(ModelOutput):
r"""
loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
Loss computed during training.
matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
Index of keypoint matched in the other image.
matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
Scores of predicted matches.
keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
Absolute (x, y) coordinates of predicted keypoints in a given image.
mask (`torch.IntTensor` of shape `(batch_size, num_keypoints)`):
Mask indicating which values in matches and matching_scores are keypoint matching information.
hidden_states (`tuple[torch.FloatTensor, ...]`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
num_keypoints)`, returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`.
attentions (`tuple[torch.FloatTensor, ...]`, *optional*):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
        num_keypoints)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`.
"""
loss: Optional[torch.FloatTensor] = None
matches: Optional[torch.FloatTensor] = None
matching_scores: Optional[torch.FloatTensor] = None
keypoints: Optional[torch.FloatTensor] = None
mask: Optional[torch.IntTensor] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None
attentions: Optional[tuple[torch.FloatTensor]] = None
class SuperGlueMultiLayerPerceptron(nn.Module):
def __init__(self, config: SuperGlueConfig, in_channels: int, out_channels: int) -> None:
super().__init__()
self.linear = nn.Linear(in_channels, out_channels)
self.batch_norm = nn.BatchNorm1d(out_channels)
self.activation = nn.ReLU()
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.linear(hidden_state)
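        # nn.BatchNorm1d normalizes over dimension 1, so move the feature dimension there and back.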
hidden_state = hidden_state.transpose(-1, -2)
hidden_state = self.batch_norm(hidden_state)
hidden_state = hidden_state.transpose(-1, -2)
hidden_state = self.activation(hidden_state)
return hidden_state
class SuperGlueKeypointEncoder(nn.Module):
def __init__(self, config: SuperGlueConfig) -> None:
super().__init__()
layer_sizes = config.keypoint_encoder_sizes
hidden_size = config.hidden_size
# 3 here consists of 2 for the (x, y) coordinates and 1 for the score of the keypoint
encoder_channels = [3] + layer_sizes + [hidden_size]
layers = [
SuperGlueMultiLayerPerceptron(config, encoder_channels[i - 1], encoder_channels[i])
for i in range(1, len(encoder_channels) - 1)
]
layers.append(nn.Linear(encoder_channels[-2], encoder_channels[-1]))
self.encoder = nn.ModuleList(layers)
def forward(
self,
keypoints: torch.Tensor,
scores: torch.Tensor,
output_hidden_states: Optional[bool] = False,
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor]]]:
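        # Each keypoint is encoded from its normalized (x, y) location and its detection confidence score.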
scores = scores.unsqueeze(2)
hidden_state = torch.cat([keypoints, scores], dim=2)
all_hidden_states = () if output_hidden_states else None
for layer in self.encoder:
hidden_state = layer(hidden_state)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,)
return hidden_state, all_hidden_states
class SuperGlueSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]:
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
current_states = encoder_hidden_states if is_cross_attention else hidden_states
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
batch_size = hidden_states.shape[0]
key_layer = (
self.key(current_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
value_layer = (
self.value(current_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
query_layer = (
self.query(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the SuperGlueModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (None,)
return outputs
class SuperGlueSelfOutput(nn.Module):
def __init__(self, config: SuperGlueConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor, *args) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
return hidden_states
SUPERGLUE_SELF_ATTENTION_CLASSES = {
"eager": SuperGlueSelfAttention,
}
class SuperGlueAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = SUPERGLUE_SELF_ATTENTION_CLASSES[config._attn_implementation](
config,
position_embedding_type=position_embedding_type,
)
self.output = SuperGlueSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class SuperGlueAttentionalPropagation(nn.Module):
def __init__(self, config: SuperGlueConfig) -> None:
super().__init__()
hidden_size = config.hidden_size
self.attention = SuperGlueAttention(config)
mlp_channels = [hidden_size * 2, hidden_size * 2, hidden_size]
layers = [
SuperGlueMultiLayerPerceptron(config, mlp_channels[i - 1], mlp_channels[i])
for i in range(1, len(mlp_channels) - 1)
]
layers.append(nn.Linear(mlp_channels[-2], mlp_channels[-1]))
self.mlp = nn.ModuleList(layers)
def forward(
self,
descriptors: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor]], Optional[tuple[torch.Tensor]]]:
attention_outputs = self.attention(
descriptors,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
output = attention_outputs[0]
attention = attention_outputs[1:]
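        # Message passing step: concatenate the input descriptors with the attention-aggregated message
        # and fuse them through the MLP to produce the residual update.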
hidden_state = torch.cat([descriptors, output], dim=2)
all_hidden_states = () if output_hidden_states else None
for layer in self.mlp:
hidden_state = layer(hidden_state)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,)
return hidden_state, all_hidden_states, attention
class SuperGlueAttentionalGNN(nn.Module):
def __init__(self, config: SuperGlueConfig) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.layers_types = config.gnn_layers_types
self.layers = nn.ModuleList([SuperGlueAttentionalPropagation(config) for _ in range(len(self.layers_types))])
def forward(
self,
descriptors: torch.Tensor,
mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: Optional[bool] = False,
) -> tuple[torch.Tensor, Optional[tuple], Optional[tuple]]:
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
batch_size, num_keypoints, _ = descriptors.shape
if output_hidden_states:
all_hidden_states = all_hidden_states + (descriptors,)
for gnn_layer, layer_type in zip(self.layers, self.layers_types):
encoder_hidden_states = None
encoder_attention_mask = None
if layer_type == "cross":
encoder_hidden_states = (
descriptors.reshape(-1, 2, num_keypoints, self.hidden_size)
.flip(1)
.reshape(batch_size, num_keypoints, self.hidden_size)
)
encoder_attention_mask = (
mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints)
if mask is not None
else None
)
gnn_outputs = gnn_layer(
descriptors,
attention_mask=mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
)
delta = gnn_outputs[0]
if output_hidden_states:
all_hidden_states = all_hidden_states + gnn_outputs[1]
if output_attentions:
all_attentions = all_attentions + gnn_outputs[2]
descriptors = descriptors + delta
return descriptors, all_hidden_states, all_attentions
class SuperGlueFinalProjection(nn.Module):
def __init__(self, config: SuperGlueConfig) -> None:
super().__init__()
hidden_size = config.hidden_size
self.final_proj = nn.Linear(hidden_size, hidden_size, bias=True)
def forward(self, descriptors: torch.Tensor) -> torch.Tensor:
return self.final_proj(descriptors)
@auto_docstring
class SuperGluePreTrainedModel(PreTrainedModel):
config: SuperGlueConfig
base_model_prefix = "superglue"
main_input_name = "pixel_values"
def _init_weights(self, module: nn.Module) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.BatchNorm1d):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
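        # `bin_score` is the learnable dustbin score used by the optimal transport layer; it is initialized to 1.0.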
if hasattr(module, "bin_score"):
module.bin_score.data.fill_(1.0)
@auto_docstring(
custom_intro="""
SuperGlue model taking images as inputs and outputting the matching of them.
"""
)
class SuperGlueForKeypointMatching(SuperGluePreTrainedModel):
"""SuperGlue feature matching middle-end
Given two sets of keypoints and locations, we determine the
correspondences by:
1. Keypoint Encoding (normalization + visual feature and location fusion)
2. Graph Neural Network with multiple self and cross-attention layers
3. Final projection layer
4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)
5. Thresholding matrix based on mutual exclusivity and a match_threshold
The correspondence ids use -1 to indicate non-matching points.
Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
Networks. In CVPR, 2020. https://huggingface.co/papers/1911.11763
"""
def __init__(self, config: SuperGlueConfig) -> None:
super().__init__(config)
self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config)
self.keypoint_encoder = SuperGlueKeypointEncoder(config)
self.gnn = SuperGlueAttentionalGNN(config)
self.final_projection = SuperGlueFinalProjection(config)
bin_score = torch.nn.Parameter(torch.tensor(1.0))
self.register_parameter("bin_score", bin_score)
self.post_init()
def _match_image_pair(
self,
keypoints: torch.Tensor,
descriptors: torch.Tensor,
scores: torch.Tensor,
height: int,
width: int,
mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> tuple[torch.Tensor, torch.Tensor, tuple, tuple]:
"""
Perform keypoint matching between two images.
Args:
keypoints (`torch.Tensor` of shape `(batch_size, 2, num_keypoints, 2)`):
Keypoints detected in the pair of image.
descriptors (`torch.Tensor` of shape `(batch_size, 2, descriptor_dim, num_keypoints)`):
Descriptors of the keypoints detected in the image pair.
scores (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
Confidence scores of the keypoints detected in the image pair.
height (`int`): Image height.
width (`int`): Image width.
mask (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`, *optional*):
Mask indicating which values in the keypoints, matches and matching_scores tensors are keypoint matching
information.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors. Default to `config.output_attentions`.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. Default to `config.output_hidden_states`.
Returns:
matches (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
For each image pair, for each keypoint in image0, the index of the keypoint in image1 that was matched
with. And for each keypoint in image1, the index of the keypoint in image0 that was matched with.
matching_scores (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
Scores of predicted matches for each image pair
all_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(1, 2, num_keypoints,
num_channels)`.
all_attentions (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(1, 2, num_heads, num_keypoints,
num_keypoints)`.
"""
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
if keypoints.shape[2] == 0: # no keypoints
shape = keypoints.shape[:-1]
return (
keypoints.new_full(shape, -1, dtype=torch.int),
keypoints.new_zeros(shape),
all_hidden_states,
all_attentions,
)
batch_size, _, num_keypoints, _ = keypoints.shape
# (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2)
keypoints = keypoints.reshape(batch_size * 2, num_keypoints, 2)
descriptors = descriptors.reshape(batch_size * 2, num_keypoints, self.config.hidden_size)
scores = scores.reshape(batch_size * 2, num_keypoints)
mask = mask.reshape(batch_size * 2, num_keypoints) if mask is not None else None
# Keypoint normalization
keypoints = normalize_keypoints(keypoints, height, width)
encoded_keypoints = self.keypoint_encoder(keypoints, scores, output_hidden_states=output_hidden_states)
last_hidden_state = encoded_keypoints[0]
# Keypoint MLP encoder.
descriptors = descriptors + last_hidden_state
if mask is not None:
input_shape = descriptors.size()
extended_attention_mask = self.get_extended_attention_mask(mask, input_shape)
else:
extended_attention_mask = torch.ones((batch_size, num_keypoints), device=keypoints.device)
# Multi-layer Transformer network.
gnn_outputs = self.gnn(
descriptors,
mask=extended_attention_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
)
descriptors = gnn_outputs[0]
# Final MLP projection.
projected_descriptors = self.final_projection(descriptors)
# (batch_size * 2, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim)
final_descriptors = projected_descriptors.reshape(batch_size, 2, num_keypoints, self.config.hidden_size)
final_descriptors0 = final_descriptors[:, 0]
final_descriptors1 = final_descriptors[:, 1]
# Compute matching descriptor distance.
scores = final_descriptors0 @ final_descriptors1.transpose(1, 2)
scores = scores / self.config.hidden_size**0.5
if mask is not None:
mask = mask.reshape(batch_size, 2, num_keypoints)
mask0 = mask[:, 0].unsqueeze(-1).expand(-1, -1, num_keypoints)
scores = scores.masked_fill(mask0 == 0, -1e9)
# Run the optimal transport.
scores = log_optimal_transport(scores, self.bin_score, iterations=self.config.sinkhorn_iterations)
# Get the matches with score above "match_threshold".
max0 = scores[:, :-1, :-1].max(2)
max1 = scores[:, :-1, :-1].max(1)
indices0 = max0.indices
indices1 = max1.indices
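        # Keep only mutual nearest neighbors: keypoint i in image0 must select j in image1 and j must select
        # i back; everything else is marked as unmatched (-1) below.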
mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
zero = scores.new_tensor(0)
matching_scores0 = torch.where(mutual0, max0.values.exp(), zero)
matching_scores0 = torch.where(matching_scores0 > self.config.matching_threshold, matching_scores0, zero)
matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, indices1), zero)
valid0 = mutual0 & (matching_scores0 > zero)
valid1 = mutual1 & valid0.gather(1, indices1)
matches0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
matches1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
matches = torch.cat([matches0, matches1], dim=1).reshape(batch_size, 2, -1)
matching_scores = torch.cat([matching_scores0, matching_scores1], dim=1).reshape(batch_size, 2, -1)
if output_hidden_states:
all_hidden_states = all_hidden_states + encoded_keypoints[1]
all_hidden_states = all_hidden_states + gnn_outputs[1]
all_hidden_states = all_hidden_states + (projected_descriptors,)
all_hidden_states = tuple(
x.reshape(batch_size, 2, num_keypoints, -1).transpose(-1, -2) for x in all_hidden_states
)
if output_attentions:
all_attentions = all_attentions + gnn_outputs[2]
all_attentions = tuple(x.reshape(batch_size, 2, -1, num_keypoints, num_keypoints) for x in all_attentions)
return (
matches,
matching_scores,
all_hidden_states,
all_attentions,
)
@auto_docstring
def forward(
self,
pixel_values: torch.FloatTensor,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, KeypointMatchingOutput]:
r"""
Examples:
```python
>>> from transformers import AutoImageProcessor, AutoModel
>>> import torch
>>> from PIL import Image
>>> import requests
>>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true"
>>> image1 = Image.open(requests.get(url, stream=True).raw)
>>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true"
>>> image2 = Image.open(requests.get(url, stream=True).raw)
>>> images = [image1, image2]
>>> processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
>>> model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")
>>> with torch.no_grad():
        ...     inputs = processor(images, return_tensors="pt")
        ...     outputs = model(**inputs)
```"""
loss = None
if labels is not None:
raise ValueError("SuperGlue is not trainable, no labels should be provided.")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values.ndim != 5 or pixel_values.size(1) != 2:
raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")
batch_size, _, channels, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width)
keypoint_detections = self.keypoint_detector(pixel_values)
keypoints, scores, descriptors, mask = keypoint_detections[:4]
keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values)
scores = scores.reshape(batch_size, 2, -1).to(pixel_values)
descriptors = descriptors.reshape(batch_size, 2, -1, self.config.hidden_size).to(pixel_values)
mask = mask.reshape(batch_size, 2, -1)
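        # Keypoints from the detector are relative coordinates in [0, 1]; scale them by the image size
        # to obtain absolute pixel coordinates before SuperGlue's own normalization.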
absolute_keypoints = keypoints.clone()
absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width
absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height
matches, matching_scores, hidden_states, attentions = self._match_image_pair(
absolute_keypoints,
descriptors,
scores,
height,
width,
mask=mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
if not return_dict:
return tuple(
v
for v in [loss, matches, matching_scores, keypoints, mask, hidden_states, attentions]
if v is not None
)
return KeypointMatchingOutput(
loss=loss,
matches=matches,
matching_scores=matching_scores,
keypoints=keypoints,
mask=mask,
hidden_states=hidden_states,
attentions=attentions,
)
__all__ = ["SuperGluePreTrainedModel", "SuperGlueForKeypointMatching"]