# Copyright 2025 NXAI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""xLSTM configuration."""

from typing import Optional

from ...configuration_utils import PretrainedConfig
from ...utils import is_xlstm_available, logging

if is_xlstm_available():
    from xlstm.xlstm_large.model import (
        BackendModeType,
        ChunkwiseKernelType,
        DtypeType,
        SequenceKernelType,
        StepKernelType,
        WeightModeType,
        round_up_to_next_multiple_of,
        xLSTMLargeConfig,
    )

    external_xlstm = True
else:
    from typing import Literal

    BackendModeType = Literal["train", "train_with_padding", "inference"]
    ChunkwiseKernelType = Literal[
        "chunkwise--native_autograd",
        "parallel--native_autograd",
    ]
    DtypeType = Literal["float32", "bfloat16", "float16"]
    SequenceKernelType = Literal["native_sequence__native"]
    StepKernelType = Literal["native"]
    WeightModeType = Literal["single", "fused"]

    def round_up_to_next_multiple_of(x: int, multiple_of: int) -> int:
        """Rounds up x to the next multiple of multiple_of."""
        return int(((x + multiple_of - 1) // multiple_of) * multiple_of)

    external_xlstm = False
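
# Illustrative note (added): the fallback helper above mirrors the external
# implementation and rounds up to the nearest multiple, e.g.
#   round_up_to_next_multiple_of(100, 64) == 128
#   round_up_to_next_multiple_of(128, 64) == 128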

logger = logging.get_logger(__name__)


class xLSTMConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`xLSTM`]. It is used to instantiate an xLSTM
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with
    the defaults will yield a configuration similar to that of the xLSTM-7b
    [NX-AI/xLSTM-7b](https://huggingface.co/NX-AI/xLSTM-7b) model.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 50304):
            Vocabulary size of the xLSTM model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`xLSTMModel`]. Defaults to the GPT-NeoX tokenizer size.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimensionality of the embeddings or hidden states.
        embedding_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the embeddings or hidden states; `hidden_size` is used if this is None.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of blocks of the xLSTM model.
        num_blocks (`int`, *optional*, defaults to 32):
            Number of blocks of the xLSTM model; `num_hidden_layers` is used if this is None.
        num_heads (`int`, *optional*, defaults to 8):
            Number of heads for the xLSTM Layer/Cell.
        use_bias (`bool`, *optional*, defaults to `False`):
            Whether to use biases in the xLSTM model.
        norm_reduction_force_float32 (`bool`, *optional*, defaults to `True`):
            Whether to force the norm reduction op to be computed in float32 precision.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie the word embeddings to the LM head weights.
        add_out_norm (`bool`, *optional*, defaults to `True`):
            Whether to add an output norm after the blocks, before the LM head.
        norm_eps (`float`, *optional*, defaults to 1e-06):
            Norm epsilon for RMSNorm and LayerNorm.
        qk_dim_factor (`float`, *optional*, defaults to 0.5):
            Scale factor for the query and key dimension.
        v_dim_factor (`float`, *optional*, defaults to 1.0):
            Scale factor for the value dimension.
        chunkwise_kernel (`ChunkwiseKernelType`, *optional*, defaults to `"chunkwise--native_autograd"`):
            Kernel type for chunkwise processing mode.
        sequence_kernel (`SequenceKernelType`, *optional*, defaults to `"native_sequence__native"`):
            Kernel type for sequence processing mode.
        step_kernel (`StepKernelType`, *optional*, defaults to `"native"`):
            Kernel type for step processing mode.
        mode (`BackendModeType`, *optional*, defaults to `"inference"`):
            Operation mode (inference is needed for generation).
        chunk_size (`int`, *optional*, defaults to 64):
            Internal chunk size.
        return_last_states (`bool`, *optional*, defaults to `True`):
            Whether to return the last states / cache internally. Needs to be True for generation.
        autocast_kernel_dtype (`DtypeType`, *optional*, defaults to `"bfloat16"`):
            Kernel dtype for the states.
        eps (`float`, *optional*, defaults to 1e-06):
            Epsilon for the mLSTM cell post norm.
        inference_state_dtype (`DtypeType`, *optional*, defaults to `"float32"`):
            Kernel dtype for states in inference.
        ffn_proj_factor (`float`, *optional*, defaults to 2.667):
            Size factor of the post-up projection gated Feed Forward network.
        ffn_round_up_to_multiple_of (`int`, *optional*, defaults to 64):
            Multiple to which the hidden size of the post-up projection gated Feed Forward network is rounded up.
        gate_soft_cap (`float`, *optional*, defaults to 15.0):
            Gate soft cap scale.
        output_logit_soft_cap (`float`, *optional*, defaults to 30.0):
            Output logit soft cap scale.
        weight_mode (`Literal`, *optional*, defaults to `"single"`):
            Whether parallel linear layers are kept separate (`"single"`) or fused (`"fused"`).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether to use the cache (xLSTMCache).
        pad_token_id (`int`, *optional*, defaults to 1):
            Pad token id needed for generation.
        bos_token_id (`int`, *optional*, defaults to 0):
            BOS token id needed for generation.
        eos_token_id (`int`, *optional*, defaults to 2):
            EOS token id needed for generation.
        max_inference_chunksize (`int`, *optional*, defaults to 16384):
            Limit the chunk size for inference to save memory.

    Example:

    ```python
    >>> from transformers import xLSTMConfig, xLSTMModel

    >>> # Initializing a xLSTM configuration
    >>> configuration = xLSTMConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = xLSTMModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "xlstm"
    def __init__(
        self,
        vocab_size: int = 50304,
        hidden_size: int = 4096,
        embedding_dim: Optional[int] = None,
        num_hidden_layers: Optional[int] = 32,
        num_blocks: Optional[int] = None,
        num_heads: int = 8,
        use_bias: bool = False,
        norm_reduction_force_float32: bool = True,
        tie_word_embeddings: bool = False,
        add_out_norm: bool = True,
        norm_eps: float = 1e-6,
        # mlstm_layer
        qk_dim_factor: float = 0.5,
        v_dim_factor: float = 1.0,
        # mlstm backend
        chunkwise_kernel: ChunkwiseKernelType = "chunkwise--native_autograd",
        sequence_kernel: SequenceKernelType = "native_sequence__native",
        step_kernel: StepKernelType = "native",
        # needed to enable generation
        mode: BackendModeType = "inference",
        chunk_size: int = 64,
        # needs to be True for generation
        return_last_states: bool = True,
        autocast_kernel_dtype: DtypeType = "bfloat16",
        eps: float = 1e-6,
        inference_state_dtype: DtypeType = "float32",
        # feedforward
        ffn_proj_factor: float = 2.667,
        ffn_round_up_to_multiple_of: int = 64,
        # capping
        gate_soft_cap: float = 15.0,
        output_logit_soft_cap: float = 30.0,
        # weights
        weight_mode: WeightModeType = "single",
        # HF interface
        use_cache: bool = True,
        pad_token_id: int = 1,
        bos_token_id: int = 0,
        eos_token_id: int = 2,
        max_inference_chunksize: int = 16384,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size if hidden_size is not None else embedding_dim
        self.embedding_dim = embedding_dim if embedding_dim is not None else hidden_size
        self.num_hidden_layers = num_hidden_layers if num_hidden_layers is not None else num_blocks
        self.num_blocks = num_blocks if num_blocks is not None else num_hidden_layers
        self.num_heads = num_heads
        self.use_bias = use_bias
        self.tie_word_embeddings = tie_word_embeddings
        self.add_out_norm = add_out_norm
        self.norm_eps = norm_eps
        self.norm_reduction_force_float32 = norm_reduction_force_float32
        # mlstm_layer
        self.qk_dim_factor = qk_dim_factor
        self.v_dim_factor = v_dim_factor
        # mlstm backend
        self.chunkwise_kernel = chunkwise_kernel
        self.sequence_kernel = sequence_kernel
        self.step_kernel = step_kernel
        self.mode = mode
        self.chunk_size = chunk_size
        self.return_last_states = return_last_states
        self.autocast_kernel_dtype = autocast_kernel_dtype
        self.eps = eps
        self.inference_state_dtype = inference_state_dtype
        # feedforward
        self.ffn_proj_factor = ffn_proj_factor
        self.ffn_round_up_to_multiple_of = ffn_round_up_to_multiple_of
        # capping
        self.gate_soft_cap = gate_soft_cap
        self.output_logit_soft_cap = output_logit_soft_cap
        self.weight_mode = weight_mode
        self.use_cache = use_cache
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.max_inference_chunksize = max_inference_chunksize

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @property
    def qk_dim(self):
        return round_up_to_next_multiple_of(
            self.hidden_size * self.qk_dim_factor,
            multiple_of=64,
        )

    @property
    def v_dim(self):
        return round_up_to_next_multiple_of(
            self.hidden_size * self.v_dim_factor,
            multiple_of=64,
        )

    @property
    def qk_head_dim(self):
        return self.qk_dim // self.num_heads

    @property
    def v_head_dim(self):
        return self.v_dim // self.num_heads
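
    # Illustrative note (added): with the default hidden_size=4096, qk_dim_factor=0.5,
    # v_dim_factor=1.0 and num_heads=8, the properties above yield qk_dim=2048,
    # v_dim=4096, qk_head_dim=256 and v_head_dim=512.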

    def to_xlstm_block_config(self):
        if external_xlstm:
            return xLSTMLargeConfig(
                vocab_size=self.vocab_size,
                embedding_dim=self.hidden_size,
                num_blocks=self.num_hidden_layers,
                num_heads=self.num_heads,
                use_bias=self.use_bias,
                add_out_norm=self.add_out_norm,
                norm_eps=self.norm_eps,
                norm_reduction_force_float32=self.norm_reduction_force_float32,
                # mlstm_layer
                qk_dim_factor=self.qk_dim_factor,
                v_dim_factor=self.v_dim_factor,
                # mlstm backend
                chunkwise_kernel=self.chunkwise_kernel,
                sequence_kernel=self.sequence_kernel,
                step_kernel=self.step_kernel,
                mode=self.mode,
                chunk_size=self.chunk_size,
                return_last_states=self.return_last_states,
                autocast_kernel_dtype=self.autocast_kernel_dtype,
                eps=self.eps,
                inference_state_dtype=self.inference_state_dtype,
                # feedforward
                ffn_proj_factor=self.ffn_proj_factor,
                ffn_round_up_to_multiple_of=self.ffn_round_up_to_multiple_of,
                # capping
                gate_soft_cap=self.gate_soft_cap,
                output_logit_soft_cap=self.output_logit_soft_cap,
                weight_mode=self.weight_mode,
            )
        else:
            return self

__all__ = ["xLSTMConfig"]
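
# Minimal usage sketch (illustrative addition, not part of the published module):
#
#     from transformers import xLSTMConfig
#
#     config = xLSTMConfig(hidden_size=1024, num_hidden_layers=8, num_heads=4)
#     print(config.qk_dim, config.v_dim)  # 512 1024 for these settings
#
#     # With the external `xlstm` package installed, to_xlstm_block_config() returns an
#     # xLSTMLargeConfig mirroring the fields above; otherwise the config itself is
#     # returned unchanged.
#     backend_config = config.to_xlstm_block_config()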