# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.models.siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
from transformers.models.siglip.modeling_siglip import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    ImageClassifierOutput,
    SiglipForImageClassification,
    SiglipModel,
    SiglipMultiheadAttentionPoolingHead,
    SiglipOutput,
    SiglipPreTrainedModel,
    SiglipTextModel,
    SiglipTextModelOutput,
    SiglipVisionModel,
    SiglipVisionModelOutput,
    SiglipVisionTransformer,
)

from ...modeling_attn_mask_utils import _prepare_4d_attention_mask


class Siglip2TextConfig(SiglipTextConfig):
    pass


class Siglip2VisionConfig(SiglipVisionConfig):
    r"""
    This is the configuration class to store the configuration of a [`Siglip2VisionModel`]. It is used to instantiate a
    Siglip2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip2
    [google/siglip2-base-patch16-naflex](https://huggingface.co/google/siglip2-base-patch16-naflex) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input images.
        num_patches (`int`, *optional*, defaults to 256):
            The number of patches of size (`patch_size`, `patch_size`) the image is split into. The image is resized
            to fill at most this number of patches while preserving its aspect ratio. If the resulting number of
            patches is lower, the patch sequence is padded up to `num_patches`.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    Example:

    ```python
    >>> from transformers import Siglip2VisionConfig, Siglip2VisionModel

    >>> # Initializing a Siglip2VisionConfig with google/siglip2-base-patch16-naflex style configuration
    >>> configuration = Siglip2VisionConfig()

    >>> # Initializing a Siglip2VisionModel (with random weights) from the google/siglip2-base-patch16-naflex style configuration
    >>> model = Siglip2VisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        num_patches=256,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_channels=num_channels,
            patch_size=patch_size,
            hidden_act=hidden_act,
            layer_norm_eps=layer_norm_eps,
            attention_dropout=attention_dropout,
            **kwargs,
        )
        self.num_patches = num_patches
        del self.image_size


class Siglip2Config(SiglipConfig):
    pass


class Siglip2VisionOutput(SiglipVisionModelOutput):
    pass


class Siglip2TextOutput(SiglipTextModelOutput):
    pass


class Siglip2Output(SiglipOutput):
    pass


class Siglip2VisionEmbeddings(nn.Module):
    def __init__(self, config: Siglip2VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Linear(
            in_features=config.num_channels * self.patch_size * self.patch_size,
            out_features=self.embed_dim,
        )

        self.num_patches = config.num_patches
        self.position_embedding_size = int(self.num_patches**0.5)
        self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

    @staticmethod
    def resize_positional_embeddings(
        positional_embeddings: torch.Tensor,
        spatial_shapes: torch.LongTensor,
        max_length: int,
    ) -> torch.Tensor:
        """
        Resize positional embeddings to image-specific size and pad to a fixed size.

        Args:
            positional_embeddings (`torch.Tensor`):
                Position embeddings of shape (height, width, embed_dim)
            spatial_shapes (`torch.LongTensor`):
                Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
            max_length (`int`):
                Maximum length of the positional embeddings to pad resized positional embeddings to

        Returns:
            `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
        """
        batch_size = spatial_shapes.shape[0]
        embed_dim = positional_embeddings.shape[-1]
        source_dtype = positional_embeddings.dtype

        resulted_positional_embeddings = torch.empty(
            (batch_size, max_length, embed_dim),
            device=positional_embeddings.device,
            dtype=source_dtype,
        )

        # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
        positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)

        # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
        if positional_embeddings.device.type == "cpu":
            positional_embeddings = positional_embeddings.to(torch.float32)

        for i in range(batch_size):
            # (1, dim, height, width) -> (1, dim, target_height, target_width)
            height, width = spatial_shapes[i]
            resized_embeddings = F.interpolate(
                positional_embeddings,
                size=(height, width),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )

            # (1, dim, target_height, target_width) -> (target_height * target_width, dim)
            resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)

            # Cast to original dtype
            resized_embeddings = resized_embeddings.to(source_dtype)

            resulted_positional_embeddings[i, : height * width] = resized_embeddings
            resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]

        return resulted_positional_embeddings

    def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
        """
        Args:
            pixel_values (`torch.FloatTensor`):
                Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
            spatial_shapes (`torch.LongTensor`):
                Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
        """
        # Apply patch embeddings to already patchified pixel values
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))

        # Get resized and padded positional embeddings
        positional_embeddings = self.position_embedding.weight.reshape(
            self.position_embedding_size, self.position_embedding_size, -1
        )
        resized_positional_embeddings = self.resize_positional_embeddings(
            positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
        )

        # Add positional embeddings to patch embeddings
        embeddings = patch_embeds + resized_positional_embeddings
        return embeddings


class Siglip2VisionTransformer(SiglipVisionTransformer):
    def __init__(self, config: Siglip2VisionConfig):
        super().__init__(config)
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

    # Update: add `spatial_shapes` and `attention_mask`
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        attention_mask: torch.Tensor,
        spatial_shapes: torch.LongTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutputWithPooling:
        r"""
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        hidden_states = self.embeddings(pixel_values, spatial_shapes)

        if attention_mask is not None and not self._use_flash_attention_2:
            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
        else:
            encoder_attention_mask = attention_mask

        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        last_hidden_state = self.post_layernorm(last_hidden_state)

        pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class Siglip2PreTrainedModel(SiglipPreTrainedModel):
    pass


class Siglip2TextModel(SiglipTextModel):
    pass


class Siglip2MultiheadAttentionPoolingHead(SiglipMultiheadAttentionPoolingHead):
    def __init__(self, config: Siglip2VisionConfig):
        super().__init__(config)
        self.num_heads = config.num_attention_heads

    def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        if attention_mask is not None:
            target_len, source_len = probe.shape[1], hidden_state.shape[1]
            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len)
            attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1)
            attention_mask = attention_mask.reshape(-1, target_len, source_len)
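        # The underlying `nn.MultiheadAttention` accepts a 3D `attn_mask` of shape
        # (batch_size * num_heads, target_len, source_len), which is why the padding mask built above is
        # expanded across heads and flattened before the attention call.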
        hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        return hidden_state[:, 0]


class Siglip2VisionModel(SiglipVisionModel):
    # Update: add `spatial_shapes` and `pixel_attention_mask`
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: torch.Tensor,
        spatial_shapes: torch.LongTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutputWithPooling:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Siglip2VisionModel

        >>> model = Siglip2VisionModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled features
        ```"""
        return self.vision_model(
            pixel_values=pixel_values,
            attention_mask=pixel_attention_mask,
            spatial_shapes=spatial_shapes,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )


class Siglip2Model(SiglipModel):
    # Update: add `spatial_shapes` and `pixel_attention_mask`
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_attention_mask: Optional[torch.Tensor] = None,
        spatial_shapes: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.

        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`Siglip2VisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     image_features = model.get_image_features(**inputs)
        ```
        """
        # Use Siglip2Model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            attention_mask=pixel_attention_mask,
            spatial_shapes=spatial_shapes,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        pooled_output = vision_outputs.pooler_output

        return pooled_output

    # Update: add `spatial_shapes` and `pixel_attention_mask`
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_attention_mask: Optional[torch.Tensor] = None,
        spatial_shapes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Siglip2Output:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
        >>> # important: we pass `padding=max_length` since the model was trained with this
        >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> logits_per_image = outputs.logits_per_image
        >>> probs = torch.sigmoid(logits_per_image)  # these are the probabilities
        >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
        31.9% that image 0 is 'a photo of 2 cats'
        ```
        """
        # Use Siglip2 model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            attention_mask=pixel_attention_mask,
            spatial_shapes=spatial_shapes,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        image_embeds = vision_outputs.pooler_output
        text_embeds = text_outputs.pooler_output

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))

        logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
        logits_per_text = logits_per_text * logit_scale.exp() + logit_bias

        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip2.py#L287
            eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
            m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
            loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
            nll = -torch.sum(loglik, dim=-1)
            loss = nll.mean()

        return Siglip2Output(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )


class Siglip2ForImageClassification(SiglipForImageClassification):
    # Update: add `spatial_shapes` and `pixel_attention_mask`
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_attention_mask: Optional[torch.Tensor] = None,
        spatial_shapes: Optional[torch.LongTensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> ImageClassifierOutput:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, Siglip2ForImageClassification
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> torch.manual_seed(3)  # doctest: +IGNORE_RESULT
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> # note: we are loading a `Siglip2Model` from the hub here,
        >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
        >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224")
        >>> model = Siglip2ForImageClassification.from_pretrained("google/siglip2-base-patch16-224")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> # model predicts one of the two classes
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: LABEL_1
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values,
            attention_mask=pixel_attention_mask,
            spatial_shapes=spatial_shapes,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = outputs.last_hidden_state

        # average pool the patch tokens
        if pixel_attention_mask is not None:
            pool_mask = pixel_attention_mask[..., None].to(sequence_output.device)
            sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / torch.sum(pool_mask, dim=1)
        else:
            sequence_output = torch.mean(sequence_output, dim=1)

        # apply classifier
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = [
    "Siglip2Config",
    "Siglip2TextConfig",
    "Siglip2VisionConfig",
    "Siglip2Model",
    "Siglip2PreTrainedModel",
    "Siglip2TextModel",
    "Siglip2VisionModel",
    "Siglip2ForImageClassification",
]
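
# Usage sketch (illustrative only, run as a standalone script): how the NaFlex-style inputs described in the
# docstrings above fit together, i.e. patchified `pixel_values`, a `pixel_attention_mask` over patch positions,
# and per-image `spatial_shapes`. It assumes the converted Siglip2 classes are importable from `transformers`;
# the tiny config values and random inputs are arbitrary and only meant as a shape check.
#
#     import torch
#     from transformers import Siglip2VisionConfig, Siglip2VisionModel
#
#     config = Siglip2VisionConfig(
#         hidden_size=64, intermediate_size=128, num_hidden_layers=2, num_attention_heads=4,
#         num_patches=64, patch_size=16,
#     )
#     model = Siglip2VisionModel(config)  # randomly initialized
#
#     batch_size, max_num_patches = 2, 64
#     pixel_values = torch.randn(batch_size, max_num_patches, 3 * 16 * 16)  # already patchified
#     pixel_attention_mask = torch.ones(batch_size, max_num_patches, dtype=torch.long)  # 1 = real patch
#     spatial_shapes = torch.tensor([[8, 8], [4, 16]], dtype=torch.long)  # (height, width) in patches
#
#     outputs = model(
#         pixel_values=pixel_values,
#         pixel_attention_mask=pixel_attention_mask,
#         spatial_shapes=spatial_shapes,
#     )
#     print(outputs.last_hidden_state.shape)  # (batch_size, max_num_patches, hidden_size)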