# coding=utf-8
# Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on Llama implementations in this library and Microsoft's
# Differential Transformer implementations.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional

import torch
from torch import nn

from ...cache_utils import Cache, StaticCache
from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
from ...utils import logging
from ..gemma.modeling_gemma import GemmaForCausalLM
from ..llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaForQuestionAnswering,
    LlamaForSequenceClassification,
    LlamaForTokenClassification,
    LlamaModel,
    LlamaPreTrainedModel,
    apply_rotary_pos_emb,
    repeat_kv,
)
from ..mistral.modeling_mistral import MistralMLP
from .configuration_diffllama import DiffLlamaConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut"
_CONFIG_FOR_DOC = "DiffLlamaConfig"


class DiffLlamaMLP(MistralMLP):
    pass


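# Depth-dependent initial value for the differential-attention scalar lambda, following the
# Differential Transformer schedule lambda_init = 0.8 - 0.6 * exp(-0.3 * layer_idx): roughly 0.2
# for the first layer and approaching 0.8 for deep layers.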
def lambda_init_fn(layer_idx):
    return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)


class DiffLlamaAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need', extended with the differential attention
    mechanism from Microsoft's Differential Transformer.
    """

    def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        # The attributes below are unused in this implementation; they are kept for config compatibility.
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

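        # Differential attention parameters. The forward pass re-parameterizes a per-layer scalar as
        #     lambda = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init,
        # following the Differential Transformer formulation. `groupnorm` is a parameter-free RMSNorm
        # applied head-wise to the concatenated pair of value halves (hence 2 * head_dim).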
        self.lambda_init = lambda_init_fn(layer_idx)
        self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
        self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
        self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
        self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
        self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        bsz, target_len, _ = hidden_states.size()
        q_len = target_len

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

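        # Rearrange the values for differential attention: expand the grouped KV heads to the full
        # head count, chunk the value heads into two groups, concatenate them along the feature axis
        # (2 * head_dim per head) and duplicate the head axis. The single softmax/matmul below then
        # produces both attention outputs, which are split back into `attn_output1` / `attn_output2`.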
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
        value_states = value_states.repeat(1, 2, 1, 1)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

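        # Apply the shared attention weights to the stacked values, split the result into the
        # differential pair, and recombine as out1 - lambda * out2. The head-wise RMSNorm and the
        # (1 - lambda_init) rescaling follow the Differential Transformer formulation and keep the
        # output scale comparable to standard attention at initialization.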
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)

        attn_output = attn_output1 - lambda_full * attn_output2
        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class DiffLlamaFlashAttention2(DiffLlamaAttention):
    """
    DiffLlama flash attention module. This module inherits from `DiffLlamaAttention` as the weights of the module stay
    untouched. The only required change is on the forward pass, where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        if isinstance(past_key_value, StaticCache):
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2`; "
                "make sure to use `sdpa` in the meantime, and open an issue at https://github.com/huggingface/transformers"
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash Attention requires the layout [batch_size, seq_length, num_heads, head_dim]. We first move to
        # [batch_size, num_heads, seq_length, head_dim] for RoPE and the KV cache update, then transpose back
        # before calling the kernel.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transposes are quite inefficient, but Flash Attention requires the layout
        # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, the layer norms are usually cast to float32 for training stability, so the input
        # hidden states may be silently cast to float32. Hence, we cast them back to the correct
        # dtype just to be sure everything works as expected.
        # This might slow down training & inference, so it is recommended not to cast the LayerNorms
        # to fp32. (DiffLlamaRMSNorm handles it correctly)

        input_dtype = query_states.dtype
        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = (
                    torch.get_autocast_dtype(device_type)
                    if hasattr(torch, "get_autocast_dtype")
                    else torch.get_autocast_gpu_dtype()
                )
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seem to have been silently cast to float32; this might be related to"
                f" the fact that you have upcast embedding or layer norm layers to float32. We will cast back the input to"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

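        # Differential attention with the flash kernel: split the value heads into two halves, tile
        # each half back to the full key/value head count, and run flash attention once per half.
        # The two outputs are recombined below exactly as in the eager path.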
        value_states1, value_states2 = torch.chunk(value_states, 2, dim=2)
        value_states1 = value_states1.repeat(1, 1, 2, 1)
        value_states2 = value_states2.repeat(1, 1, 2, 1)

        attn_output1 = _flash_attention_forward(
            query_states,
            key_states,
            value_states1,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        attn_output2 = _flash_attention_forward(
            query_states,
            key_states,
            value_states2,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

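        # Recombine the two flash-attention outputs: concatenate along the feature axis, split the
        # head axis into the differential pair, then compute out1 - lambda * out2 followed by the
        # head-wise RMSNorm and the (1 - lambda_init) rescaling.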
        attn_output = torch.cat([attn_output1, attn_output2], dim=-1)
        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2)

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        attn_output = attn_output1 - lambda_full * attn_output2
        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, None


class DiffLlamaSdpaAttention(DiffLlamaAttention):
    """
    DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `DiffLlamaAttention` as the weights of the module stay untouched. The only changes are on the forward pass to
    adapt to the SDPA API.
    """

    # Adapted from DiffLlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

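        # Same value rearrangement as in the eager path: expand the grouped KV heads, stack the two
        # value groups along the feature axis and duplicate the head axis, so a single SDPA call
        # computes both halves of the differential pair.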
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
        value_states = value_states.repeat(1, 2, 1, 1)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        attn_output = attn_output1 - lambda_full * attn_output2
        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, None


DIFFLLAMA_ATTENTION_CLASSES = {
    "eager": DiffLlamaAttention,
    "flash_attention_2": DiffLlamaFlashAttention2,
    "sdpa": DiffLlamaSdpaAttention,
}


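# The decoder layer reuses Llama's block structure and only swaps in the differential attention
# implementation selected by `config._attn_implementation`.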
class DiffLlamaDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: DiffLlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)

        self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)


class DiffLlamaPreTrainedModel(LlamaPreTrainedModel):
    _supports_flex_attn = False
    _supports_attention_backend = False

    def _init_weights(self, module):
        LlamaPreTrainedModel._init_weights(self, module)
        if isinstance(module, DiffLlamaAttention):
            module.lambda_q1.data.normal_(0, self.config.lambda_std_dev)
            module.lambda_k1.data.normal_(0, self.config.lambda_std_dev)
            module.lambda_q2.data.normal_(0, self.config.lambda_std_dev)
            module.lambda_k2.data.normal_(0, self.config.lambda_std_dev)


class DiffLlamaModel(LlamaModel):
    pass


class DiffLlamaForCausalLM(GemmaForCausalLM):
    pass


class DiffLlamaForSequenceClassification(LlamaForSequenceClassification):
    pass


class DiffLlamaForQuestionAnswering(LlamaForQuestionAnswering):
    pass


class DiffLlamaForTokenClassification(LlamaForTokenClassification):
    pass


__all__ = [
    "DiffLlamaPreTrainedModel",
    "DiffLlamaModel",  # noqa: F822
    "DiffLlamaForCausalLM",
    "DiffLlamaForSequenceClassification",
    "DiffLlamaForQuestionAnswering",
    "DiffLlamaForTokenClassification",
]
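
# Minimal usage sketch (kept as a comment so this modular definition stays import-safe). It assumes
# the `kajuma/DiffLlama-0.3B-handcut` checkpoint referenced above and the generated
# `modeling_diffllama.py` being available through the auto classes:
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("kajuma/DiffLlama-0.3B-handcut")
#     model = AutoModelForCausalLM.from_pretrained("kajuma/DiffLlama-0.3B-handcut")
#     inputs = tokenizer("Differential attention cancels attention noise by", return_tensors="pt")
#     outputs = model.generate(**inputs, max_new_tokens=20)
#     print(tokenizer.decode(outputs[0], skip_special_tokens=True))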