# coding=utf-8 # Copyright 2024 University of Sydney and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch VitPose model.""" from dataclasses import dataclass from typing import Optional, Union import torch import torch.utils.checkpoint from torch import nn from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, auto_docstring, logging, ) from ...utils.backbone_utils import load_backbone from .configuration_vitpose import VitPoseConfig logger = logging.get_logger(__name__) # General docstring @dataclass @auto_docstring( custom_intro=""" Class for outputs of pose estimation models. """ ) class VitPoseEstimatorOutput(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Loss is not supported at this moment. See https://github.com/ViTAE-Transformer/ViTPose/tree/main/mmpose/models/losses for further detail. heatmaps (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`): Heatmaps as predicted by the model. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the model at the output of each stage. """ loss: Optional[torch.FloatTensor] = None heatmaps: Optional[torch.FloatTensor] = None hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None attentions: Optional[tuple[torch.FloatTensor, ...]] = None @auto_docstring class VitPosePreTrainedModel(PreTrainedModel): config: VitPoseConfig base_model_prefix = "vit" main_input_name = "pixel_values" supports_gradient_checkpointing = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues module.weight.data = nn.init.trunc_normal_( module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range ).to(module.weight.dtype) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"): """Flip the flipped heatmaps back to the original form. Args: output_flipped (`torch.tensor` of shape `(batch_size, num_keypoints, height, width)`): The output heatmaps obtained from the flipped images. flip_pairs (`torch.Tensor` of shape `(num_keypoints, 2)`): Pairs of keypoints which are mirrored (for example, left ear -- right ear). target_type (`str`, *optional*, defaults to `"gaussian-heatmap"`): Target type to use. Can be gaussian-heatmap or combined-target. gaussian-heatmap: Classification target with gaussian distribution. combined-target: The combination of classification target (response map) and regression target (offset map). Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020). Returns: torch.Tensor: heatmaps that flipped back to the original image """ if target_type not in ["gaussian-heatmap", "combined-target"]: raise ValueError("target_type should be gaussian-heatmap or combined-target") if output_flipped.ndim != 4: raise ValueError("output_flipped should be [batch_size, num_keypoints, height, width]") batch_size, num_keypoints, height, width = output_flipped.shape channels = 1 if target_type == "combined-target": channels = 3 output_flipped[:, 1::3, ...] = -output_flipped[:, 1::3, ...] output_flipped = output_flipped.reshape(batch_size, -1, channels, height, width) output_flipped_back = output_flipped.clone() # Swap left-right parts for left, right in flip_pairs.tolist(): output_flipped_back[:, left, ...] = output_flipped[:, right, ...] output_flipped_back[:, right, ...] = output_flipped[:, left, ...] output_flipped_back = output_flipped_back.reshape((batch_size, num_keypoints, height, width)) # Flip horizontally output_flipped_back = output_flipped_back.flip(-1) return output_flipped_back class VitPoseSimpleDecoder(nn.Module): """ Simple decoding head consisting of a ReLU activation, 4x upsampling and a 3x3 convolution, turning the feature maps into heatmaps. """ def __init__(self, config) -> None: super().__init__() self.activation = nn.ReLU() self.upsampling = nn.Upsample(scale_factor=config.scale_factor, mode="bilinear", align_corners=False) self.conv = nn.Conv2d( config.backbone_config.hidden_size, config.num_labels, kernel_size=3, stride=1, padding=1 ) def forward(self, hidden_state: torch.Tensor, flip_pairs: Optional[torch.Tensor] = None) -> torch.Tensor: # Transform input: ReLU + upsample hidden_state = self.activation(hidden_state) hidden_state = self.upsampling(hidden_state) heatmaps = self.conv(hidden_state) if flip_pairs is not None: heatmaps = flip_back(heatmaps, flip_pairs) return heatmaps class VitPoseClassicDecoder(nn.Module): """ Classic decoding head consisting of a 2 deconvolutional blocks, followed by a 1x1 convolution layer, turning the feature maps into heatmaps. """ def __init__(self, config: VitPoseConfig): super().__init__() self.deconv1 = nn.ConvTranspose2d( config.backbone_config.hidden_size, 256, kernel_size=4, stride=2, padding=1, bias=False ) self.batchnorm1 = nn.BatchNorm2d(256) self.relu1 = nn.ReLU() self.deconv2 = nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1, bias=False) self.batchnorm2 = nn.BatchNorm2d(256) self.relu2 = nn.ReLU() self.conv = nn.Conv2d(256, config.num_labels, kernel_size=1, stride=1, padding=0) def forward(self, hidden_state: torch.Tensor, flip_pairs: Optional[torch.Tensor] = None): hidden_state = self.deconv1(hidden_state) hidden_state = self.batchnorm1(hidden_state) hidden_state = self.relu1(hidden_state) hidden_state = self.deconv2(hidden_state) hidden_state = self.batchnorm2(hidden_state) hidden_state = self.relu2(hidden_state) heatmaps = self.conv(hidden_state) if flip_pairs is not None: heatmaps = flip_back(heatmaps, flip_pairs) return heatmaps @auto_docstring( custom_intro=""" The VitPose model with a pose estimation head on top. """ ) class VitPoseForPoseEstimation(VitPosePreTrainedModel): def __init__(self, config: VitPoseConfig) -> None: super().__init__(config) self.backbone = load_backbone(config) # add backbone attributes if not hasattr(self.backbone.config, "hidden_size"): raise ValueError("The backbone should have a hidden_size attribute") if not hasattr(self.backbone.config, "image_size"): raise ValueError("The backbone should have an image_size attribute") if not hasattr(self.backbone.config, "patch_size"): raise ValueError("The backbone should have a patch_size attribute") self.head = VitPoseSimpleDecoder(config) if config.use_simple_decoder else VitPoseClassicDecoder(config) # Initialize weights and apply final processing self.post_init() @auto_docstring def forward( self, pixel_values: torch.Tensor, dataset_index: Optional[torch.Tensor] = None, flip_pairs: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[tuple, VitPoseEstimatorOutput]: r""" dataset_index (`torch.Tensor` of shape `(batch_size,)`): Index to use in the Mixture-of-Experts (MoE) blocks of the backbone. This corresponds to the dataset index used during training, e.g. For the single dataset index 0 refers to the corresponding dataset. For the multiple datasets index 0 refers to dataset A (e.g. MPII) and index 1 refers to dataset B (e.g. CrowdPose). flip_pairs (`torch.tensor`, *optional*): Whether to mirror pairs of keypoints (for example, left ear -- right ear). Examples: ```python >>> from transformers import AutoImageProcessor, VitPoseForPoseEstimation >>> import torch >>> from PIL import Image >>> import requests >>> processor = AutoImageProcessor.from_pretrained("usyd-community/vitpose-base-simple") >>> model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]] >>> inputs = processor(image, boxes=boxes, return_tensors="pt") >>> with torch.no_grad(): ... outputs = model(**inputs) >>> heatmaps = outputs.heatmaps ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions loss = None if labels is not None: raise NotImplementedError("Training is not yet supported") outputs = self.backbone.forward_with_filtered_kwargs( pixel_values, dataset_index=dataset_index, output_hidden_states=output_hidden_states, output_attentions=output_attentions, return_dict=return_dict, ) # Turn output hidden states in tensor of shape (batch_size, num_channels, height, width) sequence_output = outputs.feature_maps[-1] if return_dict else outputs[0][-1] batch_size = sequence_output.shape[0] patch_height = self.config.backbone_config.image_size[0] // self.config.backbone_config.patch_size[0] patch_width = self.config.backbone_config.image_size[1] // self.config.backbone_config.patch_size[1] sequence_output = ( sequence_output.permute(0, 2, 1).reshape(batch_size, -1, patch_height, patch_width).contiguous() ) heatmaps = self.head(sequence_output, flip_pairs=flip_pairs) if not return_dict: if output_hidden_states: output = (heatmaps,) + outputs[1:] else: output = (heatmaps,) + outputs[2:] return ((loss,) + output) if loss is not None else output return VitPoseEstimatorOutput( loss=loss, heatmaps=heatmaps, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) __all__ = ["VitPosePreTrainedModel", "VitPoseForPoseEstimation"]