from typing import Optional

import torch

from ..utils import logging
from ..utils.import_utils import is_torch_greater_or_equal


logger = logging.get_logger(__name__)

_is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def use_gqa_in_sdpa(attention_mask: Optional[torch.Tensor], key: torch.Tensor) -> bool:
    # GQA can only be used under the following conditions
    # 1. torch version >= 2.5
    # 2. attention_mask is None (otherwise it will fall back to the math kernel)
    # 3. key is not a torch.fx.Proxy (otherwise it will fail with a tracing error)
    return _is_torch_greater_or_equal_than_2_5 and attention_mask is None and not isinstance(key, torch.fx.Proxy)


def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    if kwargs.get("output_attentions", False) or kwargs.get("head_mask", None) is not None:
        logger.warning_once(
            "`sdpa` attention does not support `output_attentions=True` or `head_mask`."
            " Please set your attention to `eager` if you want any of these features."
        )

    sdpa_kwargs = {}
    if hasattr(module, "num_key_value_groups"):
        if not use_gqa_in_sdpa(attention_mask, key):
            key = repeat_kv(key, module.num_key_value_groups)
            value = repeat_kv(value, module.num_key_value_groups)
        else:
            sdpa_kwargs = {"enable_gqa": True}

    if attention_mask is not None and attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]

    # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
    # Reference: https://github.com/pytorch/pytorch/issues/112577.
    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
    if is_causal is None:
        # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag
        # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns
        is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)

    # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
    # We convert it to a bool for the SDPA kernel that only accepts bools.
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
        **sdpa_kwargs,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None
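

if __name__ == "__main__":
    # Illustrative sketch only, not part of the library API: a minimal smoke test of the helpers above.
    # Run it with `python -m` from the package root so the relative imports resolve.
    # The dummy module and the head counts below (4 query heads sharing 2 key-value heads) are
    # assumptions chosen for the example, not values taken from any real model config.
    class _DummyAttention(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.num_key_value_groups = 2  # 4 query heads / 2 key-value heads
            self.is_causal = True

    batch, num_heads, num_kv_heads, seq_len, head_dim = 1, 4, 2, 8, 16
    query = torch.randn(batch, num_heads, seq_len, head_dim)
    key = torch.randn(batch, num_kv_heads, seq_len, head_dim)
    value = torch.randn(batch, num_kv_heads, seq_len, head_dim)

    # `repeat_kv` matches torch.repeat_interleave along the head dimension, as its docstring states.
    assert torch.equal(repeat_kv(key, 2), torch.repeat_interleave(key, repeats=2, dim=1))

    # With attention_mask=None and a causal dummy module, the forward pass dispatches to SDPA
    # (using `enable_gqa=True` on torch >= 2.5, or repeating the KV heads otherwise).
    attn_output, _ = sdpa_attention_forward(
        _DummyAttention(), query, key, value, attention_mask=None, scaling=head_dim**-0.5
    )
    # The output comes back transposed to (batch, seq_len, num_heads, head_dim).
    print(attn_output.shape)  # torch.Size([1, 8, 4, 16])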