# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, Optional, Union

import torch
from torch import nn

from ...activations import ACT2CLS, ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BackboneOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
    ModelOutput,
    TransformersKwargs,
    auto_docstring,
    can_return_tuple,
    torch_int,
)
from ...utils.generic import check_model_inputs
from .configuration_efficientloftr import EfficientLoFTRConfig

@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of keypoint matching models. Due to the nature of keypoint detection and matching, the number
    of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of
    images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask tensor is
    used to indicate which values in the keypoints, matches and matching_scores tensors are keypoint matching
    information.
    """
)
class KeypointMatchingOutput(ModelOutput):
    r"""
    matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Index of keypoint matched in the other image.
    matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Scores of predicted matches.
    keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
        Absolute (x, y) coordinates of predicted keypoints in a given image.
    hidden_states (`tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
        num_keypoints)`, returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`.
    attentions (`tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
        num_keypoints)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`.
    """

    matches: Optional[torch.FloatTensor] = None
    matching_scores: Optional[torch.FloatTensor] = None
    keypoints: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None


class EfficientLoFTRRotaryEmbedding(nn.Module):
    def __init__(self, config: EfficientLoFTRConfig, device=None):
        super().__init__()
        self.config = config
        self.rope_type = config.rope_scaling["rope_type"]
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, _ = self.rope_init_fn(self.config, device)
        inv_freq_expanded = inv_freq[None, None, None, :].float().expand(1, 1, 1, -1)

        embed_height, embed_width = config.embedding_size
        i_indices = torch.ones(embed_height, embed_width).cumsum(0).float().unsqueeze(-1)
        j_indices = torch.ones(embed_height, embed_width).cumsum(1).float().unsqueeze(-1)

        emb = torch.zeros(1, embed_height, embed_width, self.config.hidden_size // 2)
        emb[:, :, :, 0::2] = i_indices * inv_freq_expanded
        emb[:, :, :, 1::2] = j_indices * inv_freq_expanded

        self.register_buffer("inv_freq", emb, persistent=False)

    @torch.no_grad()
    def forward(
        self, x: torch.Tensor, position_ids: Optional[tuple[torch.LongTensor, torch.LongTensor]] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            emb = self.inv_freq
            sin = emb.sin()
            cos = emb.cos()

            sin = sin.repeat_interleave(2, dim=-1)
            cos = cos.repeat_interleave(2, dim=-1)

            sin = sin.to(device=x.device, dtype=x.dtype)
            cos = cos.to(device=x.device, dtype=x.dtype)

        return cos, sin
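# Editorial note (illustrative only, not used by the model code): unlike 1D rotary embeddings, the
# "inv_freq" buffer registered above is a full 2D positional table of shape
# (1, embed_height, embed_width, hidden_size // 2): even channels hold row_index * inv_freq and odd
# channels hold column_index * inv_freq. `forward` only takes its sin/cos and repeat-interleaves
# them back to `hidden_size`, e.g.:
#
#     >>> table = torch.zeros(1, 4, 4, 2)  # toy table with hidden_size // 2 == 2
#     >>> table.sin().repeat_interleave(2, dim=-1).shape
#     torch.Size([1, 4, 4, 4])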


# Copied from transformers.models.rt_detr_v2.modeling_rt_detr_v2.RTDetrV2ConvNormLayer with RTDetrV2->EfficientLoFTR
class EfficientLoFTRConvNormLayer(nn.Module):
    def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=(kernel_size - 1) // 2 if padding is None else padding,
            bias=False,
        )
        self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
        self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()

    def forward(self, hidden_state):
        hidden_state = self.conv(hidden_state)
        hidden_state = self.norm(hidden_state)
        hidden_state = self.activation(hidden_state)
        return hidden_state


class EfficientLoFTRRepVGGBlock(GradientCheckpointingLayer):
    """
    RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again".
    """

    def __init__(self, config: EfficientLoFTRConfig, stage_idx: int, block_idx: int):
        super().__init__()
        in_channels = config.stage_block_in_channels[stage_idx][block_idx]
        out_channels = config.stage_block_out_channels[stage_idx][block_idx]
        stride = config.stage_block_stride[stage_idx][block_idx]
        activation = config.activation_function
        self.conv1 = EfficientLoFTRConvNormLayer(
            config, in_channels, out_channels, kernel_size=3, stride=stride, padding=1
        )
        self.conv2 = EfficientLoFTRConvNormLayer(
            config, in_channels, out_channels, kernel_size=1, stride=stride, padding=0
        )
        self.identity = nn.BatchNorm2d(in_channels) if in_channels == out_channels and stride == 1 else None
        self.activation = nn.Identity() if activation is None else ACT2FN[activation]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.identity is not None:
            identity_out = self.identity(hidden_states)
        else:
            identity_out = 0
        hidden_states = self.conv1(hidden_states) + self.conv2(hidden_states) + identity_out
        hidden_states = self.activation(hidden_states)
        return hidden_states
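# Editorial note (background, illustrative only): a RepVGG block sums three parallel branches before
# the activation: a 3x3 conv + BatchNorm, a 1x1 conv + BatchNorm and, when the input and output
# shapes match (stride 1, same channel count), a plain BatchNorm identity branch. At inference time
# the RepVGG paper fuses these branches into a single 3x3 convolution ("structural
# re-parameterization"); this implementation keeps the three branches explicit and uses the
# checkpoint weights as-is.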


class EfficientLoFTRRepVGGStage(nn.Module):
    def __init__(self, config: EfficientLoFTRConfig, stage_idx: int):
        super().__init__()
        self.blocks = nn.ModuleList([])
        for block_idx in range(config.stage_num_blocks[stage_idx]):
            self.blocks.append(
                EfficientLoFTRRepVGGBlock(
                    config,
                    stage_idx,
                    block_idx,
                )
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            hidden_states = block(hidden_states)
        return hidden_states


class EfficientLoFTRepVGG(nn.Module):
    def __init__(self, config: EfficientLoFTRConfig):
        super().__init__()

        self.stages = nn.ModuleList([])

        for stage_idx in range(len(config.stage_stride)):
            stage = EfficientLoFTRRepVGGStage(config, stage_idx)
            self.stages.append(stage)

    def forward(self, hidden_states: torch.Tensor) -> list[torch.Tensor]:
        outputs = []
        for stage in self.stages:
            hidden_states = stage(hidden_states)
            outputs.append(hidden_states)

        # Exclude first stage in outputs
        outputs = outputs[1:]
        return outputs


class EfficientLoFTRAggregationLayer(nn.Module):
    def __init__(self, config: EfficientLoFTRConfig):
        super().__init__()

        hidden_size = config.hidden_size

        self.q_aggregation = nn.Conv2d(
            hidden_size,
            hidden_size,
            kernel_size=config.q_aggregation_kernel_size,
            padding=0,
            stride=config.q_aggregation_stride,
            bias=False,
            groups=hidden_size,
        )
        self.kv_aggregation = torch.nn.MaxPool2d(
            kernel_size=config.kv_aggregation_kernel_size, stride=config.kv_aggregation_stride
        )
        self.norm = nn.LayerNorm(hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        query_states = hidden_states
        is_cross_attention = encoder_hidden_states is not None
        kv_states = encoder_hidden_states if is_cross_attention else hidden_states

        query_states = self.q_aggregation(query_states)
        kv_states = self.kv_aggregation(kv_states)
        query_states = query_states.permute(0, 2, 3, 1)
        kv_states = kv_states.permute(0, 2, 3, 1)
        hidden_states = self.norm(query_states)
        encoder_hidden_states = self.norm(kv_states)
        return hidden_states, encoder_hidden_states


# Copied from transformers.models.cohere.modeling_cohere.rotate_half
def rotate_half(x):
    # Split and rotate. Note that this function is different from e.g. Llama.
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
    return rot_x
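# Illustrative example (editorial, not part of the API): this interleaved variant rotates each
# even/odd channel pair (x0, x1) to (-x1, x0), rather than swapping tensor halves as in Llama:
#
#     >>> rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0]))
#     tensor([-2.,  1., -4.,  3.])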


# Copied from transformers.models.cohere.modeling_cohere.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.

    Returns:
        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
    """
    dtype = q.dtype
    q = q.float()
    k = k.float()
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)


# Copied from transformers.models.cohere.modeling_cohere.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
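# Illustrative example (shapes only, editorial): each key/value head is duplicated n_rep times along
# the head dimension so that grouped-query attention can reuse the plain attention matmul:
#
#     >>> kv = torch.randn(1, 2, 5, 16)  # (batch, num_key_value_heads, seq_len, head_dim)
#     >>> repeat_kv(kv, 2).shape
#     torch.Size([1, 4, 5, 16])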


# Copied from transformers.models.llama.modeling_llama.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class EfficientLoFTRAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->EfficientLoFTR
    def __init__(self, config: EfficientLoFTRConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        batch_size, seq_len, dim = hidden_states.shape
        input_shape = hidden_states.shape[:-1]

        query_states = self.q_proj(hidden_states).view(batch_size, seq_len, -1, dim)

        is_cross_attention = encoder_hidden_states is not None
        current_states = encoder_hidden_states if is_cross_attention else hidden_states

        key_states = self.k_proj(current_states).view(batch_size, seq_len, -1, dim)
        value_states = self.v_proj(current_states).view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=2)

        query_states = query_states.view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=None,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class EfficientLoFTRMLP(nn.Module):
    def __init__(self, config: EfficientLoFTRConfig):
        super().__init__()
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        self.fc1 = nn.Linear(hidden_size * 2, intermediate_size, bias=False)
        self.activation = ACT2FN[config.mlp_activation_function]
        self.fc2 = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


class EfficientLoFTRAggregatedAttention(nn.Module):
    def __init__(self, config: EfficientLoFTRConfig, layer_idx: int):
        super().__init__()

        self.q_aggregation_kernel_size = config.q_aggregation_kernel_size
        self.aggregation = EfficientLoFTRAggregationLayer(config)
        self.attention = EfficientLoFTRAttention(config, layer_idx)
        self.mlp = EfficientLoFTRMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        batch_size, embed_dim, _, _ = hidden_states.shape

        # Aggregate features
        aggregated_hidden_states, aggregated_encoder_hidden_states = self.aggregation(
            hidden_states, encoder_hidden_states
        )
        _, aggregated_h, aggregated_w, _ = aggregated_hidden_states.shape

        # Multi-head attention
        aggregated_hidden_states = aggregated_hidden_states.reshape(batch_size, -1, embed_dim)
        aggregated_encoder_hidden_states = aggregated_encoder_hidden_states.reshape(batch_size, -1, embed_dim)
        attn_output, _ = self.attention(
            aggregated_hidden_states,
            aggregated_encoder_hidden_states,
            position_embeddings=position_embeddings,
            **kwargs,
        )

        # Upsample features
        # (batch_size, seq_len, embed_dim) -> (batch_size, embed_dim, h, w) with seq_len = h * w
        attn_output = attn_output.permute(0, 2, 1)
        attn_output = attn_output.reshape(batch_size, embed_dim, aggregated_h, aggregated_w)
        attn_output = torch.nn.functional.interpolate(
            attn_output, scale_factor=self.q_aggregation_kernel_size, mode="bilinear", align_corners=False
        )
        intermediate_states = torch.cat([hidden_states, attn_output], dim=1)
        intermediate_states = intermediate_states.permute(0, 2, 3, 1)
        output_states = self.mlp(intermediate_states)
        output_states = output_states.permute(0, 3, 1, 2)

        hidden_states = hidden_states + output_states

        return hidden_states
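# Editorial sketch of the shape flow above (assuming the aggregation kernel size equals the
# aggregation stride, as in the default configuration; names are from this file):
#   hidden_states:      (batch_size * 2, hidden_size, h, w)          as called from the transformer layer
#   after aggregation:  (batch_size * 2, h / s, w / s, hidden_size)  for both query and key/value
# Attention then runs over (h / s) * (w / s) tokens, its output is bilinearly upsampled back to
# (h, w), concatenated with the input along channels (2 * hidden_size), passed through the MLP and
# added back to the input as a residual.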


class EfficientLoFTRLocalFeatureTransformerLayer(GradientCheckpointingLayer):
    def __init__(self, config: EfficientLoFTRConfig, layer_idx: int):
        super().__init__()

        self.self_attention = EfficientLoFTRAggregatedAttention(config, layer_idx)
        self.cross_attention = EfficientLoFTRAggregatedAttention(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        batch_size, _, embed_dim, height, width = hidden_states.shape

        hidden_states = hidden_states.reshape(-1, embed_dim, height, width)
        hidden_states = self.self_attention(hidden_states, position_embeddings=position_embeddings, **kwargs)

        encoder_hidden_states = hidden_states.reshape(-1, 2, embed_dim, height, width)
        encoder_hidden_states = encoder_hidden_states.flip(1)
        encoder_hidden_states = encoder_hidden_states.reshape(-1, embed_dim, height, width)

        hidden_states = self.cross_attention(hidden_states, encoder_hidden_states, **kwargs)
        hidden_states = hidden_states.reshape(batch_size, -1, embed_dim, height, width)

        return hidden_states
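# Editorial note on the flip(1) trick above (illustrative only): the two images of a pair are
# stacked along the batch dimension, so viewing the batch as (batch_size, 2, ...) and flipping
# dim 1 swaps the two images of every pair. Flattening back to (batch_size * 2, ...) then feeds
# each image's queries with the other image's keys/values for the cross-attention call:
#
#     >>> pairs = torch.tensor([0, 1, 10, 11])  # [img0_a, img1_a, img0_b, img1_b]
#     >>> pairs.reshape(-1, 2).flip(1).reshape(-1)
#     tensor([ 1,  0, 11, 10])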


class EfficientLoFTRLocalFeatureTransformer(nn.Module):
    def __init__(self, config: EfficientLoFTRConfig):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                EfficientLoFTRLocalFeatureTransformerLayer(config, layer_idx=i)
                for i in range(config.num_attention_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_embeddings=position_embeddings, **kwargs)
        return hidden_states


class EfficientLoFTROutConvBlock(nn.Module):
    def __init__(self, config: EfficientLoFTRConfig, hidden_size: int, intermediate_size: int):
        super().__init__()

        self.out_conv1 = nn.Conv2d(hidden_size, intermediate_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.out_conv2 = nn.Conv2d(
            intermediate_size, intermediate_size, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.batch_norm = nn.BatchNorm2d(intermediate_size)
        self.activation = ACT2CLS[config.mlp_activation_function]()
        self.out_conv3 = nn.Conv2d(intermediate_size, hidden_size, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, hidden_states: torch.Tensor, residual_states: torch.Tensor) -> torch.Tensor:
        residual_states = self.out_conv1(residual_states)
        residual_states = residual_states + hidden_states
        residual_states = self.out_conv2(residual_states)
        residual_states = self.batch_norm(residual_states)
        residual_states = self.activation(residual_states)
        residual_states = self.out_conv3(residual_states)
        residual_states = nn.functional.interpolate(
            residual_states, scale_factor=2.0, mode="bilinear", align_corners=False
        )
        return residual_states


class EfficientLoFTRFineFusionLayer(nn.Module):
    def __init__(self, config: EfficientLoFTRConfig):
        super().__init__()

        self.fine_kernel_size = config.fine_kernel_size

        fine_fusion_dims = config.fine_fusion_dims
        self.out_conv = nn.Conv2d(
            fine_fusion_dims[0], fine_fusion_dims[0], kernel_size=1, stride=1, padding=0, bias=False
        )
        self.out_conv_layers = nn.ModuleList()
        for i in range(1, len(fine_fusion_dims)):
            out_conv = EfficientLoFTROutConvBlock(config, fine_fusion_dims[i], fine_fusion_dims[i - 1])
            self.out_conv_layers.append(out_conv)

    def forward_pyramid(
        self,
        hidden_states: torch.Tensor,
        residual_states: list[torch.Tensor],
    ) -> torch.Tensor:
        hidden_states = self.out_conv(hidden_states)
        hidden_states = nn.functional.interpolate(
            hidden_states, scale_factor=2.0, mode="bilinear", align_corners=False
        )
        for i, layer in enumerate(self.out_conv_layers):
            hidden_states = layer(hidden_states, residual_states[i])

        return hidden_states

    def forward(
        self,
        coarse_features: torch.Tensor,
        residual_features: list[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        For each image pair, compute the fine features of pixels.
        In both images, compute a patch of fine features center cropped around each coarse pixel.
        In the first image, the feature patch is a window of kernel_size x kernel_size pixels.
        In the second image, it is a window of (kernel_size + 2) x (kernel_size + 2) pixels.
        """
        batch_size, _, embed_dim, coarse_height, coarse_width = coarse_features.shape

        coarse_features = coarse_features.reshape(-1, embed_dim, coarse_height, coarse_width)
        residual_features = list(reversed(residual_features))

        # 1. Fine feature extraction
        fine_features = self.forward_pyramid(coarse_features, residual_features)
        _, fine_embed_dim, fine_height, fine_width = fine_features.shape

        fine_features = fine_features.reshape(batch_size, 2, fine_embed_dim, fine_height, fine_width)
        fine_features_0 = fine_features[:, 0]
        fine_features_1 = fine_features[:, 1]

        # 2. Unfold all local windows in crops
        stride = int(fine_height // coarse_height)
        fine_features_0 = nn.functional.unfold(
            fine_features_0, kernel_size=self.fine_kernel_size, stride=stride, padding=0
        )
        _, _, seq_len = fine_features_0.shape
        fine_features_0 = fine_features_0.reshape(batch_size, -1, self.fine_kernel_size**2, seq_len)
        fine_features_0 = fine_features_0.permute(0, 3, 2, 1)

        fine_features_1 = nn.functional.unfold(
            fine_features_1, kernel_size=self.fine_kernel_size + 2, stride=stride, padding=1
        )
        fine_features_1 = fine_features_1.reshape(batch_size, -1, (self.fine_kernel_size + 2) ** 2, seq_len)
        fine_features_1 = fine_features_1.permute(0, 3, 2, 1)

        return fine_features_0, fine_features_1
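# Editorial example of the unfold step above (shapes only; the channel count and the 80x80 fine map
# are hypothetical, the kernel size of 8 matches the 8 * 8 fine window quoted in the docstrings
# below): nn.functional.unfold returns (batch, channels * kernel_size**2, num_windows), which the
# code above reshapes and permutes into (batch, num_windows, kernel_size**2, channels), i.e. one
# flattened patch of fine features per coarse cell:
#
#     >>> feats = torch.randn(2, 64, 80, 80)  # (batch, channels, fine_height, fine_width)
#     >>> nn.functional.unfold(feats, kernel_size=8, stride=8, padding=0).shape
#     torch.Size([2, 4096, 100])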


@auto_docstring
class EfficientLoFTRPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = EfficientLoFTRConfig
    base_model_prefix = "efficientloftr"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_record_outputs = {
        "hidden_states": EfficientLoFTRRepVGGBlock,
        "attentions": EfficientLoFTRAttention,
    }

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv1d, nn.BatchNorm2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    # Copied from transformers.models.superpoint.modeling_superpoint.SuperPointPreTrainedModel.extract_one_channel_pixel_values with SuperPoint->EfficientLoFTR
    def extract_one_channel_pixel_values(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
        """
        Assuming pixel_values has shape (batch_size, 3, height, width), and that all channel values are the same,
        extract the first channel value to get a tensor of shape (batch_size, 1, height, width) for EfficientLoFTR.
        This is a workaround for the issue discussed in:
        https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446

        Args:
            pixel_values: torch.FloatTensor of shape (batch_size, 3, height, width)

        Returns:
            pixel_values: torch.FloatTensor of shape (batch_size, 1, height, width)
        """
        return pixel_values[:, 0, :, :][:, None, :, :]


@auto_docstring(
    custom_intro="""
    EfficientLoFTR model taking images as inputs and outputting the features of the images.
    """
)
class EfficientLoFTRModel(EfficientLoFTRPreTrainedModel):
    def __init__(self, config: EfficientLoFTRConfig):
        super().__init__(config)

        self.config = config
        self.backbone = EfficientLoFTRepVGG(config)
        self.local_feature_transformer = EfficientLoFTRLocalFeatureTransformer(config)
        self.rotary_emb = EfficientLoFTRRotaryEmbedding(config=config)

        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BackboneOutput:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoModel
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true"
        >>> image1 = Image.open(requests.get(url, stream=True).raw)
        >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true"
        >>> image2 = Image.open(requests.get(url, stream=True).raw)
        >>> images = [image1, image2]

        >>> processor = AutoImageProcessor.from_pretrained("zju-community/efficient_loftr")
        >>> model = AutoModel.from_pretrained("zju-community/efficient_loftr")

        >>> inputs = processor(images, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ```"""
        if labels is not None:
            raise ValueError("EfficientLoFTR is not trainable, no labels should be provided.")

        if pixel_values.ndim != 5 or pixel_values.size(1) != 2:
            raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")

        batch_size, _, channels, height, width = pixel_values.shape
        pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width)
        pixel_values = self.extract_one_channel_pixel_values(pixel_values)

        # 1. Local Feature CNN
        features = self.backbone(pixel_values)
        # Last stage outputs are coarse outputs
        coarse_features = features[-1]
        # The rest are residual features used in EfficientLoFTRFineFusionLayer
        residual_features = features[:-1]
        coarse_embed_dim, coarse_height, coarse_width = coarse_features.shape[-3:]

        # 2. Coarse-level LoFTR module
        cos, sin = self.rotary_emb(coarse_features)
        cos = cos.expand(batch_size * 2, -1, -1, -1).reshape(batch_size * 2, -1, coarse_embed_dim)
        sin = sin.expand(batch_size * 2, -1, -1, -1).reshape(batch_size * 2, -1, coarse_embed_dim)
        position_embeddings = (cos, sin)

        coarse_features = coarse_features.reshape(batch_size, 2, coarse_embed_dim, coarse_height, coarse_width)
        coarse_features = self.local_feature_transformer(
            coarse_features, position_embeddings=position_embeddings, **kwargs
        )

        features = (coarse_features,) + tuple(residual_features)

        return BackboneOutput(feature_maps=features)


def mask_border(tensor: torch.Tensor, border_margin: int, value: Union[bool, float, int]) -> torch.Tensor:
    """
    Mask a tensor border with a given value

    Args:
        tensor (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`):
            The tensor to mask
        border_margin (`int`):
            The size of the border
        value (`Union[bool, int, float]`):
            The value to place in the tensor's borders

    Returns:
        tensor (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`):
            The masked tensor
    """
    if border_margin <= 0:
        return tensor

    tensor[:, :border_margin, :border_margin, :border_margin, :border_margin] = value
    tensor[:, -border_margin:, -border_margin:, -border_margin:, -border_margin:] = value
    return tensor


def create_meshgrid(
    height: Union[int, torch.Tensor],
    width: Union[int, torch.Tensor],
    normalized_coordinates: bool = False,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Copied from the kornia library: kornia/kornia/utils/grid.py:26

    Generate a coordinate grid for an image.

    When the flag ``normalized_coordinates`` is set to True, the grid is
    normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch
    function :py:func:`torch.nn.functional.grid_sample`.

    Args:
        height (`int`):
            The image height (rows).
        width (`int`):
            The image width (cols).
        normalized_coordinates (`bool`):
            Whether to normalize coordinates in the range :math:`[-1,1]` in order to be consistent with the
            PyTorch function :py:func:`torch.nn.functional.grid_sample`.
        device (`torch.device`):
            The device on which the grid will be generated.
        dtype (`torch.dtype`):
            The data type of the generated grid.

    Return:
        grid (`torch.Tensor` of shape `(1, height, width, 2)`):
            The grid tensor.

    Example:
        >>> create_meshgrid(2, 2, normalized_coordinates=True)
        tensor([[[[-1., -1.],
                  [ 1., -1.]],
        <BLANKLINE>
                 [[-1., 1.],
                  [ 1., 1.]]]])

        >>> create_meshgrid(2, 2, normalized_coordinates=False)
        tensor([[[[0., 0.],
                  [1., 0.]],
        <BLANKLINE>
                 [[0., 1.],
                  [1., 1.]]]])
    """
    xs = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
    ys = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
    if normalized_coordinates:
        xs = (xs / (width - 1) - 0.5) * 2
        ys = (ys / (height - 1) - 0.5) * 2
    grid = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1)
    grid = grid.permute(1, 0, 2).unsqueeze(0)
    return grid


def spatial_expectation2d(input: torch.Tensor, normalized_coordinates: bool = True) -> torch.Tensor:
    r"""
    Copied from the kornia library: kornia/geometry/subpix/dsnt.py:76

    Compute the expectation of coordinate values using spatial probabilities.

    The input heatmap is assumed to represent a valid spatial probability distribution,
    which can be achieved using :func:`~kornia.geometry.subpixel.spatial_softmax2d`.

    Args:
        input (`torch.Tensor` of shape `(batch_size, embed_dim, height, width)`):
            The input tensor representing dense spatial probabilities.
        normalized_coordinates (`bool`):
            Whether to return the coordinates normalized in the range of :math:`[-1, 1]`. Otherwise, it will return
            the coordinates in the range of the input shape.

    Returns:
        output (`torch.Tensor` of shape `(batch_size, embed_dim, 2)`):
            Expected value of the 2D coordinates. Output order of the coordinates is (x, y).

    Examples:
        >>> heatmaps = torch.tensor([[[
        ... [0., 0., 0.],
        ... [0., 0., 0.],
        ... [0., 1., 0.]]]])
        >>> spatial_expectation2d(heatmaps, False)
        tensor([[[1., 2.]]])
    """
    batch_size, embed_dim, height, width = input.shape

    # Create coordinates grid.
    grid = create_meshgrid(height, width, normalized_coordinates, input.device)
    grid = grid.to(input.dtype)

    pos_x = grid[..., 0].reshape(-1)
    pos_y = grid[..., 1].reshape(-1)

    input_flat = input.view(batch_size, embed_dim, -1)

    # Compute the expectation of the coordinates.
    expected_y = torch.sum(pos_y * input_flat, -1, keepdim=True)
    expected_x = torch.sum(pos_x * input_flat, -1, keepdim=True)

    output = torch.cat([expected_x, expected_y], -1)

    return output.view(batch_size, embed_dim, 2)


@auto_docstring(
    custom_intro="""
    EfficientLoFTR model taking images as inputs and outputting the matches between them.
    """
)
class EfficientLoFTRForKeypointMatching(EfficientLoFTRPreTrainedModel):
    """EfficientLoFTR dense image matcher

    Given two images, we determine the correspondences by:
    1. Extracting coarse and fine features through a backbone
    2. Transforming coarse features through self and cross attention
    3. Matching coarse features to obtain coarse coordinates of matches
    4. Obtaining full resolution fine features by fusing transformed and backbone coarse features
    5. Refining the coarse matches using fine feature patches centered at each coarse match in a two-stage refinement

    Yifan Wang, Xingyi He, Sida Peng, Dongli Tan and Xiaowei Zhou.
    Efficient LoFTR: Semi-Dense Local Feature Matching with Sparse-Like Speed
    In CVPR, 2024. https://arxiv.org/abs/2403.04765
    """

    def __init__(self, config: EfficientLoFTRConfig):
        super().__init__(config)

        self.config = config
        self.efficientloftr = EfficientLoFTRModel(config)
        self.refinement_layer = EfficientLoFTRFineFusionLayer(config)

        self.post_init()

    def _get_matches_from_scores(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Based on a keypoint score matrix, compute the best keypoint matches between the first and second image.
        Since each image pair can have a different number of matches, the matches are concatenated together for all
        pairs in the batch and a batch_indices tensor is returned to specify which match belongs to which element in
        the batch.

        Note:
            This step can be done as a postprocessing step, because it does not involve any model weights/params.
            However, we keep it in the modeling code for consistency with other keypoint matching models AND for
            easier torch.compile/torch.export (all ops are in torch).

        Args:
            scores (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`):
                Scores of keypoints

        Returns:
            matched_indices (`torch.Tensor` of shape `(2, num_matches)`):
                Indices representing which pixel in the first image matches which pixel in the second image
            matching_scores (`torch.Tensor` of shape `(num_matches,)`):
                Scores of each match
        """
        batch_size, height0, width0, height1, width1 = scores.shape

        scores = scores.view(batch_size, height0 * width0, height1 * width1)

        # For each keypoint, get the best match
        max_0 = scores.max(2, keepdim=True).values
        max_1 = scores.max(1, keepdim=True).values

        # 1. Thresholding
        mask = scores > self.config.coarse_matching_threshold

        # 2. Border removal
        mask = mask.reshape(batch_size, height0, width0, height1, width1)
        mask = mask_border(mask, self.config.coarse_matching_border_removal, False)
        mask = mask.reshape(batch_size, height0 * width0, height1 * width1)

        # 3. Mutual nearest neighbors
        mask = mask * (scores == max_0) * (scores == max_1)

        # 4. Find coarse matches
        masked_scores = scores * mask
        matching_scores_0, max_indices_0 = masked_scores.max(1)
        matching_scores_1, max_indices_1 = masked_scores.max(2)

        matching_indices = torch.cat([max_indices_0, max_indices_1]).reshape(batch_size, 2, -1)
        matching_scores = torch.stack([matching_scores_0, matching_scores_1], dim=1)

        # For the keypoints not meeting the threshold score, set the indices to -1 which corresponds to no matches found
        matching_indices = torch.where(matching_scores > 0, matching_indices, -1)

        return matching_indices, matching_scores
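    # Editorial note on the mutual-nearest-neighbour filter above (toy example, illustrative only):
    # a coarse cell pair (i, j) survives only when scores[i, j] is simultaneously the maximum of its
    # row and of its column. For
    #
    #     scores = [[0.9, 0.1],
    #               [0.8, 0.2]]
    #
    # only (0, 0) is kept: 0.9 is the best match of both row 0 and column 0, whereas 0.8 is the best
    # of row 1 but not of column 0, and 0.2 is the best of column 1 but not of row 1.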

    def _coarse_matching(
        self, coarse_features: torch.Tensor, coarse_scale: float
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        For each image pair, compute the matching confidence between each coarse element (by default
        (image_height / 8) * (image_width / 8) elements) from the first image to the second image.

        Note:
            This step can be done as a postprocessing step, because it does not involve any model weights/params.
            However, we keep it in the modeling code for consistency with other keypoint matching models AND for
            easier torch.compile/torch.export (all ops are in torch).

        Args:
            coarse_features (`torch.Tensor` of shape `(batch_size, 2, hidden_size, coarse_height, coarse_width)`):
                Coarse features
            coarse_scale (`float`): Scale between the image size and the coarse size

        Returns:
            keypoints (`torch.Tensor` of shape `(batch_size, 2, num_matches, 2)`):
                Keypoint coordinates.
            matching_scores (`torch.Tensor` of shape `(batch_size, 2, num_matches)`):
                The confidence matching score of each keypoint.
            matched_indices (`torch.Tensor` of shape `(batch_size, 2, num_matches)`):
                Indices indicating which keypoint in one image matched with which keypoint in the other image, for
                both images in the pair.
        """
        batch_size, _, embed_dim, height, width = coarse_features.shape

        # (batch_size, 2, embed_dim, height, width) -> (batch_size, 2, height * width, embed_dim)
        coarse_features = coarse_features.permute(0, 1, 3, 4, 2)
        coarse_features = coarse_features.reshape(batch_size, 2, -1, embed_dim)

        coarse_features = coarse_features / coarse_features.shape[-1] ** 0.5
        coarse_features_0 = coarse_features[:, 0]
        coarse_features_1 = coarse_features[:, 1]

        similarity = coarse_features_0 @ coarse_features_1.transpose(-1, -2)
        similarity = similarity / self.config.coarse_matching_temperature

        if self.config.coarse_matching_skip_softmax:
            confidence = similarity
        else:
            confidence = nn.functional.softmax(similarity, 1) * nn.functional.softmax(similarity, 2)

        confidence = confidence.view(batch_size, height, width, height, width)
        matched_indices, matching_scores = self._get_matches_from_scores(confidence)

        keypoints = torch.stack([matched_indices % width, matched_indices // width], dim=-1) * coarse_scale

        return keypoints, matching_scores, matched_indices
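    # Editorial arithmetic example for the keypoint recovery above (hypothetical sizes): with a
    # coarse grid of width 80 (e.g. a 640-pixel-wide image at a 1/8 coarse scale, coarse_scale = 8),
    # the flat coarse index 413 maps to x = 413 % 80 = 13 and y = 413 // 80 = 5, i.e. pixel
    # coordinates (104, 40). Entries where matched_indices is -1 (no match) do not carry meaningful
    # coordinates and are meant to be filtered out downstream using matching_scores.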

    def _get_first_stage_fine_matching(
        self,
        fine_confidence: torch.Tensor,
        coarse_matched_keypoints: torch.Tensor,
        fine_window_size: int,
        fine_scale: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        For each coarse pixel, retrieve the highest fine confidence score and index.
        The index represents the matching between a pixel position in the fine window in the first image and a pixel
        position in the fine window of the second image.
        For example, for a fine_window_size of 64 (8 * 8), the index 2474 represents the matching between the index 38
        (2474 // 64) in the fine window of the first image, and the index 42 (2474 % 64) in the fine window of the
        second image. This means that 38, which corresponds to the position (4, 6) (38 // 8 and 38 % 8), is matched
        with the position (5, 2). In this example the coarse matched coordinate will be shifted to the matched fine
        coordinates in the first and second image.

        Note:
            This step can be done as a postprocessing step, because it does not involve any model weights/params.
            However, we keep it in the modeling code for consistency with other keypoint matching models AND for
            easier torch.compile/torch.export (all ops are in torch).

        Args:
            fine_confidence (`torch.Tensor` of shape `(num_matches, fine_window_size, fine_window_size)`):
                First stage confidence of matching fine features between the first and the second image
            coarse_matched_keypoints (`torch.Tensor` of shape `(2, num_matches, 2)`):
                Coarse matched keypoints between the first and the second image.
            fine_window_size (`int`):
                Size of the window used to refine matches
            fine_scale (`float`):
                Scale between the size of fine features and coarse features

        Returns:
            indices (`torch.Tensor` of shape `(2, num_matches, 1)`):
                Indices of the fine coordinate matched in the fine window
            fine_matches (`torch.Tensor` of shape `(2, num_matches, 2)`):
                Coordinates of matched keypoints after the first fine stage
        """
        batch_size, num_keypoints, _, _ = fine_confidence.shape
        fine_kernel_size = torch_int(fine_window_size**0.5)

        fine_confidence = fine_confidence.reshape(batch_size, num_keypoints, -1)
        values, indices = torch.max(fine_confidence, dim=-1)
        indices = indices[..., None]
        indices_0 = indices // fine_window_size
        indices_1 = indices % fine_window_size

        grid = create_meshgrid(
            fine_kernel_size,
            fine_kernel_size,
            normalized_coordinates=False,
            device=fine_confidence.device,
            dtype=fine_confidence.dtype,
        )
        grid = grid - (fine_kernel_size // 2) + 0.5
        grid = grid.reshape(1, 1, -1, 2).expand(batch_size, num_keypoints, -1, -1)
        delta_0 = torch.gather(grid, 2, indices_0.unsqueeze(-1).expand(-1, -1, -1, 2)).squeeze(2)
        delta_1 = torch.gather(grid, 2, indices_1.unsqueeze(-1).expand(-1, -1, -1, 2)).squeeze(2)

        fine_matches_0 = coarse_matched_keypoints[:, 0] + delta_0 * fine_scale
        fine_matches_1 = coarse_matched_keypoints[:, 1] + delta_1 * fine_scale

        indices = torch.stack([indices_0, indices_1], dim=1)
        fine_matches = torch.stack([fine_matches_0, fine_matches_1], dim=1)

        return indices, fine_matches
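    # Editorial check of the index arithmetic documented above (fine_window_size = 64,
    # fine_kernel_size = 8): the flat argmax index 2474 decodes as
    #     2474 // 64 = 38 -> window position (38 // 8, 38 % 8) = (4, 6) in the first image
    #     2474 %  64 = 42 -> window position (42 // 8, 42 % 8) = (5, 2) in the second image
    # and each window position is then recentred by subtracting fine_kernel_size // 2 - 0.5 (3.5
    # here) before being scaled by fine_scale and added to the coarse keypoint coordinates.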

    def _get_second_stage_fine_matching(
        self,
        indices: torch.Tensor,
        fine_matches: torch.Tensor,
        fine_confidence: torch.Tensor,
        fine_window_size: int,
        fine_scale: float,
    ) -> torch.Tensor:
        """
        For the given position in their respective fine windows, retrieve the 3x3 fine confidences around this
        position. After applying softmax to these confidences, compute the 2D spatial expected coordinates.
        Shift the first stage fine matching with these expected coordinates.

        Note:
            This step can be done as a postprocessing step, because it does not involve any model weights/params.
            However, we keep it in the modeling code for consistency with other keypoint matching models AND for
            easier torch.compile/torch.export (all ops are in torch).

        Args:
            indices (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
                Indices representing the position of each keypoint in the fine window
            fine_matches (`torch.Tensor` of shape `(2, num_matches, 2)`):
                Coordinates of matched keypoints after the first fine stage
            fine_confidence (`torch.Tensor` of shape `(num_matches, fine_window_size, fine_window_size)`):
                Second stage confidence of matching fine features between the first and the second image
            fine_window_size (`int`):
                Size of the window used to refine matches
            fine_scale (`float`):
                Scale between the size of fine features and coarse features

        Returns:
            fine_matches (`torch.Tensor` of shape `(2, num_matches, 2)`):
                Coordinates of matched keypoints after the second fine stage
        """
        batch_size, num_keypoints, _, _ = fine_confidence.shape
        fine_kernel_size = torch_int(fine_window_size**0.5)

        indices_0 = indices[:, 0]
        indices_1 = indices[:, 1]
        indices_1_i = indices_1 // fine_kernel_size
        indices_1_j = indices_1 % fine_kernel_size

        # matches_indices, indices_0, indices_1_i, indices_1_j of shape (num_matches, 3, 3)
        batch_indices = torch.arange(batch_size, device=indices_0.device).reshape(batch_size, 1, 1, 1)
        matches_indices = torch.arange(num_keypoints, device=indices_0.device).reshape(1, num_keypoints, 1, 1)
        indices_0 = indices_0[..., None]
        indices_1_i = indices_1_i[..., None]
        indices_1_j = indices_1_j[..., None]

        delta = create_meshgrid(3, 3, normalized_coordinates=True, device=indices_0.device).to(torch.long)
        delta = delta[None, ...]

        indices_1_i = indices_1_i + delta[..., 1]
        indices_1_j = indices_1_j + delta[..., 0]

        fine_confidence = fine_confidence.reshape(
            batch_size, num_keypoints, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2
        )
        # (batch_size, seq_len, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2) -> (batch_size, seq_len, 3, 3)
        fine_confidence = fine_confidence[batch_indices, matches_indices, indices_0, indices_1_i, indices_1_j]
        fine_confidence = fine_confidence.reshape(batch_size, num_keypoints, 9)
        fine_confidence = nn.functional.softmax(
            fine_confidence / self.config.fine_matching_regress_temperature, dim=-1
        )

        heatmap = fine_confidence.reshape(batch_size, num_keypoints, 3, 3)
        fine_coordinates_normalized = spatial_expectation2d(heatmap, True)[0]

        fine_matches_0 = fine_matches[:, 0]
        fine_matches_1 = fine_matches[:, 1] + (fine_coordinates_normalized * (3 // 2) * fine_scale)

        fine_matches = torch.stack([fine_matches_0, fine_matches_1], dim=1)

        return fine_matches
    def _fine_matching(
        self,
        fine_features_0: torch.Tensor,
        fine_features_1: torch.Tensor,
        coarse_matched_keypoints: torch.Tensor,
        fine_scale: float,
    ) -> torch.Tensor:
        """
        For each coarse pixel with a corresponding window of fine features, compute the matching confidence between
        fine features in the first image and the second image.

        Fine features are sliced into two parts:
        - The first part, used for the first stage, consists of the first
          fine_hidden_size - config.fine_matching_slice_dim (64 - 8 = 56 by default) features.
        - The second part, used for the second stage, consists of the last config.fine_matching_slice_dim
          (8 by default) features.

        Each part is used to compute a fine confidence tensor of the following shape:
        (batch_size, (coarse_height * coarse_width), fine_window_size, fine_window_size)
        They correspond to the score between each fine pixel in the first image and each fine pixel in the second
        image.

        Args:
            fine_features_0 (`torch.Tensor` of shape `(num_matches, fine_kernel_size ** 2, fine_kernel_size ** 2)`):
                Fine features from the first image
            fine_features_1 (`torch.Tensor` of shape `(num_matches, (fine_kernel_size + 2) ** 2, (fine_kernel_size + 2)
                ** 2)`):
                Fine features from the second image
            coarse_matched_keypoints (`torch.Tensor` of shape `(2, num_matches, 2)`):
                Keypoint coordinates found in coarse matching for the first and second image
            fine_scale (`int`):
                Scale between the size of fine features and coarse features

        Returns:
            fine_coordinates (`torch.Tensor` of shape `(2, num_matches, 2)`):
                Matched keypoints between the first and the second image. All matched keypoints are concatenated in
                the second dimension.
        """
        batch_size, num_keypoints, fine_window_size, fine_embed_dim = fine_features_0.shape
        fine_matching_slice_dim = self.config.fine_matching_slice_dim

        fine_kernel_size = torch_int(fine_window_size**0.5)

        # Split fine features into first and second stage features
        split_fine_features_0 = torch.split(fine_features_0, fine_embed_dim - fine_matching_slice_dim, -1)
        split_fine_features_1 = torch.split(fine_features_1, fine_embed_dim - fine_matching_slice_dim, -1)

        # Retrieve first stage fine features
        fine_features_0 = split_fine_features_0[0]
        fine_features_1 = split_fine_features_1[0]

        # Normalize first stage fine features
        fine_features_0 = fine_features_0 / fine_features_0.shape[-1] ** 0.5
        fine_features_1 = fine_features_1 / fine_features_1.shape[-1] ** 0.5

        # Compute first stage confidence
        fine_confidence = fine_features_0 @ fine_features_1.transpose(-1, -2)
        fine_confidence = nn.functional.softmax(fine_confidence, 1) * nn.functional.softmax(fine_confidence, 2)
        fine_confidence = fine_confidence.reshape(
            batch_size, num_keypoints, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2
        )
        fine_confidence = fine_confidence[..., 1:-1, 1:-1]
        first_stage_fine_confidence = fine_confidence.reshape(
            batch_size, num_keypoints, fine_window_size, fine_window_size
        )

        fine_indices, fine_matches = self._get_first_stage_fine_matching(
            first_stage_fine_confidence,
            coarse_matched_keypoints,
            fine_window_size,
            fine_scale,
        )

        # Retrieve second stage fine features
        fine_features_0 = split_fine_features_0[1]
        fine_features_1 = split_fine_features_1[1]

        # Normalize second stage fine features
        fine_features_1 = fine_features_1 / fine_matching_slice_dim**0.5

        # Compute second stage fine confidence
        second_stage_fine_confidence = fine_features_0 @ fine_features_1.transpose(-1, -2)

        fine_coordinates = self._get_second_stage_fine_matching(
            fine_indices,
            fine_matches,
            second_stage_fine_confidence,
            fine_window_size,
            fine_scale,
        )

        return fine_coordinates
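    # Editorial example of the feature split above (using the defaults quoted in the docstring:
    # a 64-dimensional fine feature and fine_matching_slice_dim = 8): torch.split with a chunk size
    # of 56 on the last dimension yields two parts of widths 56 and 8, used for the first
    # (pixel-level) and second (sub-pixel) refinement stages respectively:
    #
    #     >>> parts = torch.split(torch.randn(3, 64), 56, dim=-1)
    #     >>> [p.shape[-1] for p in parts]
    #     [56, 8]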

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> KeypointMatchingOutput:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoModel
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true"
        >>> image1 = Image.open(requests.get(url, stream=True).raw)
        >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true"
        >>> image2 = Image.open(requests.get(url, stream=True).raw)
        >>> images = [image1, image2]

        >>> processor = AutoImageProcessor.from_pretrained("zju-community/efficient_loftr")
        >>> model = AutoModel.from_pretrained("zju-community/efficient_loftr")

        >>> inputs = processor(images, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ```"""
        if labels is not None:
            raise ValueError("EfficientLoFTR is not trainable, no labels should be provided.")

        # 1. Extract coarse and residual features
        model_outputs: BackboneOutput = self.efficientloftr(pixel_values, **kwargs)
        features = model_outputs.feature_maps

        # 2. Compute coarse-level matching
        coarse_features = features[0]
        coarse_embed_dim, coarse_height, coarse_width = coarse_features.shape[-3:]
        batch_size, _, channels, height, width = pixel_values.shape
        coarse_scale = height / coarse_height
        coarse_keypoints, coarse_matching_scores, coarse_matched_indices = self._coarse_matching(
            coarse_features, coarse_scale
        )

        # 3. Fine-level refinement
        residual_features = features[1:]
        fine_features_0, fine_features_1 = self.refinement_layer(coarse_features, residual_features)

        # Filter fine features with coarse match indices
        _, _, num_keypoints = coarse_matching_scores.shape
        batch_indices = torch.arange(batch_size)[..., None]
        fine_features_0 = fine_features_0[batch_indices, coarse_matched_indices[:, 0]]
        fine_features_1 = fine_features_1[batch_indices, coarse_matched_indices[:, 1]]

        # 4. Compute fine-level matching
        fine_height = torch_int(coarse_height * coarse_scale)
        fine_scale = height / fine_height
        matching_keypoints = self._fine_matching(fine_features_0, fine_features_1, coarse_keypoints, fine_scale)

        matching_keypoints[:, :, :, 0] = matching_keypoints[:, :, :, 0] / width
        matching_keypoints[:, :, :, 1] = matching_keypoints[:, :, :, 1] / height

        return KeypointMatchingOutput(
            matches=coarse_matched_indices,
            matching_scores=coarse_matching_scores,
            keypoints=matching_keypoints,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )


__all__ = ["EfficientLoFTRPreTrainedModel", "EfficientLoFTRModel", "EfficientLoFTRForKeypointMatching"]