# coding=utf-8 # Copyright 2023 The HuggingFace Inc. & Google team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Pix2Struct modeling file""" import math from typing import Optional, Union import torch import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel from ...utils import ( DUMMY_INPUTS, DUMMY_MASK, auto_docstring, is_torch_flex_attn_available, is_torch_fx_proxy, is_torchdynamo_compiling, logging, ) from .configuration_pix2struct import Pix2StructConfig, Pix2StructTextConfig, Pix2StructVisionConfig if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask from ...integrations.flex_attention import make_flex_block_causal_mask logger = logging.get_logger(__name__) # General docstring # Adapted from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Pix2Struct class Pix2StructLayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ Construct a layernorm module in the T5 style. No bias and no subtraction of mean. """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for # half-precision inputs is done in fp32 variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: hidden_states = hidden_states.to(self.weight.dtype) return self.weight * hidden_states try: from apex.normalization import FusedRMSNorm Pix2StructLayerNorm = FusedRMSNorm # noqa logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Pix2StructLayerNorm") except ImportError: # using the normal Pix2StructLayerNorm pass except Exception: logger.warning("Discovered apex but it failed to load, falling back to Pix2StructLayerNorm") pass class Pix2StructVisionEmbeddings(nn.Module): r""" Construct the embeddings from patch. In `Pix2Struct` the input is different from classic Vision-transformer models. Here the input is a sequence of `seq_len` flattened patches that also combines padding patches (tokens). Each patch is represented by a vector of `hidden_size` values. 
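    For illustration, the expected layout of the input (as consumed by `forward` below) is:

        flattened_patches[:, :, 0]  -> row index of each patch
        flattened_patches[:, :, 1]  -> column index of each patch
        flattened_patches[:, :, 2:] -> the flattened patch values, projected to `hidden_size` by `patch_projection`

    The row and column indices are turned into learned embeddings and summed with the projected patch values.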
""" def __init__(self, config: Pix2StructConfig) -> None: super().__init__() self.patch_projection = nn.Linear(config.patch_embed_hidden_size, config.hidden_size) self.row_embedder = nn.Embedding(config.seq_len, config.hidden_size) self.column_embedder = nn.Embedding(config.seq_len, config.hidden_size) self.dropout = nn.Dropout(config.dropout_rate) def forward(self, flattened_patches: torch.Tensor) -> torch.Tensor: # the row and column indices are stored in the first and second position of the flattened_patches # flattened_patches: `batch_size`, `seq_len`, `hidden_size` + 2 row_indices = flattened_patches[:, :, 0].long() col_indices = flattened_patches[:, :, 1].long() flattened_patches = flattened_patches[:, :, 2:] embeddings = self.patch_projection(flattened_patches) row_embeddings = self.row_embedder(row_indices) col_embeddings = self.column_embedder(col_indices) # sum all embeddings together embeddings = embeddings + row_embeddings + col_embeddings embeddings = self.dropout(embeddings) return embeddings class Pix2StructVisionAttention(nn.Module): def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size self.key_value_proj_dim = config.d_kv self.n_heads = config.num_attention_heads self.dropout = config.attention_dropout self.inner_dim = self.n_heads * self.key_value_proj_dim # Mesh TensorFlow initialization to avoid scaling before softmax self.query = nn.Linear(self.hidden_size, self.inner_dim, bias=False) self.key = nn.Linear(self.hidden_size, self.inner_dim, bias=False) self.value = nn.Linear(self.hidden_size, self.inner_dim, bias=False) self.output = nn.Linear(self.inner_dim, self.hidden_size, bias=False) self.gradient_checkpointing = False def forward( self, hidden_states, attention_mask=None, position_bias=None, layer_head_mask=None, output_attentions=False, ): """ Self-attention block """ # Input is (batch_size, seq_length, dim) # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) batch_size, seq_length = hidden_states.shape[:2] def to_projection_shape(states): """projection""" return states.contiguous().view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) # get query states # (batch_size, n_heads, seq_length, dim_per_head) query_states = to_projection_shape(self.query(hidden_states)) # get key/value states key_states = to_projection_shape(self.key(hidden_states)) value_states = to_projection_shape(self.value(hidden_states)) # compute scores # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: position_bias = torch.zeros( (1, self.n_heads, seq_length, seq_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True if attention_mask.dim() == 2: position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device) elif attention_mask is not None: # (batch_size, n_heads, seq_length, key_length) position_bias = position_bias + attention_mask.to(position_bias.device) elif not is_torchdynamo_compiling(): attention_mask = torch.ones( (batch_size, seq_length), device=position_bias.device, dtype=position_bias.dtype ) position_bias = position_bias + attention_mask.to(position_bias.device) position_bias = 1 - position_bias position_bias_masked = position_bias.masked_fill(position_bias == 1, 
torch.finfo(scores.dtype).min) scores += position_bias_masked scores = torch.max(scores, torch.tensor(torch.finfo(scores.dtype).min)) # (batch_size, n_heads, seq_length, key_length) attn_weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores) # (batch_size, n_heads, seq_length, key_length) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask attn_output = torch.matmul(attn_weights, value_states) # (batch_size, seq_length, dim) attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) attn_output = self.output(attn_output) outputs = (attn_output,) + (position_bias,) if output_attentions: outputs = outputs + (attn_weights,) return outputs # Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5DenseGatedActDense->Pix2StructVisionMlp,T5Config->Pix2StructVisionConfig,config.d_model->config.hidden_size,dropout_rate->dropout_rate class Pix2StructVisionMlp(nn.Module): def __init__(self, config: Pix2StructVisionConfig): super().__init__() self.wi_0 = nn.Linear(config.hidden_size, config.d_ff, bias=False) self.wi_1 = nn.Linear(config.hidden_size, config.d_ff, bias=False) self.wo = nn.Linear(config.d_ff, config.hidden_size, bias=False) self.dropout = nn.Dropout(config.dropout_rate) self.act = ACT2FN[config.dense_act_fn] def forward(self, hidden_states): hidden_gelu = self.act(self.wi_0(hidden_states)) hidden_linear = self.wi_1(hidden_states) hidden_states = hidden_gelu * hidden_linear hidden_states = self.dropout(hidden_states) # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. # See https://github.com/huggingface/transformers/issues/20287 # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` if ( isinstance(self.wo.weight, torch.Tensor) and hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8 ): hidden_states = hidden_states.to(self.wo.weight.dtype) hidden_states = self.wo(hidden_states) return hidden_states class Pix2StructVisionLayer(GradientCheckpointingLayer): def __init__(self, config: Pix2StructConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = Pix2StructVisionAttention(config) self.mlp = Pix2StructVisionMlp(config) self.pre_mlp_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.pre_attention_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: residual = hidden_states # in Pix2StructVision, layernorm is applied before self-attention hidden_states = self.pre_attention_layer_norm(hidden_states) self_attention_outputs = self.attention( hidden_states, attention_mask=attention_mask, layer_head_mask=head_mask, output_attentions=output_attentions, ) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] # add self attentions if we output attention weights # first residual connection hidden_states = attention_output + residual # in Pix2StructVision, layernorm is also applied after self-attention layer_output = self.pre_mlp_layer_norm(hidden_states) 
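        # Illustrative note: this is a pre-norm MLP block. The normalized tensor feeds the MLP on the next
        # line, while the un-normalized post-attention `hidden_states` is added back as the second residual.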
layer_output = self.mlp(layer_output) + hidden_states # second residual connection outputs = (layer_output,) + outputs return outputs class Pix2StructVisionEncoder(nn.Module): def __init__(self, config: Pix2StructConfig) -> None: super().__init__() self.config = config self.layer = nn.ModuleList([Pix2StructVisionLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ) -> Union[tuple, BaseModelOutput]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions, ) @auto_docstring class Pix2StructPreTrainedModel(PreTrainedModel): config: Pix2StructConfig _can_compile_fullgraph = False @property def dummy_inputs(self): input_ids = torch.tensor(DUMMY_INPUTS) input_mask = torch.tensor(DUMMY_MASK) dummy_inputs = { "decoder_input_ids": input_ids, "input_ids": input_ids, "decoder_attention_mask": input_mask, } return dummy_inputs def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, Pix2StructLayerNorm): module.weight.data.fill_(factor * 1.0) elif isinstance(module, Pix2StructTextDenseGatedActDense): hidden_size = ( self.config.text_config.hidden_size if isinstance(self.config, Pix2StructConfig) else self.config.hidden_size ) d_ff = self.config.text_config.d_ff if isinstance(self.config, Pix2StructConfig) else self.config.d_ff module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: module.wi_0.bias.data.zero_() module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: module.wi_1.bias.data.zero_() module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: module.wo.bias.data.zero_() elif isinstance(module, Pix2StructTextAttention): # Mesh TensorFlow attention initialization to avoid scaling before softmax # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 hidden_size = ( self.config.text_config.hidden_size if isinstance(self.config, Pix2StructConfig) else self.config.hidden_size ) key_value_proj_dim = ( self.config.text_config.d_kv if isinstance(self.config, Pix2StructConfig) else self.config.hidden_size ) n_heads = ( self.config.text_config.num_heads if isinstance(self.config, Pix2StructConfig) else self.config.num_heads ) 
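            # The std values below follow the Mesh TensorFlow scheme linked above: queries are scaled by
            # (hidden_size * d_kv) ** -0.5 so that no extra 1/sqrt(d) rescaling is needed inside the attention
            # scores. For a hypothetical hidden_size=768, d_kv=64 and factor=1.0 this gives a query std of
            # (768 * 64) ** -0.5 ≈ 0.0045.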
module.query.weight.data.normal_(mean=0.0, std=factor * ((hidden_size * key_value_proj_dim) ** -0.5)) module.key.weight.data.normal_(mean=0.0, std=factor * (hidden_size**-0.5)) module.value.weight.data.normal_(mean=0.0, std=factor * (hidden_size**-0.5)) module.output.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) elif isinstance(module, nn.Embedding): hidden_size = ( self.config.text_config.hidden_size if isinstance(self.config, Pix2StructConfig) else self.config.hidden_size ) module.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, Pix2StructTextModel): hidden_size = ( self.config.text_config.hidden_size if isinstance(self.config, Pix2StructConfig) else self.config.hidden_size ) module.lm_head.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) elif isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues module.weight.data = nn.init.trunc_normal_( module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range ).to(module.weight.dtype) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, Pix2StructLayerNorm): if module.weight is not None: module.weight.data.fill_(1.0) elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->Pix2Struct def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id pad_token_id = self.config.pad_token_id if decoder_start_token_id is None: raise ValueError( "self.model.config.decoder_start_token_id has to be defined. In Pix2Struct it is usually set to the pad_token_id. " "See Pix2Struct docs for more information." ) # shift inputs to the right if is_torch_fx_proxy(input_ids): # Item assignment is not supported natively for proxies. 
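            # (Worked example with hypothetical labels [a, b, c]: the shifted decoder input becomes
            # [decoder_start_token_id, a, b]. The concatenation below builds this without item assignment.)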
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) else: shifted_input_ids = input_ids.new_zeros(input_ids.shape) shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() shifted_input_ids[..., 0] = decoder_start_token_id if pad_token_id is None: raise ValueError("self.model.config.pad_token_id has to be defined.") # replace possible -100 values in labels by `pad_token_id` shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) return shifted_input_ids @auto_docstring class Pix2StructVisionModel(Pix2StructPreTrainedModel): config: Pix2StructVisionConfig main_input_name = "flattened_patches" supports_gradient_checkpointing = True _no_split_modules = ["Pix2StructVisionLayer"] def __init__(self, config: Pix2StructConfig): super().__init__(config) self.config = config self.embeddings = Pix2StructVisionEmbeddings(config) self.encoder = Pix2StructVisionEncoder(config) self.layernorm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embeddings.patch_projection def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None: """ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) @auto_docstring def forward( self, flattened_patches: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[tuple, BaseModelOutputWithPooling]: r""" flattened_patches (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_channels x patch_height x patch_width)`): Flattened and padded pixel values. These values can be obtained using [`AutoImageProcessor`]. See [`Pix2StructVisionImageProcessor.__call__`] for details. Check the [original paper](https://huggingface.co/papers/2210.03347) (figure 5) for more details. Example: ```python >>> import requests >>> from PIL import Image >>> from transformers import AutoProcessor, Pix2StructVisionModel >>> image_processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base") >>> model = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base") >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = image_processor(images=image, return_tensors="pt") >>> with torch.no_grad(): ... 
outputs = model(**inputs) >>> last_hidden_states = outputs.last_hidden_state >>> list(last_hidden_states.shape) [1, 2048, 768] ``` """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if flattened_patches is None: raise ValueError("You have to specify flattened_patches") if attention_mask is None: # check where `flattened_patches` is not 0 attention_mask = (flattened_patches.sum(dim=-1) != 0).float() # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) embedding_output = self.embeddings(flattened_patches) encoder_outputs = self.encoder( embedding_output, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = encoder_outputs[0] sequence_output = self.layernorm(sequence_output) if not return_dict: head_outputs = (sequence_output,) return head_outputs + encoder_outputs[1:] return BaseModelOutput( last_hidden_state=sequence_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) # Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->Pix2StructText,d_model->hidden_size class Pix2StructTextDenseGatedActDense(nn.Module): def __init__(self, config: Pix2StructTextConfig): super().__init__() self.wi_0 = nn.Linear(config.hidden_size, config.d_ff, bias=False) self.wi_1 = nn.Linear(config.hidden_size, config.d_ff, bias=False) self.wo = nn.Linear(config.d_ff, config.hidden_size, bias=False) self.dropout = nn.Dropout(config.dropout_rate) self.act = ACT2FN[config.dense_act_fn] def forward(self, hidden_states): hidden_gelu = self.act(self.wi_0(hidden_states)) hidden_linear = self.wi_1(hidden_states) hidden_states = hidden_gelu * hidden_linear hidden_states = self.dropout(hidden_states) # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. 
# See https://github.com/huggingface/transformers/issues/20287 # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` if ( isinstance(self.wo.weight, torch.Tensor) and hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8 ): hidden_states = hidden_states.to(self.wo.weight.dtype) hidden_states = self.wo(hidden_states) return hidden_states class Pix2StructTextLayerFF(nn.Module): def __init__(self, config: Pix2StructTextConfig): super().__init__() self.DenseReluDense = Pix2StructTextDenseGatedActDense(config) self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) # Copied from transformers.models.t5.modeling_t5.T5LayerFF.forward def forward(self, hidden_states): forwarded_states = self.layer_norm(hidden_states) forwarded_states = self.DenseReluDense(forwarded_states) hidden_states = hidden_states + self.dropout(forwarded_states) return hidden_states class Pix2StructTextAttention(nn.Module): def __init__( self, config: Pix2StructTextConfig, has_relative_attention_bias=False, layer_idx: Optional[int] = None ): super().__init__() self.has_relative_attention_bias = has_relative_attention_bias self.relative_attention_num_buckets = config.relative_attention_num_buckets self.relative_attention_max_distance = config.relative_attention_max_distance self.hidden_size = config.hidden_size self.key_value_proj_dim = config.d_kv self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) # Mesh TensorFlow initialization to avoid scaling before softmax self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=False) self.key = nn.Linear(self.hidden_size, self.hidden_size, bias=False) self.value = nn.Linear(self.hidden_size, self.hidden_size, bias=False) self.output = nn.Linear(self.hidden_size, self.hidden_size, bias=False) if self.has_relative_attention_bias: self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) self.pruned_heads = set() self.gradient_checkpointing = False @staticmethod # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): """ Adapted from Mesh Tensorflow: https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 Translate relative position to a bucket number for relative attention. The relative position is defined as memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for small absolute relative_position and larger buckets for larger absolute relative_positions. All relative positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
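        For example, with the default `num_buckets=32` and `max_distance=128` in the unidirectional case,
        distances 0-15 each get their own bucket, distances 16-127 are binned logarithmically into buckets
        16-31, and any distance >= 128 falls into the last bucket.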
This should allow for more graceful generalization to longer sequences than the model has been trained on Args: relative_position: an int32 Tensor bidirectional: a boolean - whether the attention is bidirectional num_buckets: an integer max_distance: an integer Returns: a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) """ relative_buckets = 0 if bidirectional: num_buckets //= 2 relative_buckets += (relative_position > 0).to(torch.long) * num_buckets relative_position = torch.abs(relative_position) else: relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) # now relative_position is in the range [0, inf) # half of the buckets are for exact increments in positions max_exact = num_buckets // 2 is_small = relative_position < max_exact # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance relative_position_if_large = max_exact + ( torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) ).to(torch.long) relative_position_if_large = torch.min( relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) ) relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets # Adapted from transformers.models.t5.modeling_t5.T5Attention.compute_bias def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device if cache_position is None: context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] else: context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( relative_position, # shape (query_length, key_length) bidirectional=False, num_buckets=self.relative_attention_num_buckets, max_distance=self.relative_attention_max_distance, ) values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values # Adapted from transformers.models.t5.modeling_t5.T5Attention.forward def forward( self, hidden_states, mask=None, key_value_states=None, position_bias=None, past_key_value=None, layer_head_mask=None, query_length=None, use_cache=False, output_attentions=False, cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] # if key_value_states are provided this layer is used as a cross-attention layer for the decoder is_cross_attention = key_value_states is not None query_states = self.query(hidden_states) query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) # Check is encoder-decoder model is being used. 
Otherwise we'll get `DynamicCache` if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache curr_past_key_value = past_key_value.cross_attention_cache else: curr_past_key_value = past_key_value.self_attention_cache else: curr_past_key_value = past_key_value current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.key(current_states) value_states = self.value(current_states) key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: key_length = key_states.shape[-2] # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: position_bias = self.compute_bias( real_seq_length, key_length, device=scores.device, cache_position=cache_position ) position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: causal_mask = mask[:, :, :, : key_states.shape[-2]] position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) mask[list(self.pruned_heads)] = 0 position_bias_masked = position_bias[:, mask.bool()] else: position_bias_masked = position_bias scores += position_bias_masked # (batch_size, n_heads, seq_length, key_length) attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.output(attn_output) outputs = (attn_output, position_bias) if output_attentions: outputs = outputs + (attn_weights,) return outputs # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with 
T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerSelfAttention->Pix2StructTextLayerSelfAttention,self.SelfAttention->self.attention,config.d_model->config.hidden_size class Pix2StructTextLayerSelfAttention(nn.Module): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.attention = Pix2StructTextAttention( config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx ) self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) def forward( self, hidden_states, attention_mask=None, position_bias=None, layer_head_mask=None, past_key_value=None, use_cache=False, output_attentions=False, cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.attention( normed_hidden_states, mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them return outputs # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerCrossAttention->Pix2StructTextLayerCrossAttention,self.EncDecAttention->self.attention,config.d_model->config.hidden_size class Pix2StructTextLayerCrossAttention(nn.Module): def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) def forward( self, hidden_states, key_value_states, attention_mask=None, position_bias=None, layer_head_mask=None, past_key_value=None, use_cache=False, query_length=None, output_attentions=False, cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.attention( normed_hidden_states, mask=attention_mask, key_value_states=key_value_states, position_bias=position_bias, layer_head_mask=layer_head_mask, past_key_value=past_key_value, use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them return outputs class Pix2StructTextBlock(GradientCheckpointingLayer): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.self_attention = Pix2StructTextLayerSelfAttention( config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx, ) self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention( config, layer_idx=layer_idx, ) self.mlp = Pix2StructTextLayerFF(config) def forward( self, hidden_states, attention_mask=None, position_bias=None, encoder_hidden_states=None, encoder_attention_mask=None, encoder_decoder_position_bias=None, layer_head_mask=None, cross_attn_layer_head_mask=None, past_key_value=None, use_cache=False, output_attentions=False, return_dict=True, cache_position=None, ): self_attention_outputs = self.self_attention( hidden_states, 
attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, ) hidden_states = self_attention_outputs[0] attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) do_cross_attention = encoder_hidden_states is not None if do_cross_attention: cross_attention_outputs = self.encoder_decoder_attention( hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, ) hidden_states = cross_attention_outputs[0] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[1:] # Apply Feed Forward layer hidden_states = self.mlp(hidden_states) # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) outputs = (hidden_states,) return outputs + attention_outputs @auto_docstring( custom_intro=""" The standalone text decoder of Pix2Struct """ ) class Pix2StructTextModel(Pix2StructPreTrainedModel): config: Pix2StructTextConfig _no_split_modules = ["Pix2StructTextBlock"] _tied_weights_keys = ["lm_head.weight"] supports_gradient_checkpointing = True def __init__(self, config): super().__init__(config) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layer = nn.ModuleList( [ Pix2StructTextBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers) ] ) self.final_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() self.gradient_checkpointing = False def set_input_embeddings(self, new_embeddings): self.embed_tokens = new_embeddings @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) 
-> Union[tuple[torch.FloatTensor, ...], CausalLMOutputWithCrossAttentions]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Pix2StructText is a model with relative position embeddings so you should be able to pad the inputs on both the right and the left. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for detail. [What are input IDs?](../glossary#input-ids) To know more on how to prepare `input_ids` for pretraining take a look a [Pix2StructText Training](./t5#training). cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. Example: ```python >>> from transformers import AutoProcessor, Pix2StructTextModel >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base") >>> model = Pix2StructTextModel.from_pretrained("google/pix2struct-textcaps-base") >>> inputs = processor(text="Hello, my dog is cute", return_tensors="pt") >>> outputs = model(**inputs) >>> loss = outputs.loss ``` """ use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.gradient_checkpointing and self.training and use_cache: logger.warning( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" inputs_embeds = self.embed_tokens(input_ids) batch_size, seq_length = input_shape if use_cache and past_key_values is None: if self.config.is_encoder_decoder: past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) else: past_key_values = DynamicCache() past_key_values_length = 0 if cache_position is not None: past_key_values_length = cache_position[0] elif past_key_values is not None: past_key_values_length = past_key_values.get_seq_length() if cache_position is None: cache_position = torch.arange( past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) if attention_mask is None: # required mask seq length can be calculated via length of past mask_seq_length = ( past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length ) attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) if self.config.is_decoder: causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values.self_attention_cache if isinstance(past_key_values, EncoderDecoderCache) else past_key_values, output_attentions, ) else: causal_mask = attention_mask[:, None, None, :] causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if encoder_hidden_states is not None: encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions) else None position_bias = None encoder_decoder_position_bias = None hidden_states = self.dropout(inputs_embeds) for i, layer_module in enumerate(self.layer): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_outputs = layer_module( hidden_states, causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, encoder_decoder_position_bias, # as a positional argument for gradient checkpointing layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_values, use_cache=use_cache, 
output_attentions=output_attentions, cache_position=cache_position, ) hidden_states = layer_outputs[0] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), # (cross-attention position bias), (cross-attention weights) position_bias = layer_outputs[1] if encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] if output_attentions: all_attentions = all_attentions + (layer_outputs[2],) if encoder_hidden_states is not None: all_cross_attentions = all_cross_attentions + (layer_outputs[4],) hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) logits = self.lm_head(hidden_states) # Add last layer if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) loss = None if labels is not None: # move labels to correct device to enable model parallelism labels = labels.to(logits.device) loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean") loss = loss_fct(logits.contiguous().view(-1, logits.size(-1)), labels.contiguous().view(-1)) if not return_dict: return tuple( v for v in [ loss, logits, past_key_values, all_hidden_states, all_attentions, all_cross_attentions, ] if v is not None ) return CausalLMOutputWithCrossAttentions( loss=loss, logits=logits, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training, ): return None dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
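        # (Illustrative shapes: a 2D mask of shape (batch_size, key_length) becomes a float mask of shape
        # (batch_size, 1, query_length, key_length) holding 0.0 for attendable positions and the dtype
        # minimum for masked ones; see the helper's docstring below.)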
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, dtype: torch.dtype, cache_position: torch.Tensor, batch_size: int, **kwargs, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. Args: attention_mask (`torch.Tensor`): A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. causal_mask = attention_mask else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( causal_mask.device ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) return causal_mask @auto_docstring( custom_intro=""" A conditional generation model with a language modeling head. Can be used for sequence generation tasks. 
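    The model combines a `Pix2StructVisionModel` encoder, which consumes the `flattened_patches` produced by
    `Pix2StructProcessor`, with a `Pix2StructTextModel` decoder that autoregressively generates text
    conditioned on the encoder output.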
""" ) class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel, GenerationMixin): config: Pix2StructConfig main_input_name = "flattened_patches" _tied_weights_keys = ["decoder.lm_head.weight"] def __init__(self, config: Pix2StructConfig): super().__init__(config) self.encoder = Pix2StructVisionModel(config.vision_config) self.decoder = Pix2StructTextModel(config.text_config) self.is_vqa = config.is_vqa # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.decoder.get_input_embeddings() def set_input_embeddings(self, new_embeddings): self.decoder.set_input_embeddings(new_embeddings) def get_output_embeddings(self) -> nn.Module: return self.decoder.get_output_embeddings() def set_output_embeddings(self, new_embeddings): self.decoder.set_output_embeddings(new_embeddings) def get_decoder(self): return self.decoder def get_encoder(self): return self.encoder @auto_docstring def forward( self, flattened_patches: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.BoolTensor] = None, head_mask: Optional[torch.FloatTensor] = None, decoder_head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Cache] = None, labels: Optional[torch.LongTensor] = None, decoder_inputs_embeds: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" flattened_patches (`torch.FloatTensor` of shape `(batch_size, seq_length, hidden_size)`): Flattened pixel patches. the `hidden_size` is obtained by the following formula: `hidden_size` = `num_channels` * `patch_size` * `patch_size` The process of flattening the pixel patches is done by `Pix2StructProcessor`. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids) Pix2StructText uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). To know more on how to prepare `decoder_input_ids` for pretraining take a look at [Pix2StructText Training](./t5#training). decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): Mask to nullify selected heads of the cross-attention modules in the decoder. 
Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss for the decoder. Example: Inference: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base") >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base") >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(images=image, return_tensors="pt") >>> # autoregressive generation >>> generated_ids = model.generate(**inputs, max_new_tokens=50) >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] >>> print(generated_text) A stop sign is on a street corner. >>> # conditional generation >>> text = "A picture of" >>> inputs = processor(text=text, images=image, return_tensors="pt", add_special_tokens=False) >>> generated_ids = model.generate(**inputs, max_new_tokens=50) >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] >>> print(generated_text) A picture of a stop sign with a red stop sign ``` Training: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration >>> processor = AutoProcessor.from_pretrained("google/pix2struct-base") >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-base") >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> text = "A stop sign is on the street corner." 
>>> inputs = processor(images=image, return_tensors="pt") >>> labels = processor(text=text, return_tensors="pt").input_ids >>> # forward pass >>> outputs = model(**inputs, labels=labels) >>> loss = outputs.loss >>> print(f"{loss.item():.5f}") 5.94282 ```""" use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Encode if needed (training, first prediction pass) if encoder_outputs is None: encoder_outputs = self.encoder( flattened_patches=flattened_patches, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): encoder_outputs = BaseModelOutput( last_hidden_state=encoder_outputs[0], hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) hidden_states = encoder_outputs[0] if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: # get decoder inputs from shifting lm labels to the right decoder_input_ids = self._shift_right(labels) decoder_attention_mask = ( decoder_attention_mask if decoder_attention_mask is not None else decoder_input_ids.ne(self.config.pad_token_id).float() ) # Always attend to the first token decoder_attention_mask[:, 0] = 1 # Decode decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, inputs_embeds=decoder_inputs_embeds, past_key_values=past_key_values, encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, labels=labels, return_dict=return_dict, cache_position=cache_position, ) if not return_dict: return decoder_outputs + encoder_outputs return Seq2SeqLMOutput( loss=decoder_outputs.loss, logits=decoder_outputs.logits, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, ) __all__ = [ "Pix2StructPreTrainedModel", "Pix2StructForConditionalGeneration", "Pix2StructVisionModel", "Pix2StructTextModel", ]