# Source: team-10/env/Lib/site-packages/transformers/models/exaone4/modular_exaone4.py
#
# (Web viewer residue: 520 lines, 23 KiB, Python, "Raw Normal View History")
#
# Last change: 2025-08-02 07:34:44 +02:00
# coding=utf-8
# Copyright 2025 The LG AI Research and HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LG AI Research EXAONE Lab"""
from typing import Callable, Optional, Union
import torch
from torch import nn
from transformers.utils.generic import check_model_inputs
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import (
TransformersKwargs,
logging,
)
from ..llama.modeling_llama import (
LlamaForCausalLM,
LlamaForQuestionAnswering,
LlamaForSequenceClassification,
LlamaForTokenClassification,
LlamaModel,
LlamaPreTrainedModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward,
)
from ..olmo2.modeling_olmo2 import Olmo2DecoderLayer, Olmo2MLP
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)

# Checkpoint / config identifiers; the `_FOR_DOC` suffix suggests they feed generated
# documentation (not referenced elsewhere in the visible part of this file).
_CHECKPOINT_FOR_DOC = "LGAI-EXAONE/EXAONE-4.0-Instruct"
_CONFIG_FOR_DOC = "Exaone4Config"
class Exaone4Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Exaone4Model`]. It is used to
    instantiate a EXAONE 4.0 model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the EXAONE-4.0-Instruct [LGAI-EXAONE/EXAONE-4.0-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-Instruct)
    NOTE: `EXAONE-4.0-Instruct` is a placeholder model ID. The exact model ID will be updated in the future.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model
    outputs. Read the documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 102400):
            Vocabulary size of the EXAONE 4.0 model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`Exaone4Model`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 16384):
            Dimensionality of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If `None` is passed, it falls back to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 32768 for EXAONE 3.5).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if ``config.is_decoder=True``.
        bos_token_id (`int`, *optional*, defaults to 0):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        sliding_window (`int`, *optional*, defaults to 4096):
            The size of the sliding window for the sliding window attention. If `None`, every generated layer
            type is `"full_attention"`.
        sliding_window_pattern (`int` or `str`, *optional*, defaults to 4):
            The pattern to use for sliding window attention. Only consulted when `layer_types` is not given.
            Can be one of:
                - `None` (or `0` / empty string): No sliding window attention is used
                - `int`: Every `sliding_window_pattern`-th layer uses global attention, the others use local
                  attention.
                - `str`: A sequence of "L" (local attention) and "G" (global attention) characters that defines the
                  attention pattern. The pattern starts from layer 0 and repeats every
                  `len(sliding_window_pattern)` layers. With a string pattern the final layer always uses global
                  attention regardless of the pattern.
                    For instance, sliding_window_pattern="LLLG" same as sliding_window_pattern=4, which means:
                        - Layer 0, 1, 2: local attention,
                        - Layer 3: global attention,
                        ...(repeated)
        layer_types (`list`, *optional*):
            Attention pattern for each layer. Prioritized over `sliding_window_pattern`.

    Example:

    ```python
    >>> from transformers import Exaone4Model, Exaone4Config

    >>> # Initializing a EXAONE configuration
    >>> configuration = Exaone4Config()

    >>> # Initializing a model from configuration
    >>> model = Exaone4Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "exaone4"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `LlamaModel`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=102400,
        hidden_size=4096,
        intermediate_size=16384,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        bos_token_id=0,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_dropout=0.0,
        sliding_window=4096,
        sliding_window_pattern=4,
        layer_types=None,
        **kwargs,
    ):
        # Documented fallback: an unspecified KV-head count means plain multi-head attention.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.attention_dropout = attention_dropout
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.sliding_window = sliding_window
        self.sliding_window_pattern = sliding_window_pattern
        self.layer_types = layer_types

        if self.layer_types is None:
            # Without a window size there is nothing to slide over, so the pattern is ignored.
            pattern = sliding_window_pattern if self.sliding_window is not None else None
            if not pattern:
                # None / 0 / "" -> global attention everywhere (also avoids a modulo-by-zero).
                self.layer_types = ["full_attention"] * self.num_hidden_layers
            elif isinstance(pattern, str):
                # "L" marks a local layer; any other character is treated as global.
                # The last layer is always global for string patterns, as documented above.
                last_layer = self.num_hidden_layers - 1
                self.layer_types = [
                    "sliding_attention"
                    if (i != last_layer and pattern[i % len(pattern)] == "L")
                    else "full_attention"
                    for i in range(self.num_hidden_layers)
                ]
            else:
                # Numeric pattern: every `pattern`-th layer (1-based) is global.
                self.layer_types = [
                    "sliding_attention" if (i + 1) % pattern != 0 else "full_attention"
                    for i in range(self.num_hidden_layers)
                ]

        # NOTE(review): the layer types generated above are "sliding_attention" / "full_attention",
        # so this literal "sliding_window" check never matches them. Left untouched on purpose;
        # confirm the intended key (and that "hybrid" is a valid implementation) before changing it.
        if "sliding_window" in self.layer_types:
            self._attn_implementation = "hybrid"
        layer_type_validation(self.layer_types)

        super().__init__(
            bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
        )
# EXAONE 4.0 RMSNorm: inherits everything from `LlamaRMSNorm` with no overrides.
class Exaone4RMSNorm(LlamaRMSNorm):
    pass
# EXAONE 4.0 rotary embedding: inherits everything from `LlamaRotaryEmbedding` with no overrides.
class Exaone4RotaryEmbedding(LlamaRotaryEmbedding):
    pass
class Exaone4Attention(nn.Module):
    """Self-attention layer for EXAONE 4.0 with QK-norm and optional sliding-window attention.

    Queries and keys are RMS-normalized per head before the rotary embedding is applied.
    Rotary embeddings are applied only when no sliding window is configured or when this
    layer is a sliding-attention layer; otherwise positions are not encoded in this layer.
    """

    def __init__(self, config: Exaone4Config, layer_idx: int):
        """Build the projections and norms for the layer at position `layer_idx`.

        Args:
            config: Model configuration; `config.layer_types[layer_idx]` decides whether this
                layer uses sliding-window attention.
            layer_idx: Index of this layer, also used as the slot when updating the KV cache.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size
        # An explicit `head_dim` on the config wins over the derived hidden_size / num_heads.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.scaling = self.head_dim**-0.5
        self.sliding_window = config.sliding_window
        self.sliding_window_pattern = config.sliding_window_pattern
        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=False)

        # Norms act on the last (head_dim) axis of the per-head query/key tensors.
        self.q_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention over `hidden_states`.

        Args:
            hidden_states: Input activations; the last axis is projected by q/k/v
                (presumably `(batch, seq_len, hidden_size)` given the `transpose(1, 2)` below).
            position_embeddings: `(cos, sin)` pair consumed by `apply_rotary_pos_emb`.
            attention_mask: Mask forwarded unchanged to the attention implementation.
            past_key_value: Optional cache; when given, keys/values are replaced by the result
                of `past_key_value.update(...)` for this layer.
            cache_position: Forwarded to the cache update as `cache_position`.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            `(attn_output, attn_weights)`: the output-projected attention result and whatever
            weights object the selected attention implementation returns.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Split the projection into heads, then move the head axis ahead of the sequence axis.
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # We use QK-norm
        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        cos, sin = position_embeddings
        # We use global NoPE for hybrid attention model
        # (full-attention layers skip RoPE whenever a sliding window is configured).
        if self.sliding_window is None or self.is_sliding:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Eager attention is the fallback; any other configured name is looked up in the registry.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            # Only sliding layers pass a window size to the attention implementation.
            sliding_window=self.sliding_window if self.is_sliding else None,
            **kwargs,
        )

        # Merge the head axes back into one feature axis before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
# EXAONE 4.0 feed-forward block: inherits everything from `Olmo2MLP` with no overrides.
class Exaone4MLP(Olmo2MLP):
    pass
# EXAONE 4.0 decoder layer: inherits everything from `Olmo2DecoderLayer` with no overrides.
class Exaone4DecoderLayer(Olmo2DecoderLayer):
    pass
# Base class for EXAONE 4.0 models: `LlamaPreTrainedModel` with the EXAONE config class
# and decoder-layer name substituted.
class Exaone4PreTrainedModel(LlamaPreTrainedModel):
    # Configuration class associated with this model family.
    config_class = Exaone4Config
    # Module class names listed as not-to-be-split (presumably for device placement — confirm in base class).
    _no_split_modules = ["Exaone4DecoderLayer"]
class Exaone4Model(Exaone4PreTrainedModel, LlamaModel):
    """EXAONE 4.0 decoder stack that picks a full or sliding-window causal mask per layer."""

    def __init__(self, config: Exaone4Config):
        """Create the model, replacing the inherited layers and final norm with EXAONE ones."""
        super().__init__(config)
        decoder_layers = [Exaone4DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        """Embed the inputs, run every decoder layer with its matching mask, and normalize.

        Exactly one of `input_ids` / `inputs_embeds` must be provided. `attention_mask` may
        already be a dict keyed by layer type, in which case it is used as-is.
        """
        # Both missing or both present is an error.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            # Continue counting from whatever the cache has already seen.
            start = 0 if past_key_values is None else past_key_values.get_seq_length()
            cache_position = torch.arange(start, start + inputs_embeds.shape[1], device=inputs_embeds.device)

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        causal_mask_mapping = attention_mask
        if not isinstance(causal_mask_mapping, dict):
            mask_inputs = dict(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                position_ids=position_ids,
            )
            causal_mask_mapping = {"full_attention": create_causal_mask(**mask_inputs)}
            # The sliding mask is only built when at least one layer needs it.
            if "sliding_attention" in self.config.layer_types:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_inputs)

        position_embeddings = self.rotary_emb(inputs_embeds, position_ids)

        hidden_states = inputs_embeds
        for layer_idx, decoder_layer in enumerate(self.layers):
            layer_mask = causal_mask_mapping[self.config.layer_types[layer_idx]]
            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=layer_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        return BaseModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=past_key_values if use_cache else None,
        )
class Exaone4ForCausalLM(LlamaForCausalLM):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer
        >>> model = AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-4.0-Instruct")

        >>> prompt = "Explain how wonderful you are"
        >>> messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
        >>> input_ids = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            enable_thinking=False,
        )

        >>> output = model.generate(input_ids, max_new_tokens=128)
        >>> tokenizer.decode(output[0], skip_special_tokens=False)
        "[|system|]\nYou are a helpful assistant.[|endofturn|]\n[|user|]\nExplain how wonderful you are[|endofturn|]\n[|assistant|]\n<think>\n\n</think>\n\nOh, thank you for such a kind and lovely question! 😊 \n\nIm *so* wonderful because Im here to make your life easier, brighter, and more fun! Whether you need help with: \n\n✨ **Learning** I can explain anything, from quantum physics to baking the perfect cake! \n💡 **Creativity** Need a poem, story, or a wild idea? Ive got you covered! \n🤖 **Problem-solving** Stuck on a math problem or a tricky decision? Ill help you figure it out"
        ```

        NOTE: `EXAONE-4.0-Instruct` is a placeholder model ID. The exact model ID will be updated in the future."""
        # Propagate the parent's output; without `return` this method would yield None
        # despite its `CausalLMOutputWithPast` annotation.
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
# EXAONE 4.0 sequence-classification model: inherits everything from `LlamaForSequenceClassification`.
class Exaone4ForSequenceClassification(LlamaForSequenceClassification):
    pass
# EXAONE 4.0 token-classification model: inherits everything from `LlamaForTokenClassification`.
class Exaone4ForTokenClassification(LlamaForTokenClassification):
    pass
# EXAONE 4.0 question-answering model: inherits everything from `LlamaForQuestionAnswering`.
class Exaone4ForQuestionAnswering(LlamaForQuestionAnswering):
    pass
# Public names exported by this module (config, base classes, and task heads).
__all__ = [
    "Exaone4Config",
    "Exaone4PreTrainedModel",
    "Exaone4Model",
    "Exaone4ForCausalLM",
    "Exaone4ForSequenceClassification",
    "Exaone4ForTokenClassification",
    "Exaone4ForQuestionAnswering",
]