# Copyright 2025 NXAI GmbH. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch xLSTM Model.""" from dataclasses import dataclass from typing import Optional, Union import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_xlstm_available from .configuration_xlstm import xLSTMConfig if is_xlstm_available(): from xlstm.xlstm_large.model import mLSTMBlock as xLSTMBlock from xlstm.xlstm_large.model import mLSTMStateType, soft_cap from xlstm.xlstm_large.model import xLSTMRMSNorm as xLSTMRMSNorm external_xlstm = True else: from functools import partial from typing import Callable, Literal from .configuration_xlstm import round_up_to_next_multiple_of mLSTMLayerStateType = tuple[torch.Tensor, torch.Tensor, torch.Tensor] mLSTMStateType = dict[int, mLSTMLayerStateType] external_xlstm = False def soft_cap(values: torch.Tensor, cap_value: Optional[Union[float, torch.Tensor]] = None) -> torch.Tensor: """ Soft caps a tensor to a value. Performs a tanh operation on the logits and scales the result to the cap value. Common technique in attention and output language heads to prevent large logits from dominating the softmax. See for example Gemma2: https://arxiv.org/abs/2408.00118 Args: values: The tensor to cap. cap_value: The value to cap the values to. If None, no cap is applied. Returns: The capped values. """ if cap_value is None: return values return cap_value * torch.tanh(values / cap_value) def mlstm_chunkwise_recurrent_fw_C( matK: torch.Tensor, matV: torch.Tensor, vecB: torch.Tensor, vecI: torch.Tensor, matC_states: torch.Tensor = None, vecN_states: torch.Tensor = None, scaMinter_states: torch.Tensor = None, matC_initial: torch.Tensor = None, vecN_initial: torch.Tensor = None, scaMinter_initial: torch.Tensor = None, qk_scale: Optional[float] = None, chunk_size: int = 64, num_chunks: int = 1, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch_size, nh, _, dhqk, dhhv = *matK.shape, matV.shape[-1] nc = num_chunks _dtype, _device = matK.dtype, matK.device if qk_scale is None: qk_scale = dhqk**-0.5 # initialize the states tensors if matC_states is None: matC_states = torch.zeros((batch_size, nh, (nc + 1) * dhqk, dhhv), dtype=_dtype, device=_device) if vecN_states is None: vecN_states = torch.zeros((batch_size, nh, (nc + 1) * dhqk), dtype=_dtype, device=_device) if scaMinter_states is None: scaMinter_states = torch.zeros((batch_size, nh, (nc + 1)), dtype=_dtype, device=_device) # assign the initial states to the running states matC_k = ( torch.zeros((batch_size, nh, dhqk, dhhv), dtype=_dtype, device=_device) if matC_initial is None else matC_initial ) vecN_k = ( torch.zeros((batch_size, nh, dhqk), dtype=_dtype, device=_device) if vecN_initial is None else vecN_initial ) scaM_inter_k = ( torch.zeros((batch_size, nh, 1), dtype=_dtype, device=_device) if scaMinter_initial is None else scaMinter_initial ) vecA = vecB[..., -1, None] - vecB + vecI scaG = vecB[..., -1] scaA_max = vecA.max(-1).values scaM_inter_k = scaM_inter_k.squeeze(-1) for key in range(0, num_chunks): # store the states from the previous iteration before updating them # in the first iteration, these are the initial states matC_states[:, :, key * dhqk : (key + 1) * dhqk, :] = matC_k vecN_states[:, :, key * dhqk : (key + 1) * dhqk] = vecN_k scaMinter_states[:, :, key] = scaM_inter_k # m_k update scaA_max_k = scaA_max[:, :, key] scaG_k = scaG[:, :, key] scaM_inter_k_next = torch.max(scaG_k + scaM_inter_k, scaA_max_k) # C_k update matK_chunk = matK[:, :, key * chunk_size : (key + 1) * chunk_size, :] # * qk_scale matV_chunk = matV[:, :, key * chunk_size : (key + 1) * chunk_size, :] vecA_k = vecA[:, :, key, :] vecAbar_k = torch.exp(vecA_k - scaM_inter_k_next[..., None])[:, :, :, None] matK_chunk_gated = matK_chunk * vecAbar_k scaGbar_k = torch.exp(scaG_k + scaM_inter_k - scaM_inter_k_next)[:, :, None] # NOTE: no update in-place (i.e. +=) as this gives error for autograd backward matC_k_next = scaGbar_k[..., None] * matC_k + matK_chunk_gated.transpose(-2, -1) @ (matV_chunk) # n_k update vecN_k_next = scaGbar_k * vecN_k + matK_chunk_gated.transpose(-2, -1).sum(-1) # move to the next iteration scaM_inter_k = scaM_inter_k_next matC_k = matC_k_next vecN_k = vecN_k_next # store the states from the last iteration matC_states[:, :, -dhqk:, :] = matC_k vecN_states[:, :, -dhqk:] = vecN_k scaMinter_states[:, :, -1] = scaM_inter_k return matC_states, vecN_states, scaMinter_states def mlstm_chunkwise_parallel_fw_H( matQ: torch.Tensor, matK: torch.Tensor, matV: torch.Tensor, # these states must be all states up to the last chunk, i.e. :-1 matC_states: torch.Tensor, vecN_states: torch.Tensor, scaMinter_states: torch.Tensor, vecI: torch.Tensor, vecB: torch.Tensor, qk_scale: float, chunk_size: int = 64, num_chunks: int = 1, eps: float = 1e-6, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: _device = matQ.device nc, chunk_size = num_chunks, chunk_size batch_size, nh, dqk, dhv = matC_states.shape matC_k_states = matC_states.view(batch_size, nh, nc, dqk // nc, dhv) vecN_k_states = vecN_states.view(batch_size, nh, nc, dqk // nc) scaMinter_k_states = scaMinter_states matQ = matQ.view(batch_size, nh, nc, chunk_size, dqk) matK = matK.view(batch_size, nh, nc, chunk_size, dqk) matV = matV.view(batch_size, nh, nc, chunk_size, dhv) ltr = torch.tril( torch.ones( (chunk_size, chunk_size), dtype=torch.bool, device=_device, ) ) # Compute intra chunk contribution: H_intra matF_logsig_chunk = vecB[:, :, :, :, None] - vecB[:, :, :, None, :] matF_logsig_mask_chunk = torch.where(ltr, matF_logsig_chunk, -float("inf")) matLogD_chunk = matF_logsig_mask_chunk + vecI[:, :, :, None, :] # max_state intra vecMintra_k = torch.max(matLogD_chunk, dim=-1, keepdim=False).values # max_state combined vecM_b_inter = vecB + scaMinter_k_states[:, :, :, None] vecM_k_combine = torch.maximum(vecM_b_inter, vecMintra_k) vecM_k_combine = vecM_k_combine[:, :, :, :, None] vecM_b_inter = vecM_b_inter[:, :, :, :, None] matLogD_stabilized_chunk = matLogD_chunk - vecM_k_combine matD_chunk = torch.exp(matLogD_stabilized_chunk) matS_chunk = (matQ @ matK.transpose(-2, -1)) * qk_scale matM_chunk = matS_chunk * matD_chunk # ? Combine H_intra with H_inter vecBbar = torch.exp(vecM_b_inter - vecM_k_combine) matQ_chunk_gated = matQ * vecBbar * qk_scale matNumerator_common = matQ_chunk_gated @ matC_k_states + matM_chunk @ matV vecDenom_l_common = matQ_chunk_gated @ vecN_k_states.unsqueeze(-1) + matM_chunk.sum(dim=-1, keepdim=True) vecDenom_max_common = torch.maximum(torch.abs(vecDenom_l_common), torch.exp(-vecM_k_combine)) matH_k_chunk = matNumerator_common / (vecDenom_max_common + eps) matH_out = matH_k_chunk.view(batch_size, nh, nc * chunk_size, dhv) # we need the denominator and the overall max state for the backward pass vecN_out = vecDenom_max_common.reshape(batch_size, nh, nc * chunk_size) vecM_out = vecM_k_combine(batch_size, nh, nc * chunk_size) return matH_out, vecN_out, vecM_out def mlstm_chunkwise_fw( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, igate: torch.Tensor, fgate: torch.Tensor, cstate: torch.Tensor = None, nstate: torch.Tensor = None, mstate: torch.Tensor = None, qk_scale: Optional[float] = None, return_last_states: bool = False, return_all_states: bool = False, chunk_size: int = 64, eps: float = 1e-6, ) -> tuple[ torch.Tensor, torch.Tensor, torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]], Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]], ]: batch_size, nh, sequence_length, dhqk = query.shape if sequence_length % chunk_size != 0: raise ValueError(f"Sequence length {sequence_length} is not divisible by chunk size {chunk_size}.") nc = sequence_length // chunk_size vecI = igate.view(batch_size, nh, nc, chunk_size) vecF = fgate.view(batch_size, nh, nc, chunk_size) # compute the gates, the g and the a and b vectors vecF_logsig = fgate.logsigmoid(vecF) vecB = vecF_logsig.cumsum(-1) if qk_scale is None: qk_scale = dhqk**-0.5 #! materialize the C_k, n_k, m_k states for each chunk matC_k_states, vecN_k_states, scaMinter_k_states = mlstm_chunkwise_recurrent_fw_C( matK=key, matV=value, vecB=vecB, vecI=vecI, matC_initial=cstate, vecN_initial=nstate, scaMinter_initial=mstate, qk_scale=qk_scale, chunk_size=chunk_size, num_chunks=nc, ) #! compute the outputs within each chunk matH_out, vecN_out, vecM_out = mlstm_chunkwise_parallel_fw_H( matQ=query, matK=key, matV=value, matC_states=matC_k_states[:, :, :-dhqk, :], vecN_states=vecN_k_states[:, :, :-dhqk], scaMinter_states=scaMinter_k_states[:, :, :-1], vecI=vecI, vecB=vecB, qk_scale=qk_scale, chunk_size=chunk_size, num_chunks=nc, eps=eps, ) ret_tuple = (matH_out, vecN_out, vecM_out) if return_last_states: ret_tuple += ( (matC_k_states[:, :, -dhqk:, :], vecN_k_states[:, :, -dhqk:], scaMinter_k_states[:, :, -1:]), ) else: ret_tuple += (None,) if return_all_states: ret_tuple += ((matC_k_states, vecN_k_states, scaMinter_k_states),) else: ret_tuple += (None,) return ret_tuple def mlstm_chunkwise_native_autograd( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, igate: torch.Tensor, fgate: torch.Tensor, c_initial: torch.Tensor = None, n_initial: torch.Tensor = None, m_initial: torch.Tensor = None, return_last_states: bool = False, eps: float = 1e-6, chunk_size: int = 64, **kwargs, ) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: batch_size, nh, sequence_length, dhqk = query.shape if sequence_length % chunk_size != 0: raise ValueError(f"Sequence length {sequence_length} is not divisible by chunk size {chunk_size}.") nc = sequence_length // chunk_size vecI = igate.view(batch_size, nh, nc, chunk_size) vecF = fgate.view(batch_size, nh, nc, chunk_size) # compute the gates, the g and the a and b vectors vecF_logsig = F.logsigmoid(vecF) vecB = vecF_logsig.cumsum(-1) qk_scale = dhqk**-0.5 #! materialize the C_k, n_k, m_k states for each chunk matC_k_states, vecN_k_states, scaMinter_k_states = mlstm_chunkwise_recurrent_fw_C( matK=key, matV=value, vecB=vecB, vecI=vecI, matC_initial=c_initial, vecN_initial=n_initial, scaMinter_initial=m_initial, qk_scale=qk_scale, chunk_size=chunk_size, num_chunks=nc, ) #! compute the outputs within each chunk matH_out, vecN_out, vecM_out = mlstm_chunkwise_parallel_fw_H( matQ=query, matK=key, matV=value, matC_states=matC_k_states[:, :, :-dhqk, :], vecN_states=vecN_k_states[:, :, :-dhqk], scaMinter_states=scaMinter_k_states[:, :, :-1], vecI=vecI, vecB=vecB, qk_scale=qk_scale, chunk_size=chunk_size, num_chunks=nc, eps=eps, ) last_states = (matC_k_states[:, :, -dhqk:, :], vecN_k_states[:, :, -dhqk:], scaMinter_k_states[:, :, -1:]) if return_last_states: return matH_out, last_states else: return matH_out def mlstm_recurrent_step_native( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, igate: torch.Tensor, fgate: torch.Tensor, cstate: torch.Tensor, nstate: torch.Tensor, mstate: torch.Tensor, eps: float = 1e-6, dtype_state: torch.dtype = torch.float32, **kwargs, ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """This is a single step of the mLSTM operation in recurrent form.""" dtype_qkv = query.dtype matC_old = cstate.to(dtype=dtype_state) vecN_old = nstate.to(dtype=dtype_state) scaM_old = mstate.to(dtype=dtype_state) batch_size, nh, dhqk = query.shape _, _, dhhv = value.shape if query.shape != key.shape: raise ValueError("query and key must have the same shape") if matC_old.shape != (batch_size, nh, dhqk, dhhv): raise ValueError(f"matC_old has wrong shape, got {matC_old.shape}") if vecN_old.shape != (batch_size, nh, dhqk): raise ValueError(f"vecN_old has wrong shape, got {vecN_old.shape}") if scaM_old.shape != (batch_size, nh, 1): raise ValueError(f"scaM_old has wrong shape, got {scaM_old.shape}") if igate.shape != (batch_size, nh, 1): raise ValueError(f"scaI has wrong shape, got {igate.shape}") if fgate.shape != (batch_size, nh, 1): raise ValueError(f"scaF has wrong shape, got {fgate.shape}") # gates scaF_log = torch.nn.functional.logsigmoid(fgate) # update rule scaM_state_new = torch.max(scaF_log + scaM_old, igate) scaF_act = torch.exp(scaF_log + scaM_old - scaM_state_new) scaI_act = torch.exp(igate - scaM_state_new) vecQ_scaled = query * (dhqk ** (-0.5)) matC_state_new = scaF_act[:, :, :, None] * matC_old + scaI_act[:, :, :, None] * ( key[:, :, :, None] @ value[:, :, None, :] ) vecN_state_new = scaF_act * vecN_old + scaI_act * key h_num = vecQ_scaled[:, :, None, :] @ matC_state_new.to(dtype=dtype_qkv) h_num = h_num.squeeze(2).to(dtype=dtype_state) qn_dotproduct = vecQ_scaled[:, :, None, :] @ vecN_state_new[:, :, :, None].to(dtype=dtype_qkv) qn_dotproduct = qn_dotproduct.squeeze(2) max_val = torch.exp(-scaM_state_new) h_denom = (torch.maximum(qn_dotproduct.abs(), max_val) + eps).to(dtype=dtype_state) h = h_num / h_denom h = h.to(dtype=dtype_qkv) matC_state_new = matC_state_new.to(dtype=dtype_state) vecN_state_new = vecN_state_new.to(dtype=dtype_state) scaM_state_new = scaM_state_new.to(dtype=dtype_state) return h, (matC_state_new, vecN_state_new, scaM_state_new) def mlstm_recurrent_sequence_native( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, igate: torch.Tensor, fgate: torch.Tensor, c_initial: torch.Tensor = None, n_initial: torch.Tensor = None, m_initial: torch.Tensor = None, return_last_states: bool = False, eps: float = 1e-6, dtype_state: torch.dtype = torch.float32, **kwargs, ) -> tuple[ torch.Tensor, torch.Tensor, torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]], Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]], ]: batch_size, nh, sequence_length, dhqk = query.shape dhv = value.shape[-1] device = query.device if c_initial is not None: if n_initial is None or m_initial is None: raise ValueError("Initial states must be provided together.") if n_initial is None or m_initial is None: raise ValueError("Initial states must be provided together.") matC_state, vecN_state, vecM_state = ( c_initial.to(dtype=dtype_state), n_initial.to(dtype=dtype_state), m_initial.to(dtype=dtype_state), ) else: # memory state matC_state = torch.zeros((batch_size, nh, dhqk, dhv), dtype=dtype_state, device=device) # normalizer state vecN_state = torch.zeros((batch_size, nh, dhqk), dtype=dtype_state, device=device) # max state vecM_state = torch.zeros((batch_size, nh, 1), dtype=dtype_state, device=device) vecH_list = [] for t in range(sequence_length): # gates vecF_t, vecI_t = fgate[:, :, t, None], igate[:, :, t, None] # projections vecQ_t, vecK_t, vecV_t = query[:, :, t, :], key[:, :, t, :], value[:, :, t, :] # step vecH, (matC_state, vecN_state, vecM_state) = mlstm_recurrent_step_native( cstate=matC_state, nstate=vecN_state, mstate=vecM_state, query=vecQ_t, key=vecK_t, value=vecV_t, igate=vecI_t, fgate=vecF_t, eps=eps, dtype_state=dtype_state, **kwargs, ) vecH_list.append(vecH) matH = torch.stack(vecH_list, dim=-2) if return_last_states: return matH, (matC_state, vecN_state, vecM_state) else: return matH def wrap_chunkwise_pad_zeros( mlstm_chunkwise_kernel: Callable, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, fgate: torch.Tensor, igate: torch.Tensor, c_initial: torch.Tensor = None, n_initial: torch.Tensor = None, m_initial: torch.Tensor = None, return_last_states: bool = False, eps: float = 1e-6, autocast_kernel_dtype: torch.dtype = torch.bfloat16, chunk_size: int = 64, **kwargs, ) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: if return_last_states: raise ValueError( "We are padding zeros, so we cannot return last states,", "as they would be not the true last states.", ) batch_size, nh, sequence_length, dhqk = query.shape S_unpadded = sequence_length # padding to chunk size for kernels if sequence_length % chunk_size != 0: S_padded = ((sequence_length + chunk_size - 1) // chunk_size) * chunk_size q_pad = query.new_zeros(batch_size, nh, S_padded, query.shape[3]) k_pad = key.new_zeros(batch_size, nh, S_padded, key.shape[3]) v_pad = value.new_zeros(batch_size, nh, S_padded, value.shape[3]) i_pad = igate.new_zeros(batch_size, nh, S_padded) f_pad = fgate.new_zeros(batch_size, nh, S_padded) q_pad[:, :, :S_unpadded, :] = query k_pad[:, :, :S_unpadded, :] = key v_pad[:, :, :S_unpadded, :] = value i_pad[:, :, :S_unpadded] = igate f_pad[:, :, :S_unpadded] = fgate else: q_pad = query k_pad = key v_pad = value i_pad = igate f_pad = fgate matH = mlstm_chunkwise_kernel( query=q_pad, key=k_pad, value=v_pad, igate=i_pad, fgate=f_pad, c_initial=c_initial, n_initial=n_initial, m_initial=m_initial, return_last_states=return_last_states, eps=eps, autocast_kernel_dtype=autocast_kernel_dtype, chunk_size=chunk_size, **kwargs, ) matH = matH[:, :, :S_unpadded, :] return matH def wrap_chunkwise_arbitrary_sequence_length( mlstm_chunkwise_kernel: Callable, mlstm_sequence_kernel: Callable, mlstm_step_kernel: Callable, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, fgate: torch.Tensor, igate: torch.Tensor, c_initial: torch.Tensor = None, n_initial: torch.Tensor = None, m_initial: torch.Tensor = None, return_last_states: bool = True, eps: float = 1e-6, autocast_kernel_dtype: torch.dtype = torch.bfloat16, chunk_size: int = 64, enable_logging: bool = False, ) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: """This function computes the last hidden state and matH outputs of the mLSTM, independently of the sequence length. For this it uses three kernels: - mlstm_chunkwise_kernel: mlstm chunkwise kernels that processes chunks of a given chunk size in parallel. - mlstm_sequence_kernel: mlstm kernel that processes the remaining sequence length in a single step recurrence. - mlstm_step_kernel: mlstm kernel that processes a sequence length of 1 in a single step. It tries to maximize the chunksizes to improve performance. It will start with the given chunk size and then divides the chunksize by 2 until the chunk size is smaller than 16. At every chunksize it will process the maximal number of chunks that fit into the remaining sequence length. E.g. for chunk_size = 64, this function will try the chunksizes [64, 32, 16] if necessary. For the remaining sequence length, which is smaller than 16, we use a different kernel that computes the mLSTM in a single step and loop over this in pytorch. Args: mlstm_chunkwise_kernel: The mLSTM chunkwise kernel that processes chunks of a given chunk size in parallel mlstm_sequence_kernel: The mLSTM kernel that processes the remaining sequence length in a single step recurrence query: The query tensor (batch_size, nh, sequence_length, dhqk) key: The key tensor (batch_size, nh, sequence_length, dhqk) value: The value tensor (batch_size, nh, sequence_length, dhhv) fgate: The forget gate tensor (batch_size, nh, sequence_length) igate: The input gate tensor (batch_size, nh, sequence_length) c_initial: The initial cell state tensor (batch_size, nh, dhqk, dhhv) n_initial: The initial hidden state tensor (batch_size, nh, dhqk) m_initial: The initial memory state tensor (batch_size, nh, 1) return_last_states: If True, the function will return the last states of the mLSTM eps: The epsilon value used for numerical stability autocast_kernel_dtype: The dtype used for the kernel computation chunk_size: The chunk size used for the chunkwise kernel enable_logging: If True, the function will log debug information. Default is False. Returns: The last hidden state tensor (batch_size, nh, sequence_length, dhhv) or a tuple containing the last hidden state tensor and the last states of the mLSTM Last states are (cstate (batch_size, nh, dhqk, dhhv), nstate (batch_size, nh, dhqk), mstate (batch_size, nh, 1)). """ batch_size, nh, sequence_length, dhqk = key.shape dhhv = value.shape[-1] c_state = ( c_initial if c_initial is not None else torch.zeros(batch_size, nh, dhqk, dhhv, device=key.device, dtype=torch.float32) ) n_state = ( n_initial if n_initial is not None else torch.zeros(batch_size, nh, dhqk, device=key.device, dtype=torch.float32) ) m_state = ( m_initial if m_initial is not None else torch.zeros(batch_size, nh, 1, device=key.device, dtype=torch.float32) ) if sequence_length > 1: # process the sequence length in chunks h_outs = [] seq_len_start_idx = 0 remaining_seq_len = sequence_length - seq_len_start_idx num_chunks = remaining_seq_len // chunk_size if num_chunks > 0: iter_seq_len = chunk_size * num_chunks seq_len_idx = seq_len_start_idx + iter_seq_len h_out, (c_state, n_state, m_state) = mlstm_chunkwise_kernel( query=query[..., seq_len_start_idx:seq_len_idx, :].contiguous(), key=key[..., seq_len_start_idx:seq_len_idx, :].contiguous(), value=value[..., seq_len_start_idx:seq_len_idx, :].contiguous(), fgate=fgate[..., seq_len_start_idx:seq_len_idx].contiguous(), igate=igate[..., seq_len_start_idx:seq_len_idx].contiguous(), c_initial=c_state, n_initial=n_state, m_initial=m_state, chunk_size=chunk_size, return_last_states=True, autocast_kernel_dtype=autocast_kernel_dtype, eps=eps, ) seq_len_start_idx += iter_seq_len h_outs.append(h_out) remaining_seq_len = sequence_length - seq_len_start_idx if remaining_seq_len > 0: # we use here matK as query as this kernel does not need a query, since we do not care about the outputs only about the last state h_out, (c_state, n_state, m_state) = mlstm_sequence_kernel( query=query[..., seq_len_start_idx:sequence_length, :].contiguous(), key=key[..., seq_len_start_idx:sequence_length, :].contiguous(), value=value[..., seq_len_start_idx:sequence_length, :].contiguous(), igate=igate[..., seq_len_start_idx:sequence_length].contiguous(), fgate=fgate[..., seq_len_start_idx:sequence_length].contiguous(), c_initial=c_state, n_initial=n_state, m_initial=m_state, return_last_states=True, eps=eps, ) h_outs.append(h_out) h_out = torch.concatenate(h_outs, dim=2) else: if sequence_length != 1: raise ValueError( f"Received empty sequence (sequence_length={sequence_length}), require at least single element in the sequence." ) # process the sequence length in a single step # while this case is also captured by the regular mode above, # it avoids the overhead of the loop and calls the step kernel directly # The step function does not want a sequence dimension # qkv shape is (batch_size, nh, dhqk/dhv) # igate, fgate shape is (batch_size, nh, 1) h_out, (c_state, n_state, m_state) = mlstm_step_kernel( query=query.squeeze(2), key=key.squeeze(2), value=value.squeeze(2), igate=igate, fgate=fgate, cstate=c_state, nstate=n_state, mstate=m_state, eps=eps, ) h_out = h_out[:, :, None, :] if return_last_states: return h_out, (c_state, n_state, m_state) else: return h_out class xLSTMBackend(nn.Module): """xLSTM Backend Module for PyTorch. This module wraps the xLSTM kernels and provides a high-level interface for training and inference. """ config_class = xLSTMConfig def __init__(self, config: xLSTMConfig): super().__init__() self.config = config self.chunkwise_kernel_fn = mlstm_chunkwise_native_autograd self.sequence_kernel_fn = mlstm_recurrent_sequence_native self.step_kernel_fn = mlstm_recurrent_step_native self._inference_fn = partial( wrap_chunkwise_arbitrary_sequence_length, mlstm_chunkwise_kernel=self.chunkwise_kernel_fn, mlstm_sequence_kernel=partial( self.sequence_kernel_fn, dtype_state=getattr(torch, config.inference_state_dtype), ), mlstm_step_kernel=partial( self.step_kernel_fn, dtype_state=getattr(torch, config.inference_state_dtype), ), chunk_size=config.chunk_size, eps=config.eps, autocast_kernel_dtype=getattr(torch, config.autocast_kernel_dtype), return_last_states=True, ) train_kernel_fn = partial( self.chunkwise_kernel_fn, autocast_kernel_dtype=getattr(torch, config.autocast_kernel_dtype), eps=config.eps, chunk_size=config.chunk_size, ) if "with_padding" in config.mode: train_kernel_fn = partial(wrap_chunkwise_pad_zeros, mlstm_chunkwise_kernel=train_kernel_fn) self._train_fn = train_kernel_fn def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, igate: torch.Tensor, fgate: torch.Tensor, c_initial: torch.Tensor = None, n_initial: torch.Tensor = None, m_initial: torch.Tensor = None, return_last_states: bool = False, mode: Optional[Literal["train", "inference"]] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: """Forward pass of the mLSTM backend. Depending on the configured mode, this method will call the appropriate kernel function. Args: query: The query tensor of shape (batch_size, nh, sequence_length, dhqk). key: The key tensor of shape (batch_size, nh, sequence_length, dhqk). value: The value tensor of shape (batch_size, nh, sequence_length, dhhv). igate: The input gate preactivation tensor of shape (batch_size, nh, sequence_length). fgate: The forget gate preactivation tensor of shape (batch_size, nh, sequence_length). c_initial: The initial cell state tensor of shape (batch_size, nh, dhqk, dhhv). Defaults to None. n_initial: The initial hidden state tensor of shape (batch_size, nh, dhqk). Defaults to None. m_initial: The initial memory tensor of shape (batch_size, nh, 1). Defaults to None. return_last_states: Whether to return the last states of the sequence. Defaults to None. If None, the value from the config is used. Returns: hidden states of shape (batch_size, nh, sequence_length, dhhv) hidden states and last states the last states are the cell state cstate (batch_size, nh, dhqk, dhhv), the normalizer state nstate (batch_size, nh, dhqk), and the max state mstate (batch_size, nh, 1) """ if mode is None: mode = self.config.mode if "train" in mode: if return_last_states is None: return_last_states = self.config.return_last_states if self.config.mode == "train_with_padding": if return_last_states: raise ValueError("return_last_states=True is not supported with train_with_padding mode.") return self._train_fn( query=query, key=key, value=value, igate=igate, fgate=fgate, c_initial=c_initial, n_initial=n_initial, m_initial=m_initial, return_last_states=return_last_states, ) elif "inference" in mode: # inference mode always returns the last states return self._inference_fn( query=query, key=key, value=value, igate=igate, fgate=fgate, c_initial=c_initial, n_initial=n_initial, m_initial=m_initial, ) else: raise ValueError(f"Unknown mode: {self.config.mode}") def extra_repr(self) -> str: return f"{self.config}" class xLSTMRMSNorm(nn.Module): """Root mean square normalization layer implementation similar to https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html. It normalizes the input tensor by the root mean square of the last dimension. Args: num_features: The number of features in the input tensor. eps: A small value to avoid division by zero. use_weight: Whether to use a learnable weight. use_bias: Whether to use a learnable bias. force_float32_reductions: Whether to force float32 reductions. """ def __init__( self, num_features: int, eps: float = 1e-6, use_weight: bool = True, use_bias: bool = False, force_float32_reductions: bool = True, ): super().__init__() self.num_features = num_features self.eps = eps self.force_float32_reductions = force_float32_reductions if use_weight: self.weight = nn.Parameter(torch.ones(num_features)) else: self.weight = None if use_bias: self.bias = nn.Parameter(torch.zeros(num_features)) else: self.bias = None def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor: if self.weight is not None: x = x * self.weight if self.bias is not None: x = x + self.bias return x def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor: # apply rms norm over the last dimension, i.e. HD dimension in_dtype = x.dtype if self.force_float32_reductions: x = x.float() x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) return x.to(in_dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self._rms_normalize(x) x = self._apply_weight_bias(x) return x class xLSTMMultiHeadLayerNorm(nn.Module): """Multi-head version of the LayerNorm layer. It normalizes the last dimension of the input tensor. The input is assumed to have the shape (batch_size, sequence_length, nh, DH), where: batch_size: batch size sequence_length: sequence length nh: number of heads DH: head dimension The normalization is applied over the last dimension (DH) of the input tensor. Args: num_heads: The number of heads. head_dim: The head dimension. eps: A small value to avoid division by zero. use_weight: Whether to use a learnable weight. use_bias: Whether to use a learnable bias. force_float32_reductions: Whether to force float32 reductions Returns: The normalized tensor with the shape (batch_size, sequence_length, nh * DH). """ def __init__( self, num_heads: int, head_dim: int, eps: float = 1e-6, use_weight: bool = True, use_bias: bool = False, force_float32_reductions: bool = True, ): super().__init__() self.num_features = num_heads * head_dim self.eps = eps self.force_float32_reductions = force_float32_reductions if use_weight: self.weight = nn.Parameter(torch.ones(self.num_features)) else: self.weight = None if use_bias: self.bias = nn.Parameter(torch.zeros(self.num_features)) else: self.bias = None self.num_heads = num_heads self.head_dim = head_dim def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor: if self.weight is not None: x = x * self.weight if self.bias is not None: x = x + self.bias return x def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor: # apply layer norm over the last dimension, i.e. HD dimension in_dtype = x.dtype if self.force_float32_reductions: x = x.float() x_centered = x - x.mean(dim=-1, keepdim=True) y = x_centered * torch.rsqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps) return y.to(in_dtype) def forward( self, x: torch.Tensor, ) -> torch.Tensor: batch_size, sequence_length, nh, DH = x.shape if nh != self.num_heads: raise ValueError(f"Expected {self.num_heads} heads, got {nh}, input shape: {x.shape}") if DH != self.head_dim: raise ValueError(f"Expected {self.head_dim} head dimension, got {DH}, input shape: {x.shape}") x = self._layer_normalize(x) x = x.reshape(batch_size, sequence_length, -1) x = self._apply_weight_bias(x) return x class xLSTMFeedForward(nn.Module): def __init__(self, config: xLSTMConfig): super().__init__() self.config = config self.up_proj_dim = round_up_to_next_multiple_of( config.hidden_size * config.ffn_proj_factor, config.ffn_round_up_to_multiple_of, ) if self.config.weight_mode == "single": self.proj_up_gate = nn.Linear( in_features=config.hidden_size, out_features=self.up_proj_dim, bias=self.config.use_bias, ) self.proj_up = nn.Linear( in_features=config.hidden_size, out_features=self.up_proj_dim, bias=self.config.use_bias, ) elif self.config.weight_mode == "fused": self.proj_up_gate_z = nn.Linear( in_features=config.hidden_size, out_features=2 * self.up_proj_dim, bias=self.config.use_bias, ) self.proj_down = nn.Linear( in_features=self.up_proj_dim, out_features=config.hidden_size, bias=self.config.use_bias, ) self.act_fn = nn.SiLU() def forward(self, x: torch.Tensor) -> torch.Tensor: if self.config.weight_mode == "single": x = self.act_fn(self.proj_up_gate(x)) * self.proj_up(x) elif self.config.weight_mode == "fused": x = self.proj_up_gate_z(x) gate, z = torch.tensor_split(x, (self.up_proj_dim,), dim=-1) x = self.act_fn(gate) * z y = self.proj_down(x) return y class xLSTMLayer(nn.Module): def __init__(self, config: xLSTMConfig): super().__init__() self.config = config self.v_dim = int(config.hidden_size * config.v_dim_factor) self.qk_dim = int(config.hidden_size * config.qk_dim_factor) if self.config.weight_mode == "single": self.query = nn.Linear( in_features=self.config.hidden_size, out_features=self.qk_dim, bias=self.config.use_bias, ) self.key = nn.Linear( in_features=self.config.hidden_size, out_features=self.qk_dim, bias=self.config.use_bias, ) self.value = nn.Linear( in_features=self.config.hidden_size, out_features=self.v_dim, bias=self.config.use_bias, ) self.ogate_preact = nn.Linear( in_features=self.config.hidden_size, out_features=self.v_dim, bias=self.config.use_bias, ) self.igate_preact = nn.Linear( in_features=self.config.hidden_size, out_features=self.config.num_heads, bias=True, ) self.fgate_preact = nn.Linear( in_features=self.config.hidden_size, out_features=self.config.num_heads, bias=True, ) elif self.config.weight_mode == "fused": self.qkv_opreact = nn.Linear( in_features=self.config.hidden_size, out_features=2 * self.qk_dim + 2 * self.v_dim, bias=self.config.use_bias, ) self.ifgate_preact = nn.Linear( in_features=self.config.hidden_size, out_features=2 * self.config.num_heads, bias=True, ) self.ogate_act_fn = nn.Sigmoid() self.mlstm_backend = xLSTMBackend(config=self.config) self.multihead_norm = xLSTMMultiHeadLayerNorm( num_heads=self.config.num_heads, head_dim=self.v_dim // self.config.num_heads, eps=self.config.norm_eps, use_weight=True, use_bias=self.config.use_bias, force_float32_reductions=self.config.norm_reduction_force_float32, ) self.out_proj = nn.Linear( in_features=self.v_dim, out_features=self.config.hidden_size, bias=self.config.use_bias, ) def forward( self, x: torch.Tensor, state: Optional[mLSTMLayerStateType] = None ) -> tuple[torch.Tensor, Optional[mLSTMLayerStateType]]: if x.ndim != 3: raise ValueError(f"Input must have shape [batch_size, sequence_length, HD], got {x.shape}") batch_size, sequence_length, _ = x.shape if self.config.weight_mode == "single": query = self.query(x) key = self.key(x) value = self.value(x) o_preact = self.ogate_preact(x) i_preact = soft_cap(self.igate_preact(x), cap_value=self.config.gate_soft_cap) f_preact = soft_cap(self.fgate_preact(x), cap_value=self.config.gate_soft_cap) elif self.config.weight_mode == "fused": qkv_opreact = self.qkv_opreact(x) query, key, value, o_preact = torch.tensor_split( qkv_opreact, ( self.qk_dim, 2 * self.qk_dim, 2 * self.qk_dim + self.v_dim, ), dim=-1, ) if_preact = soft_cap(self.ifgate_preact(x), cap_value=self.config.gate_soft_cap) i_preact, f_preact = torch.tensor_split(if_preact, (self.config.num_heads,), dim=-1) query = query.reshape(batch_size, sequence_length, self.config.num_heads, -1).transpose(1, 2) key = key.reshape(batch_size, sequence_length, self.config.num_heads, -1).transpose(1, 2) value = value.reshape(batch_size, sequence_length, self.config.num_heads, -1).transpose(1, 2) i_preact = i_preact.transpose(1, 2) f_preact = f_preact.transpose(1, 2) if state is None: c_initial, n_initial, m_initial = None, None, None else: c_initial, n_initial, m_initial = state h, state = self.mlstm_backend( query=query, key=key, value=value, igate=i_preact, fgate=f_preact, c_initial=c_initial, n_initial=n_initial, m_initial=m_initial, ) expected_h_shape = ( batch_size, self.config.num_heads, sequence_length, self.v_dim // self.config.num_heads, ) if h.shape != expected_h_shape: raise ValueError(f"Got {h.shape}, expected {expected_h_shape}") h = h.transpose(1, 2) h_norm = self.multihead_norm(h) h_norm = h_norm.reshape(batch_size, sequence_length, -1) h_out = self.ogate_act_fn(o_preact) * h_norm y = self.out_proj(h_out) return y, state class xLSTMBlock(nn.Module): def __init__(self, config: xLSTMConfig): super().__init__() self.config = config self.norm_mlstm = xLSTMRMSNorm( num_features=config.hidden_size, eps=config.norm_eps, use_weight=True, use_bias=config.use_bias, force_float32_reductions=config.norm_reduction_force_float32, ) self.mlstm_layer = xLSTMLayer(config) self.norm_ffn = xLSTMRMSNorm( num_features=config.hidden_size, eps=config.norm_eps, use_weight=True, use_bias=config.use_bias, force_float32_reductions=config.norm_reduction_force_float32, ) self.ffn = xLSTMFeedForward(config) def forward( self, x: torch.Tensor, state: Optional[mLSTMStateType] = None ) -> tuple[torch.Tensor, mLSTMStateType]: x_mlstm = self.norm_mlstm(x) x_mlstm, state = self.mlstm_layer(x_mlstm, state) x = x + x_mlstm x_ffn = self.norm_ffn(x) x_ffn = self.ffn(x_ffn) x = x + x_ffn return x, state def small_init_method(dim): """ Adapted from: https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py Fills the input Tensor with values according to the method described in Transformers without Tears: Improving the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution.""" std = (2 / (5 * dim)) ** (1 / 2) def init_(tensor): return torch.nn.init.normal_(tensor, mean=0.0, std=std) return init_ def wang_init_method(n_layers, dim): """ Adapted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py """ std = 2 / n_layers / dim ** (1 / 2) def init_(tensor): return torch.nn.init.normal_(tensor, mean=0.0, std=std) return init_ class xLSTMPreTrainedModel(PreTrainedModel): """ An abstract class for an interface to loading a pre-trained xLSTM model. """ config_class = xLSTMConfig base_model_prefix = "backbone" _no_split_modules = ["xLSTMBlock"] supports_gradient_checkpointing = True _is_stateful = True def _module_name_map(self, module): for name, mod in self.named_modules(): if mod is module: return name return "" def _init_weights(self, module): if isinstance(module, nn.Embedding): small_init_method(self.config.hidden_size)(self.embeddings.weight) elif isinstance(module, nn.Linear): if module.bias is not None: torch.nn.init.zeros_(module.bias) if self.config.weight_mode == "single" and "gate" in self._module_name_map(module): torch.nn.init.zeros_(module.weight) with torch.no_grad(): if "igate" in self._module_name_map(module): module.bias.copy_(-10.0 * torch.ones_like(module.bias)) elif "fgate" in self._module_name_map(module): module.bias.copy_( torch.linspace( 3.0, 6.0, module.bias.shape[-1], ).to( device=module.bias.device, dtype=module.bias.dtype, ) ) elif self.config.weight_mode == "fused" and "gate" in self._module_name_map(module): torch.nn.init.zeros_(module.weight) with torch.no_grad(): module.bias[: self.config.num_heads] += -module.bias[ : self.config.num_heads ] - 10.0 * torch.ones_like(module.bias) module.bias[: self.config.num_heads] += -module.bias[self.config.num_heads :] + torch.linspace( 3.0, 6.0, module.bias.shape[-1], ).to( device=module.bias.device, dtype=module.bias.dtype, ) elif "proj_down" in self._module_name_map(module): wang_init_method(dim=module.weight.shape[1], n_layers=self.config.num_hidden_layers)(module.weight) elif "out_proj" in self._module_name_map(module): wang_init_method(dim=self.config.hidden_size, n_layers=self.config.num_hidden_layers)(module.weight) elif module.weight is not None: small_init_method(self.config.hidden_size)(module.weight) elif isinstance(module, xLSTMRMSNorm) or hasattr(module, "_layer_normalize"): torch.nn.init.ones_(module.weight) if hasattr(module, "bias") and module.bias is not None: torch.nn.init.zeros_(module.bias) class xLSTMCache: """ Cache for xLSTM model which does not have attention mechanism and key value states. Arguments: config (`PretrainedConfig): The configuration file defining the shape-related attributes required to initialize the static cache. max_batch_size (`int`): The batch size with which the model will be used. dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): The default `dtype` to use when initializing the layer. device (`torch.device` or `str`, *optional*): The device on which the cache should be initialized. Should be the same as the layer. Attributes: seqlen_offset: int dtype: torch.dtype Example: ```python >>> from transformers import AutoTokenizer, xLSTMForCausalLM, xLSTMCache >>> model = xLSTMForCausalLM.from_pretrained("NX-AI/xLSTM-7b") >>> tokenizer = xLSTMTokenizer.from_pretrained("NX-AI/xLSTM-7b") >>> inputs = tokenizer(text="I am an xLSTM", return_tensors="pt") >>> # Prepare a cache class and pass it to model's forward >>> cache_params = xLSTMCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) >>> outputs = model(**inputs, cache_params=cache_params, use_cache=True) >>> outputs.cache_params xLSTMCache() """ def __init__( self, config: xLSTMConfig, max_batch_size: int, dtype: torch.dtype = torch.bfloat16, device: Optional[str] = None, **kwargs, ): self.seqlen_offset = 0 self.dtype = dtype self.config = config self.rnn_state = { layer: ( torch.zeros( [max_batch_size, config.num_heads, config.qk_head_dim, config.v_head_dim], dtype=dtype, device=device, ), torch.zeros([max_batch_size, config.num_heads, config.qk_head_dim], dtype=dtype, device=device), torch.zeros([max_batch_size, config.num_heads, 1], dtype=dtype, device=device), ) for layer in range(config.num_hidden_layers) } def reset(self): self.rnn_state = { layer: ( torch.zeros_like(self.rnn_state[layer][0]), torch.zeros_like(self.rnn_state[layer][1]), torch.zeros_like(self.rnn_state[layer][2]), ) for layer in self.rnn_state } @dataclass @auto_docstring class xLSTMOutput(ModelOutput): r""" cache_params (`xLSTMCache`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. """ last_hidden_state: Optional[torch.FloatTensor] cache_params: Optional[xLSTMCache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None @auto_docstring class xLSTMModel(xLSTMPreTrainedModel): def __init__(self, config): super().__init__(config) # use embbeding_dim and num_blocks once here to make use of them self.embeddings = nn.Embedding(config.vocab_size, config.embedding_dim) self.blocks = nn.ModuleList([xLSTMBlock(config) for _ in range(config.num_blocks)]) self.out_norm = xLSTMRMSNorm(config.hidden_size, eps=config.norm_eps) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embeddings def set_input_embeddings(self, new_embedding): self.embeddings = new_embedding @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, cache_params: Optional[xLSTMCache] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, **kwargs, ) -> Union[tuple, xLSTMOutput]: r""" cache_params (`xLSTMCache`, *optional*): The xLSTMCache that carries the RNN states. """ output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) if self.gradient_checkpointing and self.training and use_cache: use_cache = False if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embeddings(input_ids) if use_cache and cache_params is None: cache_params = xLSTMCache( self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype ) hidden_states = inputs_embeds if ( not self.training and self.config.max_inference_chunksize < hidden_states.shape[1] and not output_hidden_states ): offset = 0 with torch.no_grad(): if cache_params is None: cache_params = xLSTMCache(config=self.config, batch_size=hidden_states.shape[0]) final_state = torch.zeros_like(hidden_states) while offset < hidden_states.shape[1]: hidden_states_chunk = hidden_states[ :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1]) ] for layer_idx, xlstm_block in enumerate(self.blocks): hidden_states_chunk, rnn_state = xlstm_block( hidden_states_chunk, state=cache_params.rnn_state[layer_idx], ) for state_idx in range(len(cache_params.rnn_state[layer_idx])): local_rnn_state = rnn_state[state_idx] cache_params.rnn_state[layer_idx][state_idx].copy_(local_rnn_state) cache_params.rnn_state_initial = False final_state[ :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1]) ] = hidden_states_chunk offset += self.config.max_inference_chunksize hidden_states = final_state else: all_hidden_states = () if output_hidden_states else None for layer_idx, xlstm_block in enumerate(self.blocks): if self.gradient_checkpointing and self.training: hidden_states, rnn_state = self._gradient_checkpointing_func( xlstm_block.__call__, hidden_states, cache_params.rnn_state[layer_idx] if cache_params is not None else None, ) else: hidden_states, rnn_state = xlstm_block( hidden_states, state=cache_params.rnn_state[layer_idx] if cache_params is not None else None, ) if cache_params: for state_idx in range(len(cache_params.rnn_state[layer_idx])): local_rnn_state = rnn_state[state_idx] cache_params.rnn_state[layer_idx][state_idx].copy_(local_rnn_state) cache_params.rnn_state_initial = False if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if use_cache: cache_params.seqlen_offset += inputs_embeds.shape[1] hidden_states = self.out_norm(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) return xLSTMOutput( last_hidden_state=hidden_states, cache_params=cache_params, hidden_states=all_hidden_states, ) @dataclass @auto_docstring class xLSTMCausalLMOutput(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). cache_params (`xLSTMCache`, *optional*, carrying the RNN states): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. """ loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None cache_params: Optional[xLSTMCache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None @auto_docstring class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin): def __init__(self, config): super().__init__(config) self.backbone = xLSTMModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def get_input_embeddings(self): return self.backbone.get_input_embeddings() def set_input_embeddings(self, new_embeddings): return self.backbone.set_input_embeddings(new_embeddings) def prepare_inputs_for_generation( self, input_ids, inputs_embeds=None, use_cache=None, cache_params: Optional[xLSTMCache] = None, **kwargs, ): if use_cache and cache_params is not None: # If the first cache position is non-zero, we assume we are in generation mode. # Thus, the cache_params state is assumed to be the state before the last token # (lastly generated token), and all previous tokens are already ingested. # This should as well support generation from scratch with the [BOS] token inserted first. input_ids = input_ids[:, -1:] if inputs_embeds is not None: inputs_embeds = inputs_embeds[:, -1:] if inputs_embeds is not None and cache_params is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update({"cache_params": cache_params, "use_cache": use_cache}) return model_inputs @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, cache_params: Optional[xLSTMCache] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, **kwargs, ) -> Union[tuple, xLSTMCausalLMOutput]: r""" cache_params (`xLSTMCache`, *optional*): The xLSTMCache that carries the RNN states. """ xlstm_outputs = self.backbone( input_ids, cache_params=cache_params, inputs_embeds=inputs_embeds, use_cache=use_cache, output_hidden_states=output_hidden_states, **kwargs, ) hidden_states = xlstm_outputs[0] logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() if not self.training and self.config.max_inference_chunksize < logits.shape[1]: offset = 0 with torch.no_grad(): while offset < logits.shape[1]: logits[:, offset : min(offset + self.config.max_inference_chunksize, logits.shape[1])] = soft_cap( logits[:, offset : min(offset + self.config.max_inference_chunksize, logits.shape[1])], self.config.output_logit_soft_cap, ) offset += self.config.max_inference_chunksize else: logits = soft_cap(logits, self.config.output_logit_soft_cap) loss = None if labels is not None: # move labels to correct device to enable model parallelism labels = labels.to(logits.device) # Shift so that tokens < nstate predict nstate shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) return xLSTMCausalLMOutput( loss=loss, logits=logits, cache_params=xlstm_outputs.cache_params, hidden_states=xlstm_outputs.hidden_states, ) __all__ = [ "xLSTMForCausalLM", "xLSTMModel", "xLSTMPreTrainedModel", ]