# mypy: allow-untyped-defs
import logging
from typing import Optional

import torch
import torch.nn
import torch.nn.functional as F
from torch.backends.cuda import (
    can_use_cudnn_attention,
    can_use_efficient_attention,
    can_use_flash_attention,
    cudnn_sdp_enabled,
    flash_sdp_enabled,
    math_sdp_enabled,
    mem_efficient_sdp_enabled,
    SDPAParams,
)
from torch.nn.attention import SDPBackend

from .nested_tensor import NestedTensor


log = logging.getLogger(__name__)


def _validate_sdpa_input(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    if (
        not isinstance(query, NestedTensor)
        or not isinstance(key, NestedTensor)
        or not isinstance(value, NestedTensor)
    ):
        raise ValueError(
            f"Expected query, key, and value to be nested tensors, "
            f"but got query.is_nested: {query.is_nested}, key.is_nested: {key.is_nested}, "
            f"and value.is_nested: {value.is_nested} instead."
        )
    if query.dtype != key.dtype or query.dtype != value.dtype:
        raise ValueError(
            f"Expected query, key, and value to have the same dtype, "
            f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
            f"and value.dtype: {value.dtype} instead."
        )
    if query.device != key.device or query.device != value.device:
        raise ValueError(
            f"Expected query, key, and value to have the same device type, "
            f"but got query.device: {query.device}, key.device: {key.device}, "
            f"and value.device: {value.device} instead."
        )
    if query.dim() < 3 or key.dim() < 3 or value.dim() < 3:
        raise ValueError(
            f"Expected query, key, and value to all be at least 3 dimensional, but got query.dim: "
            f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
        )
    if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
        raise ValueError(
            f"Expected query, key, and value to all be ragged on the same dimension, but got ragged "
            f"dims {query._ragged_idx}, {key._ragged_idx}, and {value._ragged_idx}, respectively."
        )
    if attn_mask is not None:
        # TODO: Figure out whether masks are actually supported for this layout or not
        raise ValueError("Masks are not yet supported!")
        if attn_mask.dtype != torch.bool and attn_mask.dtype != query.dtype:
            raise ValueError(
                f"Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: "
                f"{attn_mask.dtype}, and query.dtype: {query.dtype} instead."
            )


def _check_batch_size_nested(params: SDPAParams, debug=False) -> bool:
    # This is expected to be called after check_tensor_shapes ensuring that the
    # size() calls won't error since the inputs are all 4 dimensional
    q_batch_size = params.query.size(0)
    k_batch_size = params.key.size(0)
    v_batch_size = params.value.size(0)

    # num_heads logic for nested input is checked in
    # check_for_seq_len_0_nested_tensor as there is handling there to make sure
    # num_heads is not ragged
    return q_batch_size == k_batch_size and q_batch_size == v_batch_size


def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool:
    max_size = 256
    query_size_last = params.query.size(-1)
    key_size_last = params.key.size(-1)
    value_size_last = params.value.size(-1)
    same_head_dim_size = (
        query_size_last == key_size_last and query_size_last == value_size_last
    )
    if not (
        same_head_dim_size
        and (query_size_last % 8 == 0)
        and (query_size_last <= max_size)
    ):
        if debug:
            log.warning(
                "For NestedTensor inputs, Flash attention requires q,k,v to have the same "
                "last dimension and to be a multiple of 8 and less than or equal to 256. "
                "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
                query_size_last,
                key_size_last,
                value_size_last,
            )
        return False
    return True
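

# Illustrative example (hypothetical sizes, not part of the original checks): a shared
# head dim of 96 passes all three conditions above (equal across q/k/v, 96 % 8 == 0,
# 96 <= 256), while 100 fails the multiple-of-8 check and 320 exceeds the 256 cap.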


def _check_head_dim_size_cudnn_nested(params: SDPAParams, debug=False) -> bool:
    max_size = 128
    query_size_last = params.query.size(-1)
    key_size_last = params.key.size(-1)
    value_size_last = params.value.size(-1)
    same_head_dim_size = (
        query_size_last == key_size_last and query_size_last == value_size_last
    )
    if not (
        same_head_dim_size
        and (query_size_last % 8 == 0)
        and (query_size_last <= max_size)
    ):
        if debug:
            log.warning(
                "For NestedTensor inputs, cuDNN attention requires q,k,v to have the same "
                "last dimension and to be a multiple of 8 and less than or equal to 128. "
                "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
                query_size_last,
                key_size_last,
                value_size_last,
            )
        return False
    return True


def _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
    param: torch.Tensor, param_name: str, debug=False
) -> bool:
    assert isinstance(param, NestedTensor), "param should be a jagged NT"

    if param._ragged_idx == 1:
        # num_head_dims is ragged
        if debug:
            log.warning(
                "Fused kernels do not support ragged num_head_dims, %s has a ragged num_heads.",
                param_name,
            )
        return False

    # This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
    if param._get_min_seqlen() == 0:
        if debug:
            log.warning(
                "Fused kernels do not support seq_len == 0, %s has a seq len of 0.",
                param_name,
            )
        return False

    return True


def _try_broadcast_param_size(q_size, k_size, v_size, param_name, debug=False) -> bool:
    max_size = max(q_size, k_size, v_size)
    if (
        (q_size != max_size and q_size != 1)
        or (k_size != max_size and k_size != 1)
        or (v_size != max_size and v_size != 1)
    ):
        if debug:
            log.warning(
                "Both fused kernels require query, key and value to have broadcastable %s, "
                "got Query %s %d, Key %s %d, Value %s %d instead.",
                param_name,
                param_name,
                q_size,
                param_name,
                k_size,
                param_name,
                v_size,
            )
        return False
    return True


def _check_for_seq_len_0_nested(params: SDPAParams, debug=False) -> bool:
    # When this function is called we are assured that the nt is dim==4
    q_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.query, "query", debug
        )
        if params.query.is_nested
        else True
    )
    # short circuit if any is unsafe
    if not q_is_safe:
        return False

    k_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.key, "key", debug
        )
        if params.key.is_nested
        else True
    )
    # short circuit if any is unsafe
    if not k_is_safe:
        return False

    v_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.value, "value", debug
        )
        if params.value.is_nested
        else True
    )
    # short circuit if any is unsafe
    if not v_is_safe:
        return False

    # We now know none of the inputs have ragged num_heads, so we can safely
    # access .size(1)
    q_num_heads = params.query.size(1)
    k_num_heads = params.key.size(1)
    v_num_heads = params.value.size(1)
    same_num_heads = q_num_heads == k_num_heads and q_num_heads == v_num_heads

    if not same_num_heads:
        if (
            params.query.requires_grad
            or params.key.requires_grad
            or params.value.requires_grad
        ):
            if debug:
                log.warning(
                    "Both fused kernels do not support training with broadcasted NT inputs."
                )
            return False
        return _try_broadcast_param_size(
            q_num_heads, k_num_heads, v_num_heads, "num heads", debug
        )
    return True
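

# Illustrative example (hypothetical sizes, not part of the original checks): with
# q_num_heads=8 and k_num_heads=v_num_heads=1, inference-only inputs fall through to
# _try_broadcast_param_size above and pass, but the same inputs with requires_grad=True
# on any of q/k/v are rejected because the fused kernels do not support training with
# broadcasted NT inputs.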


def _can_use_flash_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    constraints = (
        _check_batch_size_nested,
        _check_head_dim_size_flash_nested,
        _check_for_seq_len_0_nested,
    )
    for constraint in constraints:
        if not constraint(params, debug):
            return False
    return True


def _can_use_efficient_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    constraints = (
        _check_batch_size_nested,
        _check_for_seq_len_0_nested,
    )
    for constraint in constraints:
        if not constraint(params, debug):
            return False
    return True


def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    if (
        not params.query.transpose(1, 2).is_contiguous()
        or not params.key.transpose(1, 2).is_contiguous()
        or not params.value.transpose(1, 2).is_contiguous()
    ):
        if debug:
            log.warning(
                "If inputs are nested tensors they must be contiguous after transposing."
            )
        return False
    if params.is_causal:
        if debug:
            log.warning(
                "Nested tensors for query / key are not supported when is_causal=True."
            )
        return False
    return True


def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable_gqa):
    if (
        not flash_sdp_enabled()
        and not mem_efficient_sdp_enabled()
        and not math_sdp_enabled()
        and not cudnn_sdp_enabled()
    ):
        return SDPBackend.ERROR

    ordering = (
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION,
        SDPBackend.MATH,
        SDPBackend.CUDNN_ATTENTION,
    )

    params = SDPAParams(query, key, value, attn_mask, dropout, is_causal, enable_gqa)

    for backend in ordering:
        if backend == SDPBackend.CUDNN_ATTENTION:
            if can_use_cudnn_attention(params):
                return SDPBackend.CUDNN_ATTENTION
        if backend == SDPBackend.FLASH_ATTENTION:
            if can_use_flash_attention(params) and _can_use_flash_sdpa_jagged(params):
                return SDPBackend.FLASH_ATTENTION
        if backend == SDPBackend.EFFICIENT_ATTENTION:
            if can_use_efficient_attention(params) and _can_use_efficient_sdpa_jagged(
                params
            ):
                return SDPBackend.EFFICIENT_ATTENTION
        if backend == SDPBackend.MATH:
            if math_sdp_enabled() and _can_use_math_sdpa_jagged(params):
                return SDPBackend.MATH

    log.warning("Memory efficient kernel not used because:")
    can_use_efficient_attention(params, debug=True)
    _can_use_efficient_sdpa_jagged(params, debug=True)
    log.warning("Flash attention kernel not used because:")
    can_use_flash_attention(params, debug=True)
    _can_use_flash_sdpa_jagged(params, debug=True)
    log.warning("Math attention kernel not used because:")
    _can_use_math_sdpa_jagged(params, debug=True)
    log.warning("cuDNN attention kernel not used because:")
    can_use_cudnn_attention(params, debug=True)
    return SDPBackend.ERROR
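

# Illustrative note (hedged sketch of the behavior above, not an authoritative spec):
# selection walks `ordering` and returns the first backend whose generic checks
# (can_use_*_attention) and NJT-specific checks (_can_use_*_sdpa_jagged) both pass.
# For example, if the flash checks fail because of an unsupported head dim, selection
# typically falls through to efficient attention and then math before the debug
# warnings are emitted and SDPBackend.ERROR is returned.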


def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> tuple[torch.Tensor, int, int]:
    # This function is used to calculate two pieces of metadata that are needed
    # for use with flash-attention and efficient_attention kernels. They are the
    # cumulative sequence_length over a batch of sequences and the maximum
    # sequence length.

    # It returns a tuple of (cumulative sequence lengths, maximum sequence length,
    # total number of elements, i.e. the last entry of the cumulative sequence lengths).
    if not isinstance(qkv, NestedTensor):
        raise ValueError("QKV must be nested for flash cumulative_seq_len calculation.")

    if qkv.lengths() is None:
        # TODO: Explore performance impact of copying
        cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device)
        max_seqlen = qkv._get_max_seqlen()
        n_elem = qkv.values().shape[0]
    else:
        # TODO: Explore performance impact of copying
        cumulative_seqlen = (
            qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device)
        )
        max_seqlen = qkv._get_max_seqlen()
        # TODO: Explore performance impact when compiling
        n_elem = int(cumulative_seqlen[-1].item())
    return cumulative_seqlen, max_seqlen, n_elem
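

# Illustrative example (hypothetical values, not part of the original module): for an NJT
# whose components have lengths [2, 3, 1] and no holes (lengths() is None), offsets() is
# [0, 2, 5, 6], so cumulative_seqlen == [0, 2, 5, 6] (int32), max_seqlen == 3, and
# n_elem == values().shape[0] == 6.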


def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor):
    # This function checks if a nested tensor is valid for
    # use with the flash-attention and efficient_attention kernels without
    # needing to call contiguous on the nested tensor input.
    # It checks that the storage offsets' adjacent differences are a constant
    # multiple of the previous tensor in the nested tensor and that the strides
    # are monotonically decreasing. This check is done after calling transpose on
    # the nested tensor resulting in a NT of shape [bsz, {seq_len}, num_heads, dim]

    # Returns True if the underlying storage can be used directly
    # (i.e. no call to contiguous is needed for the input)
    assert isinstance(tensor, NestedTensor)
    offsets = tensor.offsets()
    strides = tensor._strides

    n_tensors = offsets.size(0) - 1
    if n_tensors <= 1:
        return True

    # Check initially that the tensor strides are in strictly descending order
    prev_stride = strides[1]
    for stride in strides[2:]:
        if prev_stride <= stride:
            # This would mean that the last stride is greater than the seq_len
            # stride
            return False
        prev_stride = stride

    # Congrats you made it!
    return True
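

# Illustrative note (hedged, hypothetical strides): after the transpose to
# [bsz, {seq_len}, num_heads, dim], a densely packed buffer has strides along the
# (seq_len, num_heads, dim) axes resembling (num_heads * dim, dim, 1), which is strictly
# decreasing, so the values() buffer can be used directly; if, say, the num_heads stride
# were >= the seq_len stride, this returns False and callers fall back to .contiguous().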


def _view_as_dense(
    tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int
) -> torch.Tensor:
    if tensor.is_nested:
        return tensor.values()
    return tensor.view(Nnz, num_heads, head_dim)


# TODO: Next iteration should add test cases and check it works
# def _sdpa_nested_preprocessing_with_broadcast(query, key, value):
#     # Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
#     # Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
#     # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
#     q_batch_size = query.size(0)
#     k_batch_size = key.size(0)
#     v_batch_size = value.size(0)

#     output_batch_size = max(q_batch_size, k_batch_size, v_batch_size)

#     q_num_heads = query.size(1)
#     k_num_heads = key.size(1)
#     v_num_heads = value.size(1)

#     output_num_heads = max(q_num_heads, k_num_heads, v_num_heads)

#     head_dim_qk = query.size(3)
#     head_dim_v = value.size(3)

#     q_t = query.transpose(1, 2)
#     k_t = key.transpose(1, 2)
#     v_t = value.transpose(1, 2)

#     # Checks in sdp_utils ensure that if {*}_batch_size/{*}_num_heads !=
#     # output_batch_size/num_heads then they are 1
#     q_batch_size_needs_broadcast = q_batch_size != output_batch_size
#     k_batch_size_needs_broadcast = k_batch_size != output_batch_size
#     v_batch_size_needs_broadcast = v_batch_size != output_batch_size

#     # If {*}_batch_size_needs_broadcast, then
#     # (1) max_seqlen_batch_{*} is given by {*}_t.size(1)
#     #     this is because needs_broadcast indicates that the batch_size is 1
#     #     and hence there is only 1 value for seq_len
#     # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1),
#     #     ..., output_batch_size * {*}_t.size(1)]
#     # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1)

#     if q_batch_size_needs_broadcast or not q_t.is_nested:
#         max_seqlen_batch_q = q_t.size(1)
#         cumulative_sequence_length_q = torch.arange(
#             0,
#             (output_batch_size + 1) * max_seqlen_batch_q,
#             max_seqlen_batch_q,
#             device=q_t.device,
#             dtype=torch.int32,
#         )
#         Nnz_q = output_batch_size * max_seqlen_batch_q
#     else:
#         (
#             cumulative_sequence_length_q,
#             max_seqlen_batch_q,
#             Nnz_q,
#         ) = _cumulative_and_max_seq_len_nnz(q_t)

#     if k_batch_size_needs_broadcast and v_batch_size_needs_broadcast:
#         assert k_t.size(1) == v_t.size(1)
#         max_seqlen_batch_kv = k_t.size(1)
#         cumulative_sequence_length_kv = torch.arange(
#             0,
#             (output_batch_size + 1) * max_seqlen_batch_kv,
#             max_seqlen_batch_kv,
#             device=k_t.device,
#             dtype=torch.int32,
#         )
#         Nnz_kv = output_batch_size * max_seqlen_batch_kv
#     else:
#         cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv = (
#             _cumulative_and_max_seq_len_nnz(v_t)
#             if k_batch_size_needs_broadcast
#             else _cumulative_and_max_seq_len_nnz(k_t)
#         )

#     q_num_heads_needs_broadcast = q_num_heads != output_num_heads
#     k_num_heads_needs_broadcast = k_num_heads != output_num_heads
#     v_num_heads_needs_broadcast = v_num_heads != output_num_heads

#     if not q_t.is_nested:
#         query_buffer_reshaped = q_t.expand(
#             output_batch_size, q_t.size(1), output_num_heads, head_dim_qk
#         )
#         query_buffer_reshaped = query_buffer_reshaped.reshape(
#             Nnz_q, output_num_heads, head_dim_qk
#         )
#     else:
#         if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
#             q_t = q_t.contiguous()
#         # If we are broadcasting then Nnz_q will be the output_batch_size since
#         # seq_len is 1
#         effective_batch_size_q = (
#             output_batch_size if q_batch_size_needs_broadcast else Nnz_q
#         )
#         query_buffer_reshaped = _view_as_dense(
#             q_t, effective_batch_size_q, output_num_heads, head_dim_qk
#         )

#     # If the physical layout of the NestedTensor's storage
#     # is not: batch, {seq_len}, num_heads, head_dim then we need
#     # to call contiguous
#     if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
#         k_t = k_t.contiguous()
#     if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
#         v_t = v_t.contiguous()

#     effective_batch_size_k = (
#         output_batch_size if k_batch_size_needs_broadcast else Nnz_kv
#     )
#     key_buffer_reshaped = _view_as_dense(
#         k_t, effective_batch_size_k, output_num_heads, head_dim_qk
#     )

#     effective_batch_size_v = (
#         output_batch_size if v_batch_size_needs_broadcast else Nnz_kv
#     )
#     value_buffer_reshaped = _view_as_dense(
#         v_t, effective_batch_size_v, output_num_heads, head_dim_v
#     )

#     if not q_batch_size_needs_broadcast:
#         output_shape = q_t._size
#         if head_dim_v != head_dim_qk:
#             output_shape[-1] = head_dim_v
#         if q_num_heads_needs_broadcast:
#             output_shape[1] = output_num_heads
#     else:
#         output_shape = torch.empty(3, dtype=torch.int64, device=torch.device("cpu"))
#         output_shape[0] = q_t.size(1)
#         output_shape[1] = output_num_heads
#         output_shape[2] = head_dim_v

#     return (
#         query_buffer_reshaped,
#         key_buffer_reshaped,
#         value_buffer_reshaped,
#         cumulative_sequence_length_q,
#         cumulative_sequence_length_kv,
#         max_seqlen_batch_q,
#         max_seqlen_batch_kv,
#         output_shape,
#     )


def _sdpa_nested_preprocessing(query, key, value):
    # Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
    # Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
    # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
    q_batch_size = query.size(0)
    k_batch_size = key.size(0)
    v_batch_size = value.size(0)

    q_num_heads = query.size(1)
    k_num_heads = key.size(1)
    v_num_heads = value.size(1)

    if not (q_batch_size == k_batch_size and q_batch_size == v_batch_size) or not (
        q_num_heads == k_num_heads and k_num_heads == v_num_heads
    ):
        raise RuntimeError(
            "This path is currently not implemented for jagged layout NT."
        )
        # return _sdpa_nested_preprocessing_with_broadcast(query, key, value)

    num_heads = query.size(1)
    head_dim_qk = query.size(3)
    head_dim_v = value.size(3)
    q_t = query.transpose(1, 2)
    k_t = key.transpose(1, 2)
    v_t = value.transpose(1, 2)

    (
        cumulative_sequence_length_q,
        max_seqlen_batch_q,
        Nnz_q,
    ) = _cumulative_and_max_seq_len_nnz(q_t)
    (
        cumulative_sequence_length_kv,
        max_seqlen_batch_kv,
        Nnz_kv,
    ) = _cumulative_and_max_seq_len_nnz(k_t)

    # [TODO] K and V have to have the same Nnz, should probably torch_check
    # assume in order to not iterate over v

    # If the physical layout of the NestedTensor's storage
    # is not: batch, {seq_len}, num_heads, head_dim then we need
    # to call contiguous
    if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
        q_t = q_t.contiguous()
    if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
        k_t = k_t.contiguous()
    if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
        v_t = v_t.contiguous()

    query_buffer_reshaped = _view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk)
    key_buffer_reshaped = _view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk)
    value_buffer_reshaped = _view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v)

    output_nt_info = {
        "offsets": q_t.offsets(),
        "lengths": q_t.lengths(),
        "max_seqlen": q_t._get_max_seqlen(),
        "min_seqlen": q_t._get_min_seqlen(),
    }

    return (
        query_buffer_reshaped,
        key_buffer_reshaped,
        value_buffer_reshaped,
        cumulative_sequence_length_q,
        cumulative_sequence_length_kv,
        max_seqlen_batch_q,
        max_seqlen_batch_kv,
        output_nt_info,
    )
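

# Illustrative example (hypothetical shapes, not part of the original module): for q/k/v
# NJTs of shape [B, num_heads, {seq_len}, head_dim] with total (summed) sequence lengths
# Nnz_q and Nnz_kv, the returned buffers are dense views of shape
# [Nnz_q, num_heads, head_dim_qk], [Nnz_kv, num_heads, head_dim_qk], and
# [Nnz_kv, num_heads, head_dim_v], alongside the int32 cumulative-seqlen tensors and max
# sequence lengths that the varlen attention kernels expect.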


def _pad_last_dim(
    tensor: torch.Tensor, alignment_size: int, slice: bool
) -> torch.Tensor:
    # FlashAttentionV2 requires the head dimension to be a multiple of 8.
    # This was previously done within the kernel; however, that causes the kernel
    # to potentially alias query, key, and value, so instead we pad the head
    # dimension to a multiple of 8 in the composite region.
    last_dim_size = tensor.size(-1)
    if last_dim_size % alignment_size == 0:
        return tensor
    pad_count = alignment_size - (last_dim_size % alignment_size)
    tensor = torch.nn.functional.pad(tensor, [0, pad_count])
    if slice:
        return tensor[..., 0:last_dim_size]
    return tensor
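

# Illustrative example (hypothetical sizes): with alignment_size=8, a head dim of 96 is
# returned unchanged, while a head dim of 100 is padded by 4 to 104; with slice=True the
# result is immediately narrowed back to the original 100 columns, so only the underlying
# storage alignment changes.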


# TODO: coalesce with torch/nn/utils/attention.py
def _calculate_scale(query, scale):
    # TODO: Investigate why math.sqrt() isn't properly handled by Dynamo?
    softmax_scale = scale if scale is not None else torch.sym_sqrt(1.0 / query.size(-1))
    return softmax_scale
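

# Illustrative example: for head_dim 64 and scale=None this evaluates to sqrt(1/64) = 0.125,
# matching the default 1/sqrt(E) softmax scaling of scaled_dot_product_attention.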


def _post_process_flash_output(out: torch.Tensor, og_size):
    if not out.is_nested and out.size(-1) != og_size:
        out = out[..., 0:og_size]
    return out


def _is_computing_meta_flops(x):
    # Note: there's a use case of using meta tensors & the dispatch-based flop counter.
    # We can use this function to check for this scenario in order to handle it specially.
    if not torch.jit.is_scripting() and x.device.type == "meta":
        torch_dispatch_mode_stack = (
            torch.utils._python_dispatch._get_current_dispatch_mode_stack()
        )
        return any(
            type(x) == torch.utils.flop_counter._FlopCounterMode
            for x in torch_dispatch_mode_stack
        )
    return False
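

# Illustrative note (hedged): this is expected to return True when, e.g., a model is run
# on meta-device tensors under torch.utils.flop_counter.FlopCounterMode so that FLOPs can
# be counted without real kernels; in that case the caller skips device-aware backend
# selection (see jagged_scaled_dot_product_attention below).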


def _autocast(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    [Autocasting SDPA for NJT]

    Normal autocasting doesn't work for NJT+SDPA right now:
    * NJT intercepts the __torch_function__ call for scaled_dot_product_attention, which happens
      before we get to any aten ops or dispatcher logic; then the torch_function logic calls into
      efficient attention or flash attention. So, autocasting on the scaled_dot_product_attention
      op won't work because we never see that aten op.
    * If we put autocasting on `_flash_attention_forward`, then we'll get autocasting to run, but
      the kernel selection logic in torch_function handling (i.e. jagged_scaled_dot_product_attention)
      won't work correctly: the kernel selection logic will run before autocasting, and choose
      a kernel based on the un-autocasted dtypes; but then autocasting will run and the actual
      attention computation will happen in a different dtype.

    An alternative is to just change the backend selection logic for SDPA+NJT to be autocast-aware
    and rely on autocasting to do the actual conversions for flash attention / efficient attention.
    However, by manually doing the actual autocast before the backend selection, we ensure that the
    autocast handling for backend selection doesn't diverge from the autocast handling for the
    actual dtype conversions.
    """
    device_type = query.device.type
    # meta device is not supported by autocast, so break early for it
    if _is_computing_meta_flops(query) or not torch.is_autocast_enabled(device_type):
        return query, key, value, attn_mask

    def cvt(x):
        if x is None:
            return x
        target_dtype = torch.get_autocast_dtype(device_type)
        if (
            (not x.dtype.is_floating_point)
            or x.dtype == target_dtype
            or x.dtype == torch.float64
        ):
            return x
        return x.to(target_dtype)

    return cvt(query), cvt(key), cvt(value), cvt(attn_mask)
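

# Illustrative example (hedged): under torch.autocast(device_type="cuda", dtype=torch.bfloat16),
# fp32 q/k/v are converted to bf16 before backend selection, while a bool attn_mask,
# already-bf16 inputs, and fp64 inputs are returned unchanged by cvt().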


def jagged_scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    query, key, value, attn_mask = _autocast(query, key, value, attn_mask)
    _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
    # for mypy, ugh
    assert (
        isinstance(query, NestedTensor)
        and isinstance(key, NestedTensor)
        and isinstance(value, NestedTensor)
    )
    from torch.nested._internal.nested_tensor import (
        nested_view_from_values_offsets_lengths,
    )

    # Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
    # second batch dim instead). For this case, we can just send the dense buffers through
    # vanilla SDPA.
    if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
        output = F.scaled_dot_product_attention(
            query.values(),
            key.values(),
            value.values(),
            attn_mask=(
                attn_mask.values() if isinstance(attn_mask, NestedTensor) else attn_mask
            ),
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )
        return nested_view_from_values_offsets_lengths(
            output,
            query.offsets(),
            query.lengths(),
            min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
            max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
        )

    compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad

    backend_choice = _select_sdp_backend(
        query, key, value, attn_mask, dropout_p, is_causal, enable_gqa
    )

    if _is_computing_meta_flops(query):
        # Backend choice will probably not be correct if we have a meta device,
        # because backend choice is device-aware. In this case, we mostly just
        # want to avoid using math backend (which does a .item() call).
        # Arbitrarily choose flash attention.
        backend_choice = SDPBackend.FLASH_ATTENTION

    if backend_choice == SDPBackend.FLASH_ATTENTION:
        og_size = query.size(-1)
        query_padded = _pad_last_dim(query, 8, False)
        key_padded = _pad_last_dim(key, 8, False)
        value_padded = _pad_last_dim(value, 8, False)
        # We need to calculate the scale based off the OG head dim size
        og_scale = _calculate_scale(query, scale)
        (
            query_buffer_reshaped,
            key_buffer_reshaped,
            value_buffer_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            output_nt_info,
        ) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded)
        (
            attention,
            _logsumexp,
            _philox_seed,
            _philox_offset,
            _debug_attn_mask,
        ) = torch.ops.aten._flash_attention_forward(
            query_buffer_reshaped,
            key_buffer_reshaped,
            value_buffer_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            dropout_p,
            is_causal,
            False,
            scale=og_scale,
        )
        # Reshape output to convert nnz to batch_size and seq_len
        attention = nested_view_from_values_offsets_lengths(
            attention,  # output from flash_attn is [total_q, num_heads, head_size_og]
            **output_nt_info,
        ).transpose(1, 2)
        return _post_process_flash_output(attention, og_size)
    elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:
        (
            query_reshaped,
            key_reshaped,
            value_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            output_nt_info,
        ) = _sdpa_nested_preprocessing(query, key, value)
        (
            attention,
            log_sumexp,
            seed,
            offset,
            max_seqlen_q,
            max_seqlen_batch_kv,
        ) = torch.ops.aten._efficient_attention_forward(
            query_reshaped.unsqueeze(0),
            key_reshaped.unsqueeze(0),
            value_reshaped.unsqueeze(0),
            None,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            dropout_p,
            int(is_causal),
            compute_logsumexp,
            scale=scale,
        )
        # Reshape output to convert nnz to batch_size and seq_len
        return nested_view_from_values_offsets_lengths(
            attention.squeeze(0),
            **output_nt_info,
        ).transpose(1, 2)
    elif backend_choice == SDPBackend.CUDNN_ATTENTION:
        (
            query_reshaped,
            key_reshaped,
            value_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            output_nt_info,
        ) = _sdpa_nested_preprocessing(query, key, value)
        (
            attention,
            logsumexp,
            cum_seqlen_q,
            cum_seqlen_kv,
            max_seqlen_q,
            max_seqlen_kv,
            seed,
            offset,
            _,
        ) = torch.ops.aten._cudnn_attention_forward(
            query_reshaped,
            key_reshaped,
            value_reshaped,
            attn_mask,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            compute_logsumexp,
            dropout_p,
            is_causal,
            False,
            scale=scale,
        )
        return nested_view_from_values_offsets_lengths(
            attention,
            **output_nt_info,
        ).transpose(1, 2)
    elif backend_choice == SDPBackend.MATH:
        # save the offsets and shape of the inputs, so we can reshape the final output
        # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D', j1] = [B, D1, j0, j1]
        # attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2]
        offsets = query.offsets()
        q_lengths = query.lengths()
        min_seqlen = query._maybe_min_seqlen
        max_seqlen = query._maybe_max_seqlen
        d1 = query._size[1]
        d2 = value._size[-1]

        # convert jagged layout Nested Tensor to strided layout Nested Tensor,
        # which supports the math implementation of SDPA
        def get_strided_layout_nested_tensor(jagged_layout_nt):
            lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1]
            transpose = torch.transpose(jagged_layout_nt, 1, 2)
            tensor_list = transpose.values().split(list(lengths), dim=0)
            strided_nt = torch.nested.as_nested_tensor(list(tensor_list))
            strided_nt = strided_nt.transpose(1, 2).contiguous()
            return strided_nt

        query = get_strided_layout_nested_tensor(query)
        key = get_strided_layout_nested_tensor(key)
        value = get_strided_layout_nested_tensor(value)

        attn_out = torch._scaled_dot_product_attention_math(
            query, key, value, attn_mask, dropout_p, is_causal, scale=scale
        )[0]

        # convert strided layout Nested Tensor back to jagged layout Nested Tensor
        attn_out = attn_out.transpose(1, 2).contiguous().values()
        attn_out = attn_out.view(-1, d1, d2)
        attn_out = nested_view_from_values_offsets_lengths(
            attn_out,
            offsets,
            lengths=q_lengths,
            min_seqlen=min_seqlen,
            max_seqlen=max_seqlen,
        ).transpose(1, 2)

        return attn_out
    else:
        raise RuntimeError(
            "No viable backend for scaled_dot_product_attention was found."
        )
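

# Illustrative usage (hedged; in practice this function is reached via the NJT
# __torch_function__ override of F.scaled_dot_product_attention rather than called directly):
#
#   q = torch.nested.nested_tensor(
#       [torch.randn(3, 8, 16), torch.randn(5, 8, 16)], layout=torch.jagged
#   ).transpose(1, 2)  # [B, num_heads, {seq_len}, head_dim]
#   k, v = q.clone(), q.clone()
#   out = F.scaled_dot_product_attention(q, k, v)  # dispatches here for jagged layout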