team-10/venv/Lib/site-packages/transformers/models/xlstm/modeling_xlstm.py
2025-08-02 02:00:33 +02:00

1623 lines
64 KiB
Python

# Copyright 2025 NXAI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch xLSTM Model."""
from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...generation import GenerationMixin
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_xlstm_available
from .configuration_xlstm import xLSTMConfig
if is_xlstm_available():
from xlstm.xlstm_large.model import mLSTMBlock as xLSTMBlock
from xlstm.xlstm_large.model import mLSTMStateType, soft_cap
from xlstm.xlstm_large.model import xLSTMRMSNorm as xLSTMRMSNorm
external_xlstm = True
else:
from functools import partial
from typing import Callable, Literal
from .configuration_xlstm import round_up_to_next_multiple_of
mLSTMLayerStateType = tuple[torch.Tensor, torch.Tensor, torch.Tensor]
mLSTMStateType = dict[int, mLSTMLayerStateType]
external_xlstm = False
def soft_cap(values: torch.Tensor, cap_value: Optional[Union[float, torch.Tensor]] = None) -> torch.Tensor:
    """Smoothly bound the magnitude of ``values`` by ``cap_value``.

    The tensor is divided by the cap, squashed with ``tanh`` and rescaled by the cap again. Small values pass
    through almost unchanged, while large magnitudes saturate just below ``|cap_value|``. This is a common trick in
    attention and output language heads to keep large logits from dominating a softmax (see e.g. Gemma2:
    https://arxiv.org/abs/2408.00118).

    Args:
        values: The tensor to cap.
        cap_value: The cap. When ``None``, no capping is performed.

    Returns:
        ``cap_value * tanh(values / cap_value)``, or ``values`` itself (the same object) when ``cap_value`` is None.
    """
    if cap_value is not None:
        squashed = torch.tanh(values / cap_value)
        return cap_value * squashed
    return values
