# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/mlcd/modular_mlcd.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_mlcd.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Callable, Optional, Union

import torch
import torch.nn as nn

from ...activations import ACT2FN
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, torch_int
from .configuration_mlcd import MLCDVisionConfig


class MLCDMLP(nn.Module):
    """Feed-forward block: `fc1` -> activation (`config.hidden_act`) -> `fc2`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class MLCDRotaryEmbedding(nn.Module):
    """Produces 2D rotary position angles (row index and column index) for a patch grid."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        # Non-persistent: the buffer follows the module across devices but is not saved in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor:
        """
        Calculate the Rotary Position Embedding (RoPE) for MLCDVisionModel based on the grid size.

        Args:
            num_patches_height (int): Number of patches in the height dimension.
            num_patches_width (int): Number of patches in the width dimension.

        Returns:
            torch.Tensor: Rotary angles of shape
            `(num_patches_height * num_patches_width, 2 * inv_freq.numel())`; for every patch the angles of its
            row index come first, followed by the angles of its column index.
        """
        # Generate position IDs for height and width dimensions
        hpos_ids = (
            torch.arange(num_patches_height, device=self.inv_freq.device).unsqueeze(1).expand(-1, num_patches_width)
        )
        wpos_ids = (
            torch.arange(num_patches_width, device=self.inv_freq.device).unsqueeze(0).expand(num_patches_height, -1)
        )

        # Flatten and stack the position IDs -> (height * width, 2)
        pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1)

        # Generate the full rotary positional embeddings for the maximum grid size
        max_grid_size = max(num_patches_height, num_patches_width)
        seq = torch.arange(max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        rotary_pos_emb_full = torch.outer(seq, self.inv_freq)

        # Select and flatten the embeddings based on the position IDs
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)

        return rotary_pos_emb


