# Copyright 2025 NXAI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch xLSTM Model."""
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import torch.utils.checkpoint
|
|
from torch import nn
|
|
from torch.nn import CrossEntropyLoss
|
|
|
|
from ...generation import GenerationMixin
|
|
from ...modeling_utils import PreTrainedModel
|
|
from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_xlstm_available
|
|
from .configuration_xlstm import xLSTMConfig
|
|
|
|
|
|
if is_xlstm_available():
|
|
from xlstm.xlstm_large.model import mLSTMBlock as xLSTMBlock
|
|
from xlstm.xlstm_large.model import mLSTMStateType, soft_cap
|
|
from xlstm.xlstm_large.model import xLSTMRMSNorm as xLSTMRMSNorm
|
|
|
|
external_xlstm = True
|
|
else:
|
|
from functools import partial
|
|
from typing import Callable, Literal
|
|
|
|
from .configuration_xlstm import round_up_to_next_multiple_of
|
|
|
|
mLSTMLayerStateType = tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
|
mLSTMStateType = dict[int, mLSTMLayerStateType]
|
|
|
|
external_xlstm = False
|
|
|
|
def soft_cap(values: torch.Tensor, cap_value: Optional[Union[float, torch.Tensor]] = None) -> torch.Tensor:
|
|
"""
|
|
Soft caps a tensor to a value.
|
|
|
|
Performs a tanh operation on the logits and scales the result to the cap value. Common technique in attention
|
|
and output language heads to prevent large logits from dominating the softmax. See for example Gemma2:
|
|
https://arxiv.org/abs/2408.00118
|
|
|
|
Args:
|
|
values: The tensor to cap.
|
|
cap_value: The value to cap the values to. If None, no cap is applied.
|
|
|
|
Returns:
|
|
The capped values.
|
|
"""
|
|
if cap_value is None:
|
|
return values
|
|
return cap_value * torch.tanh(values / cap_value)
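    # Illustrative sketch (assumed cap value, not part of the model code): soft capping squashes
    # large values smoothly toward +/- cap_value while leaving small values almost unchanged.
    #
    #     >>> x = torch.tensor([0.1, 2.0, 50.0])
    #     >>> soft_cap(x, 5.0)   # approximately [0.10, 1.90, 5.00]
    #     >>> soft_cap(x, None)  # returns x unchanged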
|
|
|
|
def mlstm_chunkwise_recurrent_fw_C(
|
|
matK: torch.Tensor,
|
|
matV: torch.Tensor,
|
|
vecB: torch.Tensor,
|
|
vecI: torch.Tensor,
|
|
matC_states: torch.Tensor = None,
|
|
vecN_states: torch.Tensor = None,
|
|
scaMinter_states: torch.Tensor = None,
|
|
matC_initial: torch.Tensor = None,
|
|
vecN_initial: torch.Tensor = None,
|
|
scaMinter_initial: torch.Tensor = None,
|
|
qk_scale: Optional[float] = None,
|
|
chunk_size: int = 64,
|
|
num_chunks: int = 1,
|
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
batch_size, nh, _, dhqk, dhhv = *matK.shape, matV.shape[-1]
|
|
nc = num_chunks
|
|
_dtype, _device = matK.dtype, matK.device
|
|
|
|
if qk_scale is None:
|
|
qk_scale = dhqk**-0.5
|
|
|
|
# initialize the states tensors
|
|
if matC_states is None:
|
|
matC_states = torch.zeros((batch_size, nh, (nc + 1) * dhqk, dhhv), dtype=_dtype, device=_device)
|
|
if vecN_states is None:
|
|
vecN_states = torch.zeros((batch_size, nh, (nc + 1) * dhqk), dtype=_dtype, device=_device)
|
|
if scaMinter_states is None:
|
|
scaMinter_states = torch.zeros((batch_size, nh, (nc + 1)), dtype=_dtype, device=_device)
|
|
|
|
# assign the initial states to the running states
|
|
matC_k = (
|
|
torch.zeros((batch_size, nh, dhqk, dhhv), dtype=_dtype, device=_device)
|
|
if matC_initial is None
|
|
else matC_initial
|
|
)
|
|
vecN_k = (
|
|
torch.zeros((batch_size, nh, dhqk), dtype=_dtype, device=_device) if vecN_initial is None else vecN_initial
|
|
)
|
|
scaM_inter_k = (
|
|
torch.zeros((batch_size, nh, 1), dtype=_dtype, device=_device)
|
|
if scaMinter_initial is None
|
|
else scaMinter_initial
|
|
)
|
|
vecA = vecB[..., -1, None] - vecB + vecI
|
|
scaG = vecB[..., -1]
|
|
scaA_max = vecA.max(-1).values
|
|
|
|
scaM_inter_k = scaM_inter_k.squeeze(-1)
|
|
|
|
for key in range(0, num_chunks):
|
|
# store the states from the previous iteration before updating them
|
|
# in the first iteration, these are the initial states
|
|
matC_states[:, :, key * dhqk : (key + 1) * dhqk, :] = matC_k
|
|
vecN_states[:, :, key * dhqk : (key + 1) * dhqk] = vecN_k
|
|
scaMinter_states[:, :, key] = scaM_inter_k
|
|
|
|
# m_k update
|
|
scaA_max_k = scaA_max[:, :, key]
|
|
scaG_k = scaG[:, :, key]
|
|
scaM_inter_k_next = torch.max(scaG_k + scaM_inter_k, scaA_max_k)
|
|
# C_k update
|
|
matK_chunk = matK[:, :, key * chunk_size : (key + 1) * chunk_size, :] # * qk_scale
|
|
matV_chunk = matV[:, :, key * chunk_size : (key + 1) * chunk_size, :]
|
|
vecA_k = vecA[:, :, key, :]
|
|
|
|
vecAbar_k = torch.exp(vecA_k - scaM_inter_k_next[..., None])[:, :, :, None]
|
|
|
|
matK_chunk_gated = matK_chunk * vecAbar_k
|
|
|
|
scaGbar_k = torch.exp(scaG_k + scaM_inter_k - scaM_inter_k_next)[:, :, None]
|
|
|
|
# NOTE: no update in-place (i.e. +=) as this gives error for autograd backward
|
|
matC_k_next = scaGbar_k[..., None] * matC_k + matK_chunk_gated.transpose(-2, -1) @ (matV_chunk)
|
|
|
|
# n_k update
|
|
vecN_k_next = scaGbar_k * vecN_k + matK_chunk_gated.transpose(-2, -1).sum(-1)
|
|
|
|
# move to the next iteration
|
|
scaM_inter_k = scaM_inter_k_next
|
|
matC_k = matC_k_next
|
|
vecN_k = vecN_k_next
|
|
|
|
# store the states from the last iteration
|
|
matC_states[:, :, -dhqk:, :] = matC_k
|
|
vecN_states[:, :, -dhqk:] = vecN_k
|
|
scaMinter_states[:, :, -1] = scaM_inter_k
|
|
|
|
return matC_states, vecN_states, scaMinter_states
|
|
|
|
def mlstm_chunkwise_parallel_fw_H(
|
|
matQ: torch.Tensor,
|
|
matK: torch.Tensor,
|
|
matV: torch.Tensor,
|
|
# these states must be all states up to the last chunk, i.e. :-1
|
|
matC_states: torch.Tensor,
|
|
vecN_states: torch.Tensor,
|
|
scaMinter_states: torch.Tensor,
|
|
vecI: torch.Tensor,
|
|
vecB: torch.Tensor,
|
|
qk_scale: float,
|
|
chunk_size: int = 64,
|
|
num_chunks: int = 1,
|
|
eps: float = 1e-6,
|
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
_device = matQ.device
|
|
nc, chunk_size = num_chunks, chunk_size
|
|
batch_size, nh, dqk, dhv = matC_states.shape
|
|
matC_k_states = matC_states.view(batch_size, nh, nc, dqk // nc, dhv)
|
|
vecN_k_states = vecN_states.view(batch_size, nh, nc, dqk // nc)
|
|
scaMinter_k_states = scaMinter_states
|
|
|
|
matQ = matQ.view(batch_size, nh, nc, chunk_size, dqk)
|
|
matK = matK.view(batch_size, nh, nc, chunk_size, dqk)
|
|
matV = matV.view(batch_size, nh, nc, chunk_size, dhv)
|
|
|
|
ltr = torch.tril(
|
|
torch.ones(
|
|
(chunk_size, chunk_size),
|
|
dtype=torch.bool,
|
|
device=_device,
|
|
)
|
|
)
|
|
|
|
# Compute intra chunk contribution: H_intra
|
|
matF_logsig_chunk = vecB[:, :, :, :, None] - vecB[:, :, :, None, :]
|
|
|
|
matF_logsig_mask_chunk = torch.where(ltr, matF_logsig_chunk, -float("inf"))
|
|
|
|
matLogD_chunk = matF_logsig_mask_chunk + vecI[:, :, :, None, :]
|
|
|
|
# max_state intra
|
|
vecMintra_k = torch.max(matLogD_chunk, dim=-1, keepdim=False).values
|
|
|
|
# max_state combined
|
|
vecM_b_inter = vecB + scaMinter_k_states[:, :, :, None]
|
|
vecM_k_combine = torch.maximum(vecM_b_inter, vecMintra_k)
|
|
|
|
vecM_k_combine = vecM_k_combine[:, :, :, :, None]
|
|
vecM_b_inter = vecM_b_inter[:, :, :, :, None]
|
|
|
|
matLogD_stabilized_chunk = matLogD_chunk - vecM_k_combine
|
|
matD_chunk = torch.exp(matLogD_stabilized_chunk)
|
|
|
|
matS_chunk = (matQ @ matK.transpose(-2, -1)) * qk_scale
|
|
|
|
matM_chunk = matS_chunk * matD_chunk
|
|
|
|
# ? Combine H_intra with H_inter
|
|
vecBbar = torch.exp(vecM_b_inter - vecM_k_combine)
|
|
matQ_chunk_gated = matQ * vecBbar * qk_scale
|
|
|
|
matNumerator_common = matQ_chunk_gated @ matC_k_states + matM_chunk @ matV
|
|
|
|
vecDenom_l_common = matQ_chunk_gated @ vecN_k_states.unsqueeze(-1) + matM_chunk.sum(dim=-1, keepdim=True)
|
|
|
|
vecDenom_max_common = torch.maximum(torch.abs(vecDenom_l_common), torch.exp(-vecM_k_combine))
|
|
|
|
matH_k_chunk = matNumerator_common / (vecDenom_max_common + eps)
|
|
|
|
matH_out = matH_k_chunk.view(batch_size, nh, nc * chunk_size, dhv)
|
|
|
|
# we need the denominator and the overall max state for the backward pass
|
|
vecN_out = vecDenom_max_common.reshape(batch_size, nh, nc * chunk_size)
        vecM_out = vecM_k_combine.reshape(batch_size, nh, nc * chunk_size)
return matH_out, vecN_out, vecM_out
|
|
|
|
def mlstm_chunkwise_fw(
|
|
query: torch.Tensor,
|
|
key: torch.Tensor,
|
|
value: torch.Tensor,
|
|
igate: torch.Tensor,
|
|
fgate: torch.Tensor,
|
|
cstate: torch.Tensor = None,
|
|
nstate: torch.Tensor = None,
|
|
mstate: torch.Tensor = None,
|
|
qk_scale: Optional[float] = None,
|
|
return_last_states: bool = False,
|
|
return_all_states: bool = False,
|
|
chunk_size: int = 64,
|
|
eps: float = 1e-6,
|
|
) -> tuple[
|
|
torch.Tensor,
|
|
torch.Tensor,
|
|
torch.Tensor,
|
|
Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
|
|
Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
|
|
]:
|
|
batch_size, nh, sequence_length, dhqk = query.shape
|
|
if sequence_length % chunk_size != 0:
|
|
raise ValueError(f"Sequence length {sequence_length} is not divisible by chunk size {chunk_size}.")
|
|
nc = sequence_length // chunk_size
|
|
|
|
vecI = igate.view(batch_size, nh, nc, chunk_size)
|
|
vecF = fgate.view(batch_size, nh, nc, chunk_size)
|
|
|
|
# compute the gates, the g and the a and b vectors
        vecF_logsig = F.logsigmoid(vecF)
vecB = vecF_logsig.cumsum(-1)
|
|
|
|
if qk_scale is None:
|
|
qk_scale = dhqk**-0.5
|
|
|
|
#! materialize the C_k, n_k, m_k states for each chunk
|
|
matC_k_states, vecN_k_states, scaMinter_k_states = mlstm_chunkwise_recurrent_fw_C(
|
|
matK=key,
|
|
matV=value,
|
|
vecB=vecB,
|
|
vecI=vecI,
|
|
matC_initial=cstate,
|
|
vecN_initial=nstate,
|
|
scaMinter_initial=mstate,
|
|
qk_scale=qk_scale,
|
|
chunk_size=chunk_size,
|
|
num_chunks=nc,
|
|
)
|
|
|
|
#! compute the outputs within each chunk
|
|
matH_out, vecN_out, vecM_out = mlstm_chunkwise_parallel_fw_H(
|
|
matQ=query,
|
|
matK=key,
|
|
matV=value,
|
|
matC_states=matC_k_states[:, :, :-dhqk, :],
|
|
vecN_states=vecN_k_states[:, :, :-dhqk],
|
|
scaMinter_states=scaMinter_k_states[:, :, :-1],
|
|
vecI=vecI,
|
|
vecB=vecB,
|
|
qk_scale=qk_scale,
|
|
chunk_size=chunk_size,
|
|
num_chunks=nc,
|
|
eps=eps,
|
|
)
|
|
|
|
ret_tuple = (matH_out, vecN_out, vecM_out)
|
|
if return_last_states:
|
|
ret_tuple += (
|
|
(matC_k_states[:, :, -dhqk:, :], vecN_k_states[:, :, -dhqk:], scaMinter_k_states[:, :, -1:]),
|
|
)
|
|
else:
|
|
ret_tuple += (None,)
|
|
|
|
if return_all_states:
|
|
ret_tuple += ((matC_k_states, vecN_k_states, scaMinter_k_states),)
|
|
else:
|
|
ret_tuple += (None,)
|
|
|
|
return ret_tuple
|
|
|
|
def mlstm_chunkwise_native_autograd(
|
|
query: torch.Tensor,
|
|
key: torch.Tensor,
|
|
value: torch.Tensor,
|
|
igate: torch.Tensor,
|
|
fgate: torch.Tensor,
|
|
c_initial: torch.Tensor = None,
|
|
n_initial: torch.Tensor = None,
|
|
m_initial: torch.Tensor = None,
|
|
return_last_states: bool = False,
|
|
eps: float = 1e-6,
|
|
chunk_size: int = 64,
|
|
**kwargs,
|
|
) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
|
|
batch_size, nh, sequence_length, dhqk = query.shape
|
|
if sequence_length % chunk_size != 0:
|
|
raise ValueError(f"Sequence length {sequence_length} is not divisible by chunk size {chunk_size}.")
|
|
nc = sequence_length // chunk_size
|
|
|
|
vecI = igate.view(batch_size, nh, nc, chunk_size)
|
|
vecF = fgate.view(batch_size, nh, nc, chunk_size)
|
|
|
|
# compute the gates, the g and the a and b vectors
|
|
vecF_logsig = F.logsigmoid(vecF)
|
|
vecB = vecF_logsig.cumsum(-1)
|
|
|
|
qk_scale = dhqk**-0.5
|
|
|
|
#! materialize the C_k, n_k, m_k states for each chunk
|
|
matC_k_states, vecN_k_states, scaMinter_k_states = mlstm_chunkwise_recurrent_fw_C(
|
|
matK=key,
|
|
matV=value,
|
|
vecB=vecB,
|
|
vecI=vecI,
|
|
matC_initial=c_initial,
|
|
vecN_initial=n_initial,
|
|
scaMinter_initial=m_initial,
|
|
qk_scale=qk_scale,
|
|
chunk_size=chunk_size,
|
|
num_chunks=nc,
|
|
)
|
|
|
|
#! compute the outputs within each chunk
|
|
matH_out, vecN_out, vecM_out = mlstm_chunkwise_parallel_fw_H(
|
|
matQ=query,
|
|
matK=key,
|
|
matV=value,
|
|
matC_states=matC_k_states[:, :, :-dhqk, :],
|
|
vecN_states=vecN_k_states[:, :, :-dhqk],
|
|
scaMinter_states=scaMinter_k_states[:, :, :-1],
|
|
vecI=vecI,
|
|
vecB=vecB,
|
|
qk_scale=qk_scale,
|
|
chunk_size=chunk_size,
|
|
num_chunks=nc,
|
|
eps=eps,
|
|
)
|
|
|
|
last_states = (matC_k_states[:, :, -dhqk:, :], vecN_k_states[:, :, -dhqk:], scaMinter_k_states[:, :, -1:])
|
|
|
|
if return_last_states:
|
|
return matH_out, last_states
|
|
else:
|
|
return matH_out
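    # Illustrative usage sketch (hypothetical shapes): the chunkwise kernel expects
    # (batch, num_heads, seq_len, head_dim) projections and per-head gate pre-activations of shape
    # (batch, num_heads, seq_len), with seq_len divisible by chunk_size.
    #
    #     >>> B, NH, S, DQK, DHV = 1, 2, 128, 16, 16
    #     >>> q, k, v = torch.randn(B, NH, S, DQK), torch.randn(B, NH, S, DQK), torch.randn(B, NH, S, DHV)
    #     >>> i, f = torch.randn(B, NH, S), torch.randn(B, NH, S)
    #     >>> h, (c, n, m) = mlstm_chunkwise_native_autograd(
    #     ...     query=q, key=k, value=v, igate=i, fgate=f, chunk_size=64, return_last_states=True
    #     ... )
    #     >>> h.shape, c.shape, n.shape, m.shape
    #     (torch.Size([1, 2, 128, 16]), torch.Size([1, 2, 16, 16]), torch.Size([1, 2, 16]), torch.Size([1, 2, 1]))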
|
|
|
|
def mlstm_recurrent_step_native(
|
|
query: torch.Tensor,
|
|
key: torch.Tensor,
|
|
value: torch.Tensor,
|
|
igate: torch.Tensor,
|
|
fgate: torch.Tensor,
|
|
cstate: torch.Tensor,
|
|
nstate: torch.Tensor,
|
|
mstate: torch.Tensor,
|
|
eps: float = 1e-6,
|
|
dtype_state: torch.dtype = torch.float32,
|
|
**kwargs,
|
|
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
|
|
"""This is a single step of the mLSTM operation in recurrent form."""
|
|
dtype_qkv = query.dtype
|
|
matC_old = cstate.to(dtype=dtype_state)
|
|
vecN_old = nstate.to(dtype=dtype_state)
|
|
scaM_old = mstate.to(dtype=dtype_state)
|
|
|
|
batch_size, nh, dhqk = query.shape
|
|
_, _, dhhv = value.shape
|
|
if query.shape != key.shape:
|
|
raise ValueError("query and key must have the same shape")
|
|
if matC_old.shape != (batch_size, nh, dhqk, dhhv):
|
|
raise ValueError(f"matC_old has wrong shape, got {matC_old.shape}")
|
|
if vecN_old.shape != (batch_size, nh, dhqk):
|
|
raise ValueError(f"vecN_old has wrong shape, got {vecN_old.shape}")
|
|
if scaM_old.shape != (batch_size, nh, 1):
|
|
raise ValueError(f"scaM_old has wrong shape, got {scaM_old.shape}")
|
|
if igate.shape != (batch_size, nh, 1):
|
|
raise ValueError(f"scaI has wrong shape, got {igate.shape}")
|
|
if fgate.shape != (batch_size, nh, 1):
|
|
raise ValueError(f"scaF has wrong shape, got {fgate.shape}")
|
|
|
|
# gates
|
|
scaF_log = torch.nn.functional.logsigmoid(fgate)
|
|
|
|
# update rule
|
|
scaM_state_new = torch.max(scaF_log + scaM_old, igate)
|
|
|
|
scaF_act = torch.exp(scaF_log + scaM_old - scaM_state_new)
|
|
scaI_act = torch.exp(igate - scaM_state_new)
|
|
|
|
vecQ_scaled = query * (dhqk ** (-0.5))
|
|
matC_state_new = scaF_act[:, :, :, None] * matC_old + scaI_act[:, :, :, None] * (
|
|
key[:, :, :, None] @ value[:, :, None, :]
|
|
)
|
|
vecN_state_new = scaF_act * vecN_old + scaI_act * key
|
|
h_num = vecQ_scaled[:, :, None, :] @ matC_state_new.to(dtype=dtype_qkv)
|
|
h_num = h_num.squeeze(2).to(dtype=dtype_state)
|
|
|
|
qn_dotproduct = vecQ_scaled[:, :, None, :] @ vecN_state_new[:, :, :, None].to(dtype=dtype_qkv)
|
|
qn_dotproduct = qn_dotproduct.squeeze(2)
|
|
max_val = torch.exp(-scaM_state_new)
|
|
h_denom = (torch.maximum(qn_dotproduct.abs(), max_val) + eps).to(dtype=dtype_state)
|
|
h = h_num / h_denom
|
|
|
|
h = h.to(dtype=dtype_qkv)
|
|
matC_state_new = matC_state_new.to(dtype=dtype_state)
|
|
vecN_state_new = vecN_state_new.to(dtype=dtype_state)
|
|
scaM_state_new = scaM_state_new.to(dtype=dtype_state)
|
|
return h, (matC_state_new, vecN_state_new, scaM_state_new)
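    # Illustrative single-step sketch (hypothetical shapes): during decoding, each new token updates
    # the three recurrent states (C, n, m) and produces one hidden vector per head, with no sequence
    # dimension involved.
    #
    #     >>> B, NH, DQK, DHV = 1, 2, 16, 16
    #     >>> q, k, v = torch.randn(B, NH, DQK), torch.randn(B, NH, DQK), torch.randn(B, NH, DHV)
    #     >>> i, f = torch.randn(B, NH, 1), torch.randn(B, NH, 1)
    #     >>> c0, n0, m0 = torch.zeros(B, NH, DQK, DHV), torch.zeros(B, NH, DQK), torch.zeros(B, NH, 1)
    #     >>> h, (c1, n1, m1) = mlstm_recurrent_step_native(
    #     ...     query=q, key=k, value=v, igate=i, fgate=f, cstate=c0, nstate=n0, mstate=m0
    #     ... )
    #     >>> h.shape
    #     torch.Size([1, 2, 16])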
|
|
|
|
def mlstm_recurrent_sequence_native(
|
|
query: torch.Tensor,
|
|
key: torch.Tensor,
|
|
value: torch.Tensor,
|
|
igate: torch.Tensor,
|
|
fgate: torch.Tensor,
|
|
c_initial: torch.Tensor = None,
|
|
n_initial: torch.Tensor = None,
|
|
m_initial: torch.Tensor = None,
|
|
return_last_states: bool = False,
|
|
eps: float = 1e-6,
|
|
dtype_state: torch.dtype = torch.float32,
|
|
**kwargs,
|
|
    ) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
|
|
batch_size, nh, sequence_length, dhqk = query.shape
|
|
dhv = value.shape[-1]
|
|
device = query.device
|
|
|
|
if c_initial is not None:
|
|
if n_initial is None or m_initial is None:
|
|
raise ValueError("Initial states must be provided together.")
matC_state, vecN_state, vecM_state = (
|
|
c_initial.to(dtype=dtype_state),
|
|
n_initial.to(dtype=dtype_state),
|
|
m_initial.to(dtype=dtype_state),
|
|
)
|
|
else:
|
|
# memory state
|
|
matC_state = torch.zeros((batch_size, nh, dhqk, dhv), dtype=dtype_state, device=device)
|
|
# normalizer state
|
|
vecN_state = torch.zeros((batch_size, nh, dhqk), dtype=dtype_state, device=device)
|
|
# max state
|
|
vecM_state = torch.zeros((batch_size, nh, 1), dtype=dtype_state, device=device)
|
|
|
|
vecH_list = []
|
|
for t in range(sequence_length):
|
|
# gates
|
|
vecF_t, vecI_t = fgate[:, :, t, None], igate[:, :, t, None]
|
|
|
|
# projections
|
|
vecQ_t, vecK_t, vecV_t = query[:, :, t, :], key[:, :, t, :], value[:, :, t, :]
|
|
|
|
# step
|
|
vecH, (matC_state, vecN_state, vecM_state) = mlstm_recurrent_step_native(
|
|
cstate=matC_state,
|
|
nstate=vecN_state,
|
|
mstate=vecM_state,
|
|
query=vecQ_t,
|
|
key=vecK_t,
|
|
value=vecV_t,
|
|
igate=vecI_t,
|
|
fgate=vecF_t,
|
|
eps=eps,
|
|
dtype_state=dtype_state,
|
|
**kwargs,
|
|
)
|
|
vecH_list.append(vecH)
|
|
|
|
matH = torch.stack(vecH_list, dim=-2)
|
|
|
|
if return_last_states:
|
|
return matH, (matC_state, vecN_state, vecM_state)
|
|
else:
|
|
return matH
|
|
|
|
def wrap_chunkwise_pad_zeros(
|
|
mlstm_chunkwise_kernel: Callable,
|
|
query: torch.Tensor,
|
|
key: torch.Tensor,
|
|
value: torch.Tensor,
|
|
fgate: torch.Tensor,
|
|
igate: torch.Tensor,
|
|
c_initial: torch.Tensor = None,
|
|
n_initial: torch.Tensor = None,
|
|
m_initial: torch.Tensor = None,
|
|
return_last_states: bool = False,
|
|
eps: float = 1e-6,
|
|
autocast_kernel_dtype: torch.dtype = torch.bfloat16,
|
|
chunk_size: int = 64,
|
|
**kwargs,
|
|
) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
|
|
if return_last_states:
|
|
            raise ValueError(
                "We are padding zeros, so we cannot return last states, "
                "as they would not be the true last states."
            )
|
|
|
|
batch_size, nh, sequence_length, dhqk = query.shape
|
|
S_unpadded = sequence_length
|
|
# padding to chunk size for kernels
|
|
if sequence_length % chunk_size != 0:
|
|
S_padded = ((sequence_length + chunk_size - 1) // chunk_size) * chunk_size
|
|
q_pad = query.new_zeros(batch_size, nh, S_padded, query.shape[3])
|
|
k_pad = key.new_zeros(batch_size, nh, S_padded, key.shape[3])
|
|
v_pad = value.new_zeros(batch_size, nh, S_padded, value.shape[3])
|
|
i_pad = igate.new_zeros(batch_size, nh, S_padded)
|
|
f_pad = fgate.new_zeros(batch_size, nh, S_padded)
|
|
q_pad[:, :, :S_unpadded, :] = query
|
|
k_pad[:, :, :S_unpadded, :] = key
|
|
v_pad[:, :, :S_unpadded, :] = value
|
|
i_pad[:, :, :S_unpadded] = igate
|
|
f_pad[:, :, :S_unpadded] = fgate
|
|
else:
|
|
q_pad = query
|
|
k_pad = key
|
|
v_pad = value
|
|
i_pad = igate
|
|
f_pad = fgate
|
|
|
|
matH = mlstm_chunkwise_kernel(
|
|
query=q_pad,
|
|
key=k_pad,
|
|
value=v_pad,
|
|
igate=i_pad,
|
|
fgate=f_pad,
|
|
c_initial=c_initial,
|
|
n_initial=n_initial,
|
|
m_initial=m_initial,
|
|
return_last_states=return_last_states,
|
|
eps=eps,
|
|
autocast_kernel_dtype=autocast_kernel_dtype,
|
|
chunk_size=chunk_size,
|
|
**kwargs,
|
|
)
|
|
matH = matH[:, :, :S_unpadded, :]
|
|
return matH
|
|
|
|
def wrap_chunkwise_arbitrary_sequence_length(
|
|
mlstm_chunkwise_kernel: Callable,
|
|
mlstm_sequence_kernel: Callable,
|
|
mlstm_step_kernel: Callable,
|
|
query: torch.Tensor,
|
|
key: torch.Tensor,
|
|
value: torch.Tensor,
|
|
fgate: torch.Tensor,
|
|
igate: torch.Tensor,
|
|
c_initial: torch.Tensor = None,
|
|
n_initial: torch.Tensor = None,
|
|
m_initial: torch.Tensor = None,
|
|
return_last_states: bool = True,
|
|
eps: float = 1e-6,
|
|
autocast_kernel_dtype: torch.dtype = torch.bfloat16,
|
|
chunk_size: int = 64,
|
|
enable_logging: bool = False,
|
|
) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
|
|
"""This function computes the last hidden state and matH outputs of the mLSTM, independently of the sequence length.
|
|
|
|
For this it uses three kernels:
|
|
- mlstm_chunkwise_kernel: mlstm chunkwise kernels that processes chunks of a given chunk size in parallel.
|
|
- mlstm_sequence_kernel: mlstm kernel that processes the remaining sequence length in a single step recurrence.
|
|
- mlstm_step_kernel: mlstm kernel that processes a sequence length of 1 in a single step.
|
|
|
|
        It tries to maximize the chunk size to improve performance.
        It starts with the given chunk size and then halves it until the chunk size is smaller than 16.
        At every chunk size it processes the maximal number of chunks that fit into the remaining sequence length.

        E.g. for chunk_size = 64, this function will try the chunk sizes [64, 32, 16] if necessary.

        For the remaining sequence length, which is smaller than 16, we use a different kernel that computes the mLSTM
        in a single step and loops over it in PyTorch.
|
|
|
|
Args:
|
|
mlstm_chunkwise_kernel: The mLSTM chunkwise kernel that processes chunks of a given chunk size in parallel
|
|
mlstm_sequence_kernel: The mLSTM kernel that processes the remaining sequence length in a single step recurrence
|
|
query: The query tensor (batch_size, nh, sequence_length, dhqk)
|
|
key: The key tensor (batch_size, nh, sequence_length, dhqk)
|
|
value: The value tensor (batch_size, nh, sequence_length, dhhv)
|
|
fgate: The forget gate tensor (batch_size, nh, sequence_length)
|
|
igate: The input gate tensor (batch_size, nh, sequence_length)
|
|
c_initial: The initial cell state tensor (batch_size, nh, dhqk, dhhv)
|
|
n_initial: The initial hidden state tensor (batch_size, nh, dhqk)
|
|
m_initial: The initial memory state tensor (batch_size, nh, 1)
|
|
return_last_states: If True, the function will return the last states of the mLSTM
|
|
eps: The epsilon value used for numerical stability
|
|
autocast_kernel_dtype: The dtype used for the kernel computation
|
|
chunk_size: The chunk size used for the chunkwise kernel
|
|
enable_logging: If True, the function will log debug information. Default is False.
|
|
|
|
Returns:
|
|
The last hidden state tensor (batch_size, nh, sequence_length, dhhv) or a tuple containing the last hidden state tensor and the last states of the mLSTM
|
|
Last states are (cstate (batch_size, nh, dhqk, dhhv), nstate (batch_size, nh, dhqk), mstate (batch_size, nh, 1)).
|
|
"""
|
|
|
|
batch_size, nh, sequence_length, dhqk = key.shape
|
|
dhhv = value.shape[-1]
|
|
|
|
c_state = (
|
|
c_initial
|
|
if c_initial is not None
|
|
else torch.zeros(batch_size, nh, dhqk, dhhv, device=key.device, dtype=torch.float32)
|
|
)
|
|
n_state = (
|
|
n_initial
|
|
if n_initial is not None
|
|
else torch.zeros(batch_size, nh, dhqk, device=key.device, dtype=torch.float32)
|
|
)
|
|
m_state = (
|
|
m_initial
|
|
if m_initial is not None
|
|
else torch.zeros(batch_size, nh, 1, device=key.device, dtype=torch.float32)
|
|
)
|
|
|
|
if sequence_length > 1:
|
|
# process the sequence length in chunks
|
|
h_outs = []
|
|
seq_len_start_idx = 0
|
|
remaining_seq_len = sequence_length - seq_len_start_idx
|
|
num_chunks = remaining_seq_len // chunk_size
|
|
if num_chunks > 0:
|
|
iter_seq_len = chunk_size * num_chunks
|
|
seq_len_idx = seq_len_start_idx + iter_seq_len
|
|
h_out, (c_state, n_state, m_state) = mlstm_chunkwise_kernel(
|
|
query=query[..., seq_len_start_idx:seq_len_idx, :].contiguous(),
|
|
key=key[..., seq_len_start_idx:seq_len_idx, :].contiguous(),
|
|
value=value[..., seq_len_start_idx:seq_len_idx, :].contiguous(),
|
|
fgate=fgate[..., seq_len_start_idx:seq_len_idx].contiguous(),
|
|
igate=igate[..., seq_len_start_idx:seq_len_idx].contiguous(),
|
|
c_initial=c_state,
|
|
n_initial=n_state,
|
|
m_initial=m_state,
|
|
chunk_size=chunk_size,
|
|
return_last_states=True,
|
|
autocast_kernel_dtype=autocast_kernel_dtype,
|
|
eps=eps,
|
|
)
|
|
seq_len_start_idx += iter_seq_len
|
|
h_outs.append(h_out)
|
|
|
|
remaining_seq_len = sequence_length - seq_len_start_idx
|
|
|
|
if remaining_seq_len > 0:
|
|
                # process the remaining tokens (fewer than one full chunk) with the recurrent sequence kernel;
                # the per-token outputs are collected and the final (C, n, m) state is carried forward
|
|
h_out, (c_state, n_state, m_state) = mlstm_sequence_kernel(
|
|
query=query[..., seq_len_start_idx:sequence_length, :].contiguous(),
|
|
key=key[..., seq_len_start_idx:sequence_length, :].contiguous(),
|
|
value=value[..., seq_len_start_idx:sequence_length, :].contiguous(),
|
|
igate=igate[..., seq_len_start_idx:sequence_length].contiguous(),
|
|
fgate=fgate[..., seq_len_start_idx:sequence_length].contiguous(),
|
|
c_initial=c_state,
|
|
n_initial=n_state,
|
|
m_initial=m_state,
|
|
return_last_states=True,
|
|
eps=eps,
|
|
)
|
|
h_outs.append(h_out)
|
|
h_out = torch.concatenate(h_outs, dim=2)
|
|
|
|
else:
|
|
if sequence_length != 1:
|
|
raise ValueError(
|
|
f"Received empty sequence (sequence_length={sequence_length}), require at least single element in the sequence."
|
|
)
|
|
# process the sequence length in a single step
|
|
# while this case is also captured by the regular mode above,
|
|
# it avoids the overhead of the loop and calls the step kernel directly
|
|
# The step function does not want a sequence dimension
|
|
# qkv shape is (batch_size, nh, dhqk/dhv)
|
|
# igate, fgate shape is (batch_size, nh, 1)
|
|
h_out, (c_state, n_state, m_state) = mlstm_step_kernel(
|
|
query=query.squeeze(2),
|
|
key=key.squeeze(2),
|
|
value=value.squeeze(2),
|
|
igate=igate,
|
|
fgate=fgate,
|
|
cstate=c_state,
|
|
nstate=n_state,
|
|
mstate=m_state,
|
|
eps=eps,
|
|
)
|
|
h_out = h_out[:, :, None, :]
|
|
|
|
if return_last_states:
|
|
return h_out, (c_state, n_state, m_state)
|
|
else:
|
|
return h_out
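    # Illustrative sketch of the splitting logic (assumed numbers): with chunk_size=64 and a
    # sequence of 150 tokens, the first 128 tokens (two full chunks) are handled by the chunkwise
    # kernel and the remaining 22 tokens by the recurrent sequence kernel, with the (C, n, m)
    # states carried over between the two calls.
    #
    #     >>> sequence_length, chunk_size = 150, 64
    #     >>> num_chunks = sequence_length // chunk_size
    #     >>> num_chunks * chunk_size, sequence_length % chunk_size
    #     (128, 22)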
|
|
|
|
class xLSTMBackend(nn.Module):
|
|
"""xLSTM Backend Module for PyTorch.
|
|
|
|
This module wraps the xLSTM kernels and provides a high-level interface for training and inference.
|
|
"""
|
|
|
|
config_class = xLSTMConfig
|
|
|
|
def __init__(self, config: xLSTMConfig):
|
|
super().__init__()
|
|
self.config = config
|
|
self.chunkwise_kernel_fn = mlstm_chunkwise_native_autograd
|
|
self.sequence_kernel_fn = mlstm_recurrent_sequence_native
|
|
self.step_kernel_fn = mlstm_recurrent_step_native
|
|
|
|
self._inference_fn = partial(
|
|
wrap_chunkwise_arbitrary_sequence_length,
|
|
mlstm_chunkwise_kernel=self.chunkwise_kernel_fn,
|
|
mlstm_sequence_kernel=partial(
|
|
self.sequence_kernel_fn,
|
|
dtype_state=getattr(torch, config.inference_state_dtype),
|
|
),
|
|
mlstm_step_kernel=partial(
|
|
self.step_kernel_fn,
|
|
dtype_state=getattr(torch, config.inference_state_dtype),
|
|
),
|
|
chunk_size=config.chunk_size,
|
|
eps=config.eps,
|
|
autocast_kernel_dtype=getattr(torch, config.autocast_kernel_dtype),
|
|
return_last_states=True,
|
|
)
|
|
|
|
train_kernel_fn = partial(
|
|
self.chunkwise_kernel_fn,
|
|
autocast_kernel_dtype=getattr(torch, config.autocast_kernel_dtype),
|
|
eps=config.eps,
|
|
chunk_size=config.chunk_size,
|
|
)
|
|
if "with_padding" in config.mode:
|
|
train_kernel_fn = partial(wrap_chunkwise_pad_zeros, mlstm_chunkwise_kernel=train_kernel_fn)
|
|
self._train_fn = train_kernel_fn
|
|
|
|
def forward(
|
|
self,
|
|
query: torch.Tensor,
|
|
key: torch.Tensor,
|
|
value: torch.Tensor,
|
|
igate: torch.Tensor,
|
|
fgate: torch.Tensor,
|
|
c_initial: torch.Tensor = None,
|
|
n_initial: torch.Tensor = None,
|
|
m_initial: torch.Tensor = None,
|
|
return_last_states: bool = False,
|
|
mode: Optional[Literal["train", "inference"]] = None,
|
|
) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
|
|
"""Forward pass of the mLSTM backend.
|
|
|
|
Depending on the configured mode, this method will call the appropriate kernel function.
|
|
|
|
Args:
|
|
query: The query tensor of shape (batch_size, nh, sequence_length, dhqk).
|
|
key: The key tensor of shape (batch_size, nh, sequence_length, dhqk).
|
|
value: The value tensor of shape (batch_size, nh, sequence_length, dhhv).
|
|
igate: The input gate preactivation tensor of shape (batch_size, nh, sequence_length).
|
|
fgate: The forget gate preactivation tensor of shape (batch_size, nh, sequence_length).
|
|
c_initial: The initial cell state tensor of shape (batch_size, nh, dhqk, dhhv).
|
|
Defaults to None.
|
|
n_initial: The initial hidden state tensor of shape (batch_size, nh, dhqk). Defaults to None.
|
|
m_initial: The initial memory tensor of shape (batch_size, nh, 1). Defaults to None.
|
|
return_last_states: Whether to return the last states of the sequence. Defaults to None.
|
|
If None, the value from the config is used.
|
|
|
|
Returns:
|
|
                Hidden states of shape (batch_size, nh, sequence_length, dhhv), or a tuple of these hidden states
                and the last states, where the last states are the cell state cstate (batch_size, nh, dhqk, dhhv),
                the normalizer state nstate (batch_size, nh, dhqk), and the max state mstate (batch_size, nh, 1).
|
|
"""
|
|
if mode is None:
|
|
mode = self.config.mode
|
|
|
|
if "train" in mode:
|
|
if return_last_states is None:
|
|
return_last_states = self.config.return_last_states
|
|
|
|
if self.config.mode == "train_with_padding":
|
|
if return_last_states:
|
|
raise ValueError("return_last_states=True is not supported with train_with_padding mode.")
|
|
|
|
return self._train_fn(
|
|
query=query,
|
|
key=key,
|
|
value=value,
|
|
igate=igate,
|
|
fgate=fgate,
|
|
c_initial=c_initial,
|
|
n_initial=n_initial,
|
|
m_initial=m_initial,
|
|
return_last_states=return_last_states,
|
|
)
|
|
|
|
elif "inference" in mode:
|
|
# inference mode always returns the last states
|
|
return self._inference_fn(
|
|
query=query,
|
|
key=key,
|
|
value=value,
|
|
igate=igate,
|
|
fgate=fgate,
|
|
c_initial=c_initial,
|
|
n_initial=n_initial,
|
|
m_initial=m_initial,
|
|
)
|
|
else:
|
|
raise ValueError(f"Unknown mode: {self.config.mode}")
|
|
|
|
def extra_repr(self) -> str:
|
|
return f"{self.config}"
|
|
|
|
class xLSTMRMSNorm(nn.Module):
|
|
"""Root mean square normalization layer implementation similar
|
|
to https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html.
|
|
|
|
It normalizes the input tensor by the root mean square of the last dimension.
|
|
|
|
Args:
|
|
num_features: The number of features in the input tensor.
|
|
eps: A small value to avoid division by zero.
|
|
use_weight: Whether to use a learnable weight.
|
|
use_bias: Whether to use a learnable bias.
|
|
force_float32_reductions: Whether to force float32 reductions.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
num_features: int,
|
|
eps: float = 1e-6,
|
|
use_weight: bool = True,
|
|
use_bias: bool = False,
|
|
force_float32_reductions: bool = True,
|
|
):
|
|
super().__init__()
|
|
self.num_features = num_features
|
|
self.eps = eps
|
|
self.force_float32_reductions = force_float32_reductions
|
|
|
|
if use_weight:
|
|
self.weight = nn.Parameter(torch.ones(num_features))
|
|
else:
|
|
self.weight = None
|
|
|
|
if use_bias:
|
|
self.bias = nn.Parameter(torch.zeros(num_features))
|
|
else:
|
|
self.bias = None
|
|
|
|
def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor:
|
|
if self.weight is not None:
|
|
x = x * self.weight
|
|
if self.bias is not None:
|
|
x = x + self.bias
|
|
return x
|
|
|
|
def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
|
|
# apply rms norm over the last dimension, i.e. HD dimension
|
|
in_dtype = x.dtype
|
|
if self.force_float32_reductions:
|
|
x = x.float()
|
|
x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
|
return x.to(in_dtype)
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
x = self._rms_normalize(x)
|
|
x = self._apply_weight_bias(x)
|
|
return x
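    # Illustrative sketch (assumed input): the RMS norm divides by the root mean square over the
    # last dimension, so a constant vector is mapped to (approximately) a vector of ones before
    # the learnable weight is applied.
    #
    #     >>> norm = xLSTMRMSNorm(num_features=4)
    #     >>> norm(torch.full((1, 1, 4), 3.0))   # approximately [[[1.0, 1.0, 1.0, 1.0]]]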
|
|
|
|
class xLSTMMultiHeadLayerNorm(nn.Module):
|
|
"""Multi-head version of the LayerNorm layer.
|
|
|
|
It normalizes the last dimension of the input tensor.
|
|
|
|
The input is assumed to have the shape (batch_size, sequence_length, nh, DH), where:
|
|
batch_size: batch size
|
|
sequence_length: sequence length
|
|
nh: number of heads
|
|
DH: head dimension
|
|
|
|
The normalization is applied over the last dimension (DH) of the input tensor.
|
|
|
|
Args:
|
|
num_heads: The number of heads.
|
|
head_dim: The head dimension.
|
|
eps: A small value to avoid division by zero.
|
|
use_weight: Whether to use a learnable weight.
|
|
use_bias: Whether to use a learnable bias.
|
|
force_float32_reductions: Whether to force float32 reductions
|
|
|
|
Returns:
|
|
The normalized tensor with the shape (batch_size, sequence_length, nh * DH).
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
num_heads: int,
|
|
head_dim: int,
|
|
eps: float = 1e-6,
|
|
use_weight: bool = True,
|
|
use_bias: bool = False,
|
|
force_float32_reductions: bool = True,
|
|
):
|
|
super().__init__()
|
|
self.num_features = num_heads * head_dim
|
|
self.eps = eps
|
|
self.force_float32_reductions = force_float32_reductions
|
|
|
|
if use_weight:
|
|
self.weight = nn.Parameter(torch.ones(self.num_features))
|
|
else:
|
|
self.weight = None
|
|
|
|
if use_bias:
|
|
self.bias = nn.Parameter(torch.zeros(self.num_features))
|
|
else:
|
|
self.bias = None
|
|
self.num_heads = num_heads
|
|
self.head_dim = head_dim
|
|
|
|
def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor:
|
|
if self.weight is not None:
|
|
x = x * self.weight
|
|
if self.bias is not None:
|
|
x = x + self.bias
|
|
return x
|
|
|
|
def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
|
|
# apply layer norm over the last dimension, i.e. HD dimension
|
|
in_dtype = x.dtype
|
|
if self.force_float32_reductions:
|
|
x = x.float()
|
|
x_centered = x - x.mean(dim=-1, keepdim=True)
|
|
y = x_centered * torch.rsqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
|
|
return y.to(in_dtype)
|
|
|
|
def forward(
|
|
self,
|
|
x: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
batch_size, sequence_length, nh, DH = x.shape
|
|
if nh != self.num_heads:
|
|
raise ValueError(f"Expected {self.num_heads} heads, got {nh}, input shape: {x.shape}")
|
|
if DH != self.head_dim:
|
|
raise ValueError(f"Expected {self.head_dim} head dimension, got {DH}, input shape: {x.shape}")
|
|
|
|
x = self._layer_normalize(x)
|
|
x = x.reshape(batch_size, sequence_length, -1)
|
|
x = self._apply_weight_bias(x)
|
|
return x
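    # Illustrative shape sketch (assumed sizes): the multi-head layer norm normalizes each head
    # separately over head_dim and then flattens the head dimension into the feature dimension.
    #
    #     >>> mh_norm = xLSTMMultiHeadLayerNorm(num_heads=4, head_dim=8)
    #     >>> mh_norm(torch.randn(2, 5, 4, 8)).shape
    #     torch.Size([2, 5, 32])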
|
|
|
|
class xLSTMFeedForward(nn.Module):
|
|
def __init__(self, config: xLSTMConfig):
|
|
super().__init__()
|
|
self.config = config
|
|
|
|
self.up_proj_dim = round_up_to_next_multiple_of(
|
|
config.hidden_size * config.ffn_proj_factor,
|
|
config.ffn_round_up_to_multiple_of,
|
|
)
|
|
|
|
if self.config.weight_mode == "single":
|
|
self.proj_up_gate = nn.Linear(
|
|
in_features=config.hidden_size,
|
|
out_features=self.up_proj_dim,
|
|
bias=self.config.use_bias,
|
|
)
|
|
self.proj_up = nn.Linear(
|
|
in_features=config.hidden_size,
|
|
out_features=self.up_proj_dim,
|
|
bias=self.config.use_bias,
|
|
)
|
|
elif self.config.weight_mode == "fused":
|
|
self.proj_up_gate_z = nn.Linear(
|
|
in_features=config.hidden_size,
|
|
out_features=2 * self.up_proj_dim,
|
|
bias=self.config.use_bias,
|
|
)
|
|
|
|
self.proj_down = nn.Linear(
|
|
in_features=self.up_proj_dim,
|
|
out_features=config.hidden_size,
|
|
bias=self.config.use_bias,
|
|
)
|
|
|
|
self.act_fn = nn.SiLU()
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
if self.config.weight_mode == "single":
|
|
x = self.act_fn(self.proj_up_gate(x)) * self.proj_up(x)
|
|
elif self.config.weight_mode == "fused":
|
|
x = self.proj_up_gate_z(x)
|
|
gate, z = torch.tensor_split(x, (self.up_proj_dim,), dim=-1)
|
|
x = self.act_fn(gate) * z
|
|
|
|
y = self.proj_down(x)
|
|
return y
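    # Illustrative sketch (hypothetical sizes): the feed-forward is a SiLU-gated (SwiGLU-style)
    # projection. For hidden_size=64, ffn_proj_factor=2.667 and ffn_round_up_to_multiple_of=64,
    # the up-projection width would be round_up_to_next_multiple_of(64 * 2.667, 64) == 192,
    # while proj_down maps back to hidden_size=64.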
|
|
|
|
class xLSTMLayer(nn.Module):
|
|
def __init__(self, config: xLSTMConfig):
|
|
super().__init__()
|
|
self.config = config
|
|
|
|
self.v_dim = int(config.hidden_size * config.v_dim_factor)
|
|
self.qk_dim = int(config.hidden_size * config.qk_dim_factor)
|
|
|
|
if self.config.weight_mode == "single":
|
|
self.query = nn.Linear(
|
|
in_features=self.config.hidden_size,
|
|
out_features=self.qk_dim,
|
|
bias=self.config.use_bias,
|
|
)
|
|
self.key = nn.Linear(
|
|
in_features=self.config.hidden_size,
|
|
out_features=self.qk_dim,
|
|
bias=self.config.use_bias,
|
|
)
|
|
self.value = nn.Linear(
|
|
in_features=self.config.hidden_size,
|
|
out_features=self.v_dim,
|
|
bias=self.config.use_bias,
|
|
)
|
|
|
|
self.ogate_preact = nn.Linear(
|
|
in_features=self.config.hidden_size,
|
|
out_features=self.v_dim,
|
|
bias=self.config.use_bias,
|
|
)
|
|
self.igate_preact = nn.Linear(
|
|
in_features=self.config.hidden_size,
|
|
out_features=self.config.num_heads,
|
|
bias=True,
|
|
)
|
|
self.fgate_preact = nn.Linear(
|
|
in_features=self.config.hidden_size,
|
|
out_features=self.config.num_heads,
|
|
bias=True,
|
|
)
|
|
elif self.config.weight_mode == "fused":
|
|
self.qkv_opreact = nn.Linear(
|
|
in_features=self.config.hidden_size,
|
|
out_features=2 * self.qk_dim + 2 * self.v_dim,
|
|
bias=self.config.use_bias,
|
|
)
|
|
self.ifgate_preact = nn.Linear(
|
|
in_features=self.config.hidden_size,
|
|
out_features=2 * self.config.num_heads,
|
|
bias=True,
|
|
)
|
|
|
|
self.ogate_act_fn = nn.Sigmoid()
|
|
self.mlstm_backend = xLSTMBackend(config=self.config)
|
|
|
|
self.multihead_norm = xLSTMMultiHeadLayerNorm(
|
|
num_heads=self.config.num_heads,
|
|
head_dim=self.v_dim // self.config.num_heads,
|
|
eps=self.config.norm_eps,
|
|
use_weight=True,
|
|
use_bias=self.config.use_bias,
|
|
force_float32_reductions=self.config.norm_reduction_force_float32,
|
|
)
|
|
self.out_proj = nn.Linear(
|
|
in_features=self.v_dim,
|
|
out_features=self.config.hidden_size,
|
|
bias=self.config.use_bias,
|
|
)
|
|
|
|
def forward(
|
|
self, x: torch.Tensor, state: Optional[mLSTMLayerStateType] = None
|
|
) -> tuple[torch.Tensor, Optional[mLSTMLayerStateType]]:
|
|
if x.ndim != 3:
|
|
raise ValueError(f"Input must have shape [batch_size, sequence_length, HD], got {x.shape}")
|
|
batch_size, sequence_length, _ = x.shape
|
|
if self.config.weight_mode == "single":
|
|
query = self.query(x)
|
|
key = self.key(x)
|
|
value = self.value(x)
|
|
o_preact = self.ogate_preact(x)
|
|
i_preact = soft_cap(self.igate_preact(x), cap_value=self.config.gate_soft_cap)
|
|
f_preact = soft_cap(self.fgate_preact(x), cap_value=self.config.gate_soft_cap)
|
|
|
|
elif self.config.weight_mode == "fused":
|
|
qkv_opreact = self.qkv_opreact(x)
|
|
query, key, value, o_preact = torch.tensor_split(
|
|
qkv_opreact,
|
|
(
|
|
self.qk_dim,
|
|
2 * self.qk_dim,
|
|
2 * self.qk_dim + self.v_dim,
|
|
),
|
|
dim=-1,
|
|
)
|
|
|
|
if_preact = soft_cap(self.ifgate_preact(x), cap_value=self.config.gate_soft_cap)
|
|
i_preact, f_preact = torch.tensor_split(if_preact, (self.config.num_heads,), dim=-1)
|
|
|
|
query = query.reshape(batch_size, sequence_length, self.config.num_heads, -1).transpose(1, 2)
|
|
key = key.reshape(batch_size, sequence_length, self.config.num_heads, -1).transpose(1, 2)
|
|
value = value.reshape(batch_size, sequence_length, self.config.num_heads, -1).transpose(1, 2)
|
|
i_preact = i_preact.transpose(1, 2)
|
|
f_preact = f_preact.transpose(1, 2)
|
|
if state is None:
|
|
c_initial, n_initial, m_initial = None, None, None
|
|
else:
|
|
c_initial, n_initial, m_initial = state
|
|
|
|
h, state = self.mlstm_backend(
|
|
query=query,
|
|
key=key,
|
|
value=value,
|
|
igate=i_preact,
|
|
fgate=f_preact,
|
|
c_initial=c_initial,
|
|
n_initial=n_initial,
|
|
m_initial=m_initial,
|
|
)
|
|
expected_h_shape = (
|
|
batch_size,
|
|
self.config.num_heads,
|
|
sequence_length,
|
|
self.v_dim // self.config.num_heads,
|
|
)
|
|
if h.shape != expected_h_shape:
|
|
raise ValueError(f"Got {h.shape}, expected {expected_h_shape}")
|
|
|
|
h = h.transpose(1, 2)
|
|
h_norm = self.multihead_norm(h)
|
|
h_norm = h_norm.reshape(batch_size, sequence_length, -1)
|
|
|
|
h_out = self.ogate_act_fn(o_preact) * h_norm
|
|
|
|
y = self.out_proj(h_out)
|
|
return y, state
|
|
|
|
class xLSTMBlock(nn.Module):
|
|
def __init__(self, config: xLSTMConfig):
|
|
super().__init__()
|
|
self.config = config
|
|
self.norm_mlstm = xLSTMRMSNorm(
|
|
num_features=config.hidden_size,
|
|
eps=config.norm_eps,
|
|
use_weight=True,
|
|
use_bias=config.use_bias,
|
|
force_float32_reductions=config.norm_reduction_force_float32,
|
|
)
|
|
self.mlstm_layer = xLSTMLayer(config)
|
|
self.norm_ffn = xLSTMRMSNorm(
|
|
num_features=config.hidden_size,
|
|
eps=config.norm_eps,
|
|
use_weight=True,
|
|
use_bias=config.use_bias,
|
|
force_float32_reductions=config.norm_reduction_force_float32,
|
|
)
|
|
self.ffn = xLSTMFeedForward(config)
|
|
|
|
def forward(
|
|
self, x: torch.Tensor, state: Optional[mLSTMStateType] = None
|
|
) -> tuple[torch.Tensor, mLSTMStateType]:
|
|
x_mlstm = self.norm_mlstm(x)
|
|
x_mlstm, state = self.mlstm_layer(x_mlstm, state)
|
|
x = x + x_mlstm
|
|
|
|
x_ffn = self.norm_ffn(x)
|
|
x_ffn = self.ffn(x_ffn)
|
|
x = x + x_ffn
|
|
|
|
return x, state
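    # Illustrative sketch: each block is a pre-norm residual block, roughly
    #
    #     x = x + mlstm_layer(norm_mlstm(x))
    #     x = x + ffn(norm_ffn(x))
    #
    # with the recurrent (cstate, nstate, mstate) tuple of the mLSTM layer threaded through and returned.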
|
|
|
|
|
|
def small_init_method(dim):
|
|
"""
|
|
Adapted from: https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
|
|
Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
|
|
the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution."""
|
|
std = (2 / (5 * dim)) ** (1 / 2)
|
|
|
|
def init_(tensor):
|
|
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
|
|
|
|
return init_
|
|
|
|
|
|
def wang_init_method(n_layers, dim):
|
|
"""
|
|
Adapted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
|
|
"""
|
|
std = 2 / n_layers / dim ** (1 / 2)
|
|
|
|
def init_(tensor):
|
|
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
|
|
|
|
return init_
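# Illustrative numbers (assumed dimensions, not taken from the config): for dim=4096 the small-init
# standard deviation is sqrt(2 / (5 * 4096)) ~ 0.0099, and for n_layers=32 the Wang-init standard
# deviation is 2 / 32 / sqrt(4096) ~ 0.00098.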
|
|
|
|
|
|
class xLSTMPreTrainedModel(PreTrainedModel):
|
|
"""
|
|
An abstract class for an interface to loading a pre-trained xLSTM model.
|
|
"""
|
|
|
|
config_class = xLSTMConfig
|
|
base_model_prefix = "backbone"
|
|
_no_split_modules = ["xLSTMBlock"]
|
|
supports_gradient_checkpointing = True
|
|
_is_stateful = True
|
|
|
|
def _module_name_map(self, module):
|
|
for name, mod in self.named_modules():
|
|
if mod is module:
|
|
return name
|
|
return ""
|
|
|
|
def _init_weights(self, module):
|
|
if isinstance(module, nn.Embedding):
|
|
            small_init_method(self.config.hidden_size)(module.weight)
|
|
elif isinstance(module, nn.Linear):
|
|
if module.bias is not None:
|
|
torch.nn.init.zeros_(module.bias)
|
|
if self.config.weight_mode == "single" and "gate" in self._module_name_map(module):
|
|
torch.nn.init.zeros_(module.weight)
|
|
with torch.no_grad():
|
|
if "igate" in self._module_name_map(module):
|
|
module.bias.copy_(-10.0 * torch.ones_like(module.bias))
|
|
elif "fgate" in self._module_name_map(module):
|
|
module.bias.copy_(
|
|
torch.linspace(
|
|
3.0,
|
|
6.0,
|
|
module.bias.shape[-1],
|
|
).to(
|
|
device=module.bias.device,
|
|
dtype=module.bias.dtype,
|
|
)
|
|
)
|
|
elif self.config.weight_mode == "fused" and "gate" in self._module_name_map(module):
|
|
torch.nn.init.zeros_(module.weight)
|
|
with torch.no_grad():
|
|
                    # the fused gate projection stacks [igate, fgate] along the output dimension:
                    # set the igate half of the bias to -10 and the fgate half to a linspace from 3 to 6
                    module.bias[: self.config.num_heads] = -10.0 * torch.ones_like(
                        module.bias[: self.config.num_heads]
                    )
                    module.bias[self.config.num_heads :] = torch.linspace(
                        3.0,
                        6.0,
                        self.config.num_heads,
                    ).to(
                        device=module.bias.device,
                        dtype=module.bias.dtype,
                    )
|
|
elif "proj_down" in self._module_name_map(module):
|
|
wang_init_method(dim=module.weight.shape[1], n_layers=self.config.num_hidden_layers)(module.weight)
|
|
elif "out_proj" in self._module_name_map(module):
|
|
wang_init_method(dim=self.config.hidden_size, n_layers=self.config.num_hidden_layers)(module.weight)
|
|
elif module.weight is not None:
|
|
small_init_method(self.config.hidden_size)(module.weight)
|
|
elif isinstance(module, xLSTMRMSNorm) or hasattr(module, "_layer_normalize"):
|
|
torch.nn.init.ones_(module.weight)
|
|
if hasattr(module, "bias") and module.bias is not None:
|
|
torch.nn.init.zeros_(module.bias)
|
|
|
|
|
|
class xLSTMCache:
|
|
"""
|
|
Cache for xLSTM model which does not have attention mechanism and key value states.
|
|
|
|
Arguments:
|
|
        config (`PretrainedConfig`):
|
|
The configuration file defining the shape-related attributes required to initialize the static cache.
|
|
max_batch_size (`int`):
|
|
The batch size with which the model will be used.
|
|
dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
|
|
The default `dtype` to use when initializing the layer.
|
|
device (`torch.device` or `str`, *optional*):
|
|
The device on which the cache should be initialized. Should be the same as the layer.
|
|
|
|
Attributes:
|
|
seqlen_offset: int
|
|
dtype: torch.dtype
|
|
|
|
Example:
|
|
|
|
```python
|
|
>>> from transformers import AutoTokenizer, xLSTMForCausalLM, xLSTMCache
|
|
|
|
>>> model = xLSTMForCausalLM.from_pretrained("NX-AI/xLSTM-7b")
|
|
    >>> tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")
|
|
|
|
>>> inputs = tokenizer(text="I am an xLSTM", return_tensors="pt")
|
|
|
|
>>> # Prepare a cache class and pass it to model's forward
|
|
>>> cache_params = xLSTMCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
|
|
>>> outputs = model(**inputs, cache_params=cache_params, use_cache=True)
|
|
>>> outputs.cache_params
|
|
xLSTMCache()
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
config: xLSTMConfig,
|
|
max_batch_size: int,
|
|
dtype: torch.dtype = torch.bfloat16,
|
|
device: Optional[str] = None,
|
|
**kwargs,
|
|
):
|
|
self.seqlen_offset = 0
|
|
self.dtype = dtype
|
|
self.config = config
|
|
self.rnn_state = {
|
|
layer: (
|
|
torch.zeros(
|
|
[max_batch_size, config.num_heads, config.qk_head_dim, config.v_head_dim],
|
|
dtype=dtype,
|
|
device=device,
|
|
),
|
|
torch.zeros([max_batch_size, config.num_heads, config.qk_head_dim], dtype=dtype, device=device),
|
|
torch.zeros([max_batch_size, config.num_heads, 1], dtype=dtype, device=device),
|
|
)
|
|
for layer in range(config.num_hidden_layers)
|
|
}
|
|
|
|
def reset(self):
|
|
self.rnn_state = {
|
|
layer: (
|
|
torch.zeros_like(self.rnn_state[layer][0]),
|
|
torch.zeros_like(self.rnn_state[layer][1]),
|
|
torch.zeros_like(self.rnn_state[layer][2]),
|
|
)
|
|
for layer in self.rnn_state
|
|
}
|
|
|
|
|
|
@dataclass
|
|
@auto_docstring
|
|
class xLSTMOutput(ModelOutput):
|
|
r"""
|
|
cache_params (`xLSTMCache`):
|
|
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
|
|
avoid providing the old `input_ids`.
|
|
"""
|
|
|
|
last_hidden_state: Optional[torch.FloatTensor]
|
|
cache_params: Optional[xLSTMCache] = None
|
|
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
|
|
|
|
|
@auto_docstring
|
|
class xLSTMModel(xLSTMPreTrainedModel):
|
|
def __init__(self, config):
|
|
super().__init__(config)
|
|
        # use embedding_dim and num_blocks once here to make use of them
|
|
self.embeddings = nn.Embedding(config.vocab_size, config.embedding_dim)
|
|
self.blocks = nn.ModuleList([xLSTMBlock(config) for _ in range(config.num_blocks)])
|
|
self.out_norm = xLSTMRMSNorm(config.hidden_size, eps=config.norm_eps)
|
|
self.gradient_checkpointing = False
|
|
# Initialize weights and apply final processing
|
|
self.post_init()
|
|
|
|
def get_input_embeddings(self):
|
|
return self.embeddings
|
|
|
|
def set_input_embeddings(self, new_embedding):
|
|
self.embeddings = new_embedding
|
|
|
|
@can_return_tuple
|
|
@auto_docstring
|
|
def forward(
|
|
self,
|
|
input_ids: Optional[torch.LongTensor] = None,
|
|
inputs_embeds: Optional[torch.LongTensor] = None,
|
|
cache_params: Optional[xLSTMCache] = None,
|
|
use_cache: Optional[bool] = None,
|
|
output_hidden_states: Optional[bool] = None,
|
|
**kwargs,
|
|
) -> Union[tuple, xLSTMOutput]:
|
|
r"""
|
|
cache_params (`xLSTMCache`, *optional*):
|
|
The xLSTMCache that carries the RNN states.
|
|
"""
|
|
output_hidden_states = (
|
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
)
|
|
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
|
|
if self.gradient_checkpointing and self.training and use_cache:
|
|
use_cache = False
|
|
|
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
|
|
|
if inputs_embeds is None:
|
|
inputs_embeds = self.embeddings(input_ids)
|
|
|
|
if use_cache and cache_params is None:
|
|
cache_params = xLSTMCache(
|
|
self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
|
|
)
|
|
|
|
hidden_states = inputs_embeds
|
|
|
|
if (
|
|
not self.training
|
|
and self.config.max_inference_chunksize < hidden_states.shape[1]
|
|
and not output_hidden_states
|
|
):
|
|
offset = 0
|
|
with torch.no_grad():
|
|
if cache_params is None:
|
|
                    cache_params = xLSTMCache(config=self.config, max_batch_size=hidden_states.shape[0])
|
|
final_state = torch.zeros_like(hidden_states)
|
|
while offset < hidden_states.shape[1]:
|
|
hidden_states_chunk = hidden_states[
|
|
:, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
|
|
]
|
|
for layer_idx, xlstm_block in enumerate(self.blocks):
|
|
hidden_states_chunk, rnn_state = xlstm_block(
|
|
hidden_states_chunk,
|
|
state=cache_params.rnn_state[layer_idx],
|
|
)
|
|
for state_idx in range(len(cache_params.rnn_state[layer_idx])):
|
|
local_rnn_state = rnn_state[state_idx]
|
|
cache_params.rnn_state[layer_idx][state_idx].copy_(local_rnn_state)
|
|
cache_params.rnn_state_initial = False
|
|
final_state[
|
|
:, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
|
|
] = hidden_states_chunk
|
|
offset += self.config.max_inference_chunksize
|
|
hidden_states = final_state
|
|
else:
|
|
all_hidden_states = () if output_hidden_states else None
|
|
for layer_idx, xlstm_block in enumerate(self.blocks):
|
|
if self.gradient_checkpointing and self.training:
|
|
hidden_states, rnn_state = self._gradient_checkpointing_func(
|
|
xlstm_block.__call__,
|
|
hidden_states,
|
|
cache_params.rnn_state[layer_idx] if cache_params is not None else None,
|
|
)
|
|
else:
|
|
hidden_states, rnn_state = xlstm_block(
|
|
hidden_states,
|
|
state=cache_params.rnn_state[layer_idx] if cache_params is not None else None,
|
|
)
|
|
if cache_params:
|
|
for state_idx in range(len(cache_params.rnn_state[layer_idx])):
|
|
local_rnn_state = rnn_state[state_idx]
|
|
cache_params.rnn_state[layer_idx][state_idx].copy_(local_rnn_state)
|
|
cache_params.rnn_state_initial = False
|
|
|
|
if output_hidden_states:
|
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
|
if use_cache:
|
|
cache_params.seqlen_offset += inputs_embeds.shape[1]
|
|
|
|
hidden_states = self.out_norm(hidden_states)
|
|
|
|
if output_hidden_states:
|
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
|
return xLSTMOutput(
|
|
last_hidden_state=hidden_states,
|
|
cache_params=cache_params,
|
|
hidden_states=all_hidden_states,
|
|
)
|
|
|
|
|
|
@dataclass
|
|
@auto_docstring
|
|
class xLSTMCausalLMOutput(ModelOutput):
|
|
r"""
|
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
|
Language modeling loss (for next-token prediction).
|
|
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
|
cache_params (`xLSTMCache`, *optional*, carrying the RNN states):
|
|
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
|
|
avoid providing the old `input_ids`.
|
|
"""
|
|
|
|
loss: Optional[torch.FloatTensor] = None
|
|
logits: Optional[torch.FloatTensor] = None
|
|
cache_params: Optional[xLSTMCache] = None
|
|
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
|
|
|
|
|
@auto_docstring
|
|
class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
|
|
def __init__(self, config):
|
|
super().__init__(config)
|
|
self.backbone = xLSTMModel(config)
|
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
|
# Initialize weights and apply final processing
|
|
self.post_init()
|
|
|
|
def get_output_embeddings(self):
|
|
return self.lm_head
|
|
|
|
def set_output_embeddings(self, new_embeddings):
|
|
self.lm_head = new_embeddings
|
|
|
|
def get_input_embeddings(self):
|
|
return self.backbone.get_input_embeddings()
|
|
|
|
def set_input_embeddings(self, new_embeddings):
|
|
return self.backbone.set_input_embeddings(new_embeddings)
|
|
|
|
def prepare_inputs_for_generation(
|
|
self,
|
|
input_ids,
|
|
inputs_embeds=None,
|
|
use_cache=None,
|
|
cache_params: Optional[xLSTMCache] = None,
|
|
**kwargs,
|
|
):
|
|
if use_cache and cache_params is not None:
|
|
# If the first cache position is non-zero, we assume we are in generation mode.
|
|
# Thus, the cache_params state is assumed to be the state before the last token
|
|
# (lastly generated token), and all previous tokens are already ingested.
|
|
# This should as well support generation from scratch with the [BOS] token inserted first.
|
|
input_ids = input_ids[:, -1:]
|
|
if inputs_embeds is not None:
|
|
inputs_embeds = inputs_embeds[:, -1:]
|
|
|
|
if inputs_embeds is not None and cache_params is None:
|
|
model_inputs = {"inputs_embeds": inputs_embeds}
|
|
else:
|
|
model_inputs = {"input_ids": input_ids}
|
|
|
|
model_inputs.update({"cache_params": cache_params, "use_cache": use_cache})
|
|
return model_inputs
|
|
|
|
@can_return_tuple
|
|
@auto_docstring
|
|
def forward(
|
|
self,
|
|
input_ids: Optional[torch.LongTensor] = None,
|
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
cache_params: Optional[xLSTMCache] = None,
|
|
labels: Optional[torch.LongTensor] = None,
|
|
use_cache: Optional[bool] = None,
|
|
output_hidden_states: Optional[bool] = None,
|
|
**kwargs,
|
|
) -> Union[tuple, xLSTMCausalLMOutput]:
|
|
r"""
|
|
cache_params (`xLSTMCache`, *optional*):
|
|
The xLSTMCache that carries the RNN states.
|
|
"""
|
|
xlstm_outputs = self.backbone(
|
|
input_ids,
|
|
cache_params=cache_params,
|
|
inputs_embeds=inputs_embeds,
|
|
use_cache=use_cache,
|
|
output_hidden_states=output_hidden_states,
|
|
**kwargs,
|
|
)
|
|
hidden_states = xlstm_outputs[0]
|
|
|
|
logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
|
|
|
|
if not self.training and self.config.max_inference_chunksize < logits.shape[1]:
|
|
offset = 0
|
|
with torch.no_grad():
|
|
while offset < logits.shape[1]:
|
|
logits[:, offset : min(offset + self.config.max_inference_chunksize, logits.shape[1])] = soft_cap(
|
|
logits[:, offset : min(offset + self.config.max_inference_chunksize, logits.shape[1])],
|
|
self.config.output_logit_soft_cap,
|
|
)
|
|
offset += self.config.max_inference_chunksize
|
|
else:
|
|
logits = soft_cap(logits, self.config.output_logit_soft_cap)
|
|
|
|
loss = None
|
|
if labels is not None:
|
|
# move labels to correct device to enable model parallelism
|
|
labels = labels.to(logits.device)
|
|
# Shift so that tokens < nstate predict nstate
|
|
shift_logits = logits[..., :-1, :].contiguous()
|
|
shift_labels = labels[..., 1:].contiguous()
|
|
# Flatten the tokens
|
|
loss_fct = CrossEntropyLoss()
|
|
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
|
|
|
return xLSTMCausalLMOutput(
|
|
loss=loss,
|
|
logits=logits,
|
|
cache_params=xlstm_outputs.cache_params,
|
|
hidden_states=xlstm_outputs.hidden_states,
|
|
)
|
|
|
|
|
|
__all__ = [
    "xLSTMForCausalLM",
    "xLSTMModel",
    "xLSTMPreTrainedModel",
]