# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, MSELoss

from .loss_d_fine import DFineForObjectDetectionLoss
from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, DeformableDetrForSegmentationLoss
from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss
from .loss_grounding_dino import GroundingDinoForObjectDetectionLoss
from .loss_rt_detr import RTDetrForObjectDetectionLoss


def fixed_cross_entropy(
    source: torch.Tensor,
    target: torch.Tensor,
    num_items_in_batch: Optional[torch.Tensor] = None,
    ignore_index: int = -100,
    **kwargs,
) -> torch.Tensor:
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        # Users may pass an int for num_items_in_batch (e.g. from a custom trainer),
        # so only move it to the loss device when it is actually a tensor.
        if torch.is_tensor(num_items_in_batch):
            num_items_in_batch = num_items_in_batch.to(loss.device)
        loss = loss / num_items_in_batch
    return loss
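
# Illustrative usage of `fixed_cross_entropy` (a sketch with toy values, not part
# of the library). Passing `num_items_in_batch` switches the reduction to a sum
# normalized by the global token count, which keeps the loss scale consistent
# across gradient-accumulation steps; without it, a plain per-token mean is used:
#
#     logits = torch.randn(6, 10)                    # (num_tokens, num_classes)
#     targets = torch.tensor([1, 4, -100, 3, 0, 2])  # -100 positions are ignored
#     per_token_mean = fixed_cross_entropy(logits, targets)
#     global_mean = fixed_cross_entropy(logits, targets, num_items_in_batch=torch.tensor(5))
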
def ForCausalLMLoss(
    logits,
    labels,
    vocab_size: int,
    num_items_in_batch: Optional[torch.Tensor] = None,
    ignore_index: int = -100,
    shift_labels: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    # Upcast to float to avoid potential precision issues when computing the loss
    logits = logits.float()

    if shift_labels is None:
        # Shift so that tokens < n predict n
        labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
        shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    logits = logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(logits.device)
    loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
    return loss
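
# Illustrative usage of `ForCausalLMLoss` (toy shapes, assuming a 100-token vocab).
# Labels are shifted internally, so `logits[:, t]` is scored against `labels[:, t + 1]`:
#
#     logits = torch.randn(2, 5, 100)           # (batch, seq_len, vocab_size)
#     labels = torch.randint(0, 100, (2, 5))    # (batch, seq_len)
#     loss = ForCausalLMLoss(logits, labels, vocab_size=100)
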
def ForMaskedLMLoss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    vocab_size: int,
    num_items_in_batch: Optional[torch.Tensor] = None,
    ignore_index: int = -100,
    **kwargs,
):
    # Upcast to float to avoid potential precision issues when computing the loss
    logits = logits.float()

    # Flatten the tokens
    logits = logits.view(-1, vocab_size)
    labels = labels.view(-1)
    # Enable model parallelism
    labels = labels.to(logits.device)
    loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index, **kwargs)
    return loss
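
# Illustrative usage of `ForMaskedLMLoss` (toy shapes, assuming a 100-token vocab).
# Unlike the causal variant, there is no shifting: only masked positions carry real
# label ids, and everything else is set to -100 and ignored:
#
#     logits = torch.randn(2, 5, 100)           # (batch, seq_len, vocab_size)
#     labels = torch.full((2, 5), -100)
#     labels[0, 2] = 7                          # a single masked position
#     loss = ForMaskedLMLoss(logits, labels, vocab_size=100)
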
def ForSequenceClassificationLoss(labels: torch.Tensor, pooled_logits: torch.Tensor, config, **kwargs) -> torch.Tensor:
    num_labels = config.num_labels
    if config.problem_type is None:
        if num_labels == 1:
            config.problem_type = "regression"
        elif num_labels > 1 and (labels.dtype in (torch.long, torch.int)):
            config.problem_type = "single_label_classification"
        else:
            config.problem_type = "multi_label_classification"

    labels = labels.to(pooled_logits.device)
    if config.problem_type == "regression":
        loss_fct = MSELoss()
        if num_labels == 1:
            return loss_fct(pooled_logits.squeeze(), labels.squeeze())
        return loss_fct(pooled_logits, labels)
    if config.problem_type == "single_label_classification":
        return fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs)
    if config.problem_type == "multi_label_classification":
        loss_fct = BCEWithLogitsLoss()
        return loss_fct(pooled_logits, labels)

    raise RuntimeError(f"Invalid problem type: {config.problem_type}")
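
# Illustrative usage of `ForSequenceClassificationLoss` (toy values; the config
# below is a hypothetical stand-in for a model config exposing `num_labels` and
# `problem_type`). Integer labels with num_labels > 1 route to the
# single-label cross-entropy path:
#
#     from types import SimpleNamespace
#     config = SimpleNamespace(num_labels=3, problem_type=None)
#     pooled_logits = torch.randn(4, 3)         # (batch, num_labels)
#     labels = torch.tensor([0, 2, 1, 2])       # long dtype -> single-label path
#     loss = ForSequenceClassificationLoss(labels, pooled_logits, config)
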
def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions, **kwargs):
    total_loss = None
    if start_positions is not None and end_positions is not None:
        # If we are on multi-GPU, the split may add a dimension
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1).to(start_logits.device)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1).to(end_logits.device)
        # Start/end positions outside the model inputs are clamped to `ignored_index`
        # and thereby excluded from the loss
        ignored_index = start_logits.size(1)
        start_positions = start_positions.clamp(0, ignored_index)
        end_positions = end_positions.clamp(0, ignored_index)

        start_loss = fixed_cross_entropy(start_logits, start_positions, ignore_index=ignored_index, **kwargs)
        end_loss = fixed_cross_entropy(end_logits, end_positions, ignore_index=ignored_index, **kwargs)
        total_loss = (start_loss + end_loss) / 2
    return total_loss
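
# Illustrative usage of `ForQuestionAnsweringLoss` (toy shapes). The result is the
# mean of the start- and end-position cross-entropies; out-of-range positions are
# clamped to seq_len and then ignored:
#
#     start_logits = torch.randn(2, 8)          # (batch, seq_len)
#     end_logits = torch.randn(2, 8)
#     start_positions = torch.tensor([1, 3])
#     end_positions = torch.tensor([2, 30])     # 30 -> clamped to 8, then ignored
#     loss = ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions)
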
def ForTokenClassification(logits: torch.Tensor, labels, config, **kwargs):
    # Flatten the tokens
    logits = logits.view(-1, config.num_labels)
    labels = labels.view(-1).to(logits.device)
    # Upcast to float to avoid potential precision issues when computing the loss
    logits = logits.float()
    return fixed_cross_entropy(logits, labels, **kwargs)
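
# Illustrative usage of `ForTokenClassification` (toy shapes; the config is a
# hypothetical stand-in exposing `num_labels`). Every token gets its own label:
#
#     from types import SimpleNamespace
#     config = SimpleNamespace(num_labels=4)
#     logits = torch.randn(2, 5, 4)             # (batch, seq_len, num_labels)
#     labels = torch.randint(0, 4, (2, 5))
#     loss = ForTokenClassification(logits, labels, config)
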
LOSS_MAPPING = {
    "ForCausalLM": ForCausalLMLoss,
    "ForMaskedLM": ForMaskedLMLoss,
    "ForQuestionAnswering": ForQuestionAnsweringLoss,
    "ForSequenceClassification": ForSequenceClassificationLoss,
    "ForImageClassification": ForSequenceClassificationLoss,
    "ForVideoClassification": ForSequenceClassificationLoss,
    "ForTokenClassification": ForTokenClassification,
    "ForSegmentation": ForSegmentationLoss,
    "ForObjectDetection": ForObjectDetectionLoss,
    "DeformableDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
    "ConditionalDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
    "DabDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
    "GroundingDinoForObjectDetection": GroundingDinoForObjectDetectionLoss,
    "ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss,
    "RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
    "RTDetrV2ForObjectDetection": RTDetrForObjectDetectionLoss,
    "DFineForObjectDetection": DFineForObjectDetectionLoss,
    "CsmForConditionalGeneration": ForCausalLMLoss,
}
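
# Illustrative dispatch through LOSS_MAPPING (a sketch: model classes resolve
# their loss function by a class-name suffix key such as "ForCausalLM"):
#
#     loss_fn = LOSS_MAPPING["ForCausalLM"]
#     loss = loss_fn(logits=torch.randn(2, 5, 100), labels=torch.randint(0, 100, (2, 5)), vocab_size=100)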