# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
"""OLMoE model configuration"""
|
|
|
|
from ...configuration_utils import PretrainedConfig
|
|
from ...modeling_rope_utils import rope_config_validation
|
|
|
|
|
|
class OlmoeConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an [`OlmoeModel`]. It is used to instantiate an
    OLMoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the
    [allenai/OLMoE-1B-7B-0924](https://huggingface.co/allenai/OLMoE-1B-7B-0924).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 50304):
            Vocabulary size of the OLMoE model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`OlmoeModel`].
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 2048):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 16):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id.
        bos_token_id (`int`, *optional*):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 50279):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. The expected format is
            `{"rope_type": strategy name, "factor": scaling factor}`; the legacy key `"type"` is still accepted and is
            copied to `"rope_type"`. The accepted strategies and their required keys are checked by
            `rope_config_validation`. The dictionary is copied on construction, so the object passed in is never
            modified. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See
            the following thread for more information on how these scaling strategies behave:
            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
            experimental feature, subject to breaking API changes in future versions.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        clip_qkv (`float`, *optional*):
            If not `None`, elements of query, key and value attention states are clipped so that their
            absolute value does not exceed this value.
        num_experts_per_tok (`int`, *optional*, defaults to 8):
            Number of selected experts.
        num_experts (`int`, *optional*, defaults to 64):
            Number of routed experts.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.01):
            The aux loss factor for the total loss.
        norm_topk_prob (`bool`, *optional*, defaults to `False`):
            Whether to normalize the topk probabilities.

    ```python
    >>> from transformers import OlmoeModel, OlmoeConfig

    >>> # Initializing an OLMoE 7B A1B style configuration
    >>> configuration = OlmoeConfig()

    >>> # Initializing a model from the OLMoE 7B A1B style configuration
    >>> model = OlmoeModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "olmoe"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=50304,
        hidden_size=2048,
        intermediate_size=2048,
        num_hidden_layers=16,
        num_attention_heads=16,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=None,
        eos_token_id=50279,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        clip_qkv=None,
        num_experts_per_tok=8,
        num_experts=64,
        output_router_logits=False,
        router_aux_loss_coef=0.01,
        norm_topk_prob=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility: when unspecified, use one key/value head per
        # attention head (i.e. plain multi-head attention)
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        # Keep a shallow copy: the backward-compatibility rewrite below adds a key,
        # and that must not leak into a dict the caller still holds or reuses.
        self.rope_scaling = dict(rope_scaling) if rope_scaling is not None else None
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.clip_qkv = clip_qkv
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.norm_topk_prob = norm_topk_prob
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a legacy 'type' field, copy it to 'rope_type' ('type' itself is kept).
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
|
|
|
|
|
|
# Public API of this module: the only name exported by ``from ... import *``.
__all__ = [
    "OlmoeConfig",
]
|