# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
import torch.nn.functional as F

from ..utils.import_utils import is_torch_npu_available


if is_torch_npu_available():
    import math

    import torch_npu
    from einops import rearrange, repeat
    from torch_npu import npu_rotary_mul


# FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
# Set environment variable `NPU_FA2_SPARSE_MODE` to 2 when using top-left aligned causal mask.
TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE = 2
DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE = 3

SPARSE_MODE = int(os.getenv("NPU_FA2_SPARSE_MODE", default=DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE))
if SPARSE_MODE not in [TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE, DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE]:
    raise ValueError(
        "Environment variable `NPU_FA2_SPARSE_MODE` can only be set as 2 (top-left aligned causal mask) "
        "or 3 (down-right aligned causal mask)."
    )
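
# For example, to opt into the top-left aligned causal mask for the fused kernels below
# (illustrative launch command; `train.py` is a placeholder script name):
#
#   NPU_FA2_SPARSE_MODE=2 python train.py
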
# Cache of the shared (2048, 2048) boolean causal masks used by the fused-attention calls below, keyed by device.
ATTN_MASK_NPU_CACHE = {}


def get_attn_mask_npu(device):
    """Get or create the causal attention mask for the specified device."""
    if device not in ATTN_MASK_NPU_CACHE:
        ATTN_MASK_NPU_CACHE[device] = torch.triu(torch.ones([2048, 2048], device=device), diagonal=1).bool()
    return ATTN_MASK_NPU_CACHE[device]


def is_npu_fa2_top_left_aligned_causal_mask():
    return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False


# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
class IndexFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        # return input[indices]
        return torch.gather(
            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
        ).reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        grad_output = rearrange(grad_output, "b ... -> b (...)")
        grad_input = torch.zeros(
            [ctx.first_axis_dim, grad_output.shape[1]],
            device=grad_output.device,
            dtype=grad_output.dtype,
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        # grad_input[indices] = grad_output
        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis = IndexFirstAxis.apply
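
# Illustrative sketch (comments only; the tensors are assumed examples, not part of this module):
# `index_first_axis(x, indices)` selects rows along the first axis, equivalent to `x[indices]`,
# but routes through the custom gather/scatter forward and backward defined above.
#
#   x = torch.randn(10, 8, 64)
#   idx = torch.tensor([0, 3, 7])
#   y = index_first_axis(x, idx)  # shape (3, 8, 64), same values as x[idx]
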
# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
class IndexPutFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values, indices, first_axis_dim):
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        output[indices] = values
        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        grad_values = grad_output[indices]
        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
        return grad_values, None, None


index_put_first_axis = IndexPutFirstAxis.apply


# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
def pad_input(hidden_states, indices, batch, seqlen):
    """
    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    # dim = hidden_states.shape[-1]
    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
    # output[indices] = hidden_states
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, "(b s) ... -> b s ...", b=batch)


# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
def unpad_input(hidden_states, attention_mask, unused_mask=None):
    """
    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
    """
    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        used_seqlens_in_batch,
    )
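
# Illustrative round-trip sketch (comments only; `hidden_states` and `attention_mask` are assumed
# example tensors, not part of this module):
#
#   batch, seqlen = attention_mask.shape
#   unpadded, indices, cu_seqlens, max_seqlen, seqused = unpad_input(hidden_states, attention_mask)
#   # ... run variable-length attention on `unpadded`, using `cu_seqlens` / `max_seqlen` ...
#   repadded = pad_input(unpadded, indices, batch, seqlen)  # back to (batch, seqlen, ...)
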
def npu_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    **kwargs,
):
    keep_prob = 1.0 - dropout_p

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])

    if not causal:
        head_num = q.shape[2]
        output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
    else:
        attn_mask_npu = get_attn_mask_npu(q.device)
        head_num = q.shape[2]
        output = torch_npu.npu_fusion_attention(
            q,
            k,
            v,
            head_num,
            "BSND",
            keep_prob=keep_prob,
            scale=softmax_scale,
            atten_mask=attn_mask_npu,
            sparse_mode=SPARSE_MODE,
        )[0]

    return output
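
# Illustrative usage sketch (comments only; shapes, dtype and device are assumptions for the example).
# `q`, `k` and `v` use the "BSND" layout expected above: (batch, seqlen, num_heads, head_dim).
#
#   q = torch.randn(2, 1024, 8, 128, dtype=torch.float16, device="npu")
#   k = torch.randn(2, 1024, 8, 128, dtype=torch.float16, device="npu")
#   v = torch.randn(2, 1024, 8, 128, dtype=torch.float16, device="npu")
#   out = npu_flash_attn_func(q, k, v, causal=True)  # (2, 1024, 8, 128)
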
def npu_flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q=None,  # defined only to align the parameter order with the corresponding `flash-attn` function
    max_seqlen_k=None,  # defined only to align the parameter order with the corresponding `flash-attn` function
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    **kwargs,
):
    keep_prob = 1.0 - dropout_p

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])

    if not causal:
        head_num = q.shape[1]
        output = torch_npu.npu_fusion_attention(
            q,
            k,
            v,
            head_num,
            pse=None,
            atten_mask=None,
            scale=softmax_scale,
            keep_prob=keep_prob,
            input_layout="TND",
            actual_seq_qlen=tuple(cu_seqlens_q[1:].cpu().numpy().tolist()),
            actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
        )[0]
    else:
        attn_mask_npu = get_attn_mask_npu(q.device)
        head_num = q.shape[1]
        output = torch_npu.npu_fusion_attention(
            q,
            k,
            v,
            head_num,
            pse=None,
            padding_mask=None,
            atten_mask=attn_mask_npu,
            scale=softmax_scale,
            keep_prob=keep_prob,
            input_layout="TND",
            actual_seq_qlen=tuple(cu_seqlens_q[1:].cpu().numpy().tolist()),
            actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
            sparse_mode=SPARSE_MODE,
        )[0]

    return output
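
# Illustrative usage sketch (comments only; names and shapes are assumptions for the example).
# Packed "TND" layout: (total_tokens, num_heads, head_dim), with the cumulative sequence lengths
# (including the leading 0) as produced by `unpad_input` above.
#
#   q = torch.randn(total_tokens, 8, 128, dtype=torch.float16, device="npu")
#   k = torch.randn(total_tokens, 8, 128, dtype=torch.float16, device="npu")
#   v = torch.randn(total_tokens, 8, 128, dtype=torch.float16, device="npu")
#   out = npu_flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, causal=True)
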
def npu_apply_rotary_emb(x, cos, sin, **kwargs):
    # On Ascend NPU, a cos tensor that was chunked to half the head dim must be repeated back to the full head dim.
    if len(cos.shape) == 2 and cos.shape[-1] == x.shape[-1] // 2:
        cos = cos.repeat(1, 2)
        # A cos tensor of shape [S, D] must be unsqueezed to a 4-D tensor of shape [1, S, 1, D].
        cos = cos.unsqueeze(0).unsqueeze(2)

    # On Ascend NPU, a sin tensor that was chunked to half the head dim must be repeated back to the full head dim.
    if len(sin.shape) == 2 and sin.shape[-1] == x.shape[-1] // 2:
        sin = sin.repeat(1, 2)
        # A sin tensor of shape [S, D] must be unsqueezed to a 4-D tensor of shape [1, S, 1, D].
        sin = sin.unsqueeze(0).unsqueeze(2)

    return npu_rotary_mul(x, cos, sin)
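
# Illustrative usage sketch (comments only; shapes, dtype and device are assumptions for the example).
# `x` is (batch, seqlen, num_heads, head_dim); `cos` / `sin` here are the half head-dim
# (seqlen, head_dim // 2) tensors handled by the branches above.
#
#   x = torch.randn(2, 1024, 8, 128, dtype=torch.float16, device="npu")
#   cos = torch.randn(1024, 64, dtype=torch.float16, device="npu")
#   sin = torch.randn(1024, 64, dtype=torch.float16, device="npu")
#   x_embed = npu_apply_rotary_emb(x, cos, sin)  # same shape as x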