# Copyright 2025 NXAI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""xLSTM configuration."""

from typing import Optional

from ...configuration_utils import PretrainedConfig
from ...utils import is_xlstm_available, logging


if is_xlstm_available():
    from xlstm.xlstm_large.model import (
        BackendModeType,
        ChunkwiseKernelType,
        DtypeType,
        SequenceKernelType,
        StepKernelType,
        WeightModeType,
        round_up_to_next_multiple_of,
        xLSTMLargeConfig,
    )

    external_xlstm = True
else:
    from typing import Literal

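    # Fallback type aliases and helper mirroring what the external `xlstm` package exports,
    # so the annotations below keep working when the package is not installed.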
    BackendModeType = Literal["train", "train_with_padding", "inference"]
    ChunkwiseKernelType = Literal[
        "chunkwise--native_autograd",
        "parallel--native_autograd",
    ]
    DtypeType = Literal["float32", "bfloat16", "float16"]
    SequenceKernelType = Literal["native_sequence__native"]
    StepKernelType = Literal["native"]
    WeightModeType = Literal["single", "fused"]

    def round_up_to_next_multiple_of(x: int, multiple_of: int) -> int:
        """Rounds up x to the next multiple of multiple_of."""
        return int(((x + multiple_of - 1) // multiple_of) * multiple_of)
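    # Illustrative: round_up_to_next_multiple_of(2048, 64) == 2048, round_up_to_next_multiple_of(2050, 64) == 2112.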

    external_xlstm = False


logger = logging.get_logger(__name__)


class xLSTMConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of an [`xLSTM`] model. It is used to instantiate an
    xLSTM model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a configuration similar to that of the xLSTM-7b model
    [NX-AI/xLSTM-7b](https://huggingface.co/NX-AI/xLSTM-7b).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 50304):
            Vocabulary size of the xLSTM model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`xLSTMModel`]. Defaults to the GPT-NeoX tokenizer vocabulary size.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimensionality of the embeddings and hidden states.
        embedding_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the embeddings and hidden states; falls back to `hidden_size` if None.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of blocks of the xLSTM model.
        num_blocks (`int`, *optional*, defaults to 32):
            Number of blocks of the xLSTM model; falls back to `num_hidden_layers` if None.
        num_heads (`int`, *optional*, defaults to 8):
            Number of heads for the xLSTM layer/cell.
        use_bias (`bool`, *optional*, defaults to `False`):
            Whether to use biases in the xLSTM model.
        norm_reduction_force_float32 (`bool`, *optional*, defaults to `True`):
            Whether to force the norm reduction op to be computed in float32 precision.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie the word embeddings to the LM head weights.
        add_out_norm (`bool`, *optional*, defaults to `True`):
            Whether to add an output norm after the blocks, before the LM head.
        norm_eps (`float`, *optional*, defaults to 1e-06):
            Epsilon for the RMSNorm and LayerNorm layers.
        qk_dim_factor (`float`, *optional*, defaults to 0.5):
            Scale factor for the query and key dimension.
        v_dim_factor (`float`, *optional*, defaults to 1.0):
            Scale factor for the value dimension.
        chunkwise_kernel (`ChunkwiseKernelType`, *optional*, defaults to `"chunkwise--native_autograd"`):
            Kernel type for chunkwise processing mode.
        sequence_kernel (`SequenceKernelType`, *optional*, defaults to `"native_sequence__native"`):
            Kernel type for sequence processing mode.
        step_kernel (`StepKernelType`, *optional*, defaults to `"native"`):
            Kernel type for step processing mode.
        mode (`BackendModeType`, *optional*, defaults to `"inference"`):
            Operation mode (`"inference"` is needed for generation).
        chunk_size (`int`, *optional*, defaults to 64):
            Internal chunk size.
        return_last_states (`bool`, *optional*, defaults to `True`):
            Whether to return the last states / cache internally. Needs to be `True` for generation.
        autocast_kernel_dtype (`DtypeType`, *optional*, defaults to `"bfloat16"`):
            Kernel dtype for the states.
        eps (`float`, *optional*, defaults to 1e-06):
            Epsilon for the mLSTM cell post norm.
        inference_state_dtype (`DtypeType`, *optional*, defaults to `"float32"`):
            Kernel dtype for the states during inference.
        ffn_proj_factor (`float`, *optional*, defaults to 2.667):
            Size factor of the post-up-projection gated feed-forward network.
        ffn_round_up_to_multiple_of (`int`, *optional*, defaults to 64):
            Rounding multiple for the size of the post-up-projection gated feed-forward network.
        gate_soft_cap (`float`, *optional*, defaults to 15.0):
            Gate soft cap scale.
        output_logit_soft_cap (`float`, *optional*, defaults to 30.0):
            Output logit soft cap scale.
        weight_mode (`WeightModeType`, *optional*, defaults to `"single"`):
            Whether parallel linear layers are kept separate (`"single"`) or fused (`"fused"`).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether to use the cache (xLSTMCache).
        pad_token_id (`int`, *optional*, defaults to 1):
            Pad token id, needed for generation.
        bos_token_id (`int`, *optional*, defaults to 0):
            BOS token id, needed for generation.
        eos_token_id (`int`, *optional*, defaults to 2):
            EOS token id, needed for generation.
        max_inference_chunksize (`int`, *optional*, defaults to 16384):
            Limit on the chunk size used during inference, to save memory.

    Example:

    ```python
    >>> from transformers import xLSTMConfig, xLSTMModel

    >>> # Initializing an xLSTM configuration
    >>> configuration = xLSTMConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = xLSTMModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "xlstm"

    def __init__(
        self,
        vocab_size: int = 50304,
        hidden_size: int = 4096,
        embedding_dim: Optional[int] = None,
        num_hidden_layers: Optional[int] = 32,
        num_blocks: Optional[int] = None,
        num_heads: int = 8,
        use_bias: bool = False,
        norm_reduction_force_float32: bool = True,
        tie_word_embeddings: bool = False,
        add_out_norm: bool = True,
        norm_eps: float = 1e-6,
        # mlstm_layer
        qk_dim_factor: float = 0.5,
        v_dim_factor: float = 1.0,
        # mlstm backend
        chunkwise_kernel: ChunkwiseKernelType = "chunkwise--native_autograd",
        sequence_kernel: SequenceKernelType = "native_sequence__native",
        step_kernel: StepKernelType = "native",
        # needed to enable generation
        mode: BackendModeType = "inference",
        chunk_size: int = 64,
        # needs to be True for generation
        return_last_states: bool = True,
        autocast_kernel_dtype: DtypeType = "bfloat16",
        eps: float = 1e-6,
        inference_state_dtype: DtypeType = "float32",
        # feedforward
        ffn_proj_factor: float = 2.667,
        ffn_round_up_to_multiple_of: int = 64,
        # capping
        gate_soft_cap: float = 15.0,
        output_logit_soft_cap: float = 30.0,
        # weights
        weight_mode: WeightModeType = "single",
        # HF interface
        use_cache: bool = True,
        pad_token_id: int = 1,
        bos_token_id: int = 0,
        eos_token_id: int = 2,
        max_inference_chunksize: int = 16384,
        **kwargs,
    ):
        self.vocab_size = vocab_size
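        # hidden_size/embedding_dim and num_hidden_layers/num_blocks are aliases of each other;
        # whichever one is provided fills in the other so both naming conventions work.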
        self.hidden_size = hidden_size if hidden_size is not None else embedding_dim
        self.embedding_dim = embedding_dim if embedding_dim is not None else hidden_size
        self.num_hidden_layers = num_hidden_layers if num_hidden_layers is not None else num_blocks
        self.num_blocks = num_blocks if num_blocks is not None else num_hidden_layers
        self.num_heads = num_heads
        self.use_bias = use_bias
        self.tie_word_embeddings = tie_word_embeddings
        self.add_out_norm = add_out_norm
        self.norm_eps = norm_eps
        self.norm_reduction_force_float32 = norm_reduction_force_float32
        # mlstm_layer
        self.qk_dim_factor = qk_dim_factor
        self.v_dim_factor = v_dim_factor
        # mlstm backend
        self.chunkwise_kernel = chunkwise_kernel
        self.sequence_kernel = sequence_kernel
        self.step_kernel = step_kernel
        self.mode = mode
        self.chunk_size = chunk_size
        self.return_last_states = return_last_states
        self.autocast_kernel_dtype = autocast_kernel_dtype
        self.eps = eps
        self.inference_state_dtype = inference_state_dtype
        # feedforward
        self.ffn_proj_factor = ffn_proj_factor
        self.ffn_round_up_to_multiple_of = ffn_round_up_to_multiple_of
        # capping
        self.gate_soft_cap = gate_soft_cap
        self.output_logit_soft_cap = output_logit_soft_cap
        self.weight_mode = weight_mode

        self.use_cache = use_cache
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.max_inference_chunksize = max_inference_chunksize

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @property
    def qk_dim(self):
        return round_up_to_next_multiple_of(
            self.hidden_size * self.qk_dim_factor,
            multiple_of=64,
        )

    @property
    def v_dim(self):
        return round_up_to_next_multiple_of(
            self.hidden_size * self.v_dim_factor,
            multiple_of=64,
        )

    @property
    def qk_head_dim(self):
        return self.qk_dim // self.num_heads

    @property
    def v_head_dim(self):
        return self.v_dim // self.num_heads
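    # Illustrative values with the defaults (hidden_size=4096, qk_dim_factor=0.5, v_dim_factor=1.0,
    # num_heads=8): qk_dim == 2048, qk_head_dim == 256, v_dim == 4096, v_head_dim == 512.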

    def to_xlstm_block_config(self):
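        """Map this configuration onto the external `xLSTMLargeConfig` when the `xlstm` package is available;
        otherwise return this config itself as a drop-in stand-in."""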
        if external_xlstm:
            return xLSTMLargeConfig(
                vocab_size=self.vocab_size,
                embedding_dim=self.hidden_size,
                num_blocks=self.num_hidden_layers,
                num_heads=self.num_heads,
                use_bias=self.use_bias,
                add_out_norm=self.add_out_norm,
                norm_eps=self.norm_eps,
                norm_reduction_force_float32=self.norm_reduction_force_float32,
                # mlstm_layer
                qk_dim_factor=self.qk_dim_factor,
                v_dim_factor=self.v_dim_factor,
                # mlstm backend
                chunkwise_kernel=self.chunkwise_kernel,
                sequence_kernel=self.sequence_kernel,
                step_kernel=self.step_kernel,
                mode=self.mode,
                chunk_size=self.chunk_size,
                return_last_states=self.return_last_states,
                autocast_kernel_dtype=self.autocast_kernel_dtype,
                eps=self.eps,
                inference_state_dtype=self.inference_state_dtype,
                # feedforward
                ffn_proj_factor=self.ffn_proj_factor,
                ffn_round_up_to_multiple_of=self.ffn_round_up_to_multiple_of,
                # capping
                gate_soft_cap=self.gate_soft_cap,
                output_logit_soft_cap=self.output_logit_soft_cap,
                weight_mode=self.weight_mode,
            )
        else:
            return self


__all__ = ["xLSTMConfig"]