# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Callable, Optional, Union

import torch
from torch import nn

from ...activations import ACT2CLS, ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BackboneOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
    ModelOutput,
    TransformersKwargs,
    auto_docstring,
    can_return_tuple,
    torch_int,
)
from ...utils.generic import check_model_inputs
from .configuration_efficientloftr import EfficientLoFTRConfig


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of keypoint matching models. Due to the nature of keypoint detection and matching, the
    number of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the
    batch of images, the maximum number of matches is set as the dimension of the matches and matching scores. The
    mask tensor is used to indicate which values in the keypoints, matches and matching_scores tensors are keypoint
    matching information.
    """
)
class KeypointMatchingOutput(ModelOutput):
    r"""
    matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Index of keypoint matched in the other image.
    matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Scores of predicted matches.
    keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
        Absolute (x, y) coordinates of predicted keypoints in a given image.
    hidden_states (`tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of each stage) of shape
        `(batch_size, 2, num_channels, num_keypoints)`, returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`.
    attentions (`tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for each layer) of shape
        `(batch_size, 2, num_heads, num_keypoints, num_keypoints)`, returned when `output_attentions=True` is passed
        or when `config.output_attentions=True`.
    """

    matches: Optional[torch.FloatTensor] = None
    matching_scores: Optional[torch.FloatTensor] = None
    keypoints: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None


class EfficientLoFTRRotaryEmbedding(nn.Module):
    def __init__(self, config: EfficientLoFTRConfig, device=None):
        super().__init__()
        self.config = config
        self.rope_type = config.rope_scaling["rope_type"]
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, _ = self.rope_init_fn(self.config, device)

        inv_freq_expanded = inv_freq[None, None, None, :].float().expand(1, 1, 1, -1)
        embed_height, embed_width = config.embedding_size
        i_indices = torch.ones(embed_height, embed_width).cumsum(0).float().unsqueeze(-1)
        j_indices = torch.ones(embed_height, embed_width).cumsum(1).float().unsqueeze(-1)

        emb = torch.zeros(1, embed_height, embed_width, self.config.hidden_size // 2)
        emb[:, :, :, 0::2] = i_indices * inv_freq_expanded
        emb[:, :, :, 1::2] = j_indices * inv_freq_expanded

        self.register_buffer("inv_freq", emb, persistent=False)

    @torch.no_grad()
    def forward(
        self, x: torch.Tensor, position_ids: Optional[tuple[torch.LongTensor, torch.LongTensor]] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            emb = self.inv_freq
            sin = emb.sin()
            cos = emb.cos()
            sin = sin.repeat_interleave(2, dim=-1)
            cos = cos.repeat_interleave(2, dim=-1)
            sin = sin.to(device=x.device, dtype=x.dtype)
            cos = cos.to(device=x.device, dtype=x.dtype)
        return cos, sin


# Copied from transformers.models.rt_detr_v2.modeling_rt_detr_v2.RTDetrV2ConvNormLayer with RTDetrV2->EfficientLoFTR
class EfficientLoFTRConvNormLayer(nn.Module):
    def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=(kernel_size - 1) // 2 if padding is None else padding,
            bias=False,
        )
        self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
        self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()

    def forward(self, hidden_state):
        hidden_state = self.conv(hidden_state)
        hidden_state = self.norm(hidden_state)
        hidden_state = self.activation(hidden_state)
        return hidden_state


class EfficientLoFTRRepVGGBlock(GradientCheckpointingLayer):
    """
    RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again".
""" def __init__(self, config: EfficientLoFTRConfig, stage_idx: int, block_idx: int): super().__init__() in_channels = config.stage_block_in_channels[stage_idx][block_idx] out_channels = config.stage_block_out_channels[stage_idx][block_idx] stride = config.stage_block_stride[stage_idx][block_idx] activation = config.activation_function self.conv1 = EfficientLoFTRConvNormLayer( config, in_channels, out_channels, kernel_size=3, stride=stride, padding=1 ) self.conv2 = EfficientLoFTRConvNormLayer( config, in_channels, out_channels, kernel_size=1, stride=stride, padding=0 ) self.identity = nn.BatchNorm2d(in_channels) if in_channels == out_channels and stride == 1 else None self.activation = nn.Identity() if activation is None else ACT2FN[activation] def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.identity is not None: identity_out = self.identity(hidden_states) else: identity_out = 0 hidden_states = self.conv1(hidden_states) + self.conv2(hidden_states) + identity_out hidden_states = self.activation(hidden_states) return hidden_states class EfficientLoFTRRepVGGStage(nn.Module): def __init__(self, config: EfficientLoFTRConfig, stage_idx: int): super().__init__() self.blocks = nn.ModuleList([]) for block_idx in range(config.stage_num_blocks[stage_idx]): self.blocks.append( EfficientLoFTRRepVGGBlock( config, stage_idx, block_idx, ) ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for block in self.blocks: hidden_states = block(hidden_states) return hidden_states class EfficientLoFTRepVGG(nn.Module): def __init__(self, config: EfficientLoFTRConfig): super().__init__() self.stages = nn.ModuleList([]) for stage_idx in range(len(config.stage_stride)): stage = EfficientLoFTRRepVGGStage(config, stage_idx) self.stages.append(stage) def forward(self, hidden_states: torch.Tensor) -> list[torch.Tensor]: outputs = [] for stage in self.stages: hidden_states = stage(hidden_states) outputs.append(hidden_states) # Exclude first stage in outputs outputs = outputs[1:] return outputs class EfficientLoFTRAggregationLayer(nn.Module): def __init__(self, config: EfficientLoFTRConfig): super().__init__() hidden_size = config.hidden_size self.q_aggregation = nn.Conv2d( hidden_size, hidden_size, kernel_size=config.q_aggregation_kernel_size, padding=0, stride=config.q_aggregation_stride, bias=False, groups=hidden_size, ) self.kv_aggregation = torch.nn.MaxPool2d( kernel_size=config.kv_aggregation_kernel_size, stride=config.kv_aggregation_stride ) self.norm = nn.LayerNorm(hidden_size) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: query_states = hidden_states is_cross_attention = encoder_hidden_states is not None kv_states = encoder_hidden_states if is_cross_attention else hidden_states query_states = self.q_aggregation(query_states) kv_states = self.kv_aggregation(kv_states) query_states = query_states.permute(0, 2, 3, 1) kv_states = kv_states.permute(0, 2, 3, 1) hidden_states = self.norm(query_states) encoder_hidden_states = self.norm(kv_states) return hidden_states, encoder_hidden_states # Copied from transformers.models.cohere.modeling_cohere.rotate_half def rotate_half(x): # Split and rotate. Note that this function is different from e.g. Llama. 
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
    return rot_x


# Copied from transformers.models.cohere.modeling_cohere.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example,
            note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then,
            if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k
            have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    dtype = q.dtype
    q = q.float()
    k = k.float()
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)


# Copied from transformers.models.cohere.modeling_cohere.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


# Copied from transformers.models.llama.modeling_llama.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class EfficientLoFTRAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->EfficientLoFTR
    def __init__(self, config: EfficientLoFTRConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        batch_size, seq_len, dim = hidden_states.shape
        input_shape = hidden_states.shape[:-1]

        query_states = self.q_proj(hidden_states).view(batch_size, seq_len, -1, dim)

        is_cross_attention = encoder_hidden_states is not None
        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        key_states = self.k_proj(current_states).view(batch_size, seq_len, -1, dim)
        value_states = self.v_proj(current_states).view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=2)
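        # (batch_size, seq_len, num_heads, head_dim) -> (batch_size, num_heads, seq_len, head_dim),
        # the layout expected by the attention implementations below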
        query_states = query_states.view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=None,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class EfficientLoFTRMLP(nn.Module):
    def __init__(self, config: EfficientLoFTRConfig):
        super().__init__()
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        self.fc1 = nn.Linear(hidden_size * 2, intermediate_size, bias=False)
        self.activation = ACT2FN[config.mlp_activation_function]
        self.fc2 = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


class EfficientLoFTRAggregatedAttention(nn.Module):
    def __init__(self, config: EfficientLoFTRConfig, layer_idx: int):
        super().__init__()
        self.q_aggregation_kernel_size = config.q_aggregation_kernel_size
        self.aggregation = EfficientLoFTRAggregationLayer(config)
        self.attention = EfficientLoFTRAttention(config, layer_idx)
        self.mlp = EfficientLoFTRMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        batch_size, embed_dim, _, _ = hidden_states.shape

        # Aggregate features
        aggregated_hidden_states, aggregated_encoder_hidden_states = self.aggregation(
            hidden_states, encoder_hidden_states
        )
        _, aggregated_h, aggregated_w, _ = aggregated_hidden_states.shape

        # Multi-head attention
        aggregated_hidden_states = aggregated_hidden_states.reshape(batch_size, -1, embed_dim)
        aggregated_encoder_hidden_states = aggregated_encoder_hidden_states.reshape(batch_size, -1, embed_dim)
        attn_output, _ = self.attention(
            aggregated_hidden_states,
            aggregated_encoder_hidden_states,
            position_embeddings=position_embeddings,
            **kwargs,
        )

        # Upsample features
        # (batch_size, seq_len, embed_dim) -> (batch_size, embed_dim, h, w) with seq_len = h * w
        attn_output = attn_output.permute(0, 2, 1)
        attn_output = attn_output.reshape(batch_size, embed_dim, aggregated_h, aggregated_w)
        attn_output = torch.nn.functional.interpolate(
            attn_output, scale_factor=self.q_aggregation_kernel_size, mode="bilinear", align_corners=False
        )

        intermediate_states = torch.cat([hidden_states, attn_output], dim=1)
        intermediate_states = intermediate_states.permute(0, 2, 3, 1)
        output_states = self.mlp(intermediate_states)
        output_states = output_states.permute(0, 3, 1, 2)

        hidden_states = hidden_states + output_states
        return hidden_states
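

# A local feature transformer layer runs aggregated self-attention within each image of the pair, then aggregated
# cross-attention between the two images (the pair dimension is flipped to build the encoder hidden states).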
class EfficientLoFTRLocalFeatureTransformerLayer(GradientCheckpointingLayer):
    def __init__(self, config: EfficientLoFTRConfig, layer_idx: int):
        super().__init__()

        self.self_attention = EfficientLoFTRAggregatedAttention(config, layer_idx)
        self.cross_attention = EfficientLoFTRAggregatedAttention(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        batch_size, _, embed_dim, height, width = hidden_states.shape

        hidden_states = hidden_states.reshape(-1, embed_dim, height, width)
        hidden_states = self.self_attention(hidden_states, position_embeddings=position_embeddings, **kwargs)
        encoder_hidden_states = hidden_states.reshape(-1, 2, embed_dim, height, width)
        encoder_hidden_states = encoder_hidden_states.flip(1)
        encoder_hidden_states = encoder_hidden_states.reshape(-1, embed_dim, height, width)
        hidden_states = self.cross_attention(hidden_states, encoder_hidden_states, **kwargs)
        hidden_states = hidden_states.reshape(batch_size, -1, embed_dim, height, width)
        return hidden_states


class EfficientLoFTRLocalFeatureTransformer(nn.Module):
    def __init__(self, config: EfficientLoFTRConfig):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                EfficientLoFTRLocalFeatureTransformerLayer(config, layer_idx=i)
                for i in range(config.num_attention_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_embeddings=position_embeddings, **kwargs)
        return hidden_states


class EfficientLoFTROutConvBlock(nn.Module):
    def __init__(self, config: EfficientLoFTRConfig, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.out_conv1 = nn.Conv2d(hidden_size, intermediate_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.out_conv2 = nn.Conv2d(
            intermediate_size, intermediate_size, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.batch_norm = nn.BatchNorm2d(intermediate_size)
        self.activation = ACT2CLS[config.mlp_activation_function]()
        self.out_conv3 = nn.Conv2d(intermediate_size, hidden_size, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, hidden_states: torch.Tensor, residual_states: torch.Tensor) -> torch.Tensor:
        residual_states = self.out_conv1(residual_states)
        residual_states = residual_states + hidden_states
        residual_states = self.out_conv2(residual_states)
        residual_states = self.batch_norm(residual_states)
        residual_states = self.activation(residual_states)
        residual_states = self.out_conv3(residual_states)
        residual_states = nn.functional.interpolate(
            residual_states, scale_factor=2.0, mode="bilinear", align_corners=False
        )
        return residual_states


class EfficientLoFTRFineFusionLayer(nn.Module):
    def __init__(self, config: EfficientLoFTRConfig):
        super().__init__()
        self.fine_kernel_size = config.fine_kernel_size

        fine_fusion_dims = config.fine_fusion_dims
        self.out_conv = nn.Conv2d(
            fine_fusion_dims[0], fine_fusion_dims[0], kernel_size=1, stride=1, padding=0, bias=False
        )
        self.out_conv_layers = nn.ModuleList()
        for i in range(1, len(fine_fusion_dims)):
            out_conv = EfficientLoFTROutConvBlock(config, fine_fusion_dims[i], fine_fusion_dims[i - 1])
            self.out_conv_layers.append(out_conv)

    def forward_pyramid(
        self,
        hidden_states: torch.Tensor,
        residual_states: list[torch.Tensor],
    ) -> torch.Tensor:
        hidden_states = self.out_conv(hidden_states)
        hidden_states = nn.functional.interpolate(
            hidden_states, scale_factor=2.0, mode="bilinear", align_corners=False
        )
        for i, layer in enumerate(self.out_conv_layers):
            hidden_states = layer(hidden_states, residual_states[i])

        return hidden_states
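
    # forward fuses the transformed coarse features with the backbone residuals through forward_pyramid
    # (2x bilinear upsampling at each step), then unfolds fixed-size local windows around every coarse pixel.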
    def forward(
        self,
        coarse_features: torch.Tensor,
        residual_features: list[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        For each image pair, compute the fine features of pixels. In both images, compute a patch of fine features
        center cropped around each coarse pixel. In the first image, the feature patch is of size
        (kernel_size, kernel_size). In the second image, it is of size (kernel_size + 2, kernel_size + 2).
        """
        batch_size, _, embed_dim, coarse_height, coarse_width = coarse_features.shape
        coarse_features = coarse_features.reshape(-1, embed_dim, coarse_height, coarse_width)
        residual_features = list(reversed(residual_features))

        # 1. Fine feature extraction
        fine_features = self.forward_pyramid(coarse_features, residual_features)
        _, fine_embed_dim, fine_height, fine_width = fine_features.shape
        fine_features = fine_features.reshape(batch_size, 2, fine_embed_dim, fine_height, fine_width)
        fine_features_0 = fine_features[:, 0]
        fine_features_1 = fine_features[:, 1]

        # 2. Unfold all local windows in crops
        stride = int(fine_height // coarse_height)
        fine_features_0 = nn.functional.unfold(
            fine_features_0, kernel_size=self.fine_kernel_size, stride=stride, padding=0
        )
        _, _, seq_len = fine_features_0.shape
        fine_features_0 = fine_features_0.reshape(batch_size, -1, self.fine_kernel_size**2, seq_len)
        fine_features_0 = fine_features_0.permute(0, 3, 2, 1)
        fine_features_1 = nn.functional.unfold(
            fine_features_1, kernel_size=self.fine_kernel_size + 2, stride=stride, padding=1
        )
        fine_features_1 = fine_features_1.reshape(batch_size, -1, (self.fine_kernel_size + 2) ** 2, seq_len)
        fine_features_1 = fine_features_1.permute(0, 3, 2, 1)

        return fine_features_0, fine_features_1


@auto_docstring
class EfficientLoFTRPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = EfficientLoFTRConfig
    base_model_prefix = "efficientloftr"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_record_outputs = {
        "hidden_states": EfficientLoFTRRepVGGBlock,
        "attentions": EfficientLoFTRAttention,
    }

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv1d, nn.BatchNorm2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    # Copied from transformers.models.superpoint.modeling_superpoint.SuperPointPreTrainedModel.extract_one_channel_pixel_values with SuperPoint->EfficientLoFTR
    def extract_one_channel_pixel_values(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
        """
        Assuming pixel_values has shape (batch_size, 3, height, width), and that all channel values are the same,
        extract the first channel value to get a tensor of shape (batch_size, 1, height, width) for EfficientLoFTR.
        This is a workaround for the issue discussed in :
        https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446

        Args:
            pixel_values: torch.FloatTensor of shape (batch_size, 3, height, width)

        Returns:
            pixel_values: torch.FloatTensor of shape (batch_size, 1, height, width)
        """
        return pixel_values[:, 0, :, :][:, None, :, :]


@auto_docstring(
    custom_intro="""
    EfficientLoFTR model taking images as inputs and outputting the features of the images.
""" ) class EfficientLoFTRModel(EfficientLoFTRPreTrainedModel): def __init__(self, config: EfficientLoFTRConfig): super().__init__(config) self.config = config self.backbone = EfficientLoFTRepVGG(config) self.local_feature_transformer = EfficientLoFTRLocalFeatureTransformer(config) self.rotary_emb = EfficientLoFTRRotaryEmbedding(config=config) self.post_init() @check_model_inputs @auto_docstring def forward( self, pixel_values: torch.FloatTensor, labels: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> BackboneOutput: r""" Examples: ```python >>> from transformers import AutoImageProcessor, AutoModel >>> import torch >>> from PIL import Image >>> import requests >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true" >>> image1 = Image.open(requests.get(url, stream=True).raw) >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true" >>> image2 = Image.open(requests.get(url, stream=True).raw) >>> images = [image1, image2] >>> processor = AutoImageProcessor.from_pretrained("zju-community/efficient_loftr") >>> model = AutoModel.from_pretrained("zju-community/efficient_loftr") >>> with torch.no_grad(): >>> inputs = processor(images, return_tensors="pt") >>> outputs = model(**inputs) ```""" if labels is not None: raise ValueError("EfficientLoFTR is not trainable, no labels should be provided.") if pixel_values.ndim != 5 or pixel_values.size(1) != 2: raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)") batch_size, _, channels, height, width = pixel_values.shape pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width) pixel_values = self.extract_one_channel_pixel_values(pixel_values) # 1. Local Feature CNN features = self.backbone(pixel_values) # Last stage outputs are coarse outputs coarse_features = features[-1] # Rest is residual features used in EfficientLoFTRFineFusionLayer residual_features = features[:-1] coarse_embed_dim, coarse_height, coarse_width = coarse_features.shape[-3:] # 2. 
        # 2. Coarse-level LoFTR module
        cos, sin = self.rotary_emb(coarse_features)
        cos = cos.expand(batch_size * 2, -1, -1, -1).reshape(batch_size * 2, -1, coarse_embed_dim)
        sin = sin.expand(batch_size * 2, -1, -1, -1).reshape(batch_size * 2, -1, coarse_embed_dim)
        position_embeddings = (cos, sin)

        coarse_features = coarse_features.reshape(batch_size, 2, coarse_embed_dim, coarse_height, coarse_width)
        coarse_features = self.local_feature_transformer(
            coarse_features, position_embeddings=position_embeddings, **kwargs
        )

        features = (coarse_features,) + tuple(residual_features)
        return BackboneOutput(feature_maps=features)


def mask_border(tensor: torch.Tensor, border_margin: int, value: Union[bool, float, int]) -> torch.Tensor:
    """
    Mask a tensor border with a given value

    Args:
        tensor (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`):
            The tensor to mask
        border_margin (`int`):
            The size of the border
        value (`Union[bool, int, float]`):
            The value to place in the tensor's borders

    Returns:
        tensor (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`):
            The masked tensor
    """
    if border_margin <= 0:
        return tensor

    tensor[:, :border_margin, :border_margin, :border_margin, :border_margin] = value
    tensor[:, -border_margin:, -border_margin:, -border_margin:, -border_margin:] = value
    return tensor


def create_meshgrid(
    height: Union[int, torch.Tensor],
    width: Union[int, torch.Tensor],
    normalized_coordinates: bool = False,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Copied from kornia library : kornia/kornia/utils/grid.py:26

    Generate a coordinate grid for an image.

    When the flag ``normalized_coordinates`` is set to True, the grid is normalized to be in the range :math:`[-1,1]`
    to be consistent with the pytorch function :py:func:`torch.nn.functional.grid_sample`.

    Args:
        height (`int`):
            The image height (rows).
        width (`int`):
            The image width (cols).
        normalized_coordinates (`bool`):
            Whether to normalize coordinates in the range :math:`[-1,1]` in order to be consistent with the PyTorch
            function :py:func:`torch.nn.functional.grid_sample`.
        device (`torch.device`):
            The device on which the grid will be generated.
        dtype (`torch.dtype`):
            The data type of the generated grid.

    Return:
        grid (`torch.Tensor` of shape `(1, height, width, 2)`):
            The grid tensor.

    Example:
        >>> create_meshgrid(2, 2, normalized_coordinates=True)
        tensor([[[[-1., -1.], [ 1., -1.]], [[-1., 1.], [ 1., 1.]]]])

        >>> create_meshgrid(2, 2, normalized_coordinates=False)
        tensor([[[[0., 0.], [1., 0.]], [[0., 1.], [1., 1.]]]])
    """
    xs = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
    ys = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
    if normalized_coordinates:
        xs = (xs / (width - 1) - 0.5) * 2
        ys = (ys / (height - 1) - 0.5) * 2
    grid = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1)
    grid = grid.permute(1, 0, 2).unsqueeze(0)
    return grid


def spatial_expectation2d(input: torch.Tensor, normalized_coordinates: bool = True) -> torch.Tensor:
    r"""
    Copied from kornia library : kornia/geometry/subpix/dsnt.py:76

    Compute the expectation of coordinate values using spatial probabilities.

    The input heatmap is assumed to represent a valid spatial probability distribution, which can be achieved using
    :func:`~kornia.geometry.subpixel.spatial_softmax2d`.

    Args:
        input (`torch.Tensor` of shape `(batch_size, embed_dim, height, width)`):
            The input tensor representing dense spatial probabilities.
        normalized_coordinates (`bool`):
            Whether to return the coordinates normalized in the range of :math:`[-1, 1]`. Otherwise, it will return
            the coordinates in the range of the input shape.

    Returns:
        output (`torch.Tensor` of shape `(batch_size, embed_dim, 2)`):
            Expected value of the 2D coordinates. Output order of the coordinates is (x, y).

    Examples:
        >>> heatmaps = torch.tensor([[[
        ... [0., 0., 0.],
        ... [0., 0., 0.],
        ... [0., 1., 0.]]]])
        >>> spatial_expectation2d(heatmaps, False)
        tensor([[[1., 2.]]])
    """
    batch_size, embed_dim, height, width = input.shape

    # Create coordinates grid.
    grid = create_meshgrid(height, width, normalized_coordinates, input.device)
    grid = grid.to(input.dtype)

    pos_x = grid[..., 0].reshape(-1)
    pos_y = grid[..., 1].reshape(-1)

    input_flat = input.view(batch_size, embed_dim, -1)

    # Compute the expectation of the coordinates.
    expected_y = torch.sum(pos_y * input_flat, -1, keepdim=True)
    expected_x = torch.sum(pos_x * input_flat, -1, keepdim=True)

    output = torch.cat([expected_x, expected_y], -1)

    return output.view(batch_size, embed_dim, 2)


@auto_docstring(
    custom_intro="""
    EfficientLoFTR model taking images as inputs and outputting the matches between them.
    """
)
class EfficientLoFTRForKeypointMatching(EfficientLoFTRPreTrainedModel):
    """EfficientLoFTR dense image matcher

    Given two images, we determine the correspondences by:
        1. Extracting coarse and fine features through a backbone
        2. Transforming coarse features through self and cross attention
        3. Matching coarse features to obtain coarse coordinates of matches
        4. Obtaining full resolution fine features by fusing transformed and backbone coarse features
        5. Refining the coarse matches using fine feature patches centered at each coarse match in a two-stage
           refinement

    Yifan Wang, Xingyi He, Sida Peng, Dongli Tan and Xiaowei Zhou.
    Efficient LoFTR: Semi-Dense Local Feature Matching with Sparse-Like Speed. In CVPR, 2024.
    https://arxiv.org/abs/2403.04765
    """

    def __init__(self, config: EfficientLoFTRConfig):
        super().__init__(config)

        self.config = config
        self.efficientloftr = EfficientLoFTRModel(config)
        self.refinement_layer = EfficientLoFTRFineFusionLayer(config)

        self.post_init()

    def _get_matches_from_scores(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Based on a keypoint score matrix, compute the best keypoint matches between the first and second image.
        Since each image pair can have a different number of matches, the matches are concatenated together for all
        pairs in the batch and a batch_indices tensor is returned to specify which match belongs to which element in
        the batch.

        Note: This step can be done as a postprocessing step, because it does not involve any model weights/params.
        However, we keep it in the modeling code for consistency with other keypoint matching models AND for easier
        torch.compile/torch.export (all ops are in torch).

        Args:
            scores (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`):
                Scores of keypoints

        Returns:
            matched_indices (`torch.Tensor` of shape `(2, num_matches)`):
                Indices representing which pixel in the first image matches which pixel in the second image
            matching_scores (`torch.Tensor` of shape `(num_matches,)`):
                Scores of each match
        """
        batch_size, height0, width0, height1, width1 = scores.shape
        scores = scores.view(batch_size, height0 * width0, height1 * width1)

        # For each keypoint, get the best match
        max_0 = scores.max(2, keepdim=True).values
        max_1 = scores.max(1, keepdim=True).values

        # 1. Thresholding
        mask = scores > self.config.coarse_matching_threshold
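        # Steps 2-4 below prune the thresholded scores: matches at the border are removed, only mutual nearest
        # neighbors (score equal to both max_0 and max_1) are kept, and the best match per keypoint is selected.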
        # 2. Border removal
        mask = mask.reshape(batch_size, height0, width0, height1, width1)
        mask = mask_border(mask, self.config.coarse_matching_border_removal, False)
        mask = mask.reshape(batch_size, height0 * width0, height1 * width1)

        # 3. Mutual nearest neighbors
        mask = mask * (scores == max_0) * (scores == max_1)

        # 4. Find coarse matches
        masked_scores = scores * mask
        matching_scores_0, max_indices_0 = masked_scores.max(1)
        matching_scores_1, max_indices_1 = masked_scores.max(2)
        matching_indices = torch.cat([max_indices_0, max_indices_1]).reshape(batch_size, 2, -1)
        matching_scores = torch.stack([matching_scores_0, matching_scores_1], dim=1)

        # For the keypoints not meeting the threshold score, set the indices to -1 which corresponds to no matches found
        matching_indices = torch.where(matching_scores > 0, matching_indices, -1)

        return matching_indices, matching_scores

    def _coarse_matching(
        self, coarse_features: torch.Tensor, coarse_scale: float
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        For each image pair, compute the matching confidence between each coarse element (by default
        (image_height / 8) * (image_width / 8) elements) from the first image to the second image.

        Note: This step can be done as a postprocessing step, because it does not involve any model weights/params.
        However, we keep it in the modeling code for consistency with other keypoint matching models AND for easier
        torch.compile/torch.export (all ops are in torch).

        Args:
            coarse_features (`torch.Tensor` of shape `(batch_size, 2, hidden_size, coarse_height, coarse_width)`):
                Coarse features
            coarse_scale (`float`):
                Scale between the image size and the coarse size

        Returns:
            keypoints (`torch.Tensor` of shape `(batch_size, 2, num_matches, 2)`):
                Keypoint coordinates.
            matching_scores (`torch.Tensor` of shape `(batch_size, 2, num_matches)`):
                The confidence matching score of each keypoint.
            matched_indices (`torch.Tensor` of shape `(batch_size, 2, num_matches)`):
                Indices indicating which keypoint in an image matched with which keypoint in the other image, for
                both images in the pair.
        """
        batch_size, _, embed_dim, height, width = coarse_features.shape

        # (batch_size, 2, embed_dim, height, width) -> (batch_size, 2, height * width, embed_dim)
        coarse_features = coarse_features.permute(0, 1, 3, 4, 2)
        coarse_features = coarse_features.reshape(batch_size, 2, -1, embed_dim)
        coarse_features = coarse_features / coarse_features.shape[-1] ** 0.5
        coarse_features_0 = coarse_features[:, 0]
        coarse_features_1 = coarse_features[:, 1]

        similarity = coarse_features_0 @ coarse_features_1.transpose(-1, -2)
        similarity = similarity / self.config.coarse_matching_temperature

        if self.config.coarse_matching_skip_softmax:
            confidence = similarity
        else:
            confidence = nn.functional.softmax(similarity, 1) * nn.functional.softmax(similarity, 2)

        confidence = confidence.view(batch_size, height, width, height, width)
        matched_indices, matching_scores = self._get_matches_from_scores(confidence)

        keypoints = torch.stack([matched_indices % width, matched_indices // width], dim=-1) * coarse_scale

        return keypoints, matching_scores, matched_indices

    def _get_first_stage_fine_matching(
        self,
        fine_confidence: torch.Tensor,
        coarse_matched_keypoints: torch.Tensor,
        fine_window_size: int,
        fine_scale: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        For each coarse pixel, retrieve the highest fine confidence score and index. The index represents the
        matching between a pixel position in the fine window in the first image and a pixel position in the fine
        window of the second image.
        For example, for a fine_window_size of 64 (8 * 8), the index 2474 represents the matching between the index
        38 (2474 // 64) in the fine window of the first image, and the index 42 (2474 % 64) in the second image.
        This means that index 38, which corresponds to the position (4, 6) (38 // 8 and 38 % 8), is matched with the
        position (5, 2) (42 // 8 and 42 % 8). In this example the coarse matched coordinate will be shifted to the
        matched fine coordinates in the first and second image.

        Note: This step can be done as a postprocessing step, because it does not involve any model weights/params.
        However, we keep it in the modeling code for consistency with other keypoint matching models AND for easier
        torch.compile/torch.export (all ops are in torch).

        Args:
            fine_confidence (`torch.Tensor` of shape `(num_matches, fine_window_size, fine_window_size)`):
                First stage confidence of matching fine features between the first and the second image
            coarse_matched_keypoints (`torch.Tensor` of shape `(2, num_matches, 2)`):
                Coarse matched keypoint between the first and the second image.
            fine_window_size (`int`):
                Size of the window used to refine matches
            fine_scale (`float`):
                Scale between the size of fine features and coarse features

        Returns:
            indices (`torch.Tensor` of shape `(2, num_matches, 1)`):
                Indices of the fine coordinate matched in the fine window
            fine_matches (`torch.Tensor` of shape `(2, num_matches, 2)`):
                Coordinates of matched keypoints after the first fine stage
        """
        batch_size, num_keypoints, _, _ = fine_confidence.shape
        fine_kernel_size = torch_int(fine_window_size**0.5)

        fine_confidence = fine_confidence.reshape(batch_size, num_keypoints, -1)
        values, indices = torch.max(fine_confidence, dim=-1)
        indices = indices[..., None]
        indices_0 = indices // fine_window_size
        indices_1 = indices % fine_window_size

        grid = create_meshgrid(
            fine_kernel_size,
            fine_kernel_size,
            normalized_coordinates=False,
            device=fine_confidence.device,
            dtype=fine_confidence.dtype,
        )
        grid = grid - (fine_kernel_size // 2) + 0.5
        grid = grid.reshape(1, 1, -1, 2).expand(batch_size, num_keypoints, -1, -1)
        delta_0 = torch.gather(grid, 2, indices_0.unsqueeze(-1).expand(-1, -1, -1, 2)).squeeze(2)
        delta_1 = torch.gather(grid, 2, indices_1.unsqueeze(-1).expand(-1, -1, -1, 2)).squeeze(2)

        fine_matches_0 = coarse_matched_keypoints[:, 0] + delta_0 * fine_scale
        fine_matches_1 = coarse_matched_keypoints[:, 1] + delta_1 * fine_scale

        indices = torch.stack([indices_0, indices_1], dim=1)
        fine_matches = torch.stack([fine_matches_0, fine_matches_1], dim=1)

        return indices, fine_matches

    def _get_second_stage_fine_matching(
        self,
        indices: torch.Tensor,
        fine_matches: torch.Tensor,
        fine_confidence: torch.Tensor,
        fine_window_size: int,
        fine_scale: float,
    ) -> torch.Tensor:
        """
        For the given position in their respective fine windows, retrieve the 3x3 fine confidences around this
        position. After applying softmax to these confidences, compute the 2D spatial expected coordinates. Shift
        the first stage fine matching with these expected coordinates.

        Note: This step can be done as a postprocessing step, because it does not involve any model weights/params.
        However, we keep it in the modeling code for consistency with other keypoint matching models AND for easier
        torch.compile/torch.export (all ops are in torch).
        Args:
            indices (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
                Indices representing the position of each keypoint in the fine window
            fine_matches (`torch.Tensor` of shape `(2, num_matches, 2)`):
                Coordinates of matched keypoints after the first fine stage
            fine_confidence (`torch.Tensor` of shape `(num_matches, fine_window_size, fine_window_size)`):
                Second stage confidence of matching fine features between the first and the second image
            fine_window_size (`int`):
                Size of the window used to refine matches
            fine_scale (`float`):
                Scale between the size of fine features and coarse features

        Returns:
            fine_matches (`torch.Tensor` of shape `(2, num_matches, 2)`):
                Coordinates of matched keypoints after the second fine stage
        """
        batch_size, num_keypoints, _, _ = fine_confidence.shape
        fine_kernel_size = torch_int(fine_window_size**0.5)

        indices_0 = indices[:, 0]
        indices_1 = indices[:, 1]
        indices_1_i = indices_1 // fine_kernel_size
        indices_1_j = indices_1 % fine_kernel_size

        # matches_indices, indices_0, indices_1_i, indices_1_j of shape (num_matches, 3, 3)
        batch_indices = torch.arange(batch_size, device=indices_0.device).reshape(batch_size, 1, 1, 1)
        matches_indices = torch.arange(num_keypoints, device=indices_0.device).reshape(1, num_keypoints, 1, 1)
        indices_0 = indices_0[..., None]
        indices_1_i = indices_1_i[..., None]
        indices_1_j = indices_1_j[..., None]

        delta = create_meshgrid(3, 3, normalized_coordinates=True, device=indices_0.device).to(torch.long)
        delta = delta[None, ...]

        indices_1_i = indices_1_i + delta[..., 1]
        indices_1_j = indices_1_j + delta[..., 0]

        fine_confidence = fine_confidence.reshape(
            batch_size, num_keypoints, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2
        )
        # (batch_size, seq_len, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2) -> (batch_size, seq_len, 3, 3)
        fine_confidence = fine_confidence[batch_indices, matches_indices, indices_0, indices_1_i, indices_1_j]
        fine_confidence = fine_confidence.reshape(batch_size, num_keypoints, 9)
        fine_confidence = nn.functional.softmax(
            fine_confidence / self.config.fine_matching_regress_temperature, dim=-1
        )

        heatmap = fine_confidence.reshape(batch_size, num_keypoints, 3, 3)
        fine_coordinates_normalized = spatial_expectation2d(heatmap, True)[0]

        fine_matches_0 = fine_matches[:, 0]
        fine_matches_1 = fine_matches[:, 1] + (fine_coordinates_normalized * (3 // 2) * fine_scale)

        fine_matches = torch.stack([fine_matches_0, fine_matches_1], dim=1)

        return fine_matches

    def _fine_matching(
        self,
        fine_features_0: torch.Tensor,
        fine_features_1: torch.Tensor,
        coarse_matched_keypoints: torch.Tensor,
        fine_scale: float,
    ) -> torch.Tensor:
        """
        For each coarse pixel with a corresponding window of fine features, compute the matching confidence between
        fine features in the first image and the second image.

        Fine features are sliced in two parts:
        - The first part, used for the first stage, consists of the first fine_hidden_size -
          config.fine_matching_slice_dim (64 - 8 = 56 by default) features.
        - The second part, used for the second stage, consists of the last config.fine_matching_slice_dim (8 by
          default) features.

        Each part is used to compute a fine confidence tensor of the following shape:
        (batch_size, (coarse_height * coarse_width), fine_window_size, fine_window_size)
        They correspond to the score between each fine pixel in the first image and each fine pixel in the second
        image.
        Args:
            fine_features_0 (`torch.Tensor` of shape `(num_matches, fine_kernel_size ** 2, fine_kernel_size ** 2)`):
                Fine features from the first image
            fine_features_1 (`torch.Tensor` of shape `(num_matches, (fine_kernel_size + 2) ** 2, (fine_kernel_size + 2) ** 2)`):
                Fine features from the second image
            coarse_matched_keypoints (`torch.Tensor` of shape `(2, num_matches, 2)`):
                Keypoint coordinates found in coarse matching for the first and second image
            fine_scale (`float`):
                Scale between the size of fine features and coarse features

        Returns:
            fine_coordinates (`torch.Tensor` of shape `(2, num_matches, 2)`):
                Matched keypoint between the first and the second image. All matched keypoints are concatenated in
                the second dimension.
        """
        batch_size, num_keypoints, fine_window_size, fine_embed_dim = fine_features_0.shape
        fine_matching_slice_dim = self.config.fine_matching_slice_dim
        fine_kernel_size = torch_int(fine_window_size**0.5)

        # Split fine features into first and second stage features
        split_fine_features_0 = torch.split(fine_features_0, fine_embed_dim - fine_matching_slice_dim, -1)
        split_fine_features_1 = torch.split(fine_features_1, fine_embed_dim - fine_matching_slice_dim, -1)

        # Retrieve first stage fine features
        fine_features_0 = split_fine_features_0[0]
        fine_features_1 = split_fine_features_1[0]

        # Normalize first stage fine features
        fine_features_0 = fine_features_0 / fine_features_0.shape[-1] ** 0.5
        fine_features_1 = fine_features_1 / fine_features_1.shape[-1] ** 0.5

        # Compute first stage confidence
        fine_confidence = fine_features_0 @ fine_features_1.transpose(-1, -2)
        fine_confidence = nn.functional.softmax(fine_confidence, 1) * nn.functional.softmax(fine_confidence, 2)
        fine_confidence = fine_confidence.reshape(
            batch_size, num_keypoints, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2
        )
        fine_confidence = fine_confidence[..., 1:-1, 1:-1]
        first_stage_fine_confidence = fine_confidence.reshape(
            batch_size, num_keypoints, fine_window_size, fine_window_size
        )

        fine_indices, fine_matches = self._get_first_stage_fine_matching(
            first_stage_fine_confidence,
            coarse_matched_keypoints,
            fine_window_size,
            fine_scale,
        )

        # Retrieve second stage fine features
        fine_features_0 = split_fine_features_0[1]
        fine_features_1 = split_fine_features_1[1]

        # Normalize second stage fine features
        fine_features_1 = fine_features_1 / fine_matching_slice_dim**0.5

        # Compute second stage fine confidence
        second_stage_fine_confidence = fine_features_0 @ fine_features_1.transpose(-1, -2)

        fine_coordinates = self._get_second_stage_fine_matching(
            fine_indices,
            fine_matches,
            second_stage_fine_confidence,
            fine_window_size,
            fine_scale,
        )

        return fine_coordinates

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> KeypointMatchingOutput:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoModel
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true"
        >>> image1 = Image.open(requests.get(url, stream=True).raw)
        >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true"
        >>> image2 = Image.open(requests.get(url, stream=True).raw)
        >>> images = [image1, image2]
        >>> processor = AutoImageProcessor.from_pretrained("zju-community/efficient_loftr")
        >>> model = AutoModel.from_pretrained("zju-community/efficient_loftr")

        >>> with torch.no_grad():
        ...     inputs = processor(images, return_tensors="pt")
        ...     outputs = model(**inputs)
        ```"""
        if labels is not None:
            raise ValueError("EfficientLoFTR is not trainable, no labels should be provided.")

        # 1. Extract coarse and residual features
        model_outputs: BackboneOutput = self.efficientloftr(pixel_values, **kwargs)
        features = model_outputs.feature_maps

        # 2. Compute coarse-level matching
        coarse_features = features[0]
        coarse_embed_dim, coarse_height, coarse_width = coarse_features.shape[-3:]
        batch_size, _, channels, height, width = pixel_values.shape
        coarse_scale = height / coarse_height
        coarse_keypoints, coarse_matching_scores, coarse_matched_indices = self._coarse_matching(
            coarse_features, coarse_scale
        )

        # 3. Fine-level refinement
        residual_features = features[1:]
        fine_features_0, fine_features_1 = self.refinement_layer(coarse_features, residual_features)

        # Filter fine features with coarse match indices
        _, _, num_keypoints = coarse_matching_scores.shape
        batch_indices = torch.arange(batch_size)[..., None]
        fine_features_0 = fine_features_0[batch_indices, coarse_matched_indices[:, 0]]
        fine_features_1 = fine_features_1[batch_indices, coarse_matched_indices[:, 1]]

        # 4. Compute fine-level matching
        fine_height = torch_int(coarse_height * coarse_scale)
        fine_scale = height / fine_height
        matching_keypoints = self._fine_matching(fine_features_0, fine_features_1, coarse_keypoints, fine_scale)
        matching_keypoints[:, :, :, 0] = matching_keypoints[:, :, :, 0] / width
        matching_keypoints[:, :, :, 1] = matching_keypoints[:, :, :, 1] / height

        return KeypointMatchingOutput(
            matches=coarse_matched_indices,
            matching_scores=coarse_matching_scores,
            keypoints=matching_keypoints,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )


__all__ = ["EfficientLoFTRPreTrainedModel", "EfficientLoFTRModel", "EfficientLoFTRForKeypointMatching"]