import torch

from ..generation.continuous_batching import PagedAttentionCache
from ..utils import is_flash_attn_2_available


if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func  # noqa: F401

def paged_attention_forward(
    module: torch.nn.Module,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attention_mask: torch.Tensor = None,
    cache: PagedAttentionCache = None,
    cumulative_seqlens_q=None,
    cumulative_seqlens_k=None,
    max_seqlen_q=None,
    max_seqlen_k=None,
    block_tables=None,
    implementation=None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    r"""Perform the forward pass of attention with a paged key-value cache.

    This function handles the cache update and performs the attention computation
    using `flash_attn_varlen_func` for efficient variable-length processing. The softmax
    scale is taken from `module.scaling`, and causal masking is always applied with an
    unbounded sliding window.

    Args:
        module: The calling attention module; provides `layer_idx` (for the cache update)
            and `scaling` (used as the softmax scale).
        q: (total_q, nheads, headdim), where total_q is the total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k is the total number of key tokens in the batch.
            If a block table is used, this can be the full paged key cache instead.
        v: (total_k, nheads_k, headdim), where total_k is the total number of key tokens in the batch.
            If a block table is used, this can be the full paged value cache instead.
        attention_mask: Not used by this attention path; kept for interface compatibility.
        cache: The `PagedAttentionCache` holding the paged key/value blocks; it is updated
            with the new `k` and `v` before the attention call.
        cumulative_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cumulative_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into the keys and values.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        block_tables: Block table mapping sequences to cache blocks (currently not forwarded
            to the attention kernel).
        implementation: Optional object exposing `flash_attn_varlen_func`; when None, the
            `flash_attn` library import is used.

    Returns:
        A tuple of the attention output and `None` (no attention weights are returned).
    """
    # Append the new keys/values for this layer to the paged cache and get back the tensors to attend over.
    k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs)

    # An explicit `implementation` can override the default flash_attn kernel.
    if implementation is not None:
        flash_attn_varlen_func = implementation.flash_attn_varlen_func
    # Collapse the (1, num_heads, total_tokens, head_dim) layout into the
    # (total_tokens, num_heads, head_dim) layout expected by the varlen kernel.
    attn_output = flash_attn_varlen_func(
        q.transpose(1, 2).squeeze(0).contiguous(),
        k.transpose(1, 2).squeeze(0).contiguous(),
        v.transpose(1, 2).squeeze(0).contiguous(),
        cumulative_seqlens_q.to(torch.int32),
        cumulative_seqlens_k.to(torch.int32).clone(),
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale=module.scaling,
        causal=True,  # effectively required: the causal mask is aligned bottom-right, which handles q shorter than k
        window_size=(-1, -1),  # -1 means an infinite context window
        # block_table=block_tables, -> torch.Tensor
        # **kwargs,
    )

    return attn_output, None
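

# Usage sketch (illustration only, not part of the module's public API): the helper below
# shows how the varlen metadata consumed by `paged_attention_forward` could be built from
# per-sequence token counts. The helper name `_build_varlen_metadata` and its return
# convention are assumptions made for this example; it only relies on `torch`.
def _build_varlen_metadata(seq_lengths: list[int], device: str = "cpu"):
    """Return (cumulative_seqlens, max_seqlen) for a list of per-sequence token counts."""
    lengths = torch.tensor(seq_lengths, dtype=torch.int32, device=device)
    # cumulative_seqlens has batch_size + 1 entries, starting at 0, so that the tokens of
    # sequence i live in the range [cumulative_seqlens[i], cumulative_seqlens[i + 1]).
    cumulative_seqlens = torch.zeros(len(seq_lengths) + 1, dtype=torch.int32, device=device)
    cumulative_seqlens[1:] = torch.cumsum(lengths, dim=0)
    return cumulative_seqlens, int(lengths.max())


# The returned values would be passed as `cumulative_seqlens_q` / `max_seqlen_q` (and the
# key-side analogues) alongside a `PagedAttentionCache`, which this sketch does not construct.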