426 lines
15 KiB
Python
426 lines
15 KiB
Python
# coding=utf-8
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from typing import Optional
|
|
|
|
from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
import torch.nn as nn
|
|
import triton
|
|
import triton.language as tl
|
|
from torch.nn import functional as F
|
|
|
|
if is_accelerate_available():
|
|
from accelerate import init_empty_weights
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
|
|
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
    """Quantize one group of `BLOCK_SIZE` consecutive values per program.

    Each program loads `BLOCK_SIZE` elements starting at `pid * BLOCK_SIZE`,
    derives a single scale from the group's absolute maximum so that the
    largest magnitude maps to 448.0, stores the scaled values into `y_ptr`
    (cast to that buffer's element type) and stores the scale at `s_ptr[pid]`.

    Args:
        x_ptr: Pointer to the input values.
        y_ptr: Pointer to the quantized output; same element count as the input.
        s_ptr: Pointer to the per-group scales, one entry per program.
        BLOCK_SIZE: Number of elements quantized together (compile-time constant).

    The loads and stores are unmasked, so the launcher must guarantee that the
    total element count is a multiple of `BLOCK_SIZE`.
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offs).to(tl.float32)
    # Keep the absolute maximum strictly positive: for an all-zero group the
    # scale would otherwise be 0 and `x / s` would evaluate 0 / 0 = NaN.
    # The floor is far below any realistic activation magnitude, so the scale
    # of every non-degenerate group is unchanged.
    amax = tl.maximum(tl.max(tl.abs(x)), 1e-12)
    s = amax / 448.0
    y = x / s
    y = y.to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offs, y)
    tl.store(s_ptr + pid, s)
|
|
|
|
|
|
def act_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize `x` to float8_e4m3fn in groups of `block_size` along the last dim.

    Args:
        x: Contiguous tensor whose last dimension is a multiple of `block_size`.
        block_size: Number of consecutive elements that share one scale.

    Returns:
        A pair `(y, s)`: `y` has the shape of `x` with dtype float8_e4m3fn, and
        `s` is a float32 tensor shaped like `x` except that its last dimension
        is `x.size(-1) // block_size`.

    Raises:
        AssertionError: If `x` is not contiguous or its last dimension is not
            divisible by `block_size`.
    """
    assert x.is_contiguous()
    last_dim = x.shape[-1]
    assert last_dim % block_size == 0

    quantized = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    scale_shape = (*x.size()[:-1], last_dim // block_size)
    scales = x.new_empty(scale_shape, dtype=torch.float32)

    # One program per group of BLOCK_SIZE elements.
    launch_grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)  # noqa: E731
    act_quant_kernel[launch_grid](x, quantized, scales, BLOCK_SIZE=block_size)
    return quantized, scales
