# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch SuperGlue model.""" import math from dataclasses import dataclass from typing import Optional, Union import torch from torch import nn from transformers import PreTrainedModel from transformers.models.superglue.configuration_superglue import SuperGlueConfig from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging from ..auto import AutoModelForKeypointDetection logger = logging.get_logger(__name__) def concat_pairs(tensor_tuple0: tuple[torch.Tensor], tensor_tuple1: tuple[torch.Tensor]) -> tuple[torch.Tensor]: """ Concatenate two tuples of tensors pairwise Args: tensor_tuple0 (`tuple[torch.Tensor]`): Tuple of tensors. tensor_tuple1 (`tuple[torch.Tensor]`): Tuple of tensors. Returns: (`tuple[torch.Tensor]`): Tuple of concatenated tensors. """ return tuple([torch.cat([tensor0, tensor1]) for tensor0, tensor1 in zip(tensor_tuple0, tensor_tuple1)]) def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor: """ Normalize keypoints locations based on image image_shape Args: keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`): Keypoints locations in (x, y) format. height (`int`): Image height. width (`int`): Image width. Returns: Normalized keypoints locations of shape (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`). """ size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None] center = size / 2 scaling = size.max(1, keepdim=True).values * 0.7 return (keypoints - center[:, None, :]) / scaling[:, None, :] def log_sinkhorn_iterations( log_cost_matrix: torch.Tensor, log_source_distribution: torch.Tensor, log_target_distribution: torch.Tensor, num_iterations: int, ) -> torch.Tensor: """ Perform Sinkhorn Normalization in Log-space for stability Args: log_cost_matrix (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): Logarithm of the cost matrix. log_source_distribution (`torch.Tensor` of shape `(batch_size, num_rows)`): Logarithm of the source distribution. log_target_distribution (`torch.Tensor` of shape `(batch_size, num_columns)`): Logarithm of the target distribution. Returns: log_cost_matrix (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): Logarithm of the optimal transport matrix. """ log_u_scaling = torch.zeros_like(log_source_distribution) log_v_scaling = torch.zeros_like(log_target_distribution) for _ in range(num_iterations): log_u_scaling = log_source_distribution - torch.logsumexp(log_cost_matrix + log_v_scaling.unsqueeze(1), dim=2) log_v_scaling = log_target_distribution - torch.logsumexp(log_cost_matrix + log_u_scaling.unsqueeze(2), dim=1) return log_cost_matrix + log_u_scaling.unsqueeze(2) + log_v_scaling.unsqueeze(1) def log_optimal_transport(scores: torch.Tensor, reg_param: torch.Tensor, iterations: int) -> torch.Tensor: """ Perform Differentiable Optimal Transport in Log-space for stability Args: scores: (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): Cost matrix. reg_param: (`torch.Tensor` of shape `(batch_size, 1, 1)`): Regularization parameter. iterations: (`int`): Number of Sinkhorn iterations. Returns: log_optimal_transport_matrix: (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): Logarithm of the optimal transport matrix. """ batch_size, num_rows, num_columns = scores.shape one_tensor = scores.new_tensor(1) num_rows_tensor, num_columns_tensor = (num_rows * one_tensor).to(scores), (num_columns * one_tensor).to(scores) source_reg_param = reg_param.expand(batch_size, num_rows, 1) target_reg_param = reg_param.expand(batch_size, 1, num_columns) reg_param = reg_param.expand(batch_size, 1, 1) couplings = torch.cat([torch.cat([scores, source_reg_param], -1), torch.cat([target_reg_param, reg_param], -1)], 1) log_normalization = -(num_rows_tensor + num_columns_tensor).log() log_source_distribution = torch.cat( [log_normalization.expand(num_rows), num_columns_tensor.log()[None] + log_normalization] ) log_target_distribution = torch.cat( [log_normalization.expand(num_columns), num_rows_tensor.log()[None] + log_normalization] ) log_source_distribution, log_target_distribution = ( log_source_distribution[None].expand(batch_size, -1), log_target_distribution[None].expand(batch_size, -1), ) log_optimal_transport_matrix = log_sinkhorn_iterations( couplings, log_source_distribution, log_target_distribution, num_iterations=iterations ) log_optimal_transport_matrix = log_optimal_transport_matrix - log_normalization # multiply probabilities by M+N return log_optimal_transport_matrix def arange_like(x, dim: int) -> torch.Tensor: return x.new_ones(x.shape[dim]).cumsum(0) - 1 @dataclass @auto_docstring( custom_intro=""" Base class for outputs of keypoint matching models. Due to the nature of keypoint detection and matching, the number of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask tensor is used to indicate which values in the keypoints, matches and matching_scores tensors are keypoint matching information. """ ) class KeypointMatchingOutput(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*): Loss computed during training. matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): Index of keypoint matched in the other image. matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): Scores of predicted matches. keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`): Absolute (x, y) coordinates of predicted keypoints in a given image. mask (`torch.IntTensor` of shape `(batch_size, num_keypoints)`): Mask indicating which values in matches and matching_scores are keypoint matching information. hidden_states (`tuple[torch.FloatTensor, ...]`, *optional*): Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels, num_keypoints)`, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`) attentions (`tuple[torch.FloatTensor, ...]`, *optional*): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints, num_keypoints)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) """ loss: Optional[torch.FloatTensor] = None matches: Optional[torch.FloatTensor] = None matching_scores: Optional[torch.FloatTensor] = None keypoints: Optional[torch.FloatTensor] = None mask: Optional[torch.IntTensor] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None class SuperGlueMultiLayerPerceptron(nn.Module): def __init__(self, config: SuperGlueConfig, in_channels: int, out_channels: int) -> None: super().__init__() self.linear = nn.Linear(in_channels, out_channels) self.batch_norm = nn.BatchNorm1d(out_channels) self.activation = nn.ReLU() def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: hidden_state = self.linear(hidden_state) hidden_state = hidden_state.transpose(-1, -2) hidden_state = self.batch_norm(hidden_state) hidden_state = hidden_state.transpose(-1, -2) hidden_state = self.activation(hidden_state) return hidden_state class SuperGlueKeypointEncoder(nn.Module): def __init__(self, config: SuperGlueConfig) -> None: super().__init__() layer_sizes = config.keypoint_encoder_sizes hidden_size = config.hidden_size # 3 here consists of 2 for the (x, y) coordinates and 1 for the score of the keypoint encoder_channels = [3] + layer_sizes + [hidden_size] layers = [ SuperGlueMultiLayerPerceptron(config, encoder_channels[i - 1], encoder_channels[i]) for i in range(1, len(encoder_channels) - 1) ] layers.append(nn.Linear(encoder_channels[-2], encoder_channels[-1])) self.encoder = nn.ModuleList(layers) def forward( self, keypoints: torch.Tensor, scores: torch.Tensor, output_hidden_states: Optional[bool] = False, ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor]]]: scores = scores.unsqueeze(2) hidden_state = torch.cat([keypoints, scores], dim=2) all_hidden_states = () if output_hidden_states else None for layer in self.encoder: hidden_state = layer(hidden_state) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_state,) return hidden_state, all_hidden_states class SuperGlueSelfAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " f"heads ({config.num_attention_heads})" ) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(config.hidden_size, self.all_head_size) self.key = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.position_embedding_type = position_embedding_type or getattr( config, "position_embedding_type", "absolute" ) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None current_states = encoder_hidden_states if is_cross_attention else hidden_states attention_mask = encoder_attention_mask if is_cross_attention else attention_mask batch_size = hidden_states.shape[0] key_layer = ( self.key(current_states) .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) .transpose(1, 2) ) value_layer = ( self.value(current_states) .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) .transpose(1, 2) ) query_layer = ( self.query(hidden_states) .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) .transpose(1, 2) ) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) distance = position_ids_l - position_ids_r positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility if self.position_embedding_type == "relative_key": relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) attention_scores = attention_scores + relative_position_scores elif self.position_embedding_type == "relative_key_query": relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key attention_scores = attention_scores / math.sqrt(self.attention_head_size) if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in SuperGlueModel forward() function) attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. attention_probs = nn.functional.softmax(attention_scores, dim=-1) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. attention_probs = self.dropout(attention_probs) # Mask heads if we want to if head_mask is not None: attention_probs = attention_probs * head_mask context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) if self.is_decoder: outputs = outputs + (None,) return outputs class SuperGlueSelfOutput(nn.Module): def __init__(self, config: SuperGlueConfig): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) def forward(self, hidden_states: torch.Tensor, *args) -> torch.Tensor: hidden_states = self.dense(hidden_states) return hidden_states SUPERGLUE_SELF_ATTENTION_CLASSES = { "eager": SuperGlueSelfAttention, } class SuperGlueAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() self.self = SUPERGLUE_SELF_ATTENTION_CLASSES[config._attn_implementation]( config, position_embedding_type=position_embedding_type, ) self.output = SuperGlueSelfOutput(config) self.pruned_heads = set() def prune_heads(self, heads): if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads ) # Prune linear layers self.self.query = prune_linear_layer(self.self.query, index) self.self.key = prune_linear_layer(self.self.key, index) self.self.value = prune_linear_layer(self.self.value, index) self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) # Update hyper params and store pruned heads self.self.num_attention_heads = self.self.num_attention_heads - len(heads) self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs class SuperGlueAttentionalPropagation(nn.Module): def __init__(self, config: SuperGlueConfig) -> None: super().__init__() hidden_size = config.hidden_size self.attention = SuperGlueAttention(config) mlp_channels = [hidden_size * 2, hidden_size * 2, hidden_size] layers = [ SuperGlueMultiLayerPerceptron(config, mlp_channels[i - 1], mlp_channels[i]) for i in range(1, len(mlp_channels) - 1) ] layers.append(nn.Linear(mlp_channels[-2], mlp_channels[-1])) self.mlp = nn.ModuleList(layers) def forward( self, descriptors: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, output_hidden_states: bool = False, ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor]], Optional[tuple[torch.Tensor]]]: attention_outputs = self.attention( descriptors, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, ) output = attention_outputs[0] attention = attention_outputs[1:] hidden_state = torch.cat([descriptors, output], dim=2) all_hidden_states = () if output_hidden_states else None for layer in self.mlp: hidden_state = layer(hidden_state) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_state,) return hidden_state, all_hidden_states, attention class SuperGlueAttentionalGNN(nn.Module): def __init__(self, config: SuperGlueConfig) -> None: super().__init__() self.hidden_size = config.hidden_size self.layers_types = config.gnn_layers_types self.layers = nn.ModuleList([SuperGlueAttentionalPropagation(config) for _ in range(len(self.layers_types))]) def forward( self, descriptors: torch.Tensor, mask: Optional[torch.Tensor] = None, output_attentions: bool = False, output_hidden_states: Optional[bool] = False, ) -> tuple[torch.Tensor, Optional[tuple], Optional[tuple]]: all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None batch_size, num_keypoints, _ = descriptors.shape if output_hidden_states: all_hidden_states = all_hidden_states + (descriptors,) for gnn_layer, layer_type in zip(self.layers, self.layers_types): encoder_hidden_states = None encoder_attention_mask = None if layer_type == "cross": encoder_hidden_states = ( descriptors.reshape(-1, 2, num_keypoints, self.hidden_size) .flip(1) .reshape(batch_size, num_keypoints, self.hidden_size) ) encoder_attention_mask = ( mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints) if mask is not None else None ) gnn_outputs = gnn_layer( descriptors, attention_mask=mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_hidden_states=output_hidden_states, output_attentions=output_attentions, ) delta = gnn_outputs[0] if output_hidden_states: all_hidden_states = all_hidden_states + gnn_outputs[1] if output_attentions: all_attentions = all_attentions + gnn_outputs[2] descriptors = descriptors + delta return descriptors, all_hidden_states, all_attentions class SuperGlueFinalProjection(nn.Module): def __init__(self, config: SuperGlueConfig) -> None: super().__init__() hidden_size = config.hidden_size self.final_proj = nn.Linear(hidden_size, hidden_size, bias=True) def forward(self, descriptors: torch.Tensor) -> torch.Tensor: return self.final_proj(descriptors) @auto_docstring class SuperGluePreTrainedModel(PreTrainedModel): config: SuperGlueConfig base_model_prefix = "superglue" main_input_name = "pixel_values" def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.BatchNorm1d): module.bias.data.zero_() module.weight.data.fill_(1.0) if hasattr(module, "bin_score"): module.bin_score.data.fill_(1.0) @auto_docstring( custom_intro=""" SuperGlue model taking images as inputs and outputting the matching of them. """ ) class SuperGlueForKeypointMatching(SuperGluePreTrainedModel): """SuperGlue feature matching middle-end Given two sets of keypoints and locations, we determine the correspondences by: 1. Keypoint Encoding (normalization + visual feature and location fusion) 2. Graph Neural Network with multiple self and cross-attention layers 3. Final projection layer 4. Optimal Transport Layer (a differentiable Hungarian matching algorithm) 5. Thresholding matrix based on mutual exclusivity and a match_threshold The correspondence ids use -1 to indicate non-matching points. Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural Networks. In CVPR, 2020. https://huggingface.co/papers/1911.11763 """ def __init__(self, config: SuperGlueConfig) -> None: super().__init__(config) self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config) self.keypoint_encoder = SuperGlueKeypointEncoder(config) self.gnn = SuperGlueAttentionalGNN(config) self.final_projection = SuperGlueFinalProjection(config) bin_score = torch.nn.Parameter(torch.tensor(1.0)) self.register_parameter("bin_score", bin_score) self.post_init() def _match_image_pair( self, keypoints: torch.Tensor, descriptors: torch.Tensor, scores: torch.Tensor, height: int, width: int, mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, ) -> tuple[torch.Tensor, torch.Tensor, tuple, tuple]: """ Perform keypoint matching between two images. Args: keypoints (`torch.Tensor` of shape `(batch_size, 2, num_keypoints, 2)`): Keypoints detected in the pair of image. descriptors (`torch.Tensor` of shape `(batch_size, 2, descriptor_dim, num_keypoints)`): Descriptors of the keypoints detected in the image pair. scores (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`): Confidence scores of the keypoints detected in the image pair. height (`int`): Image height. width (`int`): Image width. mask (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`, *optional*): Mask indicating which values in the keypoints, matches and matching_scores tensors are keypoint matching information. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors. Default to `config.output_attentions`. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. Default to `config.output_hidden_states`. Returns: matches (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`): For each image pair, for each keypoint in image0, the index of the keypoint in image1 that was matched with. And for each keypoint in image1, the index of the keypoint in image0 that was matched with. matching_scores (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`): Scores of predicted matches for each image pair all_hidden_states (`tuple(torch.FloatTensor)`, *optional*): Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(1, 2, num_keypoints, num_channels)`. all_attentions (`tuple(torch.FloatTensor)`, *optional*): Tuple of `torch.FloatTensor` (one for each layer) of shape `(1, 2, num_heads, num_keypoints, num_keypoints)`. """ all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None if keypoints.shape[2] == 0: # no keypoints shape = keypoints.shape[:-1] return ( keypoints.new_full(shape, -1, dtype=torch.int), keypoints.new_zeros(shape), all_hidden_states, all_attentions, ) batch_size, _, num_keypoints, _ = keypoints.shape # (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2) keypoints = keypoints.reshape(batch_size * 2, num_keypoints, 2) descriptors = descriptors.reshape(batch_size * 2, num_keypoints, self.config.hidden_size) scores = scores.reshape(batch_size * 2, num_keypoints) mask = mask.reshape(batch_size * 2, num_keypoints) if mask is not None else None # Keypoint normalization keypoints = normalize_keypoints(keypoints, height, width) encoded_keypoints = self.keypoint_encoder(keypoints, scores, output_hidden_states=output_hidden_states) last_hidden_state = encoded_keypoints[0] # Keypoint MLP encoder. descriptors = descriptors + last_hidden_state if mask is not None: input_shape = descriptors.size() extended_attention_mask = self.get_extended_attention_mask(mask, input_shape) else: extended_attention_mask = torch.ones((batch_size, num_keypoints), device=keypoints.device) # Multi-layer Transformer network. gnn_outputs = self.gnn( descriptors, mask=extended_attention_mask, output_hidden_states=output_hidden_states, output_attentions=output_attentions, ) descriptors = gnn_outputs[0] # Final MLP projection. projected_descriptors = self.final_projection(descriptors) # (batch_size * 2, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim) final_descriptors = projected_descriptors.reshape(batch_size, 2, num_keypoints, self.config.hidden_size) final_descriptors0 = final_descriptors[:, 0] final_descriptors1 = final_descriptors[:, 1] # Compute matching descriptor distance. scores = final_descriptors0 @ final_descriptors1.transpose(1, 2) scores = scores / self.config.hidden_size**0.5 if mask is not None: mask = mask.reshape(batch_size, 2, num_keypoints) mask0 = mask[:, 0].unsqueeze(-1).expand(-1, -1, num_keypoints) scores = scores.masked_fill(mask0 == 0, -1e9) # Run the optimal transport. scores = log_optimal_transport(scores, self.bin_score, iterations=self.config.sinkhorn_iterations) # Get the matches with score above "match_threshold". max0 = scores[:, :-1, :-1].max(2) max1 = scores[:, :-1, :-1].max(1) indices0 = max0.indices indices1 = max1.indices mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1) zero = scores.new_tensor(0) matching_scores0 = torch.where(mutual0, max0.values.exp(), zero) matching_scores0 = torch.where(matching_scores0 > self.config.matching_threshold, matching_scores0, zero) matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, indices1), zero) valid0 = mutual0 & (matching_scores0 > zero) valid1 = mutual1 & valid0.gather(1, indices1) matches0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) matches1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) matches = torch.cat([matches0, matches1], dim=1).reshape(batch_size, 2, -1) matching_scores = torch.cat([matching_scores0, matching_scores1], dim=1).reshape(batch_size, 2, -1) if output_hidden_states: all_hidden_states = all_hidden_states + encoded_keypoints[1] all_hidden_states = all_hidden_states + gnn_outputs[1] all_hidden_states = all_hidden_states + (projected_descriptors,) all_hidden_states = tuple( x.reshape(batch_size, 2, num_keypoints, -1).transpose(-1, -2) for x in all_hidden_states ) if output_attentions: all_attentions = all_attentions + gnn_outputs[2] all_attentions = tuple(x.reshape(batch_size, 2, -1, num_keypoints, num_keypoints) for x in all_attentions) return ( matches, matching_scores, all_hidden_states, all_attentions, ) @auto_docstring def forward( self, pixel_values: torch.FloatTensor, labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[tuple, KeypointMatchingOutput]: r""" Examples: ```python >>> from transformers import AutoImageProcessor, AutoModel >>> import torch >>> from PIL import Image >>> import requests >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true" >>> image1 = Image.open(requests.get(url, stream=True).raw) >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true" >>> image2 = Image.open(requests.get(url, stream=True).raw) >>> images = [image1, image2] >>> processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor") >>> model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor") >>> with torch.no_grad(): >>> inputs = processor(images, return_tensors="pt") >>> outputs = model(**inputs) ```""" loss = None if labels is not None: raise ValueError("SuperGlue is not trainable, no labels should be provided.") output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if pixel_values.ndim != 5 or pixel_values.size(1) != 2: raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)") batch_size, _, channels, height, width = pixel_values.shape pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width) keypoint_detections = self.keypoint_detector(pixel_values) keypoints, scores, descriptors, mask = keypoint_detections[:4] keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values) scores = scores.reshape(batch_size, 2, -1).to(pixel_values) descriptors = descriptors.reshape(batch_size, 2, -1, self.config.hidden_size).to(pixel_values) mask = mask.reshape(batch_size, 2, -1) absolute_keypoints = keypoints.clone() absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height matches, matching_scores, hidden_states, attentions = self._match_image_pair( absolute_keypoints, descriptors, scores, height, width, mask=mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if not return_dict: return tuple( v for v in [loss, matches, matching_scores, keypoints, mask, hidden_states, attentions] if v is not None ) return KeypointMatchingOutput( loss=loss, matches=matches, matching_scores=matching_scores, keypoints=keypoints, mask=mask, hidden_states=hidden_states, attentions=attentions, ) __all__ = ["SuperGluePreTrainedModel", "SuperGlueForKeypointMatching"]