# transformers/integrations/flash_attention.py
from typing import Optional

import torch

from ..modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
from ..utils import logging


logger = logging.get_logger(__name__)


_use_top_left_mask = flash_attn_supports_top_left_mask()
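# `_use_top_left_mask` records whether the installed flash-attn build still uses
# top-left-aligned causal masking (older releases); it is forwarded to the kernel
# call below so the mask alignment can be accounted for.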


def flash_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    if kwargs.get("output_attentions", False) or kwargs.get("head_mask", None) is not None:
        logger.warning_once(
            "`flash_attention_2` does not support `output_attentions=True` or `head_mask`."
            " Please set your attention to `eager` if you want any of these features."
        )

    # Record the sequence length before the transpose below
    seq_len = query.shape[2]

    if any(dim == 0 for dim in query.shape):
        raise ValueError(
            "Tensor `query` has a zero-sized dimension.\n"
            "FlashAttention does not support inputs with dim=0.\n"
            "Please check your input shapes or use SDPA instead."
        )

    # FA2 uses non-transposed inputs
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    # In PEFT, the layer norms are usually cast to float32 for training stability,
    # so the input hidden states get silently upcast to float32. We therefore cast
    # them back to the expected dtype to make sure everything works as intended.
    # This can slow down training and inference, so it is recommended not to cast
    # the LayerNorms to fp32 (our RMSNorm modules usually handle this correctly).
    target_dtype = None
    if query.dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(module.config, "_pre_quantization_dtype"):
            target_dtype = module.config._pre_quantization_dtype
        else:
            target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
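    # The actual downcast (if any) is performed inside `_flash_attention_forward`,
    # which receives `target_dtype` below.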

    # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
    kwargs.pop("is_causal", None)

    attn_output = _flash_attention_forward(
        query,
        key,
        value,
        attention_mask,
        query_length=seq_len,
        is_causal=module.is_causal,
        dropout=dropout,
        softmax_scale=scaling,
        sliding_window=sliding_window,
        softcap=softcap,
        use_top_left_mask=_use_top_left_mask,
        target_dtype=target_dtype,
        attn_implementation=module.config._attn_implementation,
        layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
        **kwargs,
    )

    return attn_output, None
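

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): a minimal example of
# how this kernel is typically selected in practice. Loading a model with
# `attn_implementation="flash_attention_2"` routes each attention layer's call
# through `flash_attention_forward` above. The checkpoint id below is a
# placeholder; running this also assumes the `flash-attn` package, a supported
# GPU, and fp16/bf16 weights.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "your-org/your-fa2-capable-model",  # placeholder checkpoint id
        torch_dtype=torch.bfloat16,         # FA2 expects half-precision inputs
        attn_implementation="flash_attention_2",
        device_map="cuda",
    )
    print(model.config._attn_implementation)  # expected: "flash_attention_2"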