# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/eomt/modular_eomt.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_eomt.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
import math
from dataclasses import dataclass
from typing import Callable, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

from ...activations import ACT2FN
from ...file_utils import ModelOutput, is_scipy_available, requires_backends
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import auto_docstring, can_return_tuple, is_accelerate_available
from .configuration_eomt import EomtConfig


if is_scipy_available():
    from scipy.optimize import linear_sum_assignment

if is_accelerate_available():
    from accelerate import PartialState
    from accelerate.utils import reduce


@dataclass
@auto_docstring(
    custom_intro="""
    Class for outputs of [`EomtForUniversalSegmentation`].

    This output can be directly passed to [`~EomtImageProcessor.post_process_semantic_segmentation`] or
    [`~EomtImageProcessor.post_process_instance_segmentation`] or
    [`~EomtImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please see
    [`~EomtImageProcessor`] for details regarding usage.
    """
)
class EomtForUniversalSegmentationOutput(ModelOutput):
    r"""
    loss (`torch.Tensor`, *optional*):
        The computed loss, returned when labels are present.
    class_queries_logits (`torch.FloatTensor`):
        A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
        query. Note the `+ 1` is needed because we incorporate the null class.
    masks_queries_logits (`torch.FloatTensor`):
        A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
        query.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
        Last hidden states (final feature map) of the last layer.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
        when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each
        stage.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when
        `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Self-attention weights from the transformer layers.
    patch_offsets (`list[torch.Tensor]`, *optional*):
        List of tuples indicating the image index and the start and end positions of patches for semantic
        segmentation.
    """

    loss: Optional[torch.FloatTensor] = None
    class_queries_logits: Optional[torch.FloatTensor] = None
    masks_queries_logits: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    patch_offsets: Optional[list[torch.Tensor]] = None


# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
def sample_point(
    input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs
) -> torch.Tensor:
    """
    A wrapper around `torch.nn.functional.grid_sample` to support 3D point_coordinates tensors.

    Args:
        input_features (`torch.Tensor` of shape (batch_size, channels, height, width)):
            A tensor that contains feature maps on a height * width grid.
        point_coordinates (`torch.Tensor` of shape (batch_size, num_points, 2) or (batch_size, grid_height,
        grid_width, 2)):
            A tensor that contains [0, 1] * [0, 1] normalized point coordinates.
        add_dim (`bool`):
            Boolean value to keep track of an added dimension.

    Returns:
        point_features (`torch.Tensor` of shape (batch_size, channels, num_points) or (batch_size, channels,
        height_grid, width_grid)):
            A tensor that contains features for points in `point_coordinates`.
    """
    if point_coordinates.dim() == 3:
        add_dim = True
        point_coordinates = point_coordinates.unsqueeze(2)

    # use nn.functional.grid_sample to get features for points in `point_coordinates` via bilinear interpolation
    point_features = torch.nn.functional.grid_sample(input_features, 2.0 * point_coordinates - 1.0, **kwargs)
    if add_dim:
        point_features = point_features.squeeze(3)

    return point_features


def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
    """
    A pair wise version of the dice loss, see `dice_loss` for usage.

    Args:
        inputs (`torch.Tensor`):
            A tensor representing a mask.
        labels (`torch.Tensor`):
            A tensor with the same shape as inputs. Stores the binary classification labels for each element in
            inputs (0 for the negative class and 1 for the positive class).

    Returns:
        `torch.Tensor`: The computed loss between each pair.
    """
    inputs = inputs.sigmoid().flatten(1)
    numerator = 2 * torch.matmul(inputs, labels.T)
    # using broadcasting to get a [num_queries, NUM_CLASSES] matrix
    denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss


def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    r"""
    A pair wise version of the cross entropy loss, see `sigmoid_cross_entropy_loss` for usage.

    Args:
        inputs (`torch.Tensor`):
            A tensor representing a mask.
        labels (`torch.Tensor`):
            A tensor with the same shape as inputs. Stores the binary classification labels for each element in
            inputs (0 for the negative class and 1 for the positive class).

    Returns:
        loss (`torch.Tensor`): The computed loss between each pair.
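
    Example (an illustrative sketch with arbitrary shapes, not part of the public API):

    ```python
    >>> import torch
    >>> preds = torch.randn(5, 16)                      # 5 predicted masks, 16 sampled points each
    >>> targets = torch.randint(0, 2, (3, 16)).float()  # 3 ground-truth binary masks
    >>> pair_wise_sigmoid_cross_entropy_loss(preds, targets).shape  # one value per (prediction, target) pair
    torch.Size([5, 3])
    ```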
""" height_and_width = inputs.shape[1] criterion = nn.BCEWithLogitsLoss(reduction="none") cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs)) cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs)) loss_pos = torch.matmul(cross_entropy_loss_pos / height_and_width, labels.T) loss_neg = torch.matmul(cross_entropy_loss_neg / height_and_width, (1 - labels).T) loss = loss_pos + loss_neg return loss # Adapted from https://github.com/facebookresearch/Eomt/blob/main/eomt/modeling/matcher.py class EomtHungarianMatcher(nn.Module): """This class computes an assignment between the labels and the predictions of the network. For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are un-matched (and thus treated as non-objects). """ def __init__( self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0, num_points: int = 12544 ): """Creates the matcher Params: cost_class (`float`, *optional*, defaults to 1.0): Relative weight of the classification error in the matching cost. cost_mask (`float`, *optional*, defaults to 1.0): This is the relative weight of the focal loss of the binary mask in the matching cost. cost_dice (`float`, *optional*, defaults to 1.0): This is the relative weight of the dice loss of the binary mask in the matching cost. num_points (`int`, *optional*, defaults to 12544): No. of points to sample on which the mask loss will be calculated. The same set of K points are uniformly sampled for all prediction and ground truth masks to construct the cost matrix for bipartite matching. """ super().__init__() if cost_class == 0 and cost_mask == 0 and cost_dice == 0: raise ValueError("All costs can't be 0") self.num_points = num_points self.cost_class = cost_class self.cost_mask = cost_mask self.cost_dice = cost_dice @torch.no_grad() def forward( self, masks_queries_logits: torch.Tensor, class_queries_logits: torch.Tensor, mask_labels: torch.Tensor, class_labels: torch.Tensor, ) -> list[tuple[Tensor]]: """ Params: masks_queries_logits (`torch.Tensor`): A tensor of dim `batch_size, num_queries, num_labels` with the classification logits. class_queries_logits (`torch.Tensor`): A tensor of dim `batch_size, num_queries, height, width` with the predicted masks. class_labels (`torch.Tensor`): A tensor of dim `num_target_boxes` (where num_target_boxes is the number of ground-truth objects in the target) containing the class labels. mask_labels (`torch.Tensor`): A tensor of dim `num_target_boxes, height, width` containing the target masks. Returns: matched_indices (`list[tuple[Tensor]]`): A list of size batch_size, containing tuples of (index_i, index_j) where: - index_i is the indices of the selected predictions (in order) - index_j is the indices of the corresponding selected labels (in order) For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes). """ indices: list[tuple[np.array]] = [] # iterate through batch size batch_size = masks_queries_logits.shape[0] for i in range(batch_size): pred_probs = class_queries_logits[i].softmax(-1) pred_mask = masks_queries_logits[i] # Compute the classification cost. Contrary to the loss, we don't use the NLL, but approximate it in 1 - proba[target class]. The 1 is a constant that doesn't change the matching, it can be omitted. 
            cost_class = -pred_probs[:, class_labels[i]]
            target_mask = mask_labels[i].to(pred_mask)
            target_mask = target_mask[:, None]
            pred_mask = pred_mask[:, None]

            # Sample ground truth and predicted masks
            point_coordinates = torch.rand(1, self.num_points, 2, device=pred_mask.device)

            target_coordinates = point_coordinates.repeat(target_mask.shape[0], 1, 1)
            target_mask = sample_point(target_mask, target_coordinates, align_corners=False).squeeze(1)

            pred_coordinates = point_coordinates.repeat(pred_mask.shape[0], 1, 1)
            pred_mask = sample_point(pred_mask, pred_coordinates, align_corners=False).squeeze(1)

            # compute the cross entropy loss between each pair of masks -> shape (num_queries, num_labels)
            cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask)
            # Compute the dice loss between each pair of masks -> shape (num_queries, num_labels)
            cost_dice = pair_wise_dice_loss(pred_mask, target_mask)
            # final cost matrix
            cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
            # eliminate infinite values in cost_matrix to avoid the error ``ValueError: cost matrix is infeasible``
            cost_matrix = torch.minimum(cost_matrix, torch.tensor(1e10))
            cost_matrix = torch.maximum(cost_matrix, torch.tensor(-1e10))
            cost_matrix = torch.nan_to_num(cost_matrix, 0)
            # do the assignment using the hungarian algorithm in scipy
            assigned_indices: tuple[np.array] = linear_sum_assignment(cost_matrix.cpu())
            indices.append(assigned_indices)

        # It could be stacked in one tensor
        matched_indices = [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
        ]
        return matched_indices


def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor:
    r"""
    Compute the DICE loss, similar to generalized IOU for masks as follows:

    $$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 * x \cap y}{x \cup y + 1} $$

    In practice, since `labels` is a binary mask (only 0s and 1s), dice can be computed as follows:

    $$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 * x * y}{x + y + 1} $$

    Args:
        inputs (`torch.Tensor`):
            A tensor representing a mask.
        labels (`torch.Tensor`):
            A tensor with the same shape as inputs. Stores the binary classification labels for each element in
            inputs (0 for the negative class and 1 for the positive class).
        num_masks (`int`):
            The number of masks present in the current batch, used for normalization.

    Returns:
        `torch.Tensor`: The computed loss.
    """
    probs = inputs.sigmoid().flatten(1)
    numerator = 2 * (probs * labels).sum(-1)
    denominator = probs.sum(-1) + labels.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    loss = loss.sum() / num_masks
    return loss


def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor:
    r"""
    Args:
        inputs (`torch.Tensor`):
            A float tensor of arbitrary shape.
        labels (`torch.Tensor`):
            A tensor with the same shape as inputs. Stores the binary classification labels for each element in
            inputs (0 for the negative class and 1 for the positive class).
        num_masks (`int`):
            The number of masks present in the current batch, used for normalization.

    Returns:
        loss (`torch.Tensor`): The computed loss.
    """
    criterion = nn.BCEWithLogitsLoss(reduction="none")
    cross_entropy_loss = criterion(inputs, labels)

    loss = cross_entropy_loss.mean(1).sum() / num_masks
    return loss


# Adapted from https://github.com/facebookresearch/Eomt/blob/main/eomt/modeling/criterion.py
class EomtLoss(nn.Module):
    def __init__(self, config: EomtConfig, weight_dict: dict[str, float]):
        """
        The Eomt loss. The loss is computed very similarly to DETR.
The process happens in two steps: 1) we compute hungarian assignment between ground truth masks and the outputs of the model 2) we supervise each pair of matched ground-truth / prediction (supervise class and mask) Args: config (`EomtConfig`): The configuration for Eomt model also containing loss calculation specific parameters. weight_dict (`dict[str, float]`): A dictionary of weights to be applied to the different losses. """ super().__init__() requires_backends(self, ["scipy"]) self.num_labels = config.num_labels self.weight_dict = weight_dict # Weight to apply to the null class self.eos_coef = config.no_object_weight empty_weight = torch.ones(self.num_labels + 1) empty_weight[-1] = self.eos_coef self.register_buffer("empty_weight", empty_weight) # pointwise mask loss parameters self.num_points = config.train_num_points self.oversample_ratio = config.oversample_ratio self.importance_sample_ratio = config.importance_sample_ratio self.matcher = EomtHungarianMatcher( cost_class=config.class_weight, cost_dice=config.dice_weight, cost_mask=config.mask_weight, num_points=self.num_points, ) def _max_by_axis(self, sizes: list[list[int]]) -> list[int]: maxes = sizes[0] for sublist in sizes[1:]: for index, item in enumerate(sublist): maxes[index] = max(maxes[index], item) return maxes # Adapted from nested_tensor_from_tensor_list() in original implementation def _pad_images_to_max_in_batch(self, tensors: list[Tensor]) -> tuple[Tensor, Tensor]: # get the maximum size in the batch max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors]) # compute final size batch_shape = [len(tensors)] + max_size batch_size, _, height, width = batch_shape dtype = tensors[0].dtype device = tensors[0].device padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device) padding_masks = torch.ones((batch_size, height, width), dtype=torch.bool, device=device) # pad the tensors to the size of the biggest one for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks): padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor) padding_mask[: tensor.shape[1], : tensor.shape[2]] = False return padded_tensors, padding_masks def loss_labels( self, class_queries_logits: Tensor, class_labels: list[Tensor], indices: tuple[np.array] ) -> dict[str, Tensor]: """Compute the losses related to the labels using cross entropy. Args: class_queries_logits (`torch.Tensor`): A tensor of shape `batch_size, num_queries, num_labels` class_labels (`list[torch.Tensor]`): List of class labels of shape `(labels)`. indices (`tuple[np.array])`: The indices computed by the Hungarian matcher. Returns: `dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key: - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels. 
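
        Example of the `indices` layout produced by [`EomtHungarianMatcher`] (illustrative values only):

        ```python
        >>> # batch of 2 images: 3 matched (prediction, target) pairs for the first image, 1 for the second
        >>> indices = [
        ...     (torch.tensor([4, 7, 9]), torch.tensor([0, 2, 1])),
        ...     (torch.tensor([1]), torch.tensor([0])),
        ... ]
        ```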
""" pred_logits = class_queries_logits batch_size, num_queries, _ = pred_logits.shape criterion = nn.CrossEntropyLoss(weight=self.empty_weight) idx = self._get_predictions_permutation_indices(indices) # shape of (batch_size, num_queries) target_classes_o = torch.cat( [target[j] for target, (_, j) in zip(class_labels, indices)] ) # shape of (batch_size, num_queries) target_classes = torch.full( (batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device ) target_classes[idx] = target_classes_o # Permute target_classes (batch_size, num_queries, num_labels) -> (batch_size, num_labels, num_queries) pred_logits_transposed = pred_logits.transpose(1, 2) loss_ce = criterion(pred_logits_transposed, target_classes) losses = {"loss_cross_entropy": loss_ce} return losses def loss_masks( self, masks_queries_logits: torch.Tensor, mask_labels: list[torch.Tensor], indices: tuple[np.array], num_masks: int, ) -> dict[str, torch.Tensor]: """Compute the losses related to the masks using sigmoid_cross_entropy_loss and dice loss. Args: masks_queries_logits (`torch.Tensor`): A tensor of shape `(batch_size, num_queries, height, width)`. mask_labels (`torch.Tensor`): List of mask labels of shape `(labels, height, width)`. indices (`tuple[np.array])`: The indices computed by the Hungarian matcher. num_masks (`int)`: The number of masks, used for normalization. Returns: losses (`dict[str, Tensor]`): A dict of `torch.Tensor` containing two keys: - **loss_mask** -- The loss computed using sigmoid cross entropy loss on the predicted and ground truth. masks. - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth, masks. """ src_idx = self._get_predictions_permutation_indices(indices) tgt_idx = self._get_targets_permutation_indices(indices) # shape (batch_size * num_queries, height, width) pred_masks = masks_queries_logits[src_idx] # shape (batch_size, num_queries, height, width) # pad all and stack the targets to the num_labels dimension target_masks, _ = self._pad_images_to_max_in_batch(mask_labels) target_masks = target_masks[tgt_idx] # No need to upsample predictions as we are using normalized coordinates pred_masks = pred_masks[:, None] target_masks = target_masks[:, None] # Sample point coordinates with torch.no_grad(): point_coordinates = self.sample_points_using_uncertainty( pred_masks, lambda logits: self.calculate_uncertainty(logits), self.num_points, self.oversample_ratio, self.importance_sample_ratio, ) point_labels = sample_point(target_masks, point_coordinates, align_corners=False).squeeze(1) point_logits = sample_point(pred_masks, point_coordinates, align_corners=False).squeeze(1) losses = { "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks), "loss_dice": dice_loss(point_logits, point_labels, num_masks), } del pred_masks del target_masks return losses def _get_predictions_permutation_indices(self, indices): # Permute predictions following indices batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) predictions_indices = torch.cat([src for (src, _) in indices]) return batch_indices, predictions_indices def _get_targets_permutation_indices(self, indices): # Permute labels following indices batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) target_indices = torch.cat([tgt for (_, tgt) in indices]) return batch_indices, target_indices def calculate_uncertainty(self, logits: torch.Tensor) -> torch.Tensor: """ In Eomt 
paper, uncertainty is estimated as L1 distance between 0.0 and the logit prediction in 'logits' for the foreground class in `classes`. Args: logits (`torch.Tensor`): A tensor of shape (R, 1, ...) for class-specific or class-agnostic, where R is the total number of predicted masks in all images and C is: the number of foreground classes. The values are logits. Returns: scores (`torch.Tensor`): A tensor of shape (R, 1, ...) that contains uncertainty scores with the most uncertain locations having the highest uncertainty score. """ uncertainty_scores = -(torch.abs(logits)) return uncertainty_scores def sample_points_using_uncertainty( self, logits: torch.Tensor, uncertainty_function, num_points: int, oversample_ratio: int, importance_sample_ratio: float, ) -> torch.Tensor: """ This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. The uncertainty is calculated for each point using the passed `uncertainty function` that takes points logit prediction as input. Args: logits (`float`): Logit predictions for P points. uncertainty_function: A function that takes logit predictions for P points and returns their uncertainties. num_points (`int`): The number of points P to sample. oversample_ratio (`int`): Oversampling parameter. importance_sample_ratio (`float`): Ratio of points that are sampled via importance sampling. Returns: point_coordinates (`torch.Tensor`): Coordinates for P sampled points. """ num_boxes = logits.shape[0] num_points_sampled = int(num_points * oversample_ratio) # Get random point coordinates point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device) # Get sampled prediction value for the point coordinates point_logits = sample_point(logits, point_coordinates, align_corners=False) # Calculate the uncertainties based on the sampled prediction values of the points point_uncertainties = uncertainty_function(point_logits) num_uncertain_points = int(importance_sample_ratio * num_points) num_random_points = num_points - num_uncertain_points idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device) idx += shift[:, None] point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2) if num_random_points > 0: point_coordinates = torch.cat( [point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)], dim=1, ) return point_coordinates def forward( self, masks_queries_logits: torch.Tensor, class_queries_logits: torch.Tensor, mask_labels: list[torch.Tensor], class_labels: list[torch.Tensor], auxiliary_predictions: Optional[dict[str, torch.Tensor]] = None, ) -> dict[str, torch.Tensor]: """ This performs the loss computation. Args: masks_queries_logits (`torch.Tensor`): A tensor of shape `(batch_size, num_queries, height, width)`. class_queries_logits (`torch.Tensor`): A tensor of shape `(batch_size, num_queries, num_labels)`. mask_labels (`torch.Tensor`): List of mask labels of shape `(labels, height, width)`. class_labels (`list[torch.Tensor]`): List of class labels of shape `(labels)`. auxiliary_predictions (`dict[str, torch.Tensor]`, *optional*): if `use_auxiliary_loss` was set to `true` in [`EomtConfig`], then it contains the logits from the inner layers of the EomtMaskedAttentionDecoder. 
        Returns:
            losses (`dict[str, Tensor]`): A dict of `torch.Tensor` containing three keys:
            - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth
              labels.
            - **loss_mask** -- The loss computed using sigmoid cross entropy loss on the predicted and ground
              truth masks.
            - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks.
            if `use_auxiliary_loss` was set to `true` in [`EomtConfig`], the dictionary contains additional losses
            for each auxiliary prediction.
        """
        # retrieve the matching between the outputs of the last layer and the labels
        indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
        # compute the average number of target masks for normalization purposes
        num_masks = self.get_num_masks(class_labels, device=class_labels[0].device)
        # get all the losses
        losses: dict[str, Tensor] = {
            **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
            **self.loss_labels(class_queries_logits, class_labels, indices),
        }

        # in case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if auxiliary_predictions is not None:
            for idx, aux_outputs in enumerate(auxiliary_predictions):
                masks_queries_logits = aux_outputs["masks_queries_logits"]
                class_queries_logits = aux_outputs["class_queries_logits"]
                loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
                loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()}
                losses.update(loss_dict)

        return losses

    def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
        """
        Computes the average number of target masks across the batch, for normalization purposes.
        """
        num_masks = sum([len(classes) for classes in class_labels])
        num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
        world_size = 1
        if is_accelerate_available():
            if PartialState._shared_state != {}:
                num_masks = reduce(num_masks)
                world_size = PartialState().num_processes

        num_masks = torch.clamp(num_masks / world_size, min=1)
        return num_masks


class EomtPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        num_channels = pixel_values.shape[1]
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings


class EomtEmbeddings(nn.Module):
    """
    Construct the CLS token, register tokens, position and patch embeddings.
    """

    def __init__(self, config: EomtConfig) -> None:
        super().__init__()

        self.config = config
        self.patch_size = config.patch_size

        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))

        self.patch_embeddings = EomtPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.num_prefix_tokens = 1 + config.num_register_tokens  # 1 for [CLS]

        self.position_embeddings = nn.Embedding(num_patches, config.hidden_size)
        self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size, _, _, _ = pixel_values.shape
        target_dtype = self.patch_embeddings.projection.weight.dtype
        embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        register_tokens = self.register_tokens.expand(batch_size, -1, -1)

        embeddings = embeddings + self.position_embeddings(self.position_ids)
        embeddings = torch.cat([cls_tokens, register_tokens, embeddings], dim=1)
        embeddings = self.dropout(embeddings)

        return embeddings


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class EomtAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" batch_size, seq_length, embed_dim = hidden_states.shape queries = self.q_proj(hidden_states) keys = self.k_proj(hidden_states) values = self.v_proj(hidden_states) queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, queries, keys, values, attention_mask, is_causal=self.is_causal, scaling=self.scale, dropout=0.0 if not self.training else self.dropout, ) attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() attn_output = self.out_proj(attn_output) return attn_output, attn_weights class EomtLayerScale(nn.Module): def __init__(self, config) -> None: super().__init__() self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size)) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: return hidden_state * self.lambda1 def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. 
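
    Example (illustrative; the exact zero pattern is random):

    ```python
    >>> x = torch.ones(8, 3, 4)
    >>> out = drop_path(x, drop_prob=0.5, training=True)
    >>> # roughly half of the 8 samples are zeroed out; the survivors are rescaled by 1 / (1 - 0.5) = 2
    ```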
""" if drop_prob == 0.0 or not training: return input keep_prob = 1 - drop_prob shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) random_tensor.floor_() # binarize output = input.div(keep_prob) * random_tensor return output class EomtDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" def __init__(self, drop_prob: Optional[float] = None) -> None: super().__init__() self.drop_prob = drop_prob def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return drop_path(hidden_states, self.drop_prob, self.training) def extra_repr(self) -> str: return f"p={self.drop_prob}" class EomtMLP(nn.Module): def __init__(self, config) -> None: super().__init__() in_features = out_features = config.hidden_size hidden_features = int(config.hidden_size * config.mlp_ratio) self.fc1 = nn.Linear(in_features, hidden_features, bias=True) if isinstance(config.hidden_act, str): self.activation = ACT2FN[config.hidden_act] else: self.activation = config.hidden_act self.fc2 = nn.Linear(hidden_features, out_features, bias=True) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: hidden_state = self.fc1(hidden_state) hidden_state = self.activation(hidden_state) hidden_state = self.fc2(hidden_state) return hidden_state class EomtSwiGLUFFN(nn.Module): def __init__(self, config) -> None: super().__init__() in_features = out_features = config.hidden_size hidden_features = int(config.hidden_size * config.mlp_ratio) hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True) self.weights_out = nn.Linear(hidden_features, out_features, bias=True) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: hidden_state = self.weights_in(hidden_state) x1, x2 = hidden_state.chunk(2, dim=-1) hidden = nn.functional.silu(x1) * x2 return self.weights_out(hidden) class EomtLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the original implementation.""" def __init__(self, config: EomtConfig) -> None: super().__init__() self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.attention = EomtAttention(config) self.layer_scale1 = EomtLayerScale(config) self.drop_path = EomtDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_swiglu_ffn: self.mlp = EomtSwiGLUFFN(config) else: self.mlp = EomtMLP(config) self.layer_scale2 = EomtLayerScale(config) def forward( self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: self_attention_outputs = self.attention( self.norm1(hidden_states), # in Eomt, layernorm is applied before self-attention head_mask, output_attentions=output_attentions, ) attention_output = self_attention_outputs[0] attention_output = self.layer_scale1(attention_output) outputs = self_attention_outputs[1:] # add self attentions if we output attention weights # first residual connection hidden_states = self.drop_path(attention_output) + hidden_states # in Eomt, layernorm is also applied after self-attention layer_output = self.norm2(hidden_states) layer_output = self.mlp(layer_output) layer_output = self.layer_scale2(layer_output) # second residual 
connection layer_output = self.drop_path(layer_output) + hidden_states outputs = (layer_output,) + outputs return outputs class EomtLayerNorm2d(nn.LayerNorm): def __init__(self, num_channels, eps=1e-6, affine=True): super().__init__(num_channels, eps=eps, elementwise_affine=affine) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: hidden_state = hidden_state.permute(0, 2, 3, 1) hidden_state = F.layer_norm(hidden_state, self.normalized_shape, self.weight, self.bias, self.eps) hidden_state = hidden_state.permute(0, 3, 1, 2) return hidden_state class EomtScaleLayer(nn.Module): def __init__(self, config: EomtConfig): super().__init__() hidden_size = config.hidden_size self.conv1 = nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2) self.activation = ACT2FN[config.hidden_act] self.conv2 = nn.Conv2d( hidden_size, hidden_size, kernel_size=3, padding=1, groups=hidden_size, bias=False, ) self.layernorm2d = EomtLayerNorm2d(hidden_size) def forward(self, hidden_states: torch.tensor) -> torch.Tensor: hidden_states = self.conv1(hidden_states) hidden_states = self.activation(hidden_states) hidden_states = self.conv2(hidden_states) hidden_states = self.layernorm2d(hidden_states) return hidden_states class EomtScaleBlock(nn.Module): def __init__(self, config: EomtConfig): super().__init__() self.num_blocks = config.num_upscale_blocks self.block = nn.ModuleList([EomtScaleLayer(config) for _ in range(self.num_blocks)]) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for block in self.block: hidden_states = block(hidden_states) return hidden_states class EomtMaskHead(nn.Module): def __init__(self, config: EomtConfig): super().__init__() hidden_size = config.hidden_size self.fc1 = nn.Linear(hidden_size, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, hidden_size) self.activation = ACT2FN[config.hidden_act] def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.activation(self.fc1(hidden_states)) hidden_states = self.activation(self.fc2(hidden_states)) hidden_states = self.fc3(hidden_states) return hidden_states @auto_docstring class EomtPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. 
""" config: EomtConfig base_model_prefix = "eomt" main_input_name = "pixel_values" supports_gradient_checkpointing = False _no_split_modules = ["EomtLayer"] _supports_sdpa = True _supports_flash_attn = True def _init_weights(self, module: nn.Module) -> None: std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(module.bias, -bound, bound) elif isinstance(module, nn.LayerNorm): module.weight.data.fill_(1.0) module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=1) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, EomtLayerScale): if hasattr(module, "lambda1"): module.lambda1.data.fill_(self.config.layerscale_value) elif isinstance(module, EomtEmbeddings): module.cls_token.data = nn.init.trunc_normal_( module.cls_token.data.to(torch.float32), mean=0.0, std=std ).to(module.cls_token.dtype) module.register_tokens.data.zero_() @auto_docstring( custom_intro=""" The EoMT Model with head on top for instance/semantic/panoptic segmentation. """ ) class EomtForUniversalSegmentation(EomtPreTrainedModel): main_input_name = "pixel_values" def __init__(self, config: EomtConfig) -> None: super().__init__(config) self.config = config self.num_hidden_layers = config.num_hidden_layers self.embeddings = EomtEmbeddings(config) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.query = nn.Embedding(config.num_queries, config.hidden_size) self.layers = nn.ModuleList([EomtLayer(config) for _ in range(config.num_hidden_layers)]) self.upscale_block = EomtScaleBlock(config) self.mask_head = EomtMaskHead(config) self.class_predictor = nn.Linear(config.hidden_size, config.num_labels + 1) self.grid_size = (config.image_size // config.patch_size, config.image_size // config.patch_size) self.weight_dict: dict[str, float] = { "loss_cross_entropy": config.class_weight, "loss_mask": config.mask_weight, "loss_dice": config.dice_weight, } self.criterion = EomtLoss(config=config, weight_dict=self.weight_dict) self.register_buffer("attn_mask_probs", torch.ones(config.num_blocks)) self.post_init() def get_loss_dict( self, masks_queries_logits: Tensor, class_queries_logits: Tensor, mask_labels: Tensor, class_labels: Tensor, auxiliary_predictions: dict[str, Tensor], ) -> dict[str, Tensor]: loss_dict: dict[str, Tensor] = self.criterion( masks_queries_logits=masks_queries_logits, class_queries_logits=class_queries_logits, mask_labels=mask_labels, class_labels=class_labels, auxiliary_predictions=auxiliary_predictions, ) # weight each loss by `self.weight_dict[]` including auxiliary losses for key, weight in self.weight_dict.items(): for loss_key, loss in loss_dict.items(): if key in loss_key: loss *= weight return loss_dict def get_loss(self, loss_dict: dict[str, Tensor]) -> Tensor: return sum(loss_dict.values()) @auto_docstring @can_return_tuple def forward( self, pixel_values: Tensor, mask_labels: Optional[list[Tensor]] = None, class_labels: Optional[list[Tensor]] = None, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, patch_offsets: Optional[list[Tensor]] = None, ) -> EomtForUniversalSegmentationOutput: r""" mask_labels (`list[torch.Tensor]`, *optional*): list of mask labels of shape `(num_labels, 
            height, width)` to be fed to a model
        class_labels (`list[torch.LongTensor]`, *optional*):
            list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify
            the labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` is `class_labels[i][j]`.
        patch_offsets (`list[torch.Tensor]`, *optional*):
            list of tuples indicating the image index and the start and end positions of patches for semantic
            segmentation.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        masks_queries_logits_per_layer, class_queries_logits_per_layer = (), ()
        attention_mask = None

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)

        for idx, layer_module in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if idx == self.num_hidden_layers - self.config.num_blocks:
                query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
                hidden_states = torch.cat((query, hidden_states), dim=1)

            if idx >= self.num_hidden_layers - self.config.num_blocks and (
                self.training or self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks] > 0
            ):
                norm_hidden_states = self.layernorm(hidden_states)

                masks_queries_logits, class_queries_logits = self.predict(norm_hidden_states)
                masks_queries_logits_per_layer += (masks_queries_logits,)
                class_queries_logits_per_layer += (class_queries_logits,)

                attention_mask = torch.ones(
                    hidden_states.shape[0],
                    hidden_states.shape[1],
                    hidden_states.shape[1],
                    device=hidden_states.device,
                    dtype=torch.bool,
                )

                interpolated_logits = F.interpolate(masks_queries_logits, size=self.grid_size, mode="bilinear")
                interpolated_logits = interpolated_logits.view(
                    interpolated_logits.size(0), interpolated_logits.size(1), -1
                )

                num_query_tokens = self.config.num_queries
                encoder_start_tokens = num_query_tokens + self.embeddings.num_prefix_tokens

                # Set attention mask for queries to focus on encoder tokens based on interpolated logits
                attention_mask[:, :num_query_tokens, encoder_start_tokens:] = interpolated_logits > 0

                # Disable attention mask for random query tokens.
                attention_mask = self._disable_attention_mask(
                    attention_mask,
                    prob=self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks],
                    num_query_tokens=num_query_tokens,
                    encoder_start_tokens=encoder_start_tokens,
                    device=attention_mask.device,
                )

                # Expand attention mask to 4d mask.
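                # The boolean mask of shape (batch_size, seq_len, seq_len) is broadcast over the head dimension
                # and converted to an additive float mask: positions a query may not attend to receive a large
                # negative value (-1e9) so that they vanish after the softmax.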
attention_mask = attention_mask[:, None, ...].expand(-1, self.config.num_attention_heads, -1, -1) attention_mask = attention_mask.float().masked_fill(~attention_mask, -1e9) layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) hidden_states = layer_outputs[0] if output_attentions: all_attentions += (layer_outputs[1],) sequence_output = self.layernorm(hidden_states) if output_hidden_states: all_hidden_states += (sequence_output,) masks_queries_logits, class_queries_logits = self.predict(sequence_output) masks_queries_logits_per_layer += (masks_queries_logits,) class_queries_logits_per_layer += (class_queries_logits,) loss = None if mask_labels is not None and class_labels is not None: loss = 0.0 for masks_queries_logits, class_queries_logits in zip( masks_queries_logits_per_layer, class_queries_logits_per_layer ): loss_dict = self.get_loss_dict( masks_queries_logits=masks_queries_logits, class_queries_logits=class_queries_logits, mask_labels=mask_labels, class_labels=class_labels, auxiliary_predictions=None, ) loss += self.get_loss(loss_dict) return EomtForUniversalSegmentationOutput( loss=loss, masks_queries_logits=masks_queries_logits, class_queries_logits=class_queries_logits, last_hidden_state=sequence_output, hidden_states=all_hidden_states, attentions=all_attentions, patch_offsets=patch_offsets, ) def get_input_embeddings(self): return self.embeddings.patch_embeddings def predict(self, logits: torch.Tensor): query_tokens = logits[:, : self.config.num_queries, :] class_logits = self.class_predictor(query_tokens) prefix_tokens = logits[:, self.config.num_queries + self.embeddings.num_prefix_tokens :, :] prefix_tokens = prefix_tokens.transpose(1, 2) prefix_tokens = prefix_tokens.reshape(prefix_tokens.shape[0], -1, *self.grid_size) query_tokens = self.mask_head(query_tokens) prefix_tokens = self.upscale_block(prefix_tokens) mask_logits = torch.einsum("bqc, bchw -> bqhw", query_tokens, prefix_tokens) return mask_logits, class_logits @staticmethod def _disable_attention_mask(attn_mask, prob, num_query_tokens, encoder_start_tokens, device): if prob < 1: # Generate random queries to disable based on the probs random_queries = torch.rand(attn_mask.shape[0], num_query_tokens, device=device) > prob # Disable attention to the query tokens, considering the prefix tokens attn_mask[:, :num_query_tokens, encoder_start_tokens:][random_queries] = 1 return attn_mask __all__ = ["EomtPreTrainedModel", "EomtForUniversalSegmentation"]
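

# Minimal usage sketch (illustrative only, not executed as part of this module). The checkpoint name below is a
# placeholder; see the EoMT documentation and `EomtImageProcessor` for real checkpoints and post-processing options.
#
#     from transformers import AutoImageProcessor, EomtForUniversalSegmentation
#     from PIL import Image
#
#     image = Image.open("example.jpg")
#     processor = AutoImageProcessor.from_pretrained("<eomt-checkpoint>")
#     model = EomtForUniversalSegmentation.from_pretrained("<eomt-checkpoint>")
#     inputs = processor(images=image, return_tensors="pt")
#     outputs = model(**inputs)
#     panoptic_result = processor.post_process_panoptic_segmentation(
#         outputs, target_sizes=[(image.height, image.width)]
#     )[0]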