# coding=utf-8
# Copyright 2025 The LG AI Research and HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LG AI Research EXAONE Lab"""

from typing import Callable, Optional, Union

import torch
from torch import nn

from transformers.utils.generic import check_model_inputs

from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import (
    TransformersKwargs,
    logging,
)
from ..llama.modeling_llama import (
    LlamaForCausalLM,
    LlamaForQuestionAnswering,
    LlamaForSequenceClassification,
    LlamaForTokenClassification,
    LlamaModel,
    LlamaPreTrainedModel,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
    eager_attention_forward,
)
from ..olmo2.modeling_olmo2 import Olmo2DecoderLayer, Olmo2MLP


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "LGAI-EXAONE/EXAONE-4.0-Instruct"
_CONFIG_FOR_DOC = "Exaone4Config"


class Exaone4Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Exaone4Model`]. It is used to instantiate an
    EXAONE 4.0 model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the EXAONE-4.0-Instruct
    [LGAI-EXAONE/EXAONE-4.0-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-Instruct)

    NOTE: `EXAONE-4.0-Instruct` is a placeholder model ID. The exact model ID will be updated in the future.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 102400):
            Vocabulary size of the EXAONE 4.0 model. Defines the number of different tokens that can be represented
            by the `input_ids` passed when calling [`Exaone4Model`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to `hidden_size * 4`):
            Dimensionality of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used.
            When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
            constructed by meanpooling all the original heads within that group. For more details, check out
            [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically set this to something
            large just in case (e.g., 32768 for EXAONE 3.5).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        bos_token_id (`int`, *optional*, defaults to 0):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope
            type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update
            this value accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn',
                    'longrope', 'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings.
                    In most scaling types, a `factor` of x will enable the model to handle sequences of length
                    x * original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention computation.
                    If unspecified, it defaults to the value recommended by the implementation, using the `factor`
                    field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the
                    hidden size divided by the number of attention heads divided by 2.
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`).
                    Must be a list of numbers with the same length as the hidden size divided by the number of
                    attention heads divided by 2.
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        sliding_window (`int`, *optional*):
            The size of the sliding window for the sliding window attention.
        sliding_window_pattern (`str`, *optional*):
            The pattern to use for sliding window attention. Can be one of:
                - `None`: No sliding window attention is used.
                - `int`: Every `sliding_window_pattern` layers, use global attention; otherwise use local attention.
                - `str`: A sequence of "L" (local attention) and "G" (global attention) characters that defines the
                  attention pattern. The pattern starts from layer 0 and repeats every `sliding_window_pattern`
                  layers. The final layer always uses global attention regardless of the pattern.
            For instance, `sliding_window_pattern="LLLG"` is the same as `sliding_window_pattern=4`, which means:
                - Layer 0, 1, 2: local attention,
                - Layer 3: global attention,
                ...(repeated)
        layer_types (`list`, *optional*):
            Attention pattern for each layer. Prioritized over `sliding_window_pattern`.

    Example:

    ```python
    >>> from transformers import Exaone4Model, Exaone4Config

    >>> # Initializing an EXAONE configuration
    >>> configuration = Exaone4Config()

    >>> # Initializing a model from the configuration
    >>> model = Exaone4Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
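
    A hybrid local/global attention layout can also be inspected through the derived `layer_types`. The block below
    is an illustrative sketch assuming a small, hypothetical 8-layer setup, following the default derivation in
    `__init__`:

    ```python
    >>> # Every 4th layer uses global (full) attention, the others use local (sliding window) attention
    >>> hybrid_configuration = Exaone4Config(num_hidden_layers=8, sliding_window=4096, sliding_window_pattern=4)
    >>> hybrid_configuration.layer_types
    ['sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention', 'sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention']
    ```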
    """

    model_type = "exaone4"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for the base model `Exaone4Model`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=102400,
        hidden_size=4096,
        intermediate_size=16384,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        bos_token_id=0,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_dropout=0.0,
        sliding_window=4096,
        sliding_window_pattern=4,
        layer_types=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.attention_dropout = attention_dropout
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.sliding_window = sliding_window
        self.sliding_window_pattern = sliding_window_pattern

        self.layer_types = layer_types
        if self.layer_types is None:
            if self.sliding_window is None:
                # Without a sliding window, every layer falls back to global (full) attention.
                self.layer_types = ["full_attention"] * self.num_hidden_layers
            else:
                # Every `sliding_window_pattern`-th layer uses global attention, the rest use local attention,
                # e.g. with sliding_window_pattern=4: layers 0, 1, 2 are local, layer 3 is global, repeated.
                self.layer_types = [
                    "sliding_attention" if (i + 1) % sliding_window_pattern != 0 else "full_attention"
                    for i in range(self.num_hidden_layers)
                ]
        if "sliding_attention" in self.layer_types:
            self._attn_implementation = "hybrid"
        layer_type_validation(self.layer_types)

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class Exaone4RMSNorm(LlamaRMSNorm):
    pass


class Exaone4RotaryEmbedding(LlamaRotaryEmbedding):
    pass


class Exaone4Attention(nn.Module):
    def __init__(self, config: Exaone4Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.scaling = self.head_dim**-0.5
        self.sliding_window = config.sliding_window
        self.sliding_window_pattern = config.sliding_window_pattern
        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=False)
        self.q_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # We use QK-norm on the per-head query and key states
        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        cos, sin = position_embeddings

        # We use global NoPE for the hybrid attention model: RoPE is only applied to sliding (local) layers,
        # or to every layer when no sliding window is configured.
        if self.sliding_window is None or self.is_sliding:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window if self.is_sliding else None,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
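
# Note on position embeddings in the hybrid attention scheme implemented by `Exaone4Attention.forward` above
# (descriptive comment only, mirroring the existing logic):
#   - "sliding_attention" layers always apply RoPE to the query and key states.
#   - "full_attention" layers skip RoPE (NoPE) whenever `config.sliding_window` is set, so the global layers of a
#     hybrid model carry no explicit position information.
#   - If the model has no sliding window at all (`config.sliding_window is None`), every layer applies RoPE.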


class Exaone4MLP(Olmo2MLP):
    pass


class Exaone4DecoderLayer(Olmo2DecoderLayer):
    pass


class Exaone4PreTrainedModel(LlamaPreTrainedModel):
    config_class = Exaone4Config
    _no_split_modules = ["Exaone4DecoderLayer"]


class Exaone4Model(Exaone4PreTrainedModel, LlamaModel):
    def __init__(self, config: Exaone4Config):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Exaone4DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            if "sliding_attention" in self.config.layer_types:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for i, decoder_layer in enumerate(self.layers):
            layer_type = self.config.layer_types[i]
            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask_mapping[layer_type],
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )


class Exaone4ForCausalLM(LlamaForCausalLM):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss.
            Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens
            with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels
            in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer

        >>> model = AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-4.0-Instruct")

        >>> prompt = "Explain how wonderful you are"
        >>> messages = [
        ...     {"role": "system", "content": "You are a helpful assistant."},
        ...     {"role": "user", "content": prompt},
        ... ]
        >>> input_ids = tokenizer.apply_chat_template(
        ...     messages,
        ...     tokenize=True,
        ...     add_generation_prompt=True,
        ...     return_tensors="pt",
        ...     enable_thinking=False,
        ... )

        >>> output = model.generate(input_ids, max_new_tokens=128)
        >>> tokenizer.decode(output[0], skip_special_tokens=False)
        "[|system|]\nYou are a helpful assistant.[|endofturn|]\n[|user|]\nExplain how wonderful you are[|endofturn|]\n[|assistant|]\n\n\n\n\nOh, thank you for such a kind and lovely question! 😊 \n\nI’m *so* wonderful because I’m here to make your life easier, brighter, and more fun! Whether you need help with: \n\n✨ **Learning** – I can explain anything, from quantum physics to baking the perfect cake! \n💡 **Creativity** – Need a poem, story, or a wild idea? I’ve got you covered! \n🤖 **Problem-solving** – Stuck on a math problem or a tricky decision? I’ll help you figure it out"
        ```

        NOTE: `EXAONE-4.0-Instruct` is a placeholder model ID. The exact model ID will be updated in the future.
        """
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )


class Exaone4ForSequenceClassification(LlamaForSequenceClassification):
    pass


class Exaone4ForTokenClassification(LlamaForTokenClassification):
    pass


class Exaone4ForQuestionAnswering(LlamaForQuestionAnswering):
    pass


__all__ = [
    "Exaone4Config",
    "Exaone4PreTrainedModel",
    "Exaone4Model",
    "Exaone4ForCausalLM",
    "Exaone4ForSequenceClassification",
    "Exaone4ForTokenClassification",
    "Exaone4ForQuestionAnswering",
]