# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma3n/modular_gemma3n.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma3n.py file directly. One of our CI checks enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Any, Optional, Union
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...modeling_rope_utils import rope_config_validation
from ...utils import is_timm_available, logging, requires_backends
if is_timm_available():
from timm.data import ImageNetInfo, infer_imagenet_subset
logger = logging.get_logger(__name__)
class Gemma3nTextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Gemma3nTextModel`]. It is used to instantiate a
Gemma3nTextModel according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g.
[google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
Configuration objects that inherit from [`Gemma3nTextConfig`] can be used to control the model outputs. Read
the documentation from [`Gemma3nTextConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 262400):
Vocabulary size of the Gemma3nText model. Defines the number of different tokens that can be represented by
the `input_ids` passed when calling [`Gemma3nTextModel`].
vocab_size_per_layer_input (`int`, *optional*, defaults to 262144):
Vocabulary size of the per-layer text embeddings that augment the standard embeddings.
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
hidden_size_per_layer_input (`int`, *optional*, defaults to 256):
Dimension of the hidden representations for per-layer embeddings.
intermediate_size (`int` or `Sequence[int]`, *optional*, defaults to 16384):
Dimension of the MLP representations. MatFormer configurations may provide a sequence of integers to
account for variable intermediate_size values across layers; in that case,
`len(intermediate_size) == num_hidden_layers` (see the last example block in this docstring).
num_hidden_layers (`int`, *optional*, defaults to 35):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 2):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, will default to `num_attention_heads`.
head_dim (`int`, *optional*, defaults to 256):
The attention head dimension.
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the decoder. Will default to
`"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"`
activation function.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
rope_theta (`float`, *optional*, defaults to 1000000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings used in global attention.
NOTE: if you apply a new rope type and you expect the model to work on longer `max_position_embeddings`, we
recommend updating this value accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
rope_local_base_freq (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings for local attention.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
sliding_window (`int`, *optional*, defaults to 512):
This is the size of the sliding window used by local attention layers.
layer_types (`Optional[Sequence[str]]`, *optional*):
A sequence of strings defining the attention type for each layer, either "sliding_attention" or
"full_attention". If not provided, `layer_types` will be inferred from `num_hidden_layers` using a pattern
of four "sliding_attention" layers followed by one "full_attention" layer. The last layer in the model
should always be a "full_attention" layer.
final_logit_softcapping (`float`, *optional*, defaults to 30.0):
Scaling factor when applying tanh softcapping on the logits.
altup_active_idx (`int`, *optional*, defaults to 0):
The index of the prediction from which AltUp will compute additional predictions or corrections.
altup_coef_clip (`float`, *optional*, defaults to 120.0):
The maximum amplitude of an AltUp prediction or correction coefficient weight.
altup_correct_scale (`bool`, *optional*, defaults to `True`):
If True, apply the `AltUp.correct_output_scale` to the corrected prediction at `altup_active_idx`.
altup_num_inputs (`int`, *optional*, defaults to 4):
The number of predictions that AltUp should make given the input sequence.
num_kv_shared_layers (`int`, *optional*, defaults to 15):
The number of layers that share KV cache values. During the forward pass, the last `num_kv_shared_layers`
layers in the model "share" the KV values in that each local and global layer in this range uses the KV
cache values computed for the last local or global layer, respectively, before entering this range. The
value of `num_kv_shared_layers` should be a multiple of the attention layer pattern length described by `layer_types`.
laurel_rank (`int`, *optional*, defaults to 64):
The intermediate size for the linear projections in the Learned Augmented Residual Layer.
activation_sparsity_pattern (Sequence[float], *optional*, defaults to `(0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)`):
The sparsity factor used to extract the top-k activations for a given layer. The provided Sequence must
explicitly provide a sparsity value for each layer in the model.
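When `rope_scaling` is provided it is checked by `rope_config_validation`; a minimal, purely illustrative
linear-scaling setup (not a released configuration) might look like this:
```python
>>> from transformers import Gemma3nTextConfig
>>> scaled_config = Gemma3nTextConfig(
...     max_position_embeddings=131_072,
...     rope_scaling={"rope_type": "linear", "factor": 4.0},
... )
>>> scaled_config.rope_scaling["factor"]
4.0
```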
```python
>>> from transformers import Gemma3nTextModel, Gemma3nTextConfig
>>> # Initializing a Gemma3nText gemma3n_text-E4B style configuration
>>> configuration = Gemma3nTextConfig()
>>> # Initializing a model from the gemma3n_text-E4B style configuration
>>> model = Gemma3nTextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
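A hypothetical smaller MatFormer-style variant, showing how the per-layer options interact (the values below
are illustrative and do not correspond to a released checkpoint):
```python
>>> config = Gemma3nTextConfig(
...     num_hidden_layers=10,
...     intermediate_size=[16_384] * 5 + [8_192] * 5,  # one MLP width per layer
...     activation_sparsity_pattern=None,  # falls back to 0.0 for every layer
... )
>>> len(config.intermediate_size), config.layer_types[4]
(10, 'full_attention')
```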
"""
model_type = "gemma3n_text"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size: int = 262_400,
vocab_size_per_layer_input: int = 262_144,
hidden_size: int = 2048,
hidden_size_per_layer_input: int = 256,
intermediate_size: Union[int, Sequence[int]] = 16_384,
num_hidden_layers: int = 35,
num_attention_heads: int = 8,
num_key_value_heads: int = 2,
head_dim: int = 256,
hidden_activation: str = "gelu_pytorch_tanh",
max_position_embeddings: int = 32_768,
initializer_range: float = 0.02,
rms_norm_eps: float = 1e-6,
use_cache: bool = True,
pad_token_id: int = 0,
eos_token_id: int = 1,
bos_token_id: int = 2,
rope_theta: float = 1_000_000.0,
rope_scaling: Optional[dict[str, Any]] = None,
rope_local_base_freq: float = 10_000.0,
attention_bias: bool = False,
attention_dropout: float = 0.0,
sliding_window: int = 512,
layer_types: Optional[Sequence[str]] = None,
final_logit_softcapping: float = 30.0,
altup_active_idx: int = 0,
altup_coef_clip: float = 120.0,
altup_correct_scale: bool = True,
altup_num_inputs: int = 4,
num_kv_shared_layers: int = 15,
laurel_rank: int = 64,
activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = (0.95,) * 10 + (0.0,) * 25,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs,
)
if isinstance(intermediate_size, Sequence) and (intsize_len := len(intermediate_size)) != num_hidden_layers:
raise ValueError(
"intermediate_size must have an explicit intermediate size for every layer or one for all layers. "
f"Expected {num_hidden_layers} values but got {intsize_len}."
)
elif not isinstance(intermediate_size, Sequence):
intermediate_size = [intermediate_size] * num_hidden_layers
self.vocab_size = vocab_size
self.vocab_size_per_layer_input = vocab_size_per_layer_input
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.hidden_activation = hidden_activation
self.sliding_window = sliding_window
self.final_logit_softcapping = final_logit_softcapping
self.layer_types = layer_types
self.rope_local_base_freq = rope_local_base_freq
self.rope_scaling = rope_scaling
rope_config_validation(self)
if layer_types is None:
self.layer_types = [
"full_attention" if (i + 1) % 5 == 0 else "sliding_attention" for i in range(self.num_hidden_layers)
]
else:
self.layer_types = layer_types
layer_type_validation(self.layer_types)
self.hidden_size_per_layer_input = hidden_size_per_layer_input
self.num_kv_shared_layers = num_kv_shared_layers
self.altup_active_idx = altup_active_idx
self.altup_coef_clip = altup_coef_clip
self.altup_correct_scale = altup_correct_scale
self.altup_num_inputs = altup_num_inputs
self.laurel_rank = laurel_rank
if activation_sparsity_pattern is None:
activation_sparsity_pattern = [0.0] * num_hidden_layers
if (len_asp := len(activation_sparsity_pattern)) != num_hidden_layers:
raise ValueError(
"activation_sparsity_pattern must have an explicit activation sparsity value for every layer."
f"Expected {num_hidden_layers} values but got {len_asp}."
)
self.activation_sparsity_pattern = activation_sparsity_pattern
class Gemma3nAudioConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Gemma3nAudioEncoder`]. It is used to instantiate
a `Gemma3nAudioEncoder` model according to the specified arguments, defining the model architecture. Instantiating
a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g.,
[google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
Configuration objects that inherit from [`Gemma3nAudioConfig`] can be used to control the model outputs. Read
the documentation from [`Gemma3nAudioConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 128):
Vocabulary size of the additional hard-token embeddings for audio model. These augment the embeddings
included in the `Gemma3nTextModel` to provide, e.g., the end of audio and audio soft token placeholder
tokens when converting `input_ids` to embeddings in the `Gemma3nForConditionalGeneration` model.
vocab_offset (`int`, *optional*, defaults to 262272):
Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the
0-indexed `Gemma3nMultimodalEmbedder.embedding` table.
input_feat_size (`int`, *optional*, defaults to 128):
The number of channels in each mel-spectrogram frame.
hidden_size (`int`, *optional*, defaults to 1536):
Dimension of the hidden representations.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
gradient_clipping (`float`, *optional*, defaults to 10000000000.0):
Clipping value used to stabilize extremely large gradient values.
conf_attention_chunk_size (`int`, *optional*, defaults to 12):
The sub-sequence size for local attention processing inside the Conformer ("conf") section of the
Universal Speech Model.
conf_attention_context_left (`int`, *optional*, defaults to 13):
The left context size of the local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_attention_context_right (`int`, *optional*, defaults to 0):
The right context size of the local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_attention_logit_cap (`float`, *optional*, defaults to 50.0):
Logit cap applied during local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_num_attention_heads (`int`, *optional*, defaults to 8):
The number of attention heads in local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_num_hidden_layers (`int`, *optional*, defaults to 12):
The number of layers that use local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_conv_kernel_size (`int`, *optional*, defaults to 5):
Convolution kernel size for the conformer block inside the Conformer ("conf") section of the
Universal Speech Model.
conf_reduction_factor (`int`, *optional*, defaults to 4):
Reduction factor used in the conformer block inside the Conformer ("conf") section of the
Universal Speech Model.
conf_residual_weight (`float`, *optional*, defaults to 0.5):
Residual connection weight inside the Conformer ("conf") section of the
Universal Speech Model.
sscp_conv_channel_size (`tuple(int, int)`, *optional*, defaults to `(128, 32)`):
The channel sizes for the first and second convolutional layers in the Sub-sample Convolution Projection
("sscp") section of the Universal Speech Model.
sscp_conv_group_norm_eps (`float`, *optional*, defaults to 0.001):
Epsilon used in group normalization in the subsample convolution projection in the Sub-sample Convolution
Projection ("sscp") section of the Universal Speech Model.
sscp_conv_kernel_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((3, 3), (3, 3))`):
Kernel sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample
Convolution Projection ("sscp") section of the Universal Speech Model. The kernel sizes are specified as a
tuple of height and width for each layer, where the height corresponds to the time dimension and the width
corresponds to the frequency dimension.
sscp_conv_stride_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((2, 2), (2, 2))`):
Stride sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample
Convolution Projection ("sscp") section of the Universal Speech Model. The stride sizes are specified as a
tuple of height and width for each layer, where the height corresponds to the time dimension and the width
corresponds to the frequency dimension.
Example:
```python
>>> from transformers import Gemma3nAudioConfig, Gemma3nAudioEncoder
>>> # Initializing a Gemma3nAudioEncoder gemma3n_audio-E4B-style configuration
>>> configuration = Gemma3nAudioConfig()
>>> # Initializing a model from the gemma3n_audio-E4B style configuration
>>> model = Gemma3nAudioEncoder(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
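The default `vocab_offset` places the audio hard tokens directly after the text and vision token ranges; a
quick, illustrative check:
```python
>>> config = Gemma3nAudioConfig()
>>> # text vocab size (262_144) + vision vocab size (128) = 262_272
>>> config.vocab_offset, config.vocab_offset + config.vocab_size
(262272, 262400)
```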
"""
model_type = "gemma3n_audio"
def __init__(
self,
vocab_size: int = 128,
vocab_offset: int = 262_144 + 128, # text vocab size + vision vocab size
input_feat_size: int = 128,
hidden_size: int = 1536,
rms_norm_eps: float = 1e-6,
gradient_clipping: float = 10_000_000_000.0,
conf_attention_chunk_size: int = 12,
conf_attention_context_left: int = 13,
conf_attention_context_right: int = 0,
conf_attention_logit_cap: float = 50.0,
conf_num_attention_heads: int = 8,
conf_num_hidden_layers: int = 12,
conf_conv_kernel_size: int = 5,
conf_reduction_factor: int = 4,
conf_residual_weight: float = 0.5,
sscp_conv_channel_size: tuple[int, int] = (128, 32),
sscp_conv_group_norm_eps: float = 1e-3,
sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = (
(3, 3),
(3, 3),
),
sscp_conv_stride_size: tuple[tuple[int, int], tuple[int, int]] = (
(2, 2),
(2, 2),
),
**kwargs,
):
super().__init__(**kwargs)
self.input_feat_size = input_feat_size
self.hidden_size = hidden_size
self.rms_norm_eps = rms_norm_eps
self.vocab_size = vocab_size
self.vocab_offset = vocab_offset
self.gradient_clipping = gradient_clipping
self.conf_attention_chunk_size = conf_attention_chunk_size
self.conf_attention_context_left = conf_attention_context_left
self.conf_attention_context_right = conf_attention_context_right
self.conf_attention_logit_cap = conf_attention_logit_cap
self.conf_num_attention_heads = conf_num_attention_heads
self.conf_num_hidden_layers = conf_num_hidden_layers
self.conf_conv_kernel_size = conf_conv_kernel_size
self.conf_reduction_factor = conf_reduction_factor
self.conf_residual_weight = conf_residual_weight
self.sscp_conv_channel_size = sscp_conv_channel_size
self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps
self.sscp_conv_kernel_size = sscp_conv_kernel_size
self.sscp_conv_stride_size = sscp_conv_stride_size
class Gemma3nVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`]. It is used to
instantiate a timm model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B
vision tower, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
Configuration objects inherit from [`Gemma3nVisionConfig`] and can be used to control the model outputs. Read the
documentation from [`Gemma3nVisionConfig`] for more information.
The config loads ImageNet label descriptions and stores them in the `id2label` attribute; the `label2id`
attribute for default ImageNet models is set to `None` because the label descriptions contain duplicates.
Args:
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
do_pooling (`bool`, *optional*, defaults to `False`):
Whether to do pooling for the last_hidden_state in `TimmWrapper` or not.
architecture (`str`, *optional*, defaults to `"mobilenetv5_300m_enc"`):
Determines vision architecture for TimmWrapper.
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
vocab_size (`int`, *optional*, defaults to 128):
Vocabulary size of the additional hard-token embeddings for vision model.
vocab_offset (`int`, *optional*, defaults to 262144):
Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the
0-indexed `Gemma3nMultimodalEmbedder.embedding` table.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
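model_args (`dict`, *optional*):
Additional keyword arguments forwarded to the timm backbone when it is instantiated (the name
`model_args` is kept for backwards compatibility with timm).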
Example:
```python
>>> from transformers import Gemma3nVisionConfig, TimmWrapper
>>> # Initializing a TimmWrapper gemma3n_vision-E4B-style configuration
>>> configuration = Gemma3nVisionConfig()
>>> # Initializing a gemma3n_vision-E4B-style TimmWrapper from the configuration
>>> model = TimmWrapper(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
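A sketch of how custom label names flow through `from_dict` (assuming the standard
`PretrainedConfig.from_dict` keyword handling; the labels below are placeholders):
```python
>>> cfg = Gemma3nVisionConfig.from_dict({"label_names": ["cat", "dog"], "num_classes": 2})
>>> cfg.id2label
{0: 'cat', 1: 'dog'}
```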
"""
model_type = "gemma3n_vision"
def __init__(
self,
initializer_range: float = 0.02,
do_pooling: bool = False,
architecture: str = "mobilenetv5_300m_enc",
hidden_size: int = 2048,
vocab_size: int = 128,
vocab_offset: int = 262_144,
rms_norm_eps: float = 1e-06,
model_args: Optional[dict] = None,
**kwargs,
):
super().__init__(**kwargs)
self.initializer_range = initializer_range
self.do_pooling = do_pooling
self.model_args = model_args # named "model_args" for BC with timm
self.architecture = architecture
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.vocab_offset = vocab_offset
self.rms_norm_eps = rms_norm_eps
@classmethod
def from_dict(cls, config_dict: dict[str, Any], **kwargs):
label_names = config_dict.get("label_names", None)
is_custom_model = "num_labels" in kwargs or "id2label" in kwargs
# if no labels added to config, use imagenet labeller in timm
if label_names is None and not is_custom_model:
requires_backends(cls, ["timm"])
imagenet_subset = infer_imagenet_subset(config_dict)
if imagenet_subset:
dataset_info = ImageNetInfo(imagenet_subset)
synsets = dataset_info.label_names()
label_descriptions = dataset_info.label_descriptions(as_dict=True)
label_names = [label_descriptions[synset] for synset in synsets]
if label_names is not None and not is_custom_model:
kwargs["id2label"] = dict(enumerate(label_names))
# if all label names are unique, create label2id mapping as well
if len(set(label_names)) == len(label_names):
kwargs["label2id"] = {name: i for i, name in enumerate(label_names)}
else:
kwargs["label2id"] = None
# timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict.
# We are removing these attributes in order to have the native `transformers` num_labels attribute in config
# and to avoid duplicate attributes
num_labels_in_kwargs = kwargs.pop("num_labels", None)
num_labels_in_dict = config_dict.pop("num_classes", None)
# passed num_labels has priority over num_classes in config_dict
kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict
# pop num_classes from "pretrained_cfg",
# it is not needed there because timm only uses the root-level value
if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]:
config_dict["pretrained_cfg"].pop("num_classes", None)
return super().from_dict(config_dict, **kwargs)
def to_dict(self) -> dict[str, Any]:
output = super().to_dict()
output["num_classes"] = self.num_labels
output["label_names"] = list(self.id2label.values())
output.pop("id2label", None)
output.pop("label2id", None)
return output
class Gemma3nConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Gemma3nForConditionalGeneration`]. It is used to
instantiate a Gemma3nForConditionalGeneration according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
Gemma3n-E4B, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`Union[Gemma3nTextConfig, dict]`, *optional*):
The config object of the text backbone.
vision_config (`Union[AutoConfig, dict]`, *optional*):
Custom vision config or dict.
audio_config (`Union[AutoConfig, dict]`, *optional*):
Custom audio config or dict.
audio_soft_tokens_per_image (`int`, *optional*, defaults to 188):
The number of soft tokens per audio clip.
vision_soft_tokens_per_image (`int`, *optional*, defaults to 256):
The number of soft tokens per image.
boi_token_id (`int`, *optional*, defaults to 255999):
The begin-of-image token index to wrap the image prompt.
eoi_token_id (`int`, *optional*, defaults to 262144):
The end-of-image token index to wrap the image prompt.
image_token_id (`int`, *optional*, defaults to 262145):
The image token index to encode the image prompt.
boa_token_id (`int`, *optional*, defaults to 256000):
The begin-of-audio token index to wrap the audio prompt.
eoa_token_id (`int`, *optional*, defaults to 262272):
The end-of-audio token index to wrap the audio prompt.
audio_token_id (`int`, *optional*, defaults to 262273):
The audio token index to encode the audio prompt.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
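The default token ids are consistent with the sub-config vocab offsets (a quick, illustrative check):
```python
>>> config = Gemma3nConfig()
>>> config.eoi_token_id == config.vision_config.vocab_offset
True
>>> config.eoa_token_id == config.audio_config.vocab_offset
True
```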
Example:
```python
>>> from transformers import Gemma3nForConditionalGeneration, Gemma3nConfig, Gemma3nTextConfig
>>> # Initializing a MobileNet vision config, which is loaded from TIMM
>>> vision_config = Gemma3nVisionConfig()
>>> # Initializing a Gemma3n Audio config
>>> audio_config = Gemma3nAudioConfig()
>>> # Initializing a Gemma3n Text config
>>> text_config = Gemma3nTextConfig()
>>> # Initializing a Gemma3n gemma-3n-E4B style configuration
>>> configuration = Gemma3nConfig(text_config, vision_config, audio_config)
>>> # Initializing a model from the gemma-3n-E4B style configuration
>>> model = Gemma3nForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma3n"
sub_configs = {
"text_config": Gemma3nTextConfig,
"vision_config": Gemma3nVisionConfig,
"audio_config": Gemma3nAudioConfig,
}
def __init__(
self,
text_config: Optional[Union[Gemma3nTextConfig, dict[str, Any]]] = None,
vision_config: Optional[Union[Gemma3nVisionConfig, dict[str, Any]]] = None,
audio_config: Optional[Union[Gemma3nAudioConfig, dict[str, Any]]] = None,
audio_soft_tokens_per_image: int = 188,
vision_soft_tokens_per_image: int = 256,
boi_token_id: int = 255_999,
eoi_token_id: int = 262_144,
image_token_id: int = 262_145,
boa_token_id: int = 256_000,
eoa_token_id: int = 262_272,
audio_token_id: int = 262_273,
initializer_range: float = 0.02,
**kwargs,
):
super().__init__(**kwargs)
if isinstance(text_config, dict):
text_config = Gemma3nTextConfig(**text_config)
elif text_config is None:
text_config = Gemma3nTextConfig()
logger.info("text_config is None. Using default Gemma3nTextConfig.")
if isinstance(vision_config, dict):
vision_config = Gemma3nVisionConfig(**vision_config)
elif vision_config is None:
vision_config = Gemma3nVisionConfig()
logger.info("vision_config is None. Using default Gemma3nVisionConfig.")
if isinstance(audio_config, dict):
audio_config = Gemma3nAudioConfig(**audio_config)
elif audio_config is None:
audio_config = Gemma3nAudioConfig()
logger.info("audio_config is None. Using default Gemma3nAudioConfig.")
self.text_config = text_config
self.vision_config = vision_config
self.audio_config = audio_config
self.audio_soft_tokens_per_image = audio_soft_tokens_per_image
self.vision_soft_tokens_per_image = vision_soft_tokens_per_image
self.boi_token_id = boi_token_id
self.eoi_token_id = eoi_token_id
self.image_token_id = image_token_id
self.boa_token_id = boa_token_id
self.eoa_token_id = eoa_token_id
self.audio_token_id = audio_token_id
self.initializer_range = initializer_range
__all__ = ["Gemma3nAudioConfig", "Gemma3nConfig", "Gemma3nTextConfig", "Gemma3nVisionConfig"]