class MLCDVisionEmbeddings(nn.Module):
    """Turns pixel values into a sequence of patch embeddings with a learned class token prepended."""

    def __init__(self, config: MLCDVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # Non-overlapping patches: kernel size and stride are both `patch_size`.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1  # +1 for the class token
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        # NOTE(review): this method reads `self.position_embedding`, which `__init__` above never creates, and
        # `forward` below never calls this method. As written, calling it fails on that attribute lookup unless
        # the attribute is attached elsewhere - confirm whether the method is still needed.

        num_patches = embeddings.shape[1] - 1
        position_embedding = self.position_embedding.weight.unsqueeze(0)
        num_positions = position_embedding.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        class_pos_embed = position_embedding[:, :1]
        patch_pos_embed = position_embedding[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        # patch_embeds -> shape = [batch, width, grid, grid]
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        # -> [batch, grid * grid, width]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

        return embeddings


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain PyTorch scaled dot-product attention.

    Keys and values are first repeated `module.num_key_value_groups` times along the head dimension. When given,
    `attention_mask` is added to the raw scores (after being cut to the key length on its last dimension). The
    softmax is computed in float32 and cast back to the query dtype.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has had dimensions 1 and 2 swapped relative to the
        result of `attn_weights @ value`.
    """
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embeddings to `q` and `k`.

    The rotation is computed in float32 and each result is cast back to the dtype of its input. `cos` and `sin`
    get a new second-to-last axis so that they broadcast against the second-to-last axis of `q` / `k`.
    """
    orig_q_dtype = q.dtype
    orig_k_dtype = k.dtype
    q, k = q.float(), k.float()
    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    q_embed = q_embed.to(orig_q_dtype)
    k_embed = k_embed.to(orig_k_dtype)
    return q_embed, k_embed


class MLCDAttention(nn.Module):
    """Multi-headed attention with RoPE. Refer to papers:
    - Attention is all you need:
        https://huggingface.co/papers/1706.03762
    - RoFormer: Enhanced Transformer with Rotary Position Embedding:
        https://huggingface.co/papers/2104.09864
    """

    def __init__(self, config: MLCDVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
        # Read by `eager_attention_forward` to decide how often keys/values are repeated.
        self.num_key_value_groups = config.num_key_value_groups

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel

        `position_embeddings` is the rotary `(cos, sin)` pair; a leading batch axis is added here before the
        rotation is applied to queries and keys. Returns the projected attention output together with whatever
        attention weights the selected attention implementation returns.
        """
        batch_size, seq_length = hidden_states.shape[:-1]

        # Each of shape: [batch_size, seq_length, num_heads, head_dim]
        query_states = self.q_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
        key_states = self.k_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
        value_states = self.v_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))

        # Apply positional embeddings
        cos = position_embeddings[0].unsqueeze(0).float()
        sin = position_embeddings[1].unsqueeze(0).float()
        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

        # Each of shape: [batch_size, num_heads, seq_length, head_dim]
        query_states = query_states.permute(0, 2, 1, 3).contiguous()
        key_states = key_states.permute(0, 2, 1, 3).contiguous()
        value_states = value_states.permute(0, 2, 1, 3).contiguous()

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scale,
            is_causal=self.is_causal,
            **kwargs,
        )

        attn_output = attn_output.permute(1, 0, 2, 3).contiguous()  # [seq_length, batch_size, num_heads, head_dim]
        attn_output = attn_output.view(seq_length, batch_size, -1)  # [seq_length, batch_size, embedding_dim]
        attn_output = self.out_proj(attn_output)
        attn_output = attn_output.permute(1, 0, 2).contiguous()  # [batch_size, seq_length, embedding_dim]
        return attn_output, attn_weights


class MLCDEncoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer layer: LayerNorm -> self-attention -> residual, LayerNorm -> MLP -> residual."""

    def __init__(self, config: MLCDVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = MLCDAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = MLCDMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
                Represents the hidden states from the previous layer or the input embeddings.
            position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
                The rotary `(cos, sin)` pair applied to the queries and keys inside the attention module. The
                tensors carry no batch dimension (the attention module adds one); as built by
                `MLCDVisionTransformer.forward` each is 2-dimensional with `seq_len` rows.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is truthy.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class MLCDEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`MLCDEncoderLayer`].

    Args:
        config: MLCDVisionConfig
    """

    def __init__(self, config: MLCDVisionConfig):
        """Overwrite dummy `MLCDConfig` to `MLCDVisionConfig`."""
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([MLCDEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds: torch.FloatTensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded input sequence that is fed to the first encoder layer.
            position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
                The rotary `(cos, sin)` pair, passed unchanged to every layer, where it is applied to the queries
                and keys of the attention module. The tensors carry no batch dimension; as built by
                `MLCDVisionTransformer.forward` each is 2-dimensional with `sequence_length` rows.
            attention_mask (`torch.Tensor`, *optional*):
                Passed unchanged to every layer. In the eager attention path it is indexed as a 4-dimensional
                tensor and **added** to the attention scores, so masked positions are expected to hold very large
                negative values (see [`MLCDEncoderLayer`]: shape `(batch, 1, q_len, k_v_seq_len)`).

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # Explicit arguments win; otherwise fall back to the values stored on the config.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            # Record the input of each layer; the output of the last layer is appended after the loop.
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            layer_outputs = encoder_layer(
                hidden_states=hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )


class MLCDVisionTransformer(nn.Module):
    def __init__(self, config: MLCDVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = MLCDVisionEmbeddings(config)
        # NOTE: the attribute name `pre_layrnorm` (sic) is part of the module's parameter naming; renaming it
        # would change the state-dict keys.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = MLCDEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
        # Learned rotary angles for the class token, which has no position on the patch grid.
        self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2))

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # The patch grid is derived from the actual input size, so the rotary angles follow the input resolution.
        num_patches_height = pixel_values.shape[-2] // self.config.patch_size
        num_patches_width = pixel_values.shape[-1] // self.config.patch_size
        rotary_pos_emb = self.vision_rotary_embedding(num_patches_height, num_patches_width)
        rotary_pos_emb = rotary_pos_emb.to(self.class_pos_emb.device)
        # Class-token angles go first, matching the class token being first in the embedded sequence.
        rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0)
        # Duplicate along the last dimension so both halves handled by `rotate_half` use the same angles.
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            position_embeddings=position_embeddings,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        # Pool by taking the class token (index 0) and normalizing it.
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@auto_docstring
class MLCDPreTrainedModel(PreTrainedModel):
    config: MLCDVisionConfig
    base_model_prefix = "mlcd"
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
        # Every standard deviation below is scaled by this single factor.
        factor = self.config.initializer_factor
        if isinstance(module, MLCDVisionEmbeddings):
            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
            nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
        elif isinstance(module, MLCDAttention):
            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            out_proj_std = (module.embed_dim**-0.5) * factor
            nn.init.normal_(module.q_proj.weight, std=in_proj_std)
            nn.init.normal_(module.k_proj.weight, std=in_proj_std)
            nn.init.normal_(module.v_proj.weight, std=in_proj_std)
            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
        elif isinstance(module, MLCDMLP):
            in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
            nn.init.normal_(module.fc1.weight, std=fc_std)
            nn.init.normal_(module.fc2.weight, std=in_proj_std)
        elif isinstance(module, MLCDVisionTransformer):
            pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor
            nn.init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


@auto_docstring(
    custom_intro="""
    The vision model from M_L_C_D without any head or projection on top.
    """
)
class MLCDVisionModel(MLCDPreTrainedModel):
    config: MLCDVisionConfig
    main_input_name = "pixel_values"
    _no_split_modules = ["MLCDEncoderLayer"]

    def __init__(self, config: MLCDVisionConfig):
        super().__init__(config)
        self.vision_model = MLCDVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        r"""
        Example:

        ```python
        >>> import requests
        >>> import torch
        >>> from PIL import Image
        >>> from transformers import AutoProcessor, MLCDVisionModel
        >>> model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
        >>> processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> inputs = processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs, output_attentions=True)

        >>> features = outputs.last_hidden_state
        >>> print(f"Extracted features shape: {features.shape}")
        >>> print(f"Number of attention layers: {len(outputs.attentions)}")
        >>> print(f"Attention shape: {outputs.attentions[0].shape}")
        ```"""
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


__all__ = ["MLCDPreTrainedModel", "MLCDVisionModel"]