# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import math
from contextlib import nullcontext
from typing import Literal, Optional, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...configuration_utils import PretrainedConfig
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutput,
    MaskedLMOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, is_flash_attn_2_available, logging
from ...utils.import_utils import is_triton_available
from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb


if is_flash_attn_2_available():
    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
    from flash_attn.layers.rotary import RotaryEmbedding
    from flash_attn.ops.triton.rotary import apply_rotary
else:
    RotaryEmbedding = object


logger = logging.get_logger(__name__)


class ModernBertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate a ModernBert
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the ModernBERT-base.
    e.g. [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 50368):
            Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`ModernBertModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 1152):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 22):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder. Will default to `"gelu"`
            if not specified.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_cutoff_factor (`float`, *optional*, defaults to 2.0):
            The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
        norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the normalization layers.
        norm_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the normalization layers.
        pad_token_id (`int`, *optional*, defaults to 50283):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 50282):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 50281):
            Beginning of stream token id.
        cls_token_id (`int`, *optional*, defaults to 50281):
            Classification token id.
        sep_token_id (`int`, *optional*, defaults to 50282):
            Separation token id.
        global_rope_theta (`float`, *optional*, defaults to 160000.0):
            The base period of the global RoPE embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        global_attn_every_n_layers (`int`, *optional*, defaults to 3):
            The number of layers between global attention layers.
        local_attention (`int`, *optional*, defaults to 128):
            The window size for local attention.
        local_rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the local RoPE embeddings.
        embedding_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the embeddings.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the MLP layers.
        mlp_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the MLP layers.
        decoder_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in the decoder layers.
        classifier_pooling (`str`, *optional*, defaults to `"cls"`):
            The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the
            CLS token doesn't attend to all tokens on long sequences.
        classifier_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the classifier.
        classifier_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the classifier.
        classifier_activation (`str`, *optional*, defaults to `"gelu"`):
            The activation function for the classifier.
        deterministic_flash_attn (`bool`, *optional*, defaults to `False`):
            Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic.
        sparse_prediction (`bool`, *optional*, defaults to `False`):
            Whether to use sparse prediction for the masked language model instead of returning the full dense logits.
        sparse_pred_ignore_index (`int`, *optional*, defaults to -100):
            The index to ignore for the sparse prediction.
        reference_compile (`bool`, *optional*):
            Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of
            the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
            shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
            be faster in some scenarios.
        repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
            When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only
            applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.

    Examples:

    ```python
    >>> from transformers import ModernBertModel, ModernBertConfig

    >>> # Initializing a ModernBert style configuration
    >>> configuration = ModernBertConfig()

    >>> # Initializing a model from the modernbert-base style configuration
    >>> model = ModernBertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "modernbert"
    attribute_map = {"rope_theta": "global_rope_theta"}
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=50368,
        hidden_size=768,
        intermediate_size=1152,
        num_hidden_layers=22,
        num_attention_heads=12,
        hidden_activation="gelu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        initializer_cutoff_factor=2.0,
        norm_eps=1e-5,
        norm_bias=False,
        pad_token_id=50283,
        eos_token_id=50282,
        bos_token_id=50281,
        cls_token_id=50281,
        sep_token_id=50282,
        global_rope_theta=160000.0,
        attention_bias=False,
        attention_dropout=0.0,
        global_attn_every_n_layers=3,
        local_attention=128,
        local_rope_theta=10000.0,
        embedding_dropout=0.0,
        mlp_bias=False,
        mlp_dropout=0.0,
        decoder_bias=True,
        classifier_pooling: Literal["cls", "mean"] = "cls",
        classifier_dropout=0.0,
        classifier_bias=False,
        classifier_activation="gelu",
        deterministic_flash_attn=False,
        sparse_prediction=False,
        sparse_pred_ignore_index=-100,
        reference_compile=None,
        repad_logits_with_grad=False,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            cls_token_id=cls_token_id,
            sep_token_id=sep_token_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.initializer_range = initializer_range
        self.initializer_cutoff_factor = initializer_cutoff_factor
        self.norm_eps = norm_eps
        self.norm_bias = norm_bias
        self.global_rope_theta = global_rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hidden_activation = hidden_activation
        self.global_attn_every_n_layers = global_attn_every_n_layers
        self.local_attention = local_attention
        self.local_rope_theta = local_rope_theta
        self.embedding_dropout = embedding_dropout
        self.mlp_bias = mlp_bias
        self.mlp_dropout = mlp_dropout
        self.decoder_bias = decoder_bias
        self.classifier_pooling = classifier_pooling
        self.classifier_dropout = classifier_dropout
        self.classifier_bias = classifier_bias
        self.classifier_activation = classifier_activation
        self.deterministic_flash_attn = deterministic_flash_attn
        self.sparse_prediction = sparse_prediction
        self.sparse_pred_ignore_index = sparse_pred_ignore_index
        self.reference_compile = reference_compile
        self.repad_logits_with_grad = repad_logits_with_grad

        if self.classifier_pooling not in ["cls", "mean"]:
            raise ValueError(
                f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
            )

    def to_dict(self):
        output = super().to_dict()
        output.pop("reference_compile", None)
        return output


def _unpad_modernbert_input(
    inputs: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Remove padding from input sequences.

    Args:
        inputs: (batch, seqlen, ...) or (batch, seqlen)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        position_ids: (batch, seqlen), int, position ids
        labels: (batch, seqlen), int, labels

    Returns:
        unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz)
        cu_seqlens: (batch + 1), the cumulative sequence lengths
        max_seqlen_in_batch: int
        unpadded_position_ids: (total_nnz) or None
        unpadded_labels: (total_nnz) or None
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

    if inputs.dim() == 2:
        unpadded_inputs = inputs.flatten()[indices]
    else:
        batch, seqlen, *rest = inputs.shape
        shape = batch * seqlen
        unpadded_inputs = inputs.view(shape, *rest)[indices]

    unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
    unpadded_labels = labels.flatten()[indices] if labels is not None else None

    return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels


def _pad_modernbert_output(
    inputs: torch.Tensor,
    indices: torch.Tensor,
    batch: int,
    seqlen: int,
) -> torch.Tensor:
    """
    Add padding to sequences.

    Args:
        inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz)
        batch: int, batch size
        seqlen: int, max sequence length

    Returns:
        padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
    """
    if inputs.dim() == 1:
        output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen)
    else:
        _, *rest = inputs.shape
        output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen, *rest)

    return padded_inputs


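# A worked example of the unpadding round trip above (illustrative values only, not executed
# here). For two sequences of lengths 3 and 1 padded to a common length of 4:
#
#     attention_mask = [[1, 1, 1, 0],
#                       [1, 0, 0, 0]]
#     indices        = [0, 1, 2, 4]   # flattened positions of the valid tokens
#     cu_seqlens     = [0, 3, 4]      # batch + 1 cumulative sequence lengths
#     max_seqlen     = 3
#
# `_unpad_modernbert_input` gathers the 4 valid tokens into a (4, ...) tensor that the varlen
# flash-attention kernels consume, and `_pad_modernbert_output` scatters results back into the
# original (2, 4, ...) layout using `indices`.

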
class ApplyRotaryEmbUnpad(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        # (total_nnz, 3, nheads, headdim)
        qkv = qkv.contiguous()
        total_nnz, _three, _nheads, headdim = qkv.shape
        # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
        # we get the same tensor
        # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
        qk = qkv[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            qk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=False,
            inplace=True,
        )

        ctx.save_for_backward(cos, sin, cu_seqlens)
        ctx.max_seqlen = max_seqlen
        return qkv

    @staticmethod
    def backward(ctx, do):
        cos, sin, cu_seqlens = ctx.saved_tensors
        do = do.contiguous()
        total_nnz, _three, _nheads, headdim = do.shape
        # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
        # we get the same tensor
        dqk = do[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            dqk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=False,
            inplace=True,
            conjugate=True,
        )

        return do, None, None, None, None, None, None


def apply_rotary_unpadded(
    qkv,
    cos,
    sin,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Arguments:
        qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Return:
        out: (total_nnz, dim)
    rotary_dim must be <= headdim
    Apply rotary embedding to the first rotary_dim of x.
    """
    return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)


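# What the in-place rotation above does, spelled out (illustrative sketch): with
# interleaved=False (GPT-NeoX style), the rotary dimensions of each query/key head are split
# into a first half x1 and a second half x2. For a token at position m and inverse frequency
# theta_i, the rotation is
#     x1' = x1 * cos(m * theta_i) - x2 * sin(m * theta_i)
#     x2' = x1 * sin(m * theta_i) + x2 * cos(m * theta_i)
# Because the sequences are packed (unpadded), `cu_seqlens` tells the kernel where each
# sequence starts so that positions restart from 0 at every sequence boundary.

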
class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding):
    """
    The rotary position embeddings applied directly to unpadded sequences.
    """

    def __init__(
        self,
        dim: int,
        base: float = 10000.0,
        max_seqlen: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
            up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
            the cos_sin_cache will be recomputed during the forward pass.
        """
        super().__init__(dim=dim, base=base, device=device, interleaved=False)
        self.max_seqlen = max_seqlen

        if max_seqlen is not None and device is not None and dtype is not None:
            self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: Optional[int] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        Apply rotary embedding *inplace* to qkv.
        qkv: (total_nnz, 3, nheads, headdim)
        cu_seqlens: (batch + 1,) cumulative sequence lengths
        max_seqlen: int max seq length in the batch
        """
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)

        qkv = apply_rotary_unpadded(
            qkv,
            self._cos_cached,
            self._sin_cached,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

        return qkv

    def extra_repr(self) -> str:
        return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"


class ModernBertEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.drop = nn.Dropout(config.embedding_dropout)

    @torch.compile(dynamic=True)
    def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
        return self.drop(self.norm(self.tok_embeddings(input_ids)))

    def forward(
        self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = self.drop(self.norm(inputs_embeds))
        else:
            hidden_states = (
                self.compiled_embeddings(input_ids)
                if self.config.reference_compile
                else self.drop(self.norm(self.tok_embeddings(input_ids)))
            )
        return hidden_states


class ModernBertMLP(nn.Module):
    """Applies the GLU at the end of each ModernBERT layer.

    Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
    and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias)
        self.act = ACT2FN[config.hidden_activation]
        self.drop = nn.Dropout(config.mlp_dropout)
        self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
        return self.Wo(self.drop(self.act(input) * gate))


class ModernBertRotaryEmbedding(GemmaRotaryEmbedding):
    pass


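# Shape walk-through for the gated MLP above, using the default config (hidden_size=768,
# intermediate_size=1152) as an illustrative example:
#     Wi:  (..., 768)  -> (..., 2304)               # single fused up-projection
#     chunk(2, dim=-1): (..., 1152) and (..., 1152)  # "input" and "gate" halves
#     act(input) * gate: (..., 1152)                 # GeGLU-style gating
#     Wo:  (..., 1152) -> (..., 768)                 # down-projection back to hidden_size

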
def eager_attention_forward(
    module: "ModernBertAttention",
    qkv: torch.Tensor,
    attention_mask: torch.Tensor,
    sliding_window_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    local_attention: tuple[int, int],
    bs: int,
    dim: int,
    output_attentions: Optional[bool] = False,
    **_kwargs,
) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
    # qkv: [batch_size, seqlen, 3, nheads, headdim]
    cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
    # query, key, value: [batch_size, heads, seq_len, head_dim]
    query, key = apply_rotary_pos_emb(query, key, cos, sin)

    scale = module.head_dim**-0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale

    if local_attention != (-1, -1):
        attention_mask = sliding_window_mask

    attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.view(bs, -1, dim)
    if output_attentions:
        return (attn_output, attn_weights)
    return (attn_output,)


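# Note on the additive masks used above: `attention_mask` and `sliding_window_mask` are 4D
# float masks holding 0.0 at positions that may be attended to and a large negative value
# (torch.finfo(dtype).min) elsewhere, so adding them to the scores before the softmax drives
# the masked probabilities to zero. Local layers simply swap in the sliding-window variant.

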
def flash_attention_forward(
    module: "ModernBertAttention",
    qkv: torch.Tensor,
    rotary_emb: ModernBertUnpaddedRotaryEmbedding,
    cu_seqlens: torch.Tensor,
    max_seqlen: int,
    local_attention: tuple[int, int],
    bs: int,
    dim: int,
    target_dtype: torch.dtype = torch.bfloat16,
    **_kwargs,
) -> tuple[torch.Tensor]:
    # (total_seqlen, 3, nheads, headdim)
    qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)

    convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
    if convert_dtype:
        # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
        # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
        orig_dtype = qkv.dtype
        qkv = qkv.to(target_dtype)

        attn = flash_attn_varlen_qkvpacked_func(
            qkv,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            dropout_p=module.attention_dropout if module.training else 0.0,
            deterministic=module.deterministic_flash_attn,
            window_size=local_attention,
        )
        attn = attn.to(orig_dtype)  # type: ignore
    else:
        attn = flash_attn_varlen_qkvpacked_func(
            qkv,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            dropout_p=module.attention_dropout if module.training else 0.0,
            deterministic=module.deterministic_flash_attn,
            window_size=local_attention,
        )
    return (attn.view(bs, dim),)


def sdpa_attention_forward(
    module: "ModernBertAttention",
    qkv: torch.Tensor,
    attention_mask: torch.Tensor,
    sliding_window_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    local_attention: tuple[int, int],
    bs: int,
    dim: int,
    **_kwargs,
) -> tuple[torch.Tensor]:
    # qkv: [batch_size, seqlen, 3, nheads, headdim]
    cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
    # query, key, value: [batch_size, heads, seq_len, head_dim]
    query, key = apply_rotary_pos_emb(query, key, cos, sin)

    if local_attention != (-1, -1):
        attention_mask = sliding_window_mask

    attn_output = (
        F.scaled_dot_product_attention(
            query,
            key,
            value,
            dropout_p=module.attention_dropout if module.training else 0.0,
            attn_mask=attention_mask,
        )
        .transpose(1, 2)
        .contiguous()
    )
    attn_output = attn_output.view(bs, -1, dim)
    return (attn_output,)


MODERNBERT_ATTENTION_FUNCTION = {
    "flash_attention_2": flash_attention_forward,
    "eager": eager_attention_forward,
    "sdpa": sdpa_attention_forward,
}


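# The mapping above is keyed by `config._attn_implementation`, which users typically choose
# when loading the model, e.g. (illustrative):
#
#     model = AutoModel.from_pretrained("answerdotai/ModernBERT-base", attn_implementation="sdpa")
#
# "flash_attention_2" consumes unpadded (varlen) inputs of shape (total_tokens, 3, nheads,
# headdim), while "eager" and "sdpa" operate on padded [batch, seq_len, ...] tensors.

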
class ModernBertAttention(nn.Module):
    """Performs multi-headed self attention on a batch of unpadded sequences.

    If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
    If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
    which requires padding and unpadding inputs, adding some overhead.

    See `forward` method for additional details.
    """

    def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
            )

        self.attention_dropout = config.attention_dropout
        self.deterministic_flash_attn = config.deterministic_flash_attn
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.head_dim * self.num_heads
        self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)

        if layer_id % config.global_attn_every_n_layers != 0:
            self.local_attention = (config.local_attention // 2, config.local_attention // 2)
            rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta
            max_position_embeddings = config.local_attention
        else:
            self.local_attention = (-1, -1)
            max_position_embeddings = config.max_position_embeddings
            rope_theta = config.global_rope_theta

        if config._attn_implementation == "flash_attention_2":
            self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
                dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
            )
        else:
            config_copy = copy.deepcopy(config)
            config_copy.rope_theta = rope_theta
            self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy)

        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
        self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
        self.pruned_heads = set()

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> torch.Tensor:
        qkv = self.Wqkv(hidden_states)

        bs = hidden_states.shape[0]
        if self.config._attn_implementation == "flash_attention_2":
            qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
        else:
            qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)

        attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation](
            self,
            qkv=qkv,
            rotary_emb=self.rotary_emb,
            local_attention=self.local_attention,
            bs=bs,
            dim=self.all_head_size,
            output_attentions=output_attentions,
            **kwargs,
        )
        hidden_states = attn_outputs[0]
        hidden_states = self.out_drop(self.Wo(hidden_states))

        return (hidden_states,) + attn_outputs[1:]  # add attentions if outputted


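# Layer-by-layer attention pattern implied by the constructor above, using the default config
# (global_attn_every_n_layers=3, local_attention=128) as an example:
#     layer_id 0, 3, 6, ...  -> global attention, self.local_attention == (-1, -1)
#     all other layers       -> sliding-window attention with window (64, 64),
#                               i.e. 64 tokens on each side of the query position.
# Global layers use global_rope_theta (160000.0 by default); local layers use local_rope_theta
# when it is set.

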
class ModernBertEncoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        if layer_id == 0:
            self.attn_norm = nn.Identity()
        else:
            self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.attn = ModernBertAttention(config=config, layer_id=layer_id)
        self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.mlp = ModernBertMLP(config)

    @torch.compile(dynamic=True)
    def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.mlp(self.mlp_norm(hidden_states))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.Tensor:
        attn_outputs = self.attn(
            self.attn_norm(hidden_states),
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_outputs[0]
        mlp_output = (
            self.compiled_mlp(hidden_states)
            if self.config.reference_compile
            else self.mlp(self.mlp_norm(hidden_states))
        )
        hidden_states = hidden_states + mlp_output

        return (hidden_states,) + attn_outputs[1:]  # add attentions if outputted


@auto_docstring
class ModernBertPreTrainedModel(PreTrainedModel):
    config: ModernBertConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = False

    def _init_weights(self, module: nn.Module):
        cutoff_factor = self.config.initializer_cutoff_factor
        if cutoff_factor is None:
            cutoff_factor = 3

        def init_weight(module: nn.Module, std: float):
            nn.init.trunc_normal_(
                module.weight,
                mean=0.0,
                std=std,
                a=-cutoff_factor * std,
                b=cutoff_factor * std,
            )

            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        stds = {
            "in": self.config.initializer_range,
            "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
            "embedding": self.config.initializer_range,
            "final_out": self.config.hidden_size**-0.5,
        }

        if isinstance(module, ModernBertEmbeddings):
            init_weight(module.tok_embeddings, stds["embedding"])
        elif isinstance(module, ModernBertMLP):
            init_weight(module.Wi, stds["in"])
            init_weight(module.Wo, stds["out"])
        elif isinstance(module, ModernBertAttention):
            init_weight(module.Wqkv, stds["in"])
            init_weight(module.Wo, stds["out"])
        elif isinstance(module, ModernBertPredictionHead):
            init_weight(module.dense, stds["out"])
        elif isinstance(module, ModernBertForMaskedLM):
            init_weight(module.decoder, stds["out"])
        elif isinstance(
            module,
            (ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering),
        ):
            init_weight(module.classifier, stds["final_out"])
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()

    def _check_and_adjust_attn_implementation(
        self, attn_implementation: Optional[str], is_init_check: bool = False
    ) -> str:
        """
        Checks and dispatches to the requested attention implementation.
        """
        # If the user didn't specify anything, try to use flash_attention_2 if available.
        # Otherwise we fall back to the default SDPA -> Eager from the super() method.
        # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
        # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.

        try:
            attn_implementation = (
                "flash_attention_2"
                if attn_implementation is None and self._flash_attn_2_can_dispatch()
                else attn_implementation
            )
        except (ValueError, ImportError):
            pass
        return super()._check_and_adjust_attn_implementation(
            attn_implementation=attn_implementation, is_init_check=is_init_check
        )

    def _maybe_set_compile(self):
        if self.config.reference_compile is False:
            return

        if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1:
            if self.config.reference_compile:
                logger.warning_once(
                    "If `accelerate` split the model across devices, `torch.compile` will not work. "
                    "Falling back to non-compiled mode."
                )
            self.config.reference_compile = False

        if self.device.type == "mps":
            if self.config.reference_compile:
                logger.warning_once(
                    "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. "
                    "Falling back to non-compiled mode."
                )
            self.config.reference_compile = False

        if self.device.type == "cpu":
            if self.config.reference_compile:
                logger.warning_once(
                    "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
                    "Falling back to non-compiled mode."
                )
            self.config.reference_compile = False

        if self.config.reference_compile is None:
            self.config.reference_compile = is_triton_available()

    def resize_token_embeddings(self, *args, **kwargs):
        model_embeds = super().resize_token_embeddings(*args, **kwargs)

        if self.config.reference_compile in {True, None}:
            if self.config.reference_compile:
                logger.warning_once(
                    "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
                )
            self.config.reference_compile = False

        return model_embeds


@auto_docstring
class ModernBertModel(ModernBertPreTrainedModel):
    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = ModernBertEmbeddings(config)
        self.layers = nn.ModuleList(
            [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
        )
        self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.tok_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.tok_embeddings = value

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor, ...], BaseModelOutput]:
        r"""
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        self._maybe_set_compile()

        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)

        if batch_size is None and seq_len is None:
            if inputs_embeds is not None:
                batch_size, seq_len = inputs_embeds.shape[:2]
            else:
                batch_size, seq_len = input_ids.shape[:2]
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

        repad = False
        if self.config._attn_implementation == "flash_attention_2":
            if indices is None and cu_seqlens is None and max_seqlen is None:
                repad = True
                if inputs_embeds is None:
                    with torch.no_grad():
                        input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
                            inputs=input_ids, attention_mask=attention_mask
                        )
                else:
                    inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
                        inputs=inputs_embeds, attention_mask=attention_mask
                    )
        else:
            if position_ids is None:
                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)

            attention_mask, sliding_window_mask = self._update_attention_mask(
                attention_mask, output_attentions=output_attentions
            )

        hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)

        for encoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                sliding_window_mask=sliding_window_mask,
                position_ids=position_ids,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]
            if output_attentions and len(layer_outputs) > 1:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.final_norm(hidden_states)

        if repad:
            hidden_states = _pad_modernbert_output(
                inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
            )
            if all_hidden_states is not None:
                all_hidden_states = tuple(
                    _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
                    for hs in all_hidden_states
                )

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor:
        if output_attentions:
            if self.config._attn_implementation == "sdpa":
                logger.warning_once(
                    "Outputting attentions is only supported with the 'eager' attention implementation, "
                    'not with "sdpa". Falling back to `attn_implementation="eager"`.'
                )
                self.config._attn_implementation = "eager"
            elif self.config._attn_implementation != "eager":
                logger.warning_once(
                    "Outputting attentions is only supported with the eager attention implementation, "
                    f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.'
                    " Setting `output_attentions=False`."
                )

        global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)

        # Create position indices
        rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0)
        # Calculate distance between positions
        distance = torch.abs(rows - rows.T)

        # Create sliding window mask (1 for positions within window, 0 outside)
        window_mask = (
            (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
        )
        # Combine with existing mask
        sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)

        return global_attention_mask, sliding_window_mask


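# Minimal usage sketch for the base model above (illustrative; mirrors the config docstring
# example and assumes the "answerdotai/ModernBERT-base" checkpoint is available):
#
#     from transformers import AutoTokenizer, ModernBertModel
#
#     tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
#     model = ModernBertModel.from_pretrained("answerdotai/ModernBERT-base")
#     inputs = tokenizer("ModernBERT handles sequences up to 8192 tokens.", return_tensors="pt")
#     outputs = model(**inputs)
#     outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)

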
class ModernBertPredictionHead(nn.Module):
    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
        self.act = ACT2FN[config.classifier_activation]
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.norm(self.act(self.dense(hidden_states)))


@auto_docstring(
    custom_intro="""
    The ModernBert Model with a decoder head on top that is used for masked language modeling.
    """
)
class ModernBertForMaskedLM(ModernBertPreTrainedModel):
    _tied_weights_keys = ["decoder.weight"]

    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.config = config
        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)

        self.sparse_prediction = self.config.sparse_prediction
        self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.decoder

    def set_output_embeddings(self, new_embeddings: nn.Linear):
        self.decoder = new_embeddings

    @torch.compile(dynamic=True)
    def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.head(output))

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        self._maybe_set_compile()

        if self.config._attn_implementation == "flash_attention_2":
            if indices is None and cu_seqlens is None and max_seqlen is None:
                if batch_size is None and seq_len is None:
                    if inputs_embeds is not None:
                        batch_size, seq_len = inputs_embeds.shape[:2]
                    else:
                        batch_size, seq_len = input_ids.shape[:2]
                device = input_ids.device if input_ids is not None else inputs_embeds.device

                if attention_mask is None:
                    attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

                if inputs_embeds is None:
                    with torch.no_grad():
                        input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
                            inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
                        )
                else:
                    inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
                        inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
                    )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]

        if self.sparse_prediction and labels is not None:
            # flatten labels and output first
            labels = labels.view(-1)
            last_hidden_state = last_hidden_state.view(labels.shape[0], -1)

            # then filter out the non-masked tokens
            mask_tokens = labels != self.sparse_pred_ignore_index
            last_hidden_state = last_hidden_state[mask_tokens]
            labels = labels[mask_tokens]

        logits = (
            self.compiled_head(last_hidden_state)
            if self.config.reference_compile
            else self.decoder(self.head(last_hidden_state))
        )

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)

        if self.config._attn_implementation == "flash_attention_2":
            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
                logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


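# Illustrative fill-mask sketch for the MLM head above (assumes the Hub checkpoint and its
# [MASK] token; the decoded prediction shown is only an expectation, not a guarantee):
#
#     from transformers import AutoTokenizer, ModernBertForMaskedLM
#
#     tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
#     model = ModernBertForMaskedLM.from_pretrained("answerdotai/ModernBERT-base")
#     inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
#     logits = model(**inputs).logits
#     mask_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0, 1]
#     predicted_id = logits[0, mask_index].argmax(dim=-1)
#     tokenizer.decode(predicted_id)  # something like " Paris"

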
@auto_docstring(
    custom_intro="""
    The ModernBert Model with a sequence classification head on top that performs pooling.
    """
)
class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.drop = torch.nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        self._maybe_set_compile()

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]

        if self.config.classifier_pooling == "cls":
            last_hidden_state = last_hidden_state[:, 0]
        elif self.config.classifier_pooling == "mean":
            last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
                dim=1, keepdim=True
            )

        pooled_output = self.head(last_hidden_state)
        pooled_output = self.drop(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


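# Pooling used above, spelled out: with classifier_pooling="cls" the first-token ([CLS])
# representation feeds the head; with "mean" the padded positions are masked out and the
# remaining token states are averaged,
#     pooled = sum_t(h_t * mask_t) / sum_t(mask_t),
# which is exactly the masked-mean expression in the forward pass. "mean" is useful on long
# sequences, where local attention layers prevent the CLS token from attending to every token.

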
@auto_docstring(
    custom_intro="""
    The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.
    """
)
class ModernBertForTokenClassification(ModernBertPreTrainedModel):
    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.drop = torch.nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        self._maybe_set_compile()

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]

        last_hidden_state = self.head(last_hidden_state)
        last_hidden_state = self.drop(last_hidden_state)
        logits = self.classifier(last_hidden_state)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring
class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.drop = torch.nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        self._maybe_set_compile()

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]

        last_hidden_state = self.head(last_hidden_state)
        last_hidden_state = self.drop(last_hidden_state)
        logits = self.classifier(last_hidden_state)

        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        loss = None
        if start_positions is not None and end_positions is not None:
            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

        if not return_dict:
            output = (start_logits, end_logits) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = [
    "ModernBertConfig",
    "ModernBertModel",
    "ModernBertPreTrainedModel",
    "ModernBertForMaskedLM",
    "ModernBertForSequenceClassification",
    "ModernBertForTokenClassification",
    "ModernBertForQuestionAnswering",
]