from typing import Optional, Union

import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.models.ijepa.configuration_ijepa import IJepaConfig

from ...modeling_outputs import ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, torch_int
from ..vit.modeling_vit import ViTEmbeddings, ViTForImageClassification, ViTModel

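# I-JEPA reuses the ViT building blocks and only overrides what differs: the embeddings drop the
# CLS token (I-JEPA operates purely on patch tokens), the base model skips the pooling layer by
# default, and the classification head mean-pools the patch tokens instead of reading a CLS token.
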
class IJepaEmbeddings(ViTEmbeddings):
    def __init__(self, config: IJepaConfig, use_mask_token: bool = False) -> None:
        super().__init__(config, use_mask_token)
        # Remove cls_token from IJepaEmbeddings, as it is not used in the model
        del self.cls_token
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size))

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Interpolates the pre-trained position encodings so the model can be used on higher-resolution images.
        This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1]
        num_positions = self.position_embeddings.shape[1]

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        patch_pos_embed = self.position_embeddings

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return patch_pos_embed

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        if bool_masked_pos is not None:
            seq_length = embeddings.shape[1]
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings

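# A minimal sketch (not part of the model) of what `interpolate_pos_encoding` above does, assuming
# an illustrative 16x16 grid of pre-trained position embeddings (256 positions, hidden size 768)
# resampled for an input that yields a 20x20 patch grid:
#
#     pos = torch.randn(1, 256, 768)                                   # (1, num_positions, dim)
#     grid = pos.reshape(1, 16, 16, 768).permute(0, 3, 1, 2)           # (1, dim, 16, 16)
#     grid = nn.functional.interpolate(grid, size=(20, 20), mode="bicubic", align_corners=False)
#     pos_new = grid.permute(0, 2, 3, 1).view(1, -1, 768)              # (1, 400, dim)
#
# The resampled embeddings are then added to the 400 patch embeddings of the larger image.
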
@auto_docstring
class IJepaPreTrainedModel(PreTrainedModel):
    config: IJepaConfig
    base_model_prefix = "ijepa"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"]
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input to `fp32` and cast it back to the desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, IJepaEmbeddings):
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.position_embeddings.dtype)
            if module.mask_token is not None:
                module.mask_token.data.zero_()

class IJepaModel(IJepaPreTrainedModel, ViTModel):
    def __init__(self, config: IJepaConfig, add_pooling_layer: bool = False, use_mask_token: bool = False):
        r"""
        add_pooling_layer (`bool`, *optional*, defaults to `False`):
            Whether to add a pooling layer.
        use_mask_token (`bool`, *optional*, defaults to `False`):
            Whether to use a mask token for masked image modeling.
        """
        super().__init__(config)
        self.config = config
        self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token)

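# Hedged usage sketch for the base model (the checkpoint name below is illustrative; any I-JEPA
# checkpoint on the Hub works the same way). Since there is no CLS token, patch features are
# typically mean-pooled:
#
#     from transformers import AutoImageProcessor, IJepaModel
#
#     processor = AutoImageProcessor.from_pretrained("facebook/ijepa_vith14_1k")
#     model = IJepaModel.from_pretrained("facebook/ijepa_vith14_1k")
#     inputs = processor(images=image, return_tensors="pt")            # `image` is a PIL image
#     outputs = model(**inputs)
#     features = outputs.last_hidden_state.mean(dim=1)                 # (batch_size, hidden_size)
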
@auto_docstring(
    custom_intro="""
    IJepa Model transformer with an image classification head on top (a linear layer on top of the final hidden states),
    e.g. for ImageNet.

    <Tip>

    Note that it's possible to fine-tune IJepa on higher resolution images than the ones it has been trained on, by
    setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
    position embeddings to the higher resolution.

    </Tip>
    """
)
class IJepaForImageClassification(IJepaPreTrainedModel, ViTForImageClassification):
    def __init__(self, config: IJepaConfig):
        super().__init__(config)
        self.ijepa = IJepaModel(config, add_pooling_layer=False)
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Square loss); if
            `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.ijepa(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # no CLS token: classify from the mean of the patch tokens
        logits = self.classifier(sequence_output.mean(dim=1))

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

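# Hedged sketch for classification use (checkpoint name and label count are illustrative). The head
# is a linear layer over mean-pooled patch tokens, so `num_labels` sets the logits dimension and,
# together with the label dtype, which loss branch above runs:
#
#     from transformers import AutoImageProcessor, IJepaForImageClassification
#
#     processor = AutoImageProcessor.from_pretrained("facebook/ijepa_vith14_1k")
#     model = IJepaForImageClassification.from_pretrained("facebook/ijepa_vith14_1k", num_labels=10)
#     inputs = processor(images=image, return_tensors="pt")
#     labels = torch.tensor([3])                                       # long dtype -> cross-entropy
#     out = model(**inputs, labels=labels)
#     out.loss, out.logits                                             # scalar loss, (1, 10) logits
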
__all__ = [
    "IJepaPreTrainedModel",
    "IJepaModel",
    "IJepaForImageClassification",
]