def mlstm_chunkwise_recurrent_fw_C(
    matK: torch.Tensor,
    matV: torch.Tensor,
    vecB: torch.Tensor,
    vecI: torch.Tensor,
    matC_states: torch.Tensor = None,
    vecN_states: torch.Tensor = None,
    scaMinter_states: torch.Tensor = None,
    matC_initial: torch.Tensor = None,
    vecN_initial: torch.Tensor = None,
    scaMinter_initial: torch.Tensor = None,
    qk_scale: Optional[float] = None,
    chunk_size: int = 64,
    num_chunks: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Materialize the mLSTM memory (C), normalizer (n) and max (m) states at every chunk boundary.

    Runs a sequential loop over the chunks. Before chunk ``k`` is processed, the running states are written into
    slot ``k`` of the output buffers. After the loop, the final states are written into the last slot, so every
    buffer holds ``num_chunks + 1`` states.

    Args:
        matK: Keys, unpacked as a 4-D tensor ``(batch_size, nh, sequence_length, dhqk)``. Chunk ``k`` is the slice
            ``[k * chunk_size, (k + 1) * chunk_size)`` along dim 2.
        matV: Values, sliced like ``matK``. Only its last dim (``dhhv``) is read from its shape.
        vecB: Per-chunk cumulative log forget gates, indexed as ``(batch, head, chunk, position)``.
        vecI: Input gate pre-activations, indexed like ``vecB``.
        matC_states: Optional output buffer ``(batch_size, nh, (num_chunks + 1) * dhqk, dhhv)``. It is written
            in place; a zero buffer is allocated when omitted.
        vecN_states: Optional output buffer ``(batch_size, nh, (num_chunks + 1) * dhqk)``, handled like
            ``matC_states``.
        scaMinter_states: Optional output buffer ``(batch_size, nh, num_chunks + 1)``, handled like
            ``matC_states``.
        matC_initial: Optional initial memory state. Zeros of shape ``(batch_size, nh, dhqk, dhhv)`` are used
            when omitted.
        vecN_initial: Optional initial normalizer state. Zeros of shape ``(batch_size, nh, dhqk)`` are used
            when omitted.
        scaMinter_initial: Optional initial max state. Zeros of shape ``(batch_size, nh, 1)`` are used when
            omitted.
        qk_scale: Defaulted below but never applied in this function.
        chunk_size: Number of time steps per chunk.
        num_chunks: Number of chunks to process.

    Returns:
        ``(matC_states, vecN_states, scaMinter_states)``: the three buffers, filled in.
    """
    # matK is unpacked as 4-D; the value head dim comes from matV's last dim
    batch_size, nh, _, dhqk, dhhv = *matK.shape, matV.shape[-1]
    nc = num_chunks
    _dtype, _device = matK.dtype, matK.device
    # NOTE(review): qk_scale is defaulted here but not used anywhere below (see the commented-out `# * qk_scale`)
    if qk_scale is None:
        qk_scale = dhqk**-0.5
    # initialize the states tensors
    # each buffer has num_chunks + 1 slots: one per chunk start plus one for the final state
    if matC_states is None:
        matC_states = torch.zeros((batch_size, nh, (nc + 1) * dhqk, dhhv), dtype=_dtype, device=_device)
    if vecN_states is None:
        vecN_states = torch.zeros((batch_size, nh, (nc + 1) * dhqk), dtype=_dtype, device=_device)
    if scaMinter_states is None:
        scaMinter_states = torch.zeros((batch_size, nh, (nc + 1)), dtype=_dtype, device=_device)
    # assign the initial states to the running states
    matC_k = (
        torch.zeros((batch_size, nh, dhqk, dhhv), dtype=_dtype, device=_device)
        if matC_initial is None
        else matC_initial
    )
    vecN_k = (
        torch.zeros((batch_size, nh, dhqk), dtype=_dtype, device=_device) if vecN_initial is None else vecN_initial
    )
    scaM_inter_k = (
        torch.zeros((batch_size, nh, 1), dtype=_dtype, device=_device)
        if scaMinter_initial is None
        else scaMinter_initial
    )
    # log-space weight of each step in its chunk's end state: the forget gates accumulated after the step
    # (vecB[-1] - vecB, assuming vecB is a per-chunk cumsum of log forget gates as the callers build it)
    # plus the step's own input gate
    vecA = vecB[..., -1, None] - vecB + vecI
    # total log forget gate of each chunk
    scaG = vecB[..., -1]
    # per-chunk maximum of vecA, used by the max-state stabilization below
    scaA_max = vecA.max(-1).values
    # drop the trailing singleton dim of the max state: (batch_size, nh, 1) -> (batch_size, nh)
    scaM_inter_k = scaM_inter_k.squeeze(-1)
    for key in range(0, num_chunks):
        # store the states from the previous iteration before updating them
        # in the first iteration, these are the initial states
        matC_states[:, :, key * dhqk : (key + 1) * dhqk, :] = matC_k
        vecN_states[:, :, key * dhqk : (key + 1) * dhqk] = vecN_k
        scaMinter_states[:, :, key] = scaM_inter_k
        # m_k update
        # new stabilizer: max of the decayed previous max and the largest in-chunk contribution
        scaA_max_k = scaA_max[:, :, key]
        scaG_k = scaG[:, :, key]
        scaM_inter_k_next = torch.max(scaG_k + scaM_inter_k, scaA_max_k)
        # C_k update
        matK_chunk = matK[:, :, key * chunk_size : (key + 1) * chunk_size, :]  # * qk_scale
        matV_chunk = matV[:, :, key * chunk_size : (key + 1) * chunk_size, :]
        vecA_k = vecA[:, :, key, :]
        # stabilized per-step weights, broadcast over the key head dim
        vecAbar_k = torch.exp(vecA_k - scaM_inter_k_next[..., None])[:, :, :, None]
        matK_chunk_gated = matK_chunk * vecAbar_k
        # stabilized decay applied to the carried-over state
        scaGbar_k = torch.exp(scaG_k + scaM_inter_k - scaM_inter_k_next)[:, :, None]
        # NOTE: no update in-place (i.e. +=) as this gives error for autograd backward
        matC_k_next = scaGbar_k[..., None] * matC_k + matK_chunk_gated.transpose(-2, -1) @ (matV_chunk)
        # n_k update
        vecN_k_next = scaGbar_k * vecN_k + matK_chunk_gated.transpose(-2, -1).sum(-1)
        # move to the next iteration
        scaM_inter_k = scaM_inter_k_next
        matC_k = matC_k_next
        vecN_k = vecN_k_next
    # store the states from the last iteration
    matC_states[:, :, -dhqk:, :] = matC_k
    vecN_states[:, :, -dhqk:] = vecN_k
    scaMinter_states[:, :, -1] = scaM_inter_k
    return matC_states, vecN_states, scaMinter_states
def mlstm_chunkwise_parallel_fw_H(
    matQ: torch.Tensor,
    matK: torch.Tensor,
    matV: torch.Tensor,
    # these states must be all states up to the last chunk, i.e. :-1
    matC_states: torch.Tensor,
    vecN_states: torch.Tensor,
    scaMinter_states: torch.Tensor,
    vecI: torch.Tensor,
    vecB: torch.Tensor,
    qk_scale: float,
    chunk_size: int = 64,
    num_chunks: int = 1,
    eps: float = 1e-6,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the mLSTM outputs of all chunks in parallel, given the state at the start of every chunk.

    Each output combines two terms, both stabilized by a shared log-space max state:

    - an intra-chunk term: causally masked, gate-weighted query/key attention within the chunk;
    - an inter-chunk term: the query against the chunk's starting memory ``C`` and normalizer ``n``.

    Args:
        matQ: Queries of shape ``(batch_size, nh, num_chunks * chunk_size, dhqk)``.
        matK: Keys, shaped like ``matQ``.
        matV: Values of shape ``(batch_size, nh, num_chunks * chunk_size, dhv)``.
        matC_states: Starting memory of every chunk, stacked along dim 2:
            ``(batch_size, nh, num_chunks * dhqk, dhv)``.
        vecN_states: Starting normalizer of every chunk: ``(batch_size, nh, num_chunks * dhqk)``.
        scaMinter_states: Starting max state of every chunk: ``(batch_size, nh, num_chunks)``.
        vecI: Input gate pre-activations: ``(batch_size, nh, num_chunks, chunk_size)``.
        vecB: Per-chunk cumulative log forget gates: ``(batch_size, nh, num_chunks, chunk_size)``.
        qk_scale: Scale applied to the query/key products.
        chunk_size: Number of time steps per chunk.
        num_chunks: Number of chunks.
        eps: Added to the denominator for numerical stability.

    Returns:
        ``(matH_out, vecN_out, vecM_out)``, where ``S = num_chunks * chunk_size``:

        - ``matH_out``: hidden states ``(batch_size, nh, S, dhv)``;
        - ``vecN_out``: stabilized denominators ``(batch_size, nh, S)``;
        - ``vecM_out``: combined max states ``(batch_size, nh, S)``.

    Raises:
        ValueError: If the stacked state dimension of ``matC_states`` is not ``num_chunks * dhqk``.
    """
    _device = matQ.device
    nc = num_chunks
    batch_size, nh, dqk_stacked, dhv = matC_states.shape
    # per-chunk query/key head dim; the state buffers stack nc blocks of this size along dim 2
    dhqk = matQ.shape[-1]
    if dqk_stacked != nc * dhqk:
        raise ValueError(
            f"matC_states has {dqk_stacked} rows along dim 2, expected num_chunks * dhqk = {nc * dhqk}."
        )
    matC_k_states = matC_states.view(batch_size, nh, nc, dhqk, dhv)
    vecN_k_states = vecN_states.view(batch_size, nh, nc, dhqk)
    scaMinter_k_states = scaMinter_states
    # split the time axis into chunks; the last dim is the per-chunk head dim, not the stacked state dim
    matQ = matQ.view(batch_size, nh, nc, chunk_size, dhqk)
    matK = matK.view(batch_size, nh, nc, chunk_size, dhqk)
    matV = matV.view(batch_size, nh, nc, chunk_size, dhv)
    # causal (lower-triangular) mask within a chunk
    ltr = torch.tril(
        torch.ones(
            (chunk_size, chunk_size),
            dtype=torch.bool,
            device=_device,
        )
    )
    # Compute intra chunk contribution: H_intra
    # log forget-gate decay from step j to step i inside the chunk (vecB is a cumulative sum)
    matF_logsig_chunk = vecB[:, :, :, :, None] - vecB[:, :, :, None, :]
    matF_logsig_mask_chunk = torch.where(ltr, matF_logsig_chunk, -float("inf"))
    matLogD_chunk = matF_logsig_mask_chunk + vecI[:, :, :, None, :]
    # max_state intra
    vecMintra_k = torch.max(matLogD_chunk, dim=-1, keepdim=False).values
    # max_state combined
    vecM_b_inter = vecB + scaMinter_k_states[:, :, :, None]
    vecM_k_combine = torch.maximum(vecM_b_inter, vecMintra_k)
    vecM_k_combine = vecM_k_combine[:, :, :, :, None]
    vecM_b_inter = vecM_b_inter[:, :, :, :, None]
    matLogD_stabilized_chunk = matLogD_chunk - vecM_k_combine
    matD_chunk = torch.exp(matLogD_stabilized_chunk)
    matS_chunk = (matQ @ matK.transpose(-2, -1)) * qk_scale
    matM_chunk = matS_chunk * matD_chunk
    # Combine H_intra with H_inter
    vecBbar = torch.exp(vecM_b_inter - vecM_k_combine)
    matQ_chunk_gated = matQ * vecBbar * qk_scale
    matNumerator_common = matQ_chunk_gated @ matC_k_states + matM_chunk @ matV
    vecDenom_l_common = matQ_chunk_gated @ vecN_k_states.unsqueeze(-1) + matM_chunk.sum(dim=-1, keepdim=True)
    vecDenom_max_common = torch.maximum(torch.abs(vecDenom_l_common), torch.exp(-vecM_k_combine))
    matH_k_chunk = matNumerator_common / (vecDenom_max_common + eps)
    matH_out = matH_k_chunk.view(batch_size, nh, nc * chunk_size, dhv)
    # we need the denominator and the overall max state for the backward pass
    vecN_out = vecDenom_max_common.reshape(batch_size, nh, nc * chunk_size)
    vecM_out = vecM_k_combine.reshape(batch_size, nh, nc * chunk_size)
    return matH_out, vecN_out, vecM_out
def mlstm_chunkwise_fw(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    igate: torch.Tensor,
    fgate: torch.Tensor,
    cstate: torch.Tensor = None,
    nstate: torch.Tensor = None,
    mstate: torch.Tensor = None,
    qk_scale: Optional[float] = None,
    return_last_states: bool = False,
    return_all_states: bool = False,
    chunk_size: int = 64,
    eps: float = 1e-6,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
    Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
]:
    """Chunkwise mLSTM forward pass that also exposes the auxiliary tensors of the parallel step.

    Args:
        query: Queries of shape ``(batch_size, nh, sequence_length, dhqk)``.
        key: Keys, shaped like ``query``.
        value: Values of shape ``(batch_size, nh, sequence_length, dhv)``.
        igate: Input gate pre-activations ``(batch_size, nh, sequence_length)``.
        fgate: Forget gate pre-activations, shaped like ``igate``.
        cstate: Optional initial memory state; zeros are used when omitted.
        nstate: Optional initial normalizer state; zeros are used when omitted.
        mstate: Optional initial max state; zeros are used when omitted.
        qk_scale: Query/key scale; defaults to ``dhqk ** -0.5``.
        return_last_states: If True, the 4th element holds the states after the last chunk.
        return_all_states: If True, the 5th element holds the full stacked state buffers.
        chunk_size: Chunk length; ``sequence_length`` must be a multiple of it.
        eps: Denominator epsilon.

    Returns:
        A 5-tuple ``(matH_out, vecN_out, vecM_out, last_states, all_states)``:

        - ``matH_out``, ``vecN_out``, ``vecM_out``: the outputs of ``mlstm_chunkwise_parallel_fw_H``;
        - ``last_states``: ``(C, n, m)`` with shapes ``(batch_size, nh, dhqk, dhv)``, ``(batch_size, nh, dhqk)``
          and ``(batch_size, nh, 1)``, or None;
        - ``all_states``: the buffers returned by ``mlstm_chunkwise_recurrent_fw_C``, or None.

    Raises:
        ValueError: If ``sequence_length`` is not divisible by ``chunk_size``.
    """
    batch_size, nh, sequence_length, dhqk = query.shape
    if sequence_length % chunk_size != 0:
        raise ValueError(f"Sequence length {sequence_length} is not divisible by chunk size {chunk_size}.")
    nc = sequence_length // chunk_size
    vecI = igate.view(batch_size, nh, nc, chunk_size)
    vecF = fgate.view(batch_size, nh, nc, chunk_size)
    # compute the gates, the g and the a and b vectors
    # log-sigmoid forget gates, accumulated within each chunk
    vecF_logsig = F.logsigmoid(vecF)
    vecB = vecF_logsig.cumsum(-1)
    if qk_scale is None:
        qk_scale = dhqk**-0.5
    # materialize the C_k, n_k, m_k states for each chunk
    matC_k_states, vecN_k_states, scaMinter_k_states = mlstm_chunkwise_recurrent_fw_C(
        matK=key,
        matV=value,
        vecB=vecB,
        vecI=vecI,
        matC_initial=cstate,
        vecN_initial=nstate,
        scaMinter_initial=mstate,
        qk_scale=qk_scale,
        chunk_size=chunk_size,
        num_chunks=nc,
    )
    # compute the outputs within each chunk, using the states at the start of every chunk (all but the last slot)
    matH_out, vecN_out, vecM_out = mlstm_chunkwise_parallel_fw_H(
        matQ=query,
        matK=key,
        matV=value,
        matC_states=matC_k_states[:, :, :-dhqk, :],
        vecN_states=vecN_k_states[:, :, :-dhqk],
        scaMinter_states=scaMinter_k_states[:, :, :-1],
        vecI=vecI,
        vecB=vecB,
        qk_scale=qk_scale,
        chunk_size=chunk_size,
        num_chunks=nc,
        eps=eps,
    )
    last_states = (
        (matC_k_states[:, :, -dhqk:, :], vecN_k_states[:, :, -dhqk:], scaMinter_k_states[:, :, -1:])
        if return_last_states
        else None
    )
    all_states = (matC_k_states, vecN_k_states, scaMinter_k_states) if return_all_states else None
    return matH_out, vecN_out, vecM_out, last_states, all_states
def mlstm_chunkwise_native_autograd(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    igate: torch.Tensor,
    fgate: torch.Tensor,
    c_initial: torch.Tensor = None,
    n_initial: torch.Tensor = None,
    m_initial: torch.Tensor = None,
    return_last_states: bool = False,
    eps: float = 1e-6,
    chunk_size: int = 64,
    **kwargs,
) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
    """Chunkwise-parallel mLSTM forward pass written in plain PyTorch ops, differentiable through autograd.

    The inter-chunk states are first materialized sequentially; the outputs of all chunks are then computed in
    parallel. The query/key scale is fixed to ``dhqk ** -0.5``. Extra keyword arguments (e.g.
    ``autocast_kernel_dtype``) are accepted and ignored.

    Args:
        query: Queries of shape ``(batch_size, nh, sequence_length, dhqk)``.
        key: Keys, shaped like ``query``.
        value: Values of shape ``(batch_size, nh, sequence_length, dhv)``.
        igate: Input gate pre-activations ``(batch_size, nh, sequence_length)``.
        fgate: Forget gate pre-activations, shaped like ``igate``.
        c_initial: Optional initial memory state.
        n_initial: Optional initial normalizer state.
        m_initial: Optional initial max state.
        return_last_states: If True, also return ``(C, n, m)`` after the last chunk.
        eps: Denominator epsilon.
        chunk_size: Chunk length; ``sequence_length`` must be a multiple of it.

    Returns:
        The hidden states, or ``(hidden_states, (C, n, m))`` when ``return_last_states`` is True.

    Raises:
        ValueError: If ``sequence_length`` is not divisible by ``chunk_size``.
    """
    batch_size, nh, sequence_length, dhqk = query.shape
    if sequence_length % chunk_size != 0:
        raise ValueError(f"Sequence length {sequence_length} is not divisible by chunk size {chunk_size}.")
    n_chunks = sequence_length // chunk_size
    chunked_shape = (batch_size, nh, n_chunks, chunk_size)
    input_gates = igate.view(chunked_shape)
    # per-chunk cumulative log forget gates
    log_forget_cumsum = F.logsigmoid(fgate.view(chunked_shape)).cumsum(-1)
    scale = dhqk**-0.5

    # sequential pass: state at the start of every chunk plus the final state
    stacked_c, stacked_n, stacked_m = mlstm_chunkwise_recurrent_fw_C(
        matK=key,
        matV=value,
        vecB=log_forget_cumsum,
        vecI=input_gates,
        matC_initial=c_initial,
        vecN_initial=n_initial,
        scaMinter_initial=m_initial,
        qk_scale=scale,
        chunk_size=chunk_size,
        num_chunks=n_chunks,
    )
    # parallel pass over the chunks, fed with every state except the final one
    hidden, _, _ = mlstm_chunkwise_parallel_fw_H(
        matQ=query,
        matK=key,
        matV=value,
        matC_states=stacked_c[:, :, :-dhqk, :],
        vecN_states=stacked_n[:, :, :-dhqk],
        scaMinter_states=stacked_m[:, :, :-1],
        vecI=input_gates,
        vecB=log_forget_cumsum,
        qk_scale=scale,
        chunk_size=chunk_size,
        num_chunks=n_chunks,
        eps=eps,
    )
    if not return_last_states:
        return hidden
    final_states = (stacked_c[:, :, -dhqk:, :], stacked_n[:, :, -dhqk:], stacked_m[:, :, -1:])
    return hidden, final_states
def mlstm_recurrent_step_native(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    igate: torch.Tensor,
    fgate: torch.Tensor,
    cstate: torch.Tensor,
    nstate: torch.Tensor,
    mstate: torch.Tensor,
    eps: float = 1e-6,
    dtype_state: torch.dtype = torch.float32,
    **kwargs,
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Advance the mLSTM recurrence by exactly one time step.

    States are cast to ``dtype_state`` before the update and returned in that dtype. The hidden output is returned
    in the dtype of ``query``.

    Args:
        query: Query of shape ``(batch_size, nh, dhqk)``.
        key: Key, shaped like ``query``.
        value: Value of shape ``(batch_size, nh, dhhv)``.
        igate: Input gate pre-activation ``(batch_size, nh, 1)``.
        fgate: Forget gate pre-activation ``(batch_size, nh, 1)``.
        cstate: Memory state ``(batch_size, nh, dhqk, dhhv)``.
        nstate: Normalizer state ``(batch_size, nh, dhqk)``.
        mstate: Max state ``(batch_size, nh, 1)``.
        eps: Denominator epsilon.
        dtype_state: Dtype in which the states are kept.

    Returns:
        ``(h, (C_new, n_new, m_new))`` with ``h`` of shape ``(batch_size, nh, dhhv)``.

    Raises:
        ValueError: If any input has an unexpected shape.
    """
    qkv_dtype = query.dtype
    c_prev = cstate.to(dtype=dtype_state)
    n_prev = nstate.to(dtype=dtype_state)
    m_prev = mstate.to(dtype=dtype_state)

    batch_size, nh, dhqk = query.shape
    _, _, dhhv = value.shape
    if query.shape != key.shape:
        raise ValueError("query and key must have the same shape")
    if c_prev.shape != (batch_size, nh, dhqk, dhhv):
        raise ValueError(f"matC_old has wrong shape, got {c_prev.shape}")
    if n_prev.shape != (batch_size, nh, dhqk):
        raise ValueError(f"vecN_old has wrong shape, got {n_prev.shape}")
    if m_prev.shape != (batch_size, nh, 1):
        raise ValueError(f"scaM_old has wrong shape, got {m_prev.shape}")
    if igate.shape != (batch_size, nh, 1):
        raise ValueError(f"scaI has wrong shape, got {igate.shape}")
    if fgate.shape != (batch_size, nh, 1):
        raise ValueError(f"scaF has wrong shape, got {fgate.shape}")

    # gates in log space, stabilized by the running max state
    log_forget = torch.nn.functional.logsigmoid(fgate)
    m_next = torch.max(log_forget + m_prev, igate)
    forget_act = torch.exp(log_forget + m_prev - m_next)
    input_act = torch.exp(igate - m_next)

    q_scaled = query * (dhqk ** (-0.5))
    # rank-1 update of the memory with the outer product key ⊗ value
    kv_outer = key[:, :, :, None] @ value[:, :, None, :]
    c_next = forget_act[:, :, :, None] * c_prev + input_act[:, :, :, None] * kv_outer
    n_next = forget_act * n_prev + input_act * key

    q_row = q_scaled[:, :, None, :]
    numerator = (q_row @ c_next.to(dtype=qkv_dtype)).squeeze(2).to(dtype=dtype_state)
    qn = (q_row @ n_next[:, :, :, None].to(dtype=qkv_dtype)).squeeze(2)
    # the denominator is floored at exp(-m) so it never vanishes
    denominator = (torch.maximum(qn.abs(), torch.exp(-m_next)) + eps).to(dtype=dtype_state)
    h = (numerator / denominator).to(dtype=qkv_dtype)

    return h, (c_next.to(dtype=dtype_state), n_next.to(dtype=dtype_state), m_next.to(dtype=dtype_state))
def mlstm_recurrent_sequence_native(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    igate: torch.Tensor,
    fgate: torch.Tensor,
    c_initial: torch.Tensor = None,
    n_initial: torch.Tensor = None,
    m_initial: torch.Tensor = None,
    return_last_states: bool = False,
    eps: float = 1e-6,
    dtype_state: torch.dtype = torch.float32,
    **kwargs,
) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
    """Run the mLSTM over a sequence one time step at a time, using ``mlstm_recurrent_step_native``.

    Args:
        query: Queries of shape ``(batch_size, nh, sequence_length, dhqk)``.
        key: Keys, shaped like ``query``.
        value: Values of shape ``(batch_size, nh, sequence_length, dhv)``.
        igate: Input gate pre-activations ``(batch_size, nh, sequence_length)``.
        fgate: Forget gate pre-activations, shaped like ``igate``.
        c_initial: Initial memory state. Either all three initial states are given, or none; zeros in
            ``dtype_state`` are used when none are given.
        n_initial: Initial normalizer state (see ``c_initial``).
        m_initial: Initial max state (see ``c_initial``).
        return_last_states: If True, also return the states after the last step.
        eps: Denominator epsilon.
        dtype_state: Dtype in which the states are kept.
        **kwargs: Forwarded to the step function.

    Returns:
        The hidden states ``(batch_size, nh, sequence_length, dhv)``, or ``(hidden_states, (C, n, m))`` when
        ``return_last_states`` is True. An empty sequence yields an empty hidden-state tensor and the initial states.

    Raises:
        ValueError: If only some of the initial states are provided.
    """
    batch_size, nh, sequence_length, dhqk = query.shape
    dhv = value.shape[-1]
    device = query.device
    num_provided = sum(state is not None for state in (c_initial, n_initial, m_initial))
    if num_provided not in (0, 3):
        raise ValueError("Initial states must be provided together.")
    if num_provided == 3:
        matC_state, vecN_state, vecM_state = (
            c_initial.to(dtype=dtype_state),
            n_initial.to(dtype=dtype_state),
            m_initial.to(dtype=dtype_state),
        )
    else:
        # memory state
        matC_state = torch.zeros((batch_size, nh, dhqk, dhv), dtype=dtype_state, device=device)
        # normalizer state
        vecN_state = torch.zeros((batch_size, nh, dhqk), dtype=dtype_state, device=device)
        # max state
        vecM_state = torch.zeros((batch_size, nh, 1), dtype=dtype_state, device=device)

    if sequence_length == 0:
        # torch.stack cannot handle an empty list; the step outputs would have the query dtype
        matH = query.new_zeros((batch_size, nh, 0, dhv))
    else:
        vecH_list = []
        for t in range(sequence_length):
            # gates keep a trailing singleton dim: (batch_size, nh, 1)
            vecF_t, vecI_t = fgate[:, :, t, None], igate[:, :, t, None]
            # projections for this time step, without the sequence dim
            vecQ_t, vecK_t, vecV_t = query[:, :, t, :], key[:, :, t, :], value[:, :, t, :]
            vecH, (matC_state, vecN_state, vecM_state) = mlstm_recurrent_step_native(
                cstate=matC_state,
                nstate=vecN_state,
                mstate=vecM_state,
                query=vecQ_t,
                key=vecK_t,
                value=vecV_t,
                igate=vecI_t,
                fgate=vecF_t,
                eps=eps,
                dtype_state=dtype_state,
                **kwargs,
            )
            vecH_list.append(vecH)
        matH = torch.stack(vecH_list, dim=-2)
    if return_last_states:
        return matH, (matC_state, vecN_state, vecM_state)
    return matH
def wrap_chunkwise_pad_zeros(
    mlstm_chunkwise_kernel: Callable,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    fgate: torch.Tensor,
    igate: torch.Tensor,
    c_initial: torch.Tensor = None,
    n_initial: torch.Tensor = None,
    m_initial: torch.Tensor = None,
    return_last_states: bool = False,
    eps: float = 1e-6,
    autocast_kernel_dtype: torch.dtype = torch.bfloat16,
    chunk_size: int = 64,
    **kwargs,
) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
    """Run a chunkwise kernel on a sequence whose length need not be a multiple of ``chunk_size``.

    Queries, keys, values and both gate pre-activations are right-padded with zeros along the sequence axis up to
    the next multiple of ``chunk_size``. The kernel output is then truncated back to the original length. Because
    the padding comes after the real steps, the outputs of the real steps are unaffected, assuming the kernel is
    causal.

    Args:
        mlstm_chunkwise_kernel: Kernel called with the padded tensors and the remaining keyword arguments.
        query: Queries ``(batch_size, nh, sequence_length, dhqk)``.
        key: Keys, shaped like ``query``.
        value: Values ``(batch_size, nh, sequence_length, dhv)``.
        fgate: Forget gate pre-activations ``(batch_size, nh, sequence_length)``.
        igate: Input gate pre-activations, shaped like ``fgate``.
        c_initial: Initial memory state, forwarded to the kernel.
        n_initial: Initial normalizer state, forwarded to the kernel.
        m_initial: Initial max state, forwarded to the kernel.
        return_last_states: Must be False (see Raises).
        eps: Forwarded to the kernel.
        autocast_kernel_dtype: Forwarded to the kernel.
        chunk_size: Forwarded to the kernel.
        **kwargs: Forwarded to the kernel.

    Returns:
        The kernel's hidden-state output, truncated to ``sequence_length`` along dim 2.

    Raises:
        ValueError: If ``return_last_states`` is True, because the states after the padded steps are not the true
            last states.
    """
    if return_last_states:
        raise ValueError(
            "We are padding zeros, so we cannot return last states, as they would not be the true last states."
        )
    sequence_length = query.shape[2]
    # number of zero steps needed to reach the next multiple of chunk_size (0 when already aligned)
    pad_len = (-sequence_length) % chunk_size
    if pad_len > 0:
        # F.pad lists pads from the last dim backwards: (0, 0) keeps the head dim, (0, pad_len) extends the sequence
        q_pad = F.pad(query, (0, 0, 0, pad_len))
        k_pad = F.pad(key, (0, 0, 0, pad_len))
        v_pad = F.pad(value, (0, 0, 0, pad_len))
        # the gates have no head dim, so the sequence axis is their last axis
        i_pad = F.pad(igate, (0, pad_len))
        f_pad = F.pad(fgate, (0, pad_len))
    else:
        q_pad, k_pad, v_pad, i_pad, f_pad = query, key, value, igate, fgate
    matH = mlstm_chunkwise_kernel(
        query=q_pad,
        key=k_pad,
        value=v_pad,
        igate=i_pad,
        fgate=f_pad,
        c_initial=c_initial,
        n_initial=n_initial,
        m_initial=m_initial,
        return_last_states=return_last_states,
        eps=eps,
        autocast_kernel_dtype=autocast_kernel_dtype,
        chunk_size=chunk_size,
        **kwargs,
    )
    return matH[:, :, :sequence_length, :]
def wrap_chunkwise_arbitrary_sequence_length(
    mlstm_chunkwise_kernel: Callable,
    mlstm_sequence_kernel: Callable,
    mlstm_step_kernel: Callable,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    fgate: torch.Tensor,
    igate: torch.Tensor,
    c_initial: torch.Tensor = None,
    n_initial: torch.Tensor = None,
    m_initial: torch.Tensor = None,
    return_last_states: bool = True,
    eps: float = 1e-6,
    autocast_kernel_dtype: torch.dtype = torch.bfloat16,
    chunk_size: int = 64,
    enable_logging: bool = False,
) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
    """Compute the mLSTM outputs and last states for a sequence of any length >= 1.

    Three kernels are combined:

    - ``sequence_length == 1``: one call to ``mlstm_step_kernel`` on inputs without a sequence dimension
      (q/k/v of shape ``(batch_size, nh, dh)``, gates of shape ``(batch_size, nh, 1)``).
    - ``sequence_length > 1``: the longest prefix whose length is a multiple of ``chunk_size`` goes through
      ``mlstm_chunkwise_kernel`` (skipped when the sequence is shorter than one chunk). The remaining
      ``sequence_length % chunk_size`` steps go through ``mlstm_sequence_kernel``, seeded with the states returned
      by the chunkwise kernel.

    Only the single given ``chunk_size`` is used; it is not reduced for the remainder.

    Args:
        mlstm_chunkwise_kernel: Chunkwise kernel processing whole chunks in parallel. Must return
            ``(h, (c, n, m))`` when called with ``return_last_states=True``.
        mlstm_sequence_kernel: Recurrent kernel for the remaining steps, returning ``(h, (c, n, m))``.
        mlstm_step_kernel: Single-step kernel, returning ``(h, (c, n, m))``.
        query: Queries ``(batch_size, nh, sequence_length, dhqk)``.
        key: Keys ``(batch_size, nh, sequence_length, dhqk)``.
        value: Values ``(batch_size, nh, sequence_length, dhhv)``.
        fgate: Forget gate pre-activations ``(batch_size, nh, sequence_length)``.
        igate: Input gate pre-activations ``(batch_size, nh, sequence_length)``.
        c_initial: Initial memory state ``(batch_size, nh, dhqk, dhhv)``; float32 zeros when None.
        n_initial: Initial normalizer state ``(batch_size, nh, dhqk)``; float32 zeros when None.
        m_initial: Initial max state ``(batch_size, nh, 1)``; float32 zeros when None.
        return_last_states: If True, also return the last states.
        eps: Epsilon forwarded to all kernels.
        autocast_kernel_dtype: Forwarded to the chunkwise kernel only.
        chunk_size: Chunk size of the chunkwise kernel.
        enable_logging: Currently unused; kept for interface compatibility.

    Returns:
        The hidden states ``(batch_size, nh, sequence_length, dhhv)``. When ``return_last_states`` is True, a tuple
        ``(hidden_states, (cstate, nstate, mstate))`` with shapes ``(batch_size, nh, dhqk, dhhv)``,
        ``(batch_size, nh, dhqk)`` and ``(batch_size, nh, 1)``.

    Raises:
        ValueError: If the sequence is empty.
    """
    batch_size, nh, sequence_length, dhqk = key.shape
    dhhv = value.shape[-1]
    c_state = (
        c_initial
        if c_initial is not None
        else torch.zeros(batch_size, nh, dhqk, dhhv, device=key.device, dtype=torch.float32)
    )
    n_state = (
        n_initial
        if n_initial is not None
        else torch.zeros(batch_size, nh, dhqk, device=key.device, dtype=torch.float32)
    )
    m_state = (
        m_initial
        if m_initial is not None
        else torch.zeros(batch_size, nh, 1, device=key.device, dtype=torch.float32)
    )
    if sequence_length > 1:
        h_outs = []
        seq_len_start_idx = 0
        num_chunks = sequence_length // chunk_size
        if num_chunks > 0:
            # largest chunk-aligned prefix, processed in parallel
            seq_len_idx = chunk_size * num_chunks
            h_out, (c_state, n_state, m_state) = mlstm_chunkwise_kernel(
                query=query[..., :seq_len_idx, :].contiguous(),
                key=key[..., :seq_len_idx, :].contiguous(),
                value=value[..., :seq_len_idx, :].contiguous(),
                fgate=fgate[..., :seq_len_idx].contiguous(),
                igate=igate[..., :seq_len_idx].contiguous(),
                c_initial=c_state,
                n_initial=n_state,
                m_initial=m_state,
                chunk_size=chunk_size,
                return_last_states=True,
                autocast_kernel_dtype=autocast_kernel_dtype,
                eps=eps,
            )
            seq_len_start_idx = seq_len_idx
            h_outs.append(h_out)
        if sequence_length - seq_len_start_idx > 0:
            # remaining steps (fewer than chunk_size), processed recurrently from the chunkwise states
            h_out, (c_state, n_state, m_state) = mlstm_sequence_kernel(
                query=query[..., seq_len_start_idx:sequence_length, :].contiguous(),
                key=key[..., seq_len_start_idx:sequence_length, :].contiguous(),
                value=value[..., seq_len_start_idx:sequence_length, :].contiguous(),
                igate=igate[..., seq_len_start_idx:sequence_length].contiguous(),
                fgate=fgate[..., seq_len_start_idx:sequence_length].contiguous(),
                c_initial=c_state,
                n_initial=n_state,
                m_initial=m_state,
                return_last_states=True,
                eps=eps,
            )
            h_outs.append(h_out)
        h_out = torch.cat(h_outs, dim=2)
    else:
        if sequence_length != 1:
            raise ValueError(
                f"Received empty sequence (sequence_length={sequence_length}), require at least single element in the sequence."
            )
        # single step: call the step kernel directly (it expects no sequence dimension) instead of the loop above
        h_out, (c_state, n_state, m_state) = mlstm_step_kernel(
            query=query.squeeze(2),
            key=key.squeeze(2),
            value=value.squeeze(2),
            igate=igate,
            fgate=fgate,
            cstate=c_state,
            nstate=n_state,
            mstate=m_state,
            eps=eps,
        )
        # restore the sequence dimension: (batch_size, nh, dhhv) -> (batch_size, nh, 1, dhhv)
        h_out = h_out[:, :, None, :]
    if return_last_states:
        return h_out, (c_state, n_state, m_state)
    return h_out
class xLSTMBackend(nn.Module):
    """xLSTM Backend Module for PyTorch.

    Wraps the native mLSTM kernels of this module behind a single ``forward``:

    - train mode: the chunkwise kernel (``mlstm_chunkwise_native_autograd``). It is wrapped with zero-padding to a
      multiple of ``config.chunk_size`` when ``config.mode`` contains ``"with_padding"``.
    - inference mode: ``wrap_chunkwise_arbitrary_sequence_length``, which accepts any sequence length and always
      returns the last states.
    """

    config_class = xLSTMConfig

    def __init__(self, config: xLSTMConfig):
        super().__init__()
        self.config = config
        self.chunkwise_kernel_fn = mlstm_chunkwise_native_autograd
        self.sequence_kernel_fn = mlstm_recurrent_sequence_native
        self.step_kernel_fn = mlstm_recurrent_step_native
        # the recurrent kernels keep their states in the torch dtype named by config.inference_state_dtype
        self._inference_fn = partial(
            wrap_chunkwise_arbitrary_sequence_length,
            mlstm_chunkwise_kernel=self.chunkwise_kernel_fn,
            mlstm_sequence_kernel=partial(
                self.sequence_kernel_fn,
                dtype_state=getattr(torch, config.inference_state_dtype),
            ),
            mlstm_step_kernel=partial(
                self.step_kernel_fn,
                dtype_state=getattr(torch, config.inference_state_dtype),
            ),
            chunk_size=config.chunk_size,
            eps=config.eps,
            autocast_kernel_dtype=getattr(torch, config.autocast_kernel_dtype),
            return_last_states=True,
        )
        train_kernel_fn = partial(
            self.chunkwise_kernel_fn,
            autocast_kernel_dtype=getattr(torch, config.autocast_kernel_dtype),
            eps=config.eps,
            chunk_size=config.chunk_size,
        )
        if "with_padding" in config.mode:
            train_kernel_fn = partial(wrap_chunkwise_pad_zeros, mlstm_chunkwise_kernel=train_kernel_fn)
        self._train_fn = train_kernel_fn

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        igate: torch.Tensor,
        fgate: torch.Tensor,
        c_initial: torch.Tensor = None,
        n_initial: torch.Tensor = None,
        m_initial: torch.Tensor = None,
        return_last_states: Optional[bool] = False,
        mode: Optional[Literal["train", "inference"]] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
        """Forward pass of the mLSTM backend.

        Depending on the resolved mode, this method calls the training or the inference kernel function.

        Args:
            query: The query tensor of shape (batch_size, nh, sequence_length, dhqk).
            key: The key tensor of shape (batch_size, nh, sequence_length, dhqk).
            value: The value tensor of shape (batch_size, nh, sequence_length, dhhv).
            igate: The input gate preactivation tensor of shape (batch_size, nh, sequence_length).
            fgate: The forget gate preactivation tensor of shape (batch_size, nh, sequence_length).
            c_initial: The initial cell state tensor of shape (batch_size, nh, dhqk, dhhv). Defaults to None.
            n_initial: The initial normalizer state tensor of shape (batch_size, nh, dhqk). Defaults to None.
            m_initial: The initial max state tensor of shape (batch_size, nh, 1). Defaults to None.
            return_last_states: Whether to return the last states of the sequence (train mode only; inference
                always returns them). Defaults to False. If explicitly set to None in train mode,
                ``config.return_last_states`` is used.
            mode: ``"train"``- or ``"inference"``-like mode string; defaults to ``config.mode``.

        Returns:
            The hidden states of shape (batch_size, nh, sequence_length, dhhv), or the hidden states together with
            the last states: the cell state cstate (batch_size, nh, dhqk, dhhv), the normalizer state nstate
            (batch_size, nh, dhqk) and the max state mstate (batch_size, nh, 1).

        Raises:
            ValueError: If last states are requested with a padding train kernel, or if the mode is unknown.
        """
        if mode is None:
            mode = self.config.mode
        if "train" in mode:
            if return_last_states is None:
                return_last_states = self.config.return_last_states
            # same test as in __init__, which decides whether the padding wrapper is used
            if "with_padding" in self.config.mode and return_last_states:
                raise ValueError("return_last_states=True is not supported with train_with_padding mode.")
            return self._train_fn(
                query=query,
                key=key,
                value=value,
                igate=igate,
                fgate=fgate,
                c_initial=c_initial,
                n_initial=n_initial,
                m_initial=m_initial,
                return_last_states=return_last_states,
            )
        if "inference" in mode:
            # inference mode always returns the last states
            return self._inference_fn(
                query=query,
                key=key,
                value=value,
                igate=igate,
                fgate=fgate,
                c_initial=c_initial,
                n_initial=n_initial,
                m_initial=m_initial,
            )
        # report the mode that was actually resolved (it may come from the argument, not the config)
        raise ValueError(f"Unknown mode: {mode}")

    def extra_repr(self) -> str:
        return f"{self.config}"
class xLSTMRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    The behaviour is similar to https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html. Every vector
    along the last dimension is multiplied by ``rsqrt(mean(x**2) + eps)``. An optional elementwise weight and bias
    are then applied.

    Args:
        num_features: Size of the normalized (last) dimension.
        eps: Added to the mean square before the reciprocal square root.
        use_weight: If True, learn a per-feature scale initialised to ones.
        use_bias: If True, learn a per-feature shift initialised to zeros.
        force_float32_reductions: If True, compute the statistics in float32 and cast the normalized tensor back to
            the input dtype (before the weight and bias are applied).
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-6,
        use_weight: bool = True,
        use_bias: bool = False,
        force_float32_reductions: bool = True,
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.force_float32_reductions = force_float32_reductions
        self.weight = nn.Parameter(torch.ones(num_features)) if use_weight else None
        self.bias = nn.Parameter(torch.zeros(num_features)) if use_bias else None

    def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor:
        scaled = x if self.weight is None else x * self.weight
        return scaled if self.bias is None else scaled + self.bias

    def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        work = x.float() if self.force_float32_reductions else x
        mean_square = work.pow(2).mean(dim=-1, keepdim=True)
        return (work * torch.rsqrt(mean_square + self.eps)).to(input_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._apply_weight_bias(self._rms_normalize(x))
class xLSTMMultiHeadLayerNorm(nn.Module):
    """Per-head LayerNorm followed by flattening the head axis.

    The input must have shape ``(batch_size, sequence_length, nh, DH)``. Every head vector (last dimension, DH) is
    normalized to zero mean and unit biased variance. The heads are then merged into one feature axis of size
    ``nh * DH``, and an optional learnable weight and bias of that size are applied.

    Args:
        num_heads: The number of heads (nh).
        head_dim: The head dimension (DH).
        eps: Added to the variance before the reciprocal square root.
        use_weight: Whether to use a learnable weight (initialised to ones).
        use_bias: Whether to use a learnable bias (initialised to zeros).
        force_float32_reductions: Whether to compute the statistics in float32.

    Returns:
        The normalized tensor with the shape (batch_size, sequence_length, nh * DH).
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        eps: float = 1e-6,
        use_weight: bool = True,
        use_bias: bool = False,
        force_float32_reductions: bool = True,
    ):
        super().__init__()
        self.num_features = num_heads * head_dim
        self.eps = eps
        self.force_float32_reductions = force_float32_reductions
        self.weight = nn.Parameter(torch.ones(self.num_features)) if use_weight else None
        self.bias = nn.Parameter(torch.zeros(self.num_features)) if use_bias else None
        self.num_heads = num_heads
        self.head_dim = head_dim

    def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor:
        scaled = x if self.weight is None else x * self.weight
        return scaled if self.bias is None else scaled + self.bias

    def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        work = x.float() if self.force_float32_reductions else x
        centered = work - work.mean(dim=-1, keepdim=True)
        inv_std = torch.rsqrt(work.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
        return (centered * inv_std).to(input_dtype)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        batch_size, sequence_length, nh, DH = x.shape
        if nh != self.num_heads:
            raise ValueError(f"Expected {self.num_heads} heads, got {nh}, input shape: {x.shape}")
        if DH != self.head_dim:
            raise ValueError(f"Expected {self.head_dim} head dimension, got {DH}, input shape: {x.shape}")
        merged = self._layer_normalize(x).reshape(batch_size, sequence_length, -1)
        return self._apply_weight_bias(merged)
class xLSTMFeedForward(nn.Module):
    """Gated feed-forward block: ``proj_down(SiLU(gate(x)) * up(x))``.

    The inner width is ``config.hidden_size * config.ffn_proj_factor``, rounded up to a multiple of
    ``config.ffn_round_up_to_multiple_of``. With ``weight_mode == "single"`` the gate and up projections are
    separate linear layers. With ``"fused"`` they are one layer of twice the width, split in ``forward``.

    Raises:
        ValueError: If ``config.weight_mode`` is neither ``"single"`` nor ``"fused"``.
    """

    def __init__(self, config: xLSTMConfig):
        super().__init__()
        self.config = config
        self.up_proj_dim = round_up_to_next_multiple_of(
            config.hidden_size * config.ffn_proj_factor,
            config.ffn_round_up_to_multiple_of,
        )
        if self.config.weight_mode == "single":
            self.proj_up_gate = nn.Linear(
                in_features=config.hidden_size,
                out_features=self.up_proj_dim,
                bias=self.config.use_bias,
            )
            self.proj_up = nn.Linear(
                in_features=config.hidden_size,
                out_features=self.up_proj_dim,
                bias=self.config.use_bias,
            )
        elif self.config.weight_mode == "fused":
            self.proj_up_gate_z = nn.Linear(
                in_features=config.hidden_size,
                out_features=2 * self.up_proj_dim,
                bias=self.config.use_bias,
            )
        else:
            # without this, no up-projection would exist and forward would fail or silently skip the gating
            raise ValueError(
                f"Unknown weight_mode: {self.config.weight_mode!r}. Expected 'single' or 'fused'."
            )
        self.proj_down = nn.Linear(
            in_features=self.up_proj_dim,
            out_features=config.hidden_size,
            bias=self.config.use_bias,
        )
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.config.weight_mode == "single":
            x = self.act_fn(self.proj_up_gate(x)) * self.proj_up(x)
        elif self.config.weight_mode == "fused":
            x = self.proj_up_gate_z(x)
            # first half is the gate, second half the up-projected value
            gate, z = torch.tensor_split(x, (self.up_proj_dim,), dim=-1)
            x = self.act_fn(gate) * z
        y = self.proj_down(x)
        return y
class xLSTMLayer(nn.Module):
    """
    mLSTM sequence-mixing layer.

    The layer does the following, in order:
    - Projects the input to queries, keys, values, output-gate pre-activations and per-head input/forget-gate
      pre-activations. The gate pre-activations are soft-capped with `config.gate_soft_cap`.
    - Runs `xLSTMBackend` on these projections.
    - Normalizes each head with `xLSTMMultiHeadLayerNorm`.
    - Multiplies by the sigmoid output gate and projects back to `hidden_size`.

    With `weight_mode == "single"` each projection is its own `nn.Linear`. With `weight_mode == "fused"` the
    q/k/v/o projections share `qkv_opreact` and the input/forget gates share `ifgate_preact`.
    """

    def __init__(self, config: xLSTMConfig):
        super().__init__()
        self.config = config

        self.v_dim = int(config.hidden_size * config.v_dim_factor)
        self.qk_dim = int(config.hidden_size * config.qk_dim_factor)

        hidden_size = self.config.hidden_size
        num_heads = self.config.num_heads
        use_bias = self.config.use_bias

        def make_linear(out_features, bias):
            return nn.Linear(in_features=hidden_size, out_features=out_features, bias=bias)

        if self.config.weight_mode == "single":
            self.query = make_linear(self.qk_dim, use_bias)
            self.key = make_linear(self.qk_dim, use_bias)
            self.value = make_linear(self.v_dim, use_bias)
            self.ogate_preact = make_linear(self.v_dim, use_bias)
            # gate projections always carry a bias, regardless of `use_bias`
            self.igate_preact = make_linear(num_heads, True)
            self.fgate_preact = make_linear(num_heads, True)
        elif self.config.weight_mode == "fused":
            self.qkv_opreact = make_linear(2 * self.qk_dim + 2 * self.v_dim, use_bias)
            self.ifgate_preact = make_linear(2 * num_heads, True)

        self.ogate_act_fn = nn.Sigmoid()
        self.mlstm_backend = xLSTMBackend(config=self.config)

        self.multihead_norm = xLSTMMultiHeadLayerNorm(
            num_heads=num_heads,
            head_dim=self.v_dim // num_heads,
            eps=self.config.norm_eps,
            use_weight=True,
            use_bias=use_bias,
            force_float32_reductions=self.config.norm_reduction_force_float32,
        )
        self.out_proj = nn.Linear(in_features=self.v_dim, out_features=hidden_size, bias=use_bias)

    def forward(
        self, x: torch.Tensor, state: Optional[mLSTMLayerStateType] = None
    ) -> tuple[torch.Tensor, Optional[mLSTMLayerStateType]]:
        if x.ndim != 3:
            raise ValueError(f"Input must have shape [batch_size, sequence_length, HD], got {x.shape}")
        batch_size, sequence_length, _ = x.shape
        num_heads = self.config.num_heads
        cap = self.config.gate_soft_cap

        if self.config.weight_mode == "single":
            query, key, value = self.query(x), self.key(x), self.value(x)
            o_preact = self.ogate_preact(x)
            i_preact = soft_cap(self.igate_preact(x), cap_value=cap)
            f_preact = soft_cap(self.fgate_preact(x), cap_value=cap)
        elif self.config.weight_mode == "fused":
            split_points = (self.qk_dim, 2 * self.qk_dim, 2 * self.qk_dim + self.v_dim)
            query, key, value, o_preact = torch.tensor_split(self.qkv_opreact(x), split_points, dim=-1)
            gates = soft_cap(self.ifgate_preact(x), cap_value=cap)
            i_preact, f_preact = torch.tensor_split(gates, (num_heads,), dim=-1)

        def to_heads(t):
            # (batch, seq, num_heads * d) -> (batch, num_heads, seq, d)
            return t.reshape(batch_size, sequence_length, num_heads, -1).transpose(1, 2)

        query, key, value = to_heads(query), to_heads(key), to_heads(value)
        # (batch, seq, num_heads) -> (batch, num_heads, seq)
        i_preact, f_preact = i_preact.transpose(1, 2), f_preact.transpose(1, 2)

        c_initial, n_initial, m_initial = (None, None, None) if state is None else state

        h, state = self.mlstm_backend(
            query=query,
            key=key,
            value=value,
            igate=i_preact,
            fgate=f_preact,
            c_initial=c_initial,
            n_initial=n_initial,
            m_initial=m_initial,
        )

        expected_h_shape = (batch_size, num_heads, sequence_length, self.v_dim // num_heads)
        if h.shape != expected_h_shape:
            raise ValueError(f"Got {h.shape}, expected {expected_h_shape}")

        h_norm = self.multihead_norm(h.transpose(1, 2)).reshape(batch_size, sequence_length, -1)
        y = self.out_proj(self.ogate_act_fn(o_preact) * h_norm)
        return y, state
class xLSTMBlock(nn.Module):
    """
    Pre-norm residual xLSTM block.

    The block computes `x + mlstm_layer(norm_mlstm(x))`, then adds `ffn(norm_ffn(.))` as a second residual. The
    recurrent state of this block's mLSTM layer is threaded through and returned.
    """

    def __init__(self, config: xLSTMConfig):
        super().__init__()
        self.config = config
        self.norm_mlstm = xLSTMRMSNorm(
            num_features=config.hidden_size,
            eps=config.norm_eps,
            use_weight=True,
            use_bias=config.use_bias,
            force_float32_reductions=config.norm_reduction_force_float32,
        )
        self.mlstm_layer = xLSTMLayer(config)
        self.norm_ffn = xLSTMRMSNorm(
            num_features=config.hidden_size,
            eps=config.norm_eps,
            use_weight=True,
            use_bias=config.use_bias,
            force_float32_reductions=config.norm_reduction_force_float32,
        )
        self.ffn = xLSTMFeedForward(config)

    def forward(
        self, x: torch.Tensor, state: Optional[mLSTMLayerStateType] = None
    ) -> tuple[torch.Tensor, Optional[mLSTMLayerStateType]]:
        """
        Args:
            x (`torch.Tensor`): block input. `xLSTMLayer` requires it to be 3-D.
            state (*optional*): this layer's `(c, n, m)` recurrent state, forwarded to `xLSTMLayer`. It is one
                layer's tuple, not the per-model dict.

        Returns:
            The block output and the updated `(c, n, m)` state returned by the mLSTM layer.
        """
        x_mlstm = self.norm_mlstm(x)
        x_mlstm, state = self.mlstm_layer(x_mlstm, state)
        x = x + x_mlstm

        x_ffn = self.norm_ffn(x)
        x_ffn = self.ffn(x_ffn)
        x = x + x_ffn

        return x, state
def small_init_method(dim):
    """
    Build an in-place initializer drawing from a normal distribution with mean 0 and std `sqrt(2 / (5 * dim))`.

    This follows "Transformers without Tears: Improving the Normalization of Self-Attention" - Nguyen, T. &
    Salazar, J. (2019). Adapted from:
    https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py

    Args:
        dim (`int`): model dimension used to scale the standard deviation.

    Returns:
        A callable taking a tensor, filling it in place and returning it.
    """
    sigma = (2 / (5 * dim)) ** 0.5
    return lambda tensor: torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
def wang_init_method(n_layers, dim):
    """
    Build an in-place initializer drawing from a normal distribution with mean 0 and std `2 / n_layers / sqrt(dim)`.

    Adapted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py

    Args:
        n_layers (`int`): number of layers; a deeper model gets a smaller std.
        dim (`int`): feature dimension used to scale the standard deviation.

    Returns:
        A callable taking a tensor, filling it in place and returning it.
    """
    sigma = 2 / n_layers / dim ** (1 / 2)
    return lambda tensor: torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
class xLSTMPreTrainedModel(PreTrainedModel):
    """
    An abstract class for an interface to loading a pre-trained xLSTM model.
    """

    config_class = xLSTMConfig
    base_model_prefix = "backbone"
    _no_split_modules = ["xLSTMBlock"]
    supports_gradient_checkpointing = True
    _is_stateful = True

    def _module_name_map(self, module):
        """
        Return the qualified name of `module` within this model (as produced by `named_modules()`).

        Returns `""` when `module` is not one of this model's submodules. This is a linear scan over all submodules.
        """
        for name, mod in self.named_modules():
            if mod is module:
                return name
        return ""

    @staticmethod
    def _forget_gate_bias_init(reference: torch.Tensor) -> torch.Tensor:
        """Return `linspace(3, 6)` with as many entries as the last dim of `reference`, on its device and dtype."""
        return torch.linspace(3.0, 6.0, reference.shape[-1]).to(device=reference.device, dtype=reference.dtype)

    def _init_weights(self, module):
        """
        Initialize the parameters of one submodule in place.

        - `nn.Embedding`: `small_init_method(hidden_size)`.
        - `nn.Linear`: the bias (if any) is zeroed, then the weight is chosen by the module's qualified name:
            * name contains "gate" (and `weight_mode` is "single" or "fused"): the weight is zeroed.
              - Input-gate biases are set to -10 and forget-gate biases to `linspace(3, 6)` over the heads.
              - In "fused" mode both live in `ifgate_preact`, with the input gates in the first `num_heads` entries.
            * "proj_down" / "out_proj": `wang_init_method` scaled by `num_hidden_layers`.
            * anything else: `small_init_method(hidden_size)`.
        - `xLSTMRMSNorm` / modules exposing `_layer_normalize`: weight set to one and bias to zero, where present.
        """
        if isinstance(module, nn.Embedding):
            # Initialize the module handed in. `self.embeddings` only exists on `xLSTMModel`, so using it broke
            # initialization whenever `self` was a wrapper such as `xLSTMForCausalLM`.
            small_init_method(self.config.hidden_size)(module.weight)
        elif isinstance(module, nn.Linear):
            # The name lookup scans every submodule, so compute it once per call.
            name = self._module_name_map(module)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
            # NOTE(review): the "gate" substring also matches the feed-forward gate projections
            # (`proj_up_gate` / `proj_up_gate_z`), whose weights are therefore zeroed too; confirm this is intended.
            if self.config.weight_mode == "single" and "gate" in name:
                torch.nn.init.zeros_(module.weight)
                with torch.no_grad():
                    if "igate" in name:
                        module.bias.fill_(-10.0)
                    elif "fgate" in name:
                        module.bias.copy_(self._forget_gate_bias_init(module.bias))
            elif self.config.weight_mode == "fused" and "gate" in name:
                torch.nn.init.zeros_(module.weight)
                # Only the fused input/forget projection carries gate biases. Other "gate" linears keep the zero
                # bias set above (and may have no bias at all when `use_bias` is False).
                if "ifgate" in name and module.bias is not None:
                    num_heads = self.config.num_heads
                    with torch.no_grad():
                        # `xLSTMLayer.forward` splits this projection at `num_heads`: input gates first, then forget.
                        module.bias[:num_heads].fill_(-10.0)
                        forget_bias = module.bias[num_heads:]
                        forget_bias.copy_(self._forget_gate_bias_init(forget_bias))
            elif "proj_down" in name:
                wang_init_method(dim=module.weight.shape[1], n_layers=self.config.num_hidden_layers)(module.weight)
            elif "out_proj" in name:
                wang_init_method(dim=self.config.hidden_size, n_layers=self.config.num_hidden_layers)(module.weight)
            elif module.weight is not None:
                small_init_method(self.config.hidden_size)(module.weight)
        elif isinstance(module, xLSTMRMSNorm) or hasattr(module, "_layer_normalize"):
            if getattr(module, "weight", None) is not None:
                torch.nn.init.ones_(module.weight)
            if getattr(module, "bias", None) is not None:
                torch.nn.init.zeros_(module.bias)
class xLSTMCache:
    """
    Cache for xLSTM model which does not have attention mechanism and key value states.

    Arguments:
        config (`xLSTMConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The batch size with which the model will be used.
        dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
            The default `dtype` to use when initializing the layer.
        device (`torch.device` or `str`, *optional*):
            The device on which the cache should be initialized. Should be the same as the layer.

    Attributes:
        seqlen_offset: int
            Starts at 0.
        dtype: torch.dtype
            The dtype the state tensors were created with.
        rnn_state: dict[int, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
            Per layer index, the zero-initialized `(c, n, m)` recurrent state, with shapes:
            - c: `[max_batch_size, num_heads, qk_head_dim, v_head_dim]`
            - n: `[max_batch_size, num_heads, qk_head_dim]`
            - m: `[max_batch_size, num_heads, 1]`

    Example:

    ```python
    >>> from transformers import AutoTokenizer, xLSTMForCausalLM, xLSTMCache

    >>> model = xLSTMForCausalLM.from_pretrained("NX-AI/xLSTM-7b")
    >>> tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")

    >>> inputs = tokenizer(text="I am an xLSTM", return_tensors="pt")

    >>> # Prepare a cache class and pass it to model's forward
    >>> cache_params = xLSTMCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
    >>> outputs = model(**inputs, cache_params=cache_params, use_cache=True)
    >>> outputs.cache_params
    xLSTMCache()
    ```
    """

    def __init__(
        self,
        config: xLSTMConfig,
        max_batch_size: int,
        dtype: torch.dtype = torch.bfloat16,
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ):
        self.seqlen_offset = 0
        self.dtype = dtype
        self.config = config
        self.rnn_state = {
            layer: (
                torch.zeros(
                    [max_batch_size, config.num_heads, config.qk_head_dim, config.v_head_dim],
                    dtype=dtype,
                    device=device,
                ),
                torch.zeros([max_batch_size, config.num_heads, config.qk_head_dim], dtype=dtype, device=device),
                torch.zeros([max_batch_size, config.num_heads, 1], dtype=dtype, device=device),
            )
            for layer in range(config.num_hidden_layers)
        }

    def reset(self):
        """Replace every cached state tensor with fresh zeros of the same shape, dtype and device."""
        self.rnn_state = {
            layer: tuple(torch.zeros_like(tensor) for tensor in states) for layer, states in self.rnn_state.items()
        }
@dataclass
@auto_docstring
class xLSTMOutput(ModelOutput):
    r"""
    cache_params (`xLSTMCache`, *optional*):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.
    """

    # Defaults to None, like every field of `xLSTMCausalLMOutput`; positional construction is unchanged.
    last_hidden_state: Optional[torch.FloatTensor] = None
    cache_params: Optional[xLSTMCache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
@auto_docstring
class xLSTMModel(xLSTMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # use embedding_dim and num_blocks once here to make use of them
        self.embeddings = nn.Embedding(config.vocab_size, config.embedding_dim)
        self.blocks = nn.ModuleList([xLSTMBlock(config) for _ in range(config.num_blocks)])
        self.out_norm = xLSTMRMSNorm(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, new_embedding):
        self.embeddings = new_embedding

    @staticmethod
    def _write_layer_state(cache_params: xLSTMCache, layer_idx: int, rnn_state) -> None:
        """
        Copy the recurrent state returned by block `layer_idx` into that layer's cache tensors, in place.

        `copy_` writes into the existing cache tensors, keeping their dtype and device, so callers holding
        `cache_params` observe the update.
        """
        for state_idx, cached in enumerate(cache_params.rnn_state[layer_idx]):
            cached.copy_(rnn_state[state_idx])
        cache_params.rnn_state_initial = False

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_params: Optional[xLSTMCache] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple, xLSTMOutput]:
        r"""
        cache_params (`xLSTMCache`, *optional*):
            The xLSTMCache that carries the RNN states.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        if use_cache and cache_params is None:
            cache_params = xLSTMCache(
                self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
            )

        hidden_states = inputs_embeds
        # Bound before branching. The chunked path below never collects per-layer hidden states, but the returned
        # output always references this name.
        all_hidden_states = () if output_hidden_states else None

        chunk_size = self.config.max_inference_chunksize
        seq_len = hidden_states.shape[1]
        if not self.training and chunk_size < seq_len and not output_hidden_states:
            # Inference on a sequence longer than `max_inference_chunksize`: run the whole stack chunk by chunk and
            # carry the recurrent state from one chunk to the next through `cache_params`.
            with torch.no_grad():
                if cache_params is None:
                    # `max_batch_size` is required; dtype/device follow the inputs, as in the `use_cache` path above.
                    cache_params = xLSTMCache(
                        config=self.config,
                        max_batch_size=hidden_states.shape[0],
                        dtype=hidden_states.dtype,
                        device=hidden_states.device,
                    )
                final_state = torch.zeros_like(hidden_states)
                offset = 0
                while offset < seq_len:
                    end = min(offset + chunk_size, seq_len)
                    hidden_states_chunk = hidden_states[:, offset:end]
                    for layer_idx, xlstm_block in enumerate(self.blocks):
                        hidden_states_chunk, rnn_state = xlstm_block(
                            hidden_states_chunk,
                            state=cache_params.rnn_state[layer_idx],
                        )
                        self._write_layer_state(cache_params, layer_idx, rnn_state)
                    final_state[:, offset:end] = hidden_states_chunk
                    offset += chunk_size
            hidden_states = final_state
        else:
            for layer_idx, xlstm_block in enumerate(self.blocks):
                layer_state = cache_params.rnn_state[layer_idx] if cache_params is not None else None
                if self.gradient_checkpointing and self.training:
                    hidden_states, rnn_state = self._gradient_checkpointing_func(
                        xlstm_block.__call__,
                        hidden_states,
                        layer_state,
                    )
                else:
                    hidden_states, rnn_state = xlstm_block(hidden_states, state=layer_state)
                if cache_params is not None:
                    self._write_layer_state(cache_params, layer_idx, rnn_state)
                if output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states,)

        if use_cache:
            cache_params.seqlen_offset += inputs_embeds.shape[1]

        hidden_states = self.out_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return xLSTMOutput(
            last_hidden_state=hidden_states,
            cache_params=cache_params,
            hidden_states=all_hidden_states,
        )
@dataclass
@auto_docstring
class xLSTMCausalLMOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax), after
        the soft cap `config.output_logit_soft_cap` has been applied.
    cache_params (`xLSTMCache`, *optional*):
        The RNN states of the model at the last time step. Can be used in a forward method with the next
        `input_ids` to avoid providing the old `input_ids`.
    """

    # Field order defines the tuple order of this output; keep it stable.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    cache_params: Optional[xLSTMCache] = None
    # Populated only when `output_hidden_states=True`.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
@auto_docstring
class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
    def __init__(self, config):
        super().__init__(config)
        self.backbone = xLSTMModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        return self.backbone.set_input_embeddings(new_embeddings)

    def prepare_inputs_for_generation(
        self,
        input_ids,
        inputs_embeds=None,
        use_cache=None,
        cache_params: Optional[xLSTMCache] = None,
        **kwargs,
    ):
        """
        Build the model inputs for one generation step.

        When caching is enabled and a cache is present, only the last position is forwarded. The cached recurrent
        state is assumed to already cover every earlier token, which also covers generation that starts from a
        lone [BOS] token. `inputs_embeds` are used only when there is no cache yet; otherwise `input_ids` are sent.
        """
        if use_cache and cache_params is not None:
            input_ids = input_ids[:, -1:]
            if inputs_embeds is not None:
                inputs_embeds = inputs_embeds[:, -1:]

        use_embeds = inputs_embeds is not None and cache_params is None
        model_inputs = {"inputs_embeds": inputs_embeds} if use_embeds else {"input_ids": input_ids}
        model_inputs["cache_params"] = cache_params
        model_inputs["use_cache"] = use_cache
        return model_inputs

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_params: Optional[xLSTMCache] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple, xLSTMCausalLMOutput]:
        r"""
        cache_params (`xLSTMCache`, *optional*):
            The xLSTMCache that carries the RNN states.
        """
        backbone_out = self.backbone(
            input_ids,
            cache_params=cache_params,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        last_hidden = backbone_out[0]
        logits = self.lm_head(last_hidden.to(self.lm_head.weight.dtype)).float()

        cap = self.config.output_logit_soft_cap
        chunk = self.config.max_inference_chunksize
        seq_len = logits.shape[1]
        if self.training or chunk >= seq_len:
            logits = soft_cap(logits, cap)
        else:
            # Long inference sequences: soft-cap the logits slice by slice, writing back in place.
            with torch.no_grad():
                start = 0
                while start < seq_len:
                    stop = min(start + chunk, seq_len)
                    logits[:, start:stop] = soft_cap(logits[:, start:stop], cap)
                    start += chunk

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # shift by one so that position t predicts token t + 1
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return xLSTMCausalLMOutput(
            loss=loss,
            logits=logits,
            cache_params=backbone_out.cache_params,
            hidden_states=backbone_out.hidden_states,
        )
# Public names of this module, i.e. what `from ... import *` exports.
__all__ = [
    "xLSTMForCausalLM",
    "xLSTMModel",
    "xLSTMPreTrainedModel",
]