|
|
|
|
|
|
# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py
|
|
@triton.jit
def _w8a8_block_fp8_matmul(
    # Operand, output and scale pointers
    A,
    B,
    C,
    As,
    Bs,
    # Problem dimensions
    M,
    N,
    K,
    # Quantization group sizes along N and K
    group_n,
    group_k,
    # Strides of the operands, the output and the scale tensors
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Compile-time tile configuration
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Block-wise scaled matmul: each program computes one
    (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of `C` from `A` and `B`.

    For every K step the partial dot product is multiplied by the matching
    row scale from `As` and column scale from `Bs` before being accumulated
    in float32. The tile is finally cast to the element type of `C`
    (bfloat16, float16, or float32 otherwise) and stored under a bounds mask.
    """

    # Translate the 1-D program id into a (row tile, column tile) pair; tiles
    # are walked in groups of up to GROUP_SIZE_M row tiles.
    program = tl.program_id(axis=0)
    tiles_m = tl.cdiv(M, BLOCK_SIZE_M)
    tiles_n = tl.cdiv(N, BLOCK_SIZE_N)
    programs_per_group = GROUP_SIZE_M * tiles_n
    first_tile_m = (program // programs_per_group) * GROUP_SIZE_M
    rows_in_group = min(tiles_m - first_tile_m, GROUP_SIZE_M)
    pid_m = first_tile_m + (program % rows_in_group)
    pid_n = (program % programs_per_group) // rows_in_group

    # Row/column indices are wrapped with a modulo so that loads stay inside
    # the operands; out-of-range results are discarded by the store mask.
    k_range = tl.arange(0, BLOCK_SIZE_K)
    rows = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    cols = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    a_tile_ptrs = A + (rows[:, None] * stride_am + k_range[None, :] * stride_ak)
    b_tile_ptrs = B + (k_range[:, None] * stride_bk + cols[None, :] * stride_bn)

    # One scale per row of A, one scale per group of `group_n` columns of B.
    a_scale_ptrs = As + rows * stride_As_m
    b_scale_ptrs = Bs + (cols // group_n) * stride_Bs_n

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_start = k * BLOCK_SIZE_K
        k_left = K - k_start
        # Mask the K tail of the last step and pad it with zeros.
        a = tl.load(a_tile_ptrs, mask=k_range[None, :] < k_left, other=0.0)
        b = tl.load(b_tile_ptrs, mask=k_range[:, None] < k_left, other=0.0)

        # Index of the scale group that covers this K step.
        scale_idx = k_start // group_k
        a_s = tl.load(a_scale_ptrs + scale_idx * stride_As_k)
        b_s = tl.load(b_scale_ptrs + scale_idx * stride_Bs_k)

        acc += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_tile_ptrs += BLOCK_SIZE_K * stride_ak
        b_tile_ptrs += BLOCK_SIZE_K * stride_bk

    if C.dtype.element_ty == tl.bfloat16:
        c = acc.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = acc.to(tl.float16)
    else:
        c = acc.to(tl.float32)

    # Unwrapped indices are used for the store so the mask can reject
    # rows/columns that fall outside of C.
    out_rows = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    out_cols = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_ptrs = C + stride_cm * out_rows[:, None] + stride_cn * out_cols[None, :]
    in_bounds = (out_rows[:, None] < M) & (out_cols[None, :] < N)
    tl.store(out_ptrs, c, mask=in_bounds)
|
|
|
|
|
|
def w8a8_block_fp8_matmul_triton(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Multiply block-quantized `A` by the transpose of block-quantized `B`.

    `B` is laid out as `[N, K]`; its strides are handed to the kernel swapped,
    so the result has shape `A.shape[:-1] + (N,)`.

    Args:
        A: The input tensor, e.g., activation. Must be contiguous, last dim `K`.
        B: The input tensor, e.g., weight. Must be 2-D `[N, K]` and contiguous.
        As: The per-token-group quantization scale for `A`; same leading dims
            as `A`, last dim `ceil(K / block_k)`.
        Bs: The per-block quantization scale for `B`, shaped
            `[ceil(N / block_n), ceil(K / block_k)]`.
        block_size: Two entries `[block_n, block_k]`, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.

    Returns:
        torch.Tensor: The result of the matmul.

    Raises:
        AssertionError: If any of the shape/contiguity requirements above is
            violated.
    """
    assert len(block_size) == 2
    block_n = block_size[0]
    block_k = block_size[1]

    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
    # All leading dimensions of A are flattened into the row count.
    M = A.numel() // A.shape[-1]

    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    result = A.new_empty(A.shape[:-1] + (N,), dtype=output_dtype)

    # Row tile: 128 for large inputs, otherwise the next power of two of M
    # with a lower bound of 16.
    tile_m = 128 if M >= 128 else max(triton.next_power_of_2(M), 16)
    # The N and K tiles coincide with the quantization block sizes.
    tile_k = block_k
    assert block_k % tile_k == 0
    tile_n = block_n

    # One program per output tile.
    launch_grid = (triton.cdiv(M, tile_m) * triton.cdiv(N, tile_n),)

    _w8a8_block_fp8_matmul[launch_grid](
        A,
        B,
        result,
        As,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A.stride(-2),
        A.stride(-1),
        # B and Bs are consumed as their transposes: K-stride first, N-stride second.
        B.stride(1),
        B.stride(0),
        result.stride(-2),
        result.stride(-1),
        As.stride(-2),
        As.stride(-1),
        Bs.stride(1),
        Bs.stride(0),
        BLOCK_SIZE_M=tile_m,
        BLOCK_SIZE_N=tile_n,
        BLOCK_SIZE_K=tile_k,
        GROUP_SIZE_M=8,
    )

    return result
|
|
|
|
|
|
# Python version of the Triton function above. It is much slower than the Triton version and is meant for testing.
|
|
@torch.compile
def w8a8_block_fp8_matmul_compile(
    input_q: torch.Tensor,  # [batch, seq_len, hidden_dim]
    weight_q: torch.Tensor,  # [out_features, hidden_dim]
    input_scale: torch.Tensor,  # [batch * seq_len, num_input_groups]
    weight_scale: torch.Tensor,  # [num_weight_blocks_m, num_weight_blocks_n]
    block_size: Optional[tuple[int, int]] = None,  # (M=128, N=128) for weights for example
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Performs blocked matrix multiplication with FP8 quantized matrices.

    The weight is processed one `(block_size[0], block_size[1])` block at a
    time: each block product is scaled by the block's weight scale inside
    `torch._scaled_mm`, multiplied by the per-token input scale of the matching
    input group, and accumulated into a float32 buffer.

    Args:
        input_q: Quantized input tensor, `[batch, seq_len, hidden_dim]` or
            `[seq_len, hidden_dim]` (treated as batch size 1).
        weight_q: Quantized weight tensor `[out_features, hidden_dim]`.
        input_scale: Scaling factors for input groups. Any layout that holds
            one row of group scales per token is accepted, e.g.
            `[batch * seq_len, num_groups]` or `[batch, seq_len, num_groups]`.
        weight_scale: Scaling factors for weight blocks, indexed `[i, j]` by
            block row / block column.
        block_size: Tuple of (M, N) for weight block dimensions. Required in
            practice: it is indexed unconditionally, so `None` fails.
        output_dtype: Desired output dtype.

    Returns:
        Tensor of shape `[batch, seq_len, out_features]` in `output_dtype`.
        Only whole blocks are processed (floor division), so trailing rows or
        columns of the weight that do not fill a block do not contribute.
    """
    batch_size, seq_len, hidden_dim = input_q.shape if input_q.ndim == 3 else (1, input_q.shape[0], input_q.shape[1])
    out_features = weight_q.shape[0]
    num_tokens = batch_size * seq_len

    # Flatten tokens so every row is one token.
    input_reshaped = input_q.view(-1, hidden_dim)  # [batch*seq_len, hidden_dim]
    # Give the scales the same one-row-per-token layout. Using the token count
    # (rather than the scale's own first dim) keeps 3-D scales such as those
    # returned by `act_quant` for 3-D inputs aligned with `input_reshaped`.
    input_scale_reshaped = input_scale.view(num_tokens, -1)  # [batch*seq_len, num_groups]

    # Calculate number of whole blocks
    num_weight_blocks_m = out_features // block_size[0]
    num_weight_blocks_n = hidden_dim // block_size[1]

    output = torch.zeros((num_tokens, out_features), dtype=torch.float32, device=input_q.device)
    # The input scale is applied after the matmul, so `_scaled_mm` gets a unit scale for A.
    unit_scale = torch.tensor(1, dtype=torch.float32, device=input_q.device)

    for i in range(num_weight_blocks_m):
        m_start = i * block_size[0]
        m_end = m_start + block_size[0]

        for j in range(num_weight_blocks_n):
            n_start = j * block_size[1]
            n_end = n_start + block_size[1]

            # Extract current blocks
            input_block = input_reshaped[:, n_start:n_end]
            weight_block = weight_q[m_start:m_end, n_start:n_end]

            # Get corresponding scales
            curr_input_scale = input_scale_reshaped[:, j : j + 1]  # [batch*seq_len, 1]
            curr_weight_scale = weight_scale[i, j]  # scalar

            block_result = (
                torch._scaled_mm(
                    input_block,
                    weight_block.t(),
                    scale_a=unit_scale,
                    scale_b=curr_weight_scale,
                    out_dtype=output_dtype,
                )
                * curr_input_scale
            )

            output[:, m_start:m_end] += block_result

    output = output.view(batch_size, seq_len, out_features)

    return output.to(output_dtype)
|
|
|
|
|
|
class FP8Linear(nn.Linear):
|
|
dtype = torch.float8_e4m3fn
|
|
|
|
def __init__(
|
|
self,
|
|
in_features: int,
|
|
out_features: int,
|
|
bias: bool = False,
|
|
dtype=None,
|
|
block_size: Optional[tuple[int, int]] = None,
|
|
device=None,
|
|
activation_scheme="dynamic",
|
|
):
|
|
super().__init__(in_features, out_features)
|
|
self.in_features = in_features
|
|
self.out_features = out_features
|
|
|
|
self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=FP8Linear.dtype, device=device))
|
|
|
|
if self.weight.element_size() == 1:
|
|
scale_out_features = (out_features + block_size[0] - 1) // block_size[0]
|
|
scale_in_features = (in_features + block_size[1] - 1) // block_size[1]
|
|
self.weight_scale_inv = nn.Parameter(
|
|
torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device)
|
|
)
|
|
else:
|
|
self.register_parameter("weight_scale_inv", None)
|
|
|
|
self.block_size = block_size
|
|
|
|
self.activation_scheme = activation_scheme
|
|
|
|
if bias:
|
|
self.bias = nn.Parameter(torch.empty(self.out_features))
|
|
else:
|
|
self.register_parameter("bias", None)
|
|
|
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
|
if self.weight.element_size() > 1:
|
|
return F.linear(input, self.weight, self.bias)
|
|
else:
|
|
# Context manager used to switch among the available accelerators
|
|
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
|
|
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
|
|
with torch_accelerator_module.device(input.device):
|
|
qinput, scale = act_quant(input, self.block_size[1])
|
|
output = w8a8_block_fp8_matmul_triton(
|
|
qinput,
|
|
self.weight,
|
|
scale,
|
|
self.weight_scale_inv,
|
|
self.block_size,
|
|
output_dtype=input.dtype,
|
|
)
|
|
# Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
|
|
# preceding operations are ready before proceeding
|
|
torch_accelerator_module.synchronize()
|
|
if self.bias is not None:
|
|
output = output + self.bias
|
|
return output.to(dtype=input.dtype)
|
|
|
|
|
|
def _replace_with_fp8_linear(
    model,
    tp_plan=None,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    """Recursively swap `nn.Linear` children of `model` for `FP8Linear`.

    A linear child is left untouched when its own name is listed in
    `modules_to_not_convert`, or when any listed entry occurs as a substring of
    its dotted path from the root.

    Args:
        model: Module whose children are visited (modified in place).
        tp_plan: Forwarded unchanged to recursive calls; not otherwise used here.
        modules_to_not_convert: Names / path fragments to skip (may be None).
        current_key_name: Path components from the root to `model`; shared and
            mutated across the recursion.
        quantization_config: Provides `activation_scheme` and `weight_block_size`
            for the new layers.
        has_been_replaced: Running flag carried through the recursion.

    Returns:
        Tuple `(model, has_been_replaced)`.
    """
    key_path = [] if current_key_name is None else current_key_name
    excluded = modules_to_not_convert or []

    for child_name, child in model.named_children():
        key_path.append(child_name)

        if isinstance(child, nn.Linear) and child_name not in excluded:
            dotted_path = ".".join(key_path)
            if all(pattern not in dotted_path for pattern in excluded):
                # Build the replacement without allocating real storage.
                with init_empty_weights():
                    model._modules[child_name] = FP8Linear(
                        in_features=child.in_features,
                        out_features=child.out_features,
                        bias=child.bias is not None,
                        device=child.weight.device,
                        dtype=child.weight.dtype,
                        activation_scheme=quantization_config.activation_scheme,
                        block_size=quantization_config.weight_block_size,
                    )
                has_been_replaced = True
            # when changing a layer the TP PLAN for that layer should be updated. TODO

        # Descend into the original child whenever it has sub-modules of its own.
        if list(child.children()):
            _, has_been_replaced = _replace_with_fp8_linear(
                child,
                tp_plan,
                modules_to_not_convert,
                key_path,
                quantization_config,
                has_been_replaced=has_been_replaced,
            )

        key_path.pop()

    return model, has_been_replaced
|
|
|
|
|
|
def replace_with_fp8_linear(
    model,
    modules_to_not_convert=None,
    quantization_config=None,
):
    """Helper function to replace model layers with FP8 versions.

    Args:
        model: Model whose `nn.Linear` layers are replaced in place.
        modules_to_not_convert: Names of modules to keep unquantized. Defaults
            to `["lm_head"]`. The caller's list is not modified.
        quantization_config: Configuration object. Its `modules_to_not_convert`
            (if not None) is merged into the skip list; its `activation_scheme`
            and `weight_block_size` are read while building the new layers. It is
            required in practice: passing None fails on the first attribute access.

    Returns:
        The same `model`, after replacement. A warning is logged when no layer
        was replaced.
    """
    # Copy so that merging the config's entries never mutates the caller's list.
    skip_modules = ["lm_head"] if modules_to_not_convert is None else list(modules_to_not_convert)

    if quantization_config.modules_to_not_convert is not None:
        skip_modules.extend(quantization_config.modules_to_not_convert)
    # Drop duplicates while keeping a deterministic (first-seen) order.
    skip_modules = list(dict.fromkeys(skip_modules))

    model, has_been_replaced = _replace_with_fp8_linear(
        model,
        # The plan is only passed through by the helper, so a model without one is fine.
        tp_plan=getattr(model, "_tp_plan", None),
        modules_to_not_convert=skip_modules,
        quantization_config=quantization_config,
    )

    if not has_been_replaced:
        logger.warning(
            "You are loading your model using fp8 but no linear modules were found in your model."
            " Please double check your model architecture."
        )

    return model
|