team-10/venv/Lib/site-packages/transformers/integrations/flash_paged.py

from typing import Optional

import torch

from ..generation.continuous_batching import PagedAttentionCache
from ..utils import is_flash_attn_2_available


if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func  # noqa: F401

def paged_attention_forward(
    module: torch.nn.Module,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    cache: Optional[PagedAttentionCache] = None,
    cumulative_seqlens_q=None,
    cumulative_seqlens_k=None,
    max_seqlen_q=None,
    max_seqlen_k=None,
    block_tables=None,
    implementation=None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
r"""Perform the forward pass of attention with paged key-value cache.
This function handles the cache updates and performs the attention computation
using the flash_attn_varlen_func for efficient processing.
Args:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. but if there is a block table it can be the full k
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. but if there is a block table it can be the full v
cumulative_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cumulative_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. Anything > 0 activates softcapping attention.
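
    Example:
        For a batch of two sequences with 3 and 5 query tokens, ``cumulative_seqlens_q`` is
        ``torch.tensor([0, 3, 8], dtype=torch.int32)``; the queries of sequence ``i`` occupy rows
        ``cumulative_seqlens_q[i]:cumulative_seqlens_q[i + 1]`` of ``q``.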
"""
    k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs)
    if implementation is not None:
        flash_attn_varlen_func = implementation.flash_attn_varlen_func
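    # q/k/v arrive as (1, num_heads, total_tokens, head_dim); transpose + squeeze converts them to
    # the (total_tokens, num_heads, head_dim) layout that flash_attn_varlen_func expects.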
    attn_output = flash_attn_varlen_func(
        q.transpose(1, 2).squeeze(0).contiguous(),
        k.transpose(1, 2).squeeze(0).contiguous(),
        v.transpose(1, 2).squeeze(0).contiguous(),
        cumulative_seqlens_q.to(torch.int32),
        cumulative_seqlens_k.to(torch.int32).clone(),
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale=module.scaling,
        causal=True,  # required here: the kernel aligns the causal mask to the bottom-right when q is shorter than k
        window_size=(-1, -1),  # (-1, -1) disables the sliding window (infinite context)
        # block_table=block_tables, -> torch.Tensor
        # **kwargs,
    )
    return attn_output, None
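

# Illustrative usage sketch (not part of the module). It assumes an attention module exposing
# ``layer_idx`` and ``scaling`` attributes and an already-configured ``PagedAttentionCache``;
# all variable names below are hypothetical:
#
#     attn_output, _ = paged_attention_forward(
#         module=attn_module,
#         q=q, k=k, v=v,                      # (1, num_heads, total_tokens, head_dim)
#         cache=paged_cache,                  # PagedAttentionCache instance
#         cumulative_seqlens_q=cu_seqlens_q,  # int32, shape (batch_size + 1,)
#         cumulative_seqlens_k=cu_seqlens_k,  # int32, shape (batch_size + 1,)
#         max_seqlen_q=max_q,
#         max_seqlen_k=max_k,
#     )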