# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.models.siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
from transformers.models.siglip.modeling_siglip import (
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
SiglipForImageClassification,
SiglipModel,
SiglipMultiheadAttentionPoolingHead,
SiglipOutput,
SiglipPreTrainedModel,
SiglipTextModel,
SiglipTextModelOutput,
SiglipVisionModel,
SiglipVisionModelOutput,
SiglipVisionTransformer,
)
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
class Siglip2TextConfig(SiglipTextConfig):
pass
class Siglip2VisionConfig(SiglipVisionConfig):
r"""
This is the configuration class to store the configuration of a [`Siglip2VisionModel`]. It is used to instantiate a
Siglip2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip2
[google/siglip2-base-patch16-naflex](https://huggingface.co/google/siglip2-base-patch16-naflex) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
Number of channels in the input images.
num_patches (`int`, *optional*, defaults to 256):
            The maximum number of patches of size (`patch_size`, `patch_size`) in an image. The image is resized,
            preserving its aspect ratio, so that it fills at most this number of patches; if the resulting number
            of patches is lower, the sequence is padded along the patch dimension.
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
Example:
```python
>>> from transformers import Siglip2VisionConfig, Siglip2VisionModel
>>> # Initializing a Siglip2VisionConfig with google/siglip2-base-patch16-naflex style configuration
>>> configuration = Siglip2VisionConfig()
>>> # Initializing a Siglip2VisionModel (with random weights) from the google/siglip2-base-patch16-naflex style configuration
>>> model = Siglip2VisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
def __init__(
self,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
num_channels=3,
num_patches=256,
patch_size=16,
hidden_act="gelu_pytorch_tanh",
layer_norm_eps=1e-6,
attention_dropout=0.0,
**kwargs,
):
        super().__init__(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_channels=num_channels,
            patch_size=patch_size,
            hidden_act=hidden_act,
            layer_norm_eps=layer_norm_eps,
            attention_dropout=attention_dropout,
            **kwargs,
        )
self.num_patches = num_patches
del self.image_size
class Siglip2Config(SiglipConfig):
pass
class Siglip2VisionOutput(SiglipVisionModelOutput):
pass
class Siglip2TextOutput(SiglipTextModelOutput):
pass
class Siglip2Output(SiglipOutput):
pass
class Siglip2VisionEmbeddings(nn.Module):
def __init__(self, config: Siglip2VisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Linear(
in_features=config.num_channels * self.patch_size * self.patch_size,
out_features=self.embed_dim,
)
self.num_patches = config.num_patches
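        # The learned position embeddings cover a square grid of `num_patches` positions
        # (e.g. 256 -> 16x16) and are interpolated to each image's actual patch grid in
        # `resize_positional_embeddings`.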
self.position_embedding_size = int(self.num_patches**0.5)
self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
@staticmethod
def resize_positional_embeddings(
positional_embeddings: torch.Tensor,
spatial_shapes: torch.LongTensor,
max_length: int,
) -> torch.Tensor:
"""
Resize positional embeddings to image-specific size and pad to a fixed size.
Args:
positional_embeddings (`torch.Tensor`):
Position embeddings of shape (height, width, embed_dim)
spatial_shapes (`torch.LongTensor`):
Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
max_length (`int`):
Maximum length of the positional embeddings to pad resized positional embeddings to
Returns:
`torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
"""
batch_size = spatial_shapes.shape[0]
embed_dim = positional_embeddings.shape[-1]
source_dtype = positional_embeddings.dtype
resulted_positional_embeddings = torch.empty(
(batch_size, max_length, embed_dim),
device=positional_embeddings.device,
dtype=source_dtype,
)
# (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
# Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
if positional_embeddings.device.type == "cpu":
positional_embeddings = positional_embeddings.to(torch.float32)
for i in range(batch_size):
# (1, dim, height, width) -> (1, dim, target_height, target_width)
height, width = spatial_shapes[i]
resized_embeddings = F.interpolate(
positional_embeddings,
size=(height, width),
mode="bilinear",
align_corners=False,
antialias=True,
)
# (1, dim, target_height, target_width) -> (target_height * target_width, dim)
resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)
# Cast to original dtype
resized_embeddings = resized_embeddings.to(source_dtype)
resulted_positional_embeddings[i, : height * width] = resized_embeddings
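            # The output tensor was allocated with `torch.empty`, so fill the remaining (padded, masked-out)
            # positions with a valid embedding instead of leaving them uninitialized.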
resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
return resulted_positional_embeddings
def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
"""
Args:
pixel_values (`torch.FloatTensor`):
Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
            spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
                Spatial shape (height, width) of each image's patch grid, used to resize the positional embeddings.
"""
# Apply patch embeddings to already patchified pixel values
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        # Get resized and padded positional embeddings
positional_embeddings = self.position_embedding.weight.reshape(
self.position_embedding_size, self.position_embedding_size, -1
)
resized_positional_embeddings = self.resize_positional_embeddings(
positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
)
# Add positional embeddings to patch embeddings
embeddings = patch_embeds + resized_positional_embeddings
return embeddings
class Siglip2VisionTransformer(SiglipVisionTransformer):
def __init__(self, config: Siglip2VisionConfig):
        super().__init__(config)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
# Update: add `spatial_shapes` and `attention_mask`
def forward(
self,
pixel_values: torch.FloatTensor,
attention_mask: torch.Tensor,
spatial_shapes: torch.LongTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutputWithPooling:
r"""
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width) of the input images.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
hidden_states = self.embeddings(pixel_values, spatial_shapes)
if attention_mask is not None and not self._use_flash_attention_2:
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
else:
encoder_attention_mask = attention_mask
encoder_outputs: BaseModelOutput = self.encoder(
inputs_embeds=hidden_states,
attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
last_hidden_state = encoder_outputs.last_hidden_state
last_hidden_state = self.post_layernorm(last_hidden_state)
pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooler_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class Siglip2PreTrainedModel(SiglipPreTrainedModel):
pass
class Siglip2TextModel(SiglipTextModel):
pass
class Siglip2MultiheadAttentionPoolingHead(SiglipMultiheadAttentionPoolingHead):
def __init__(self, config: Siglip2VisionConfig):
super().__init__(config)
self.num_heads = config.num_attention_heads
def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size = hidden_state.shape[0]
probe = self.probe.repeat(batch_size, 1, 1)
if attention_mask is not None:
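            # nn.MultiheadAttention accepts an additive float attn_mask of shape
            # (batch_size * num_heads, target_len, source_len), so the 4D mask is expanded
            # per head and the batch/head dimensions are flattened.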
target_len, source_len = probe.shape[1], hidden_state.shape[1]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len)
attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1)
attention_mask = attention_mask.reshape(-1, target_len, source_len)
hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0]
residual = hidden_state
hidden_state = self.layernorm(hidden_state)
hidden_state = residual + self.mlp(hidden_state)
return hidden_state[:, 0]
class Siglip2VisionModel(SiglipVisionModel):
# Update: add `spatial_shapes` and `pixel_attention_mask`
def forward(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: torch.Tensor,
spatial_shapes: torch.LongTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutputWithPooling:
r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, max_num_patches)`, *optional*):
            Mask to avoid performing attention on padded patch positions.
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width) of the input images.
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Siglip2VisionModel
>>> model = Siglip2VisionModel.from_pretrained("google/siglip2-base-patch16-224")
>>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled features
```"""
return self.vision_model(
pixel_values=pixel_values,
attention_mask=pixel_attention_mask,
spatial_shapes=spatial_shapes,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
class Siglip2Model(SiglipModel):
# Update: add `spatial_shapes` and `pixel_attention_mask`
def get_image_features(
self,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_attention_mask: Optional[torch.Tensor] = None,
spatial_shapes: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> torch.FloatTensor:
r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, max_num_patches)`, *optional*):
            Mask to avoid performing attention on padded patch positions.
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width) of the input images.
Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
applying the projection layer to the pooled output of [`Siglip2VisionModel`].
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, AutoModel
>>> import torch
>>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
>>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> with torch.no_grad():
... image_features = model.get_image_features(**inputs)
```
"""
# Use Siglip2Model's config for some fields (if specified) instead of those of vision & text components.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
vision_outputs: BaseModelOutputWithPooling = self.vision_model(
pixel_values=pixel_values,
attention_mask=pixel_attention_mask,
spatial_shapes=spatial_shapes,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
pooled_output = vision_outputs.pooler_output
return pooled_output
# Update: add `spatial_shapes` and `pixel_attention_mask`
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_attention_mask: Optional[torch.Tensor] = None,
spatial_shapes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> Siglip2Output:
r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, max_num_patches)`, *optional*):
            Mask to avoid performing attention on padded patch positions.
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width) of the input images.
return_loss (`bool`, *optional*):
Whether or not to return the contrastive loss.
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, AutoModel
>>> import torch
>>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
>>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
>>> # important: we pass `padding=max_length` since the model was trained with this
>>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> logits_per_image = outputs.logits_per_image
>>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
>>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
31.9% that image 0 is 'a photo of 2 cats'
```
"""
# Use Siglip2 model's config for some fields (if specified) instead of those of vision & text components.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
vision_outputs: BaseModelOutputWithPooling = self.vision_model(
pixel_values=pixel_values,
attention_mask=pixel_attention_mask,
spatial_shapes=spatial_shapes,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
text_outputs: BaseModelOutputWithPooling = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
image_embeds = vision_outputs.pooler_output
text_embeds = text_outputs.pooler_output
# normalized features
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
# cosine similarity as logits
logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
logits_per_text = logits_per_text * logit_scale.exp() + logit_bias
logits_per_image = logits_per_text.t()
loss = None
if return_loss:
# Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip2.py#L287
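            # Pairwise sigmoid loss: the implicit label matrix is +1 on the diagonal (matched
            # image-text pairs) and -1 elsewhere, so `m1_diag1 * logits` flips the sign of
            # non-matching logits before scoring every pair independently with log-sigmoid.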
eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
nll = -torch.sum(loglik, dim=-1)
loss = nll.mean()
return Siglip2Output(
loss=loss,
logits_per_image=logits_per_image,
logits_per_text=logits_per_text,
text_embeds=text_embeds,
image_embeds=image_embeds,
text_model_output=text_outputs,
vision_model_output=vision_outputs,
)
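# A minimal usage sketch for the NaFlex variant (assumptions for illustration: the
# `google/siglip2-base-patch16-naflex` checkpoint and that its processor returns
# `pixel_values`, `pixel_attention_mask` and `spatial_shapes` alongside `input_ids`):
#
#     from transformers import AutoModel, AutoProcessor
#     model = AutoModel.from_pretrained("google/siglip2-base-patch16-naflex")
#     processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-naflex")
#     inputs = processor(text=["a photo of 2 cats"], images=image, padding="max_length", max_length=64, return_tensors="pt")
#     with torch.no_grad():
#         outputs = model(**inputs)  # Siglip2Output with logits_per_image / logits_per_text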
class Siglip2ForImageClassification(SiglipForImageClassification):
# Update: add `spatial_shapes` and `pixel_attention_mask`
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
pixel_attention_mask: Optional[torch.Tensor] = None,
spatial_shapes: Optional[torch.LongTensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> ImageClassifierOutput:
r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, max_num_patches)`, *optional*):
            Mask to avoid performing attention on padded patch positions.
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width) of the input images.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Examples:
```python
>>> from transformers import AutoImageProcessor, Siglip2ForImageClassification
>>> import torch
>>> from PIL import Image
>>> import requests
>>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> # note: we are loading a `Siglip2Model` from the hub here,
>>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
>>> image_processor = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224")
>>> model = Siglip2ForImageClassification.from_pretrained("google/siglip2-base-patch16-224")
>>> inputs = image_processor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the two classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: LABEL_1
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
outputs: BaseModelOutputWithPooling = self.vision_model(
pixel_values,
attention_mask=pixel_attention_mask,
spatial_shapes=spatial_shapes,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs.last_hidden_state
# average pool the patch tokens
if pixel_attention_mask is not None:
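            # Masked mean pool: zero out padded patch positions and normalize by the number of
            # real patches in each image.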
pool_mask = pixel_attention_mask[..., None].to(sequence_output.device)
sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / torch.sum(pool_mask, dim=1)
else:
sequence_output = torch.mean(sequence_output, dim=1)
# apply classifier
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
__all__ = [
"Siglip2Config",
"Siglip2TextConfig",
"Siglip2VisionConfig",
"Siglip2Model",
"Siglip2PreTrainedModel",
"Siglip2TextModel",
"Siglip2VisionModel",
"Siglip2ForImageClassification",
]