# coding=utf-8
# Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on Llama implementations in this library and Microsoft's
# Differential Transformer implementations.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional

import torch
from torch import nn

from ...cache_utils import Cache, StaticCache
from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
from ...utils import logging
from ..gemma.modeling_gemma import GemmaForCausalLM
from ..llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaForQuestionAnswering,
    LlamaForSequenceClassification,
    LlamaForTokenClassification,
    LlamaModel,
    LlamaPreTrainedModel,
    apply_rotary_pos_emb,
    repeat_kv,
)
from ..mistral.modeling_mistral import MistralMLP
from .configuration_diffllama import DiffLlamaConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut"
_CONFIG_FOR_DOC = "DiffLlamaConfig"


class DiffLlamaMLP(MistralMLP):
    pass


def lambda_init_fn(layer_idx):
    """
    Return the layer-dependent constant used to initialise the differential-attention lambda.

    The value is `0.8 - 0.6 * exp(-0.3 * layer_idx)`, i.e. 0.2 for `layer_idx == 0` and increasing towards 0.8 as
    `layer_idx` grows. `layer_idx` must be a number; `None` raises a `TypeError`.
    """
    return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)


class DiffLlamaAttention(nn.Module):
    """
    Eager differential attention.

    Softmax attention is computed for all `num_heads` heads, the heads are split into two halves, and the output of
    the second half, scaled by a learned scalar `lambda_full`, is subtracted from the output of the first half. The
    result is RMS-normalised, scaled by `(1 - lambda_init)` and projected back to `hidden_size`.
    """

    def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        # The two attributes below are stored but not read anywhere in this module.
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

        # NOTE: `lambda_init_fn` does arithmetic on `layer_idx`, so a `None` layer index raises a `TypeError` here
        # even though only a warning is emitted above.
        self.lambda_init = lambda_init_fn(layer_idx)
        # Four learnable vectors of size `head_dim`; the forward pass reduces them to a single scalar lambda.
        self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
        self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
        self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
        self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
        # Normalises over the last dimension, which is `2 * head_dim` after the value halves are concatenated.
        self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Run eager differential attention.

        Args:
            hidden_states: tensor of shape `(batch_size, seq_len, hidden_size)`.
            position_embeddings: `(cos, sin)` pair passed to `apply_rotary_pos_emb`.
            attention_mask: optional additive mask; it is sliced on its last dimension to the key length and added
                to the attention scores.
            position_ids: accepted for interface compatibility; not used by this implementation.
            past_key_value: optional cache; when given, it is updated with the new key/value states.
            use_cache: accepted for interface compatibility; not used by this implementation.
            cache_position: forwarded to the cache update.

        Returns:
            `(attn_output, attn_weights)` where `attn_output` has shape `(batch_size, seq_len, hidden_size)` and
            `attn_weights` are the post-softmax, post-dropout attention probabilities of all heads.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads, head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        # Concatenate the two halves of the value heads on the feature axis (last dim becomes 2 * head_dim), then
        # tile along the head axis so that every query head attends over the same doubled-width values.
        value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
        value_states = value_states.repeat(1, 2, 1, 1)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

        # Scalar lambda: exp(<lambda_q1, lambda_k1>) - exp(<lambda_q2, lambda_k2>) + lambda_init
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        attn_output = torch.matmul(attn_weights, value_states)
        # Differential step: first half of the heads minus lambda times the second half.
        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
        attn_output = attn_output1 - lambda_full * attn_output2
        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights


class DiffLlamaFlashAttention2(DiffLlamaAttention):
    """
    DiffLlama flash attention module. This module inherits from `DiffLlamaAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, None]:
        """
        Run differential attention through `_flash_attention_forward`.

        The arguments match `DiffLlamaAttention.forward`. Extra keyword arguments are accepted and ignored so that
        this implementation can be called the same way as the eager and SDPA ones.

        Returns:
            `(attn_output, None)`; attention weights are not available from this implementation.

        Raises:
            ValueError: if `past_key_value` is a `StaticCache`.
        """
        if isinstance(past_key_value, StaticCache):
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # `position_embeddings` is mandatory: this module owns no rotary embedding to fall back on.
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (DiffLlamaRMSNorm handles it correctly)

        input_dtype = query_states.dtype
        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = (
                    torch.get_autocast_dtype(device_type)
                    if hasattr(torch, "get_autocast_dtype")
                    else torch.get_autocast_gpu_dtype()
                )
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Layout here is (bsz, seq_len, heads, head_dim): split the value heads in two halves and tile each half
        # back along the head axis, then run one flash-attention call per half.
        value_states1, value_states2 = torch.chunk(value_states, 2, dim=2)
        value_states1 = value_states1.repeat(1, 1, 2, 1)
        value_states2 = value_states2.repeat(1, 1, 2, 1)

        attn_output1 = _flash_attention_forward(
            query_states,
            key_states,
            value_states1,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        attn_output2 = _flash_attention_forward(
            query_states,
            key_states,
            value_states2,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        # Join the two results on the feature axis (last dim becomes 2 * head_dim), then split on the head axis to
        # obtain the two groups of heads that are subtracted from each other.
        attn_output = torch.cat([attn_output1, attn_output2], dim=-1)
        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2)

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        attn_output = attn_output1 - lambda_full * attn_output2
        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output, None


class DiffLlamaSdpaAttention(DiffLlamaAttention):
    """
    DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `DiffLlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from DiffLlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, None]:
        """
        Run differential attention through `torch.nn.functional.scaled_dot_product_attention`.

        The arguments match `DiffLlamaAttention.forward`.

        Returns:
            `(attn_output, None)`; attention weights are not available from this implementation.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads, head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        # Same value layout as the eager path: halves concatenated on the feature axis, then tiled over heads.
        value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
        value_states = value_states.repeat(1, 2, 1, 1)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        # Differential step: first half of the heads minus lambda times the second half.
        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        attn_output = attn_output1 - lambda_full * attn_output2
        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, None


# Maps `config._attn_implementation` to the attention class instantiated by `DiffLlamaDecoderLayer`.
DIFFLLAMA_ATTENTION_CLASSES = {
    "eager": DiffLlamaAttention,
    "flash_attention_2": DiffLlamaFlashAttention2,
    "sdpa": DiffLlamaSdpaAttention,
}


class DiffLlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer whose `self_attn` is replaced by the configured DiffLlama attention class."""

    def __init__(self, config: DiffLlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)

        self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)


class DiffLlamaPreTrainedModel(LlamaPreTrainedModel):
    _supports_flex_attn = False
    _supports_attention_backend = False

    def _init_weights(self, module):
        """
        Initialise `module`, additionally re-drawing the lambda parameters of DiffLlama attention modules.

        The parent initialisation runs first; then, for `DiffLlamaAttention` instances (including its subclasses),
        the four lambda vectors are filled in place from a normal distribution with mean 0 and standard deviation
        `config.lambda_std_dev`.
        """
        # Bound call through `super()`: calling `LlamaPreTrainedModel._init_weights(module)` directly would bind
        # `module` as `self` and leave the real `module` argument missing.
        super()._init_weights(module)
        if isinstance(module, DiffLlamaAttention):
            module.lambda_q1.data.normal_(0, self.config.lambda_std_dev)
            module.lambda_k1.data.normal_(0, self.config.lambda_std_dev)
            module.lambda_q2.data.normal_(0, self.config.lambda_std_dev)
            module.lambda_k2.data.normal_(0, self.config.lambda_std_dev)


class DiffLlamaModel(LlamaModel):
    pass


class DiffLlamaForCausalLM(GemmaForCausalLM):
    pass


class DiffLlamaForSequenceClassification(LlamaForSequenceClassification):
    pass


class DiffLlamaForQuestionAnswering(LlamaForQuestionAnswering):
    pass


class DiffLlamaForTokenClassification(LlamaForTokenClassification):
    pass


__all__ = [
    "DiffLlamaPreTrainedModel",
    "DiffLlamaModel",  # noqa: F822
    "DiffLlamaForCausalLM",
    "DiffLlamaForSequenceClassification",
    "DiffLlamaForQuestionAnswering",
    "DiffLlamaForTokenClassification",
]