# team-10/env/Lib/site-packages/transformers/models/vitpose/modeling_vitpose.py
# 2025-08-02 07:34:44 +02:00
#
# 298 lines
# 12 KiB
# Python

# coding=utf-8
# Copyright 2024 University of Sydney and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch VitPose model."""
from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
auto_docstring,
logging,
)
from ...utils.backbone_utils import load_backbone
from .configuration_vitpose import VitPoseConfig
# Module-level logger, named after this module via `__name__`.
logger = logging.get_logger(__name__)
# General docstring
@dataclass
@auto_docstring(
    custom_intro="""
Class for outputs of pose estimation models.
"""
)
class VitPoseEstimatorOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Loss is not supported at this moment. See https://github.com/ViTAE-Transformer/ViTPose/tree/main/mmpose/models/losses for further detail.
    heatmaps (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`):
        Heatmaps as predicted by the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
        (also called feature maps) of the model at the output of each stage.
    """

    # Every field defaults to None, so an output object can be built with any subset of them.

    # Placeholder for a training loss; `VitPoseForPoseEstimation.forward` in this file always passes `None`
    # because it raises `NotImplementedError` when `labels` are supplied.
    loss: Optional[torch.FloatTensor] = None
    # Heatmaps produced by the decoding head.
    heatmaps: Optional[torch.FloatTensor] = None
    # Forwarded unchanged from the backbone outputs (`outputs.hidden_states`).
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    # Forwarded unchanged from the backbone outputs (`outputs.attentions`); not described in the
    # docstring above -- NOTE(review): presumably one attention tensor per backbone layer, confirm.
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
@auto_docstring
class VitPosePreTrainedModel(PreTrainedModel):
    # Common base for the VitPose models in this module: declares the config type, the checkpoint
    # prefix / main input name, and the weight-initialisation scheme.
    config: VitPoseConfig
    base_model_prefix = "vit"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialise the parameters of a single sub-module.

        `nn.LayerNorm` gets a zero bias and a unit weight. `nn.Linear` and `nn.Conv2d` get a
        truncated-normal weight with standard deviation `config.initializer_range` and, when a bias
        exists, a zero bias. Any other module type is left untouched.
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            return

        # `trunc_normal_` has no CPU kernel for half precision, so the values are drawn on an fp32
        # version of the weight and cast back to the parameter's own dtype afterwards.
        weight_fp32 = module.weight.data.to(torch.float32)
        initialised = nn.init.trunc_normal_(weight_fp32, mean=0.0, std=self.config.initializer_range)
        module.weight.data = initialised.to(module.weight.dtype)

        if module.bias is not None:
            module.bias.data.zero_()
def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"):
    """Flip the flipped heatmaps back to the original form.

    The input tensor is never modified; a new tensor is returned.

    Args:
        output_flipped (`torch.tensor` of shape `(batch_size, num_keypoints, height, width)`):
            The output heatmaps obtained from the flipped images.
        flip_pairs (`torch.Tensor` of shape `(num_pairs, 2)`):
            Pairs of keypoints which are mirrored (for example, left ear -- right ear). Each row is read as
            `(left, right)` and the two entries are exchanged.
        target_type (`str`, *optional*, defaults to `"gaussian-heatmap"`):
            Target type to use. Can be gaussian-heatmap or combined-target.
            gaussian-heatmap: Classification target with gaussian distribution.
            combined-target: The combination of classification target (response map) and regression target (offset map).
            Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020).

    Returns:
        torch.Tensor: heatmaps that flipped back to the original image

    Raises:
        ValueError: If `target_type` is not one of the two supported values, or if `output_flipped` is not
            4-dimensional.
    """
    if target_type not in ["gaussian-heatmap", "combined-target"]:
        raise ValueError("target_type should be gaussian-heatmap or combined-target")

    if output_flipped.ndim != 4:
        raise ValueError("output_flipped should be [batch_size, num_keypoints, height, width]")

    batch_size, num_keypoints, height, width = output_flipped.shape
    channels = 1
    if target_type == "combined-target":
        channels = 3
        # Every second channel of each 3-channel group is negated for the mirrored image. Work on a
        # private copy so the caller's tensor is not altered as a side effect.
        output_flipped = output_flipped.clone()
        output_flipped[:, 1::3, ...] = -output_flipped[:, 1::3, ...]

    # Group the channel axis per keypoint so that a whole keypoint can be moved at once.
    output_flipped = output_flipped.reshape(batch_size, -1, channels, height, width)
    output_flipped_back = output_flipped.clone()

    # Swap left-right parts; reads come from the untouched source so the swaps cannot interfere.
    for left, right in flip_pairs.tolist():
        output_flipped_back[:, left, ...] = output_flipped[:, right, ...]
        output_flipped_back[:, right, ...] = output_flipped[:, left, ...]
    output_flipped_back = output_flipped_back.reshape((batch_size, num_keypoints, height, width))

    # Flip horizontally
    output_flipped_back = output_flipped_back.flip(-1)
    return output_flipped_back
class VitPoseSimpleDecoder(nn.Module):
    """
    Lightweight decoding head: a ReLU activation, bilinear upsampling by `config.scale_factor` and a 3x3
    convolution that maps the backbone channels to `config.num_labels` heatmaps.
    """

    def __init__(self, config) -> None:
        super().__init__()

        self.activation = nn.ReLU()
        self.upsampling = nn.Upsample(scale_factor=config.scale_factor, mode="bilinear", align_corners=False)

        in_channels = config.backbone_config.hidden_size
        out_channels = config.num_labels
        # kernel 3 / stride 1 / padding 1 keeps the spatial size produced by the upsampling step.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_state: torch.Tensor, flip_pairs: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Turn a feature map into heatmaps; when `flip_pairs` is given the result goes through `flip_back`."""
        # ReLU -> upsample -> project to one channel per label
        heatmaps = self.conv(self.upsampling(self.activation(hidden_state)))

        if flip_pairs is None:
            return heatmaps
        return flip_back(heatmaps, flip_pairs)
class VitPoseClassicDecoder(nn.Module):
"""
Classic decoding head consisting of a 2 deconvolutional blocks, followed by a 1x1 convolution layer,
turning the feature maps into heatmaps.
"""
def __init__(self, config: VitPoseConfig):
super().__init__()
self.deconv1 = nn.ConvTranspose2d(
config.backbone_config.hidden_size, 256, kernel_size=4, stride=2, padding=1, bias=False
)
self.batchnorm1 = nn.BatchNorm2d(256)
self.relu1 = nn.ReLU()
self.deconv2 = nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1, bias=False)
self.batchnorm2 = nn.BatchNorm2d(256)
self.relu2 = nn.ReLU()
self.conv = nn.Conv2d(256, config.num_labels, kernel_size=1, stride=1, padding=0)
def forward(self, hidden_state: torch.Tensor, flip_pairs: Optional[torch.Tensor] = None):
hidden_state = self.deconv1(hidden_state)
hidden_state = self.batchnorm1(hidden_state)
hidden_state = self.relu1(hidden_state)
hidden_state = self.deconv2(hidden_state)
hidden_state = self.batchnorm2(hidden_state)
hidden_state = self.relu2(hidden_state)
heatmaps = self.conv(hidden_state)
if flip_pairs is not None:
heatmaps = flip_back(heatmaps, flip_pairs)
return heatmaps
@auto_docstring(
    custom_intro="""
The VitPose model with a pose estimation head on top.
"""
)
class VitPoseForPoseEstimation(VitPosePreTrainedModel):
    def __init__(self, config: VitPoseConfig) -> None:
        super().__init__(config)

        self.backbone = load_backbone(config)

        # The loaded backbone's config must expose these attributes; fail early with an explicit message.
        required_attributes = (
            ("hidden_size", "The backbone should have a hidden_size attribute"),
            ("image_size", "The backbone should have an image_size attribute"),
            ("patch_size", "The backbone should have a patch_size attribute"),
        )
        for attribute, message in required_attributes:
            if not hasattr(self.backbone.config, attribute):
                raise ValueError(message)

        # `config.use_simple_decoder` selects which of the two decoding heads is built.
        head_class = VitPoseSimpleDecoder if config.use_simple_decoder else VitPoseClassicDecoder
        self.head = head_class(config)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor,
        dataset_index: Optional[torch.Tensor] = None,
        flip_pairs: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, VitPoseEstimatorOutput]:
        r"""
        dataset_index (`torch.Tensor` of shape `(batch_size,)`):
            Index to use in the Mixture-of-Experts (MoE) blocks of the backbone.

            This corresponds to the dataset index used during training, e.g. For the single dataset index 0 refers to the corresponding dataset. For the multiple datasets index 0 refers to dataset A (e.g. MPII) and index 1 refers to dataset B (e.g. CrowdPose).
        flip_pairs (`torch.tensor`, *optional*):
            Whether to mirror pairs of keypoints (for example, left ear -- right ear).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, VitPoseForPoseEstimation
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> processor = AutoImageProcessor.from_pretrained("usyd-community/vitpose-base-simple")
        >>> model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]]
        >>> inputs = processor(image, boxes=boxes, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        >>> heatmaps = outputs.heatmaps
        ```"""
        # Flags left as `None` fall back to the values stored on the model config.
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if output_attentions is None:
            output_attentions = self.config.output_attentions

        loss = None
        if labels is not None:
            raise NotImplementedError("Training is not yet supported")

        outputs = self.backbone.forward_with_filtered_kwargs(
            pixel_values,
            dataset_index=dataset_index,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )

        # The last feature map is stored under `feature_maps` (dict-style) or as the first tuple element.
        if return_dict:
            sequence_output = outputs.feature_maps[-1]
        else:
            sequence_output = outputs[0][-1]

        # Turn output hidden states in tensor of shape (batch_size, num_channels, height, width):
        # the patch grid size comes from the configured image size divided by the patch size.
        backbone_config = self.config.backbone_config
        batch_size = sequence_output.shape[0]
        patch_height = backbone_config.image_size[0] // backbone_config.patch_size[0]
        patch_width = backbone_config.image_size[1] // backbone_config.patch_size[1]

        channels_first = sequence_output.permute(0, 2, 1)
        feature_map = channels_first.reshape(batch_size, -1, patch_height, patch_width)
        sequence_output = feature_map.contiguous()

        heatmaps = self.head(sequence_output, flip_pairs=flip_pairs)

        if return_dict:
            return VitPoseEstimatorOutput(
                loss=loss,
                heatmaps=heatmaps,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple output: heatmaps followed by the trailing backbone outputs.
        # NOTE(review): without hidden states the slice starts at index 2 -- confirm against the
        # backbone's tuple layout that this does not drop attentions.
        first_extra = 1 if output_hidden_states else 2
        output = (heatmaps,) + outputs[first_extra:]
        if loss is not None:
            return (loss,) + output
        return output
# Names exported by `from ... import *`: only the pretrained base class and the pose-estimation model.
# The output dataclass, the two decoder heads and `flip_back` are not listed here.
__all__ = ["VitPosePreTrainedModel", "VitPoseForPoseEstimation"]