# coding=utf-8 # Copyright 2023 Microsoft Research and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch KOSMOS-2 model.""" import math from dataclasses import dataclass from typing import Any, Callable, Optional, Union import torch import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPooling, CausalLMOutputWithCrossAttentions, ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int from .configuration_kosmos2 import Kosmos2Config, Kosmos2TextConfig, Kosmos2VisionConfig logger = logging.get_logger(__name__) def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) inverted_mask = 1.0 - expanded_mask return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) def _make_causal_mask( input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 ): """ Make causal mask used for bi-directional self-attention. """ bsz, tgt_len = input_ids_shape mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) mask_cond = torch.arange(mask.size(-1), device=device) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) if past_key_values_length > 0: mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) # Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): """ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols are ignored. This is modified from fairseq's `utils.make_positions`. Args: x: torch.Tensor x: Returns: torch.Tensor """ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. mask = input_ids.ne(padding_idx).int() incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask return incremental_indices.long() + padding_idx @dataclass @auto_docstring( custom_intro=""" Base class for text model's outputs that also contains a pooling of the last hidden states. """ ) class Kosmos2ModelOutput(ModelOutput): r""" past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. projection_attentions (`tuple(torch.FloatTensor)`, *optional*): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute the weighted average in the self-attention heads. vision_model_output (`BaseModelOutputWithPooling`, *optional*): The output of the [`Kosmos2VisionModel`]. """ last_hidden_state: Optional[torch.FloatTensor] = None past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None image_embeds: Optional[torch.FloatTensor] = None projection_attentions: Optional[tuple[torch.FloatTensor]] = None vision_model_output: BaseModelOutputWithPooling = None def to_tuple(self) -> tuple[Any]: return tuple( self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() for k in self.keys() ) @dataclass @auto_docstring( custom_intro=""" Model output class for `Kosmos2ForConditionalGeneration`. """ ) class Kosmos2ForConditionalGenerationModelOutput(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. projection_attentions (`tuple(torch.FloatTensor)`, *optional*): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute the weighted average in the self-attention heads. vision_model_output (`BaseModelOutputWithPooling`, *optional*): The output of the [`Kosmos2VisionModel`]. """ loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None image_embeds: Optional[torch.FloatTensor] = None projection_attentions: Optional[tuple[torch.FloatTensor]] = None vision_model_output: BaseModelOutputWithPooling = None def to_tuple(self) -> tuple[Any]: return tuple( self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() for k in self.keys() ) # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Kosmos2 class Kosmos2VisionEmbeddings(nn.Module): def __init__(self, config: Kosmos2VisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False, ) self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches + 1 self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution images. This method is also adapted to support torch.jit tracing. Adapted from: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ num_patches = embeddings.shape[1] - 1 position_embedding = self.position_embedding.weight.unsqueeze(0) num_positions = position_embedding.shape[1] - 1 # always interpolate when tracing to ensure the exported model works for dynamic input shapes if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embedding(self.position_ids) class_pos_embed = position_embedding[:, :1] patch_pos_embed = position_embedding[:, 1:] dim = embeddings.shape[-1] new_height = height // self.patch_size new_width = width // self.patch_size sqrt_num_positions = torch_int(num_positions**0.5) patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, size=(new_height, new_width), mode="bicubic", align_corners=False, ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: batch_size, _, height, width = pixel_values.shape if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size): raise ValueError( f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})." ) target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) if interpolate_pos_encoding: embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) else: embeddings = embeddings + self.position_embedding(self.position_ids) return embeddings # Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> Kosmos2 doesn't cast attn weights to fp32 def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs, ): attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class Kosmos2VisionAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config): super().__init__() self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads})." ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, causal_attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" batch_size, seq_length, embed_dim = hidden_states.shape queries = self.q_proj(hidden_states) keys = self.k_proj(hidden_states) values = self.v_proj(hidden_states) queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) # CLIP text model uses both `causal_attention_mask` and `attention_mask` # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask` if self.config._attn_implementation != "flash_attention_2": if attention_mask is not None and causal_attention_mask is not None: attention_mask = attention_mask + causal_attention_mask elif causal_attention_mask is not None: attention_mask = causal_attention_mask else: self.is_causal = causal_attention_mask is not None attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, queries, keys, values, attention_mask, is_causal=self.is_causal, scaling=self.scale, dropout=0.0 if not self.training else self.dropout, ) attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() attn_output = self.out_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Kosmos2Vision class Kosmos2VisionMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.activation_fn = ACT2FN[config.hidden_act] self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = self.fc2(hidden_states) return hidden_states # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->Kosmos2Vision class Kosmos2VisionEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Kosmos2VisionConfig): super().__init__() self.embed_dim = config.hidden_size self.self_attn = Kosmos2VisionAttention(config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = Kosmos2VisionMLP(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, causal_attention_mask: torch.Tensor, output_attentions: Optional[bool] = False, ) -> tuple[torch.FloatTensor]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(config.encoder_attention_heads,)`. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. """ residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, causal_attention_mask=causal_attention_mask, output_attentions=output_attentions, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Kosmos2Vision class Kosmos2VisionEncoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`Kosmos2VisionEncoderLayer`]. Args: config: Kosmos2VisionConfig """ def __init__(self, config: Kosmos2VisionConfig): super().__init__() self.config = config self.layers = nn.ModuleList([Kosmos2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False @can_return_tuple def forward( self, inputs_embeds, attention_mask: Optional[torch.Tensor] = None, causal_attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[tuple, BaseModelOutput]: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Causal mask for the text model. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None hidden_states = inputs_embeds for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) layer_outputs = encoder_layer( hidden_states, attention_mask, causal_attention_mask, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) # Similar to `transformers.models.clip.modeling_clip.CLIPVisionTransformer` but without docstring for `forward` class Kosmos2VisionTransformer(nn.Module): # Copied from transformers.models.altclip.modeling_altclip.AltCLIPVisionTransformer.__init__ with AltCLIPVision->Kosmos2Vision,ALTCLIP_VISION->KOSMOS2_VISION,AltCLIP->Kosmos2Vision def __init__(self, config: Kosmos2VisionConfig): super().__init__() self.config = config embed_dim = config.hidden_size self.embeddings = Kosmos2VisionEmbeddings(config) self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.encoder = Kosmos2VisionEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) def forward( self, pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[tuple, BaseModelOutputWithPooling]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if pixel_values is None: raise ValueError("You have to specify pixel_values") hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.pre_layrnorm(hidden_states) encoder_outputs = self.encoder( inputs_embeds=hidden_states, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = encoder_outputs[0] pooled_output = last_hidden_state[:, 0, :] pooled_output = self.post_layernorm(pooled_output) if not return_dict: return (last_hidden_state, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) # Similar to `transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding` but allowing to pass `position_ids` class Kosmos2TextSinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length.""" # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.__init__ def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): super().__init__() self.offset = 2 self.embedding_dim = embedding_dim self.padding_idx = padding_idx self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.make_weights def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) if hasattr(self, "weights"): # in forward put the weights on the correct dtype and device of the param emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) self.register_buffer("weights", emb_weights, persistent=False) @staticmethod # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.get_embedding def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): """ Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of "Attention Is All You Need". """ half_dim = embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) if embedding_dim % 2 == 1: # zero pad emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) if padding_idx is not None: emb[padding_idx, :] = 0 return emb.to(torch.get_default_dtype()) @torch.no_grad() def forward( self, input_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None, ): if input_ids is not None: bsz, seq_len = input_ids.size() if position_ids is None: # Create the position ids from the input token ids. Any padded tokens remain padded. position_ids = create_position_ids_from_input_ids( input_ids, self.padding_idx, past_key_values_length ).to(input_ids.device) else: bsz, seq_len = inputs_embeds.size()[:-1] if position_ids is None: position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length) # expand embeddings if needed max_pos = self.padding_idx + 1 + seq_len + past_key_values_length if max_pos > self.weights.size(0): self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach() # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.create_position_ids_from_inputs_embeds def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length): """ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. Args: inputs_embeds: torch.Tensor Returns: torch.Tensor """ input_shape = inputs_embeds.size()[:-1] sequence_length = input_shape[1] position_ids = torch.arange( self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device ) return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length class KosmosTextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" # Similar to transformers.models.bart.modeling_bart.BartAttention.__init__ except an additional `inner_attn_ln`. def __init__( self, config, embed_dim: int, num_heads: int, dropout: float = 0.0, is_decoder: Optional[bool] = False, add_inner_attn_layernorm: Optional[bool] = False, bias: Optional[bool] = True, layer_idx: Optional[bool] = None, ): super().__init__() self.config = config self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {num_heads})." ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) # End opy self.inner_attn_ln = None if add_inner_attn_layernorm: self.inner_attn_ln = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = encoder_hidden_states is not None batch_size, seq_length = hidden_states.shape[:2] query_states = self.q_proj(hidden_states) query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache curr_past_key_value = past_key_value.cross_attention_cache else: curr_past_key_value = past_key_value.self_attention_cache else: curr_past_key_value = past_key_value current_states = encoder_hidden_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.key_cache[self.layer_idx] value_states = curr_past_key_value.value_cache[self.layer_idx] else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() if self.inner_attn_ln is not None: attn_output = self.inner_attn_ln(attn_output) attn_output = self.out_proj(attn_output) return attn_output, attn_weights class Kosmos2TextFFN(nn.Module): def __init__(self, config: Kosmos2TextConfig): super().__init__() self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.fc1 = nn.Linear(config.embed_dim, config.ffn_dim) self.fc2 = nn.Linear(config.ffn_dim, config.embed_dim) self.ffn_layernorm = nn.LayerNorm(config.ffn_dim, eps=config.layer_norm_eps) def forward(self, hidden_states): hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.ffn_layernorm(hidden_states) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) return hidden_states class Kosmos2TextBlock(GradientCheckpointingLayer): def __init__(self, config: Kosmos2TextConfig, layer_idx=None): super().__init__() self.embed_dim = config.embed_dim self.self_attn = KosmosTextAttention( config, embed_dim=self.embed_dim, num_heads=config.attention_heads, dropout=config.attention_dropout, is_decoder=True, add_inner_attn_layernorm=True, layer_idx=layer_idx, ) self.dropout = config.dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) if config.add_cross_attention: self.encoder_attn = KosmosTextAttention( config, embed_dim=self.embed_dim, num_heads=config.attention_heads, dropout=config.attention_dropout, is_decoder=True, add_inner_attn_layernorm=False, layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.ffn = Kosmos2TextFFN(config) self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, cache_position=cache_position, **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block cross_attn_weights = None if encoder_hidden_states is not None: if not hasattr(self, "encoder_attn"): raise ValueError( f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" " by setting `config.add_cross_attention=True`" ) residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, output_attentions=output_attentions, cache_position=cache_position, **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) # FFN hidden_states = self.ffn(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) return outputs class Kosmos2TextTransformer(nn.Module): """ Transformer decoder consisting of `config.layers` layers. Each layer is a [`Kosmos2TextBlock`]. Args: config: Kosmos2TextConfig """ def __init__(self, config: Kosmos2TextConfig): super().__init__() self.config = config self.dropout = config.dropout self.layerdrop = config.layerdrop self.embed_scale = math.sqrt(config.embed_dim) if config.scale_embedding else 1.0 self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=config.pad_token_id) self.embed_positions = Kosmos2TextSinusoidalPositionalEmbedding( num_positions=config.max_position_embeddings, embedding_dim=config.embed_dim, padding_idx=config.pad_token_id, ) self.layers = nn.ModuleList([Kosmos2TextBlock(config, layer_idx=i) for i in range(config.layers)]) self.layer_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps) self.gradient_checkpointing = False def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, device=inputs_embeds.device, past_key_values_length=past_key_values_length, ) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( inputs_embeds.device ) combined_attention_mask = ( expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask ) return combined_attention_mask def forward_embedding( self, input_ids, inputs_embeds: Optional[torch.Tensor] = None, image_embeds: Optional[torch.Tensor] = None, img_input_mask: Optional[torch.Tensor] = None, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None, ): # The argument `inputs_embeds` should be the one without being multiplied by `self.embed_scale`. if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if image_embeds is not None: inputs_embeds[img_input_mask.to(dtype=torch.bool)] = image_embeds.to(inputs_embeds.device).view( -1, image_embeds.size(-1) ) inputs_embeds = inputs_embeds * self.embed_scale # embed positions positions = self.embed_positions( input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids, ) positions = positions.to(inputs_embeds.device) hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) return hidden_states def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, image_embeds: Optional[torch.Tensor] = None, image_embeds_position_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.Tensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.shape input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False return_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) return_legacy_cache = True past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 # We don't need img info. when `past_key_values_length` > 0 if past_key_values_length > 0: image_embeds = None image_embeds_position_mask = None hidden_states = self.forward_embedding( input_ids=input_ids, inputs_embeds=inputs_embeds, image_embeds=image_embeds, img_input_mask=image_embeds_position_mask, past_key_values_length=past_key_values_length, position_ids=position_ids, ) attention_mask = self._prepare_decoder_attention_mask( attention_mask, input_shape, hidden_states, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): if attn_mask is not None: if attn_mask.size()[0] != (len(self.layers)): raise ValueError( f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" f" {head_mask.size()[0]}." ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) if output_hidden_states: all_hidden_states += (hidden_states,) if self.training: dropout_probability = torch.rand([]) if dropout_probability < self.layerdrop: continue layer_outputs = decoder_layer( hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, **kwargs, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) if encoder_hidden_states is not None: all_cross_attentions += (layer_outputs[2],) # add final layer norm hidden_states = self.layer_norm(hidden_states) if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, ) @auto_docstring class Kosmos2PreTrainedModel(PreTrainedModel): config: Kosmos2Config supports_gradient_checkpointing = True _no_split_modules = ["Kosmos2VisionEncoderLayer", "Kosmos2TextBlock"] _supports_attention_backend = True _supports_flash_attn = True _supports_sdpa = True def _init_weights(self, module: nn.Module): """Initialize the weights""" if isinstance(self, Kosmos2VisionModel): factor = self.config.initializer_factor elif isinstance(self, (Kosmos2Model, Kosmos2ForConditionalGeneration)): factor = self.config.vision_config.initializer_factor if isinstance(self, (Kosmos2TextModel, Kosmos2TextForCausalLM)): std = self.config.init_std elif isinstance(self, (Kosmos2Model, Kosmos2ForConditionalGeneration)): std = self.config.text_config.init_std if isinstance(module, Kosmos2VisionEmbeddings): nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) elif isinstance(module, Kosmos2VisionAttention): in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor out_proj_std = (module.embed_dim**-0.5) * factor nn.init.normal_(module.q_proj.weight, std=in_proj_std) nn.init.normal_(module.k_proj.weight, std=in_proj_std) nn.init.normal_(module.v_proj.weight, std=in_proj_std) nn.init.normal_(module.out_proj.weight, std=out_proj_std) elif isinstance(module, Kosmos2VisionMLP): in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor fc_std = (2 * module.config.hidden_size) ** -0.5 * factor nn.init.normal_(module.fc1.weight, std=fc_std) nn.init.normal_(module.fc2.weight, std=in_proj_std) elif isinstance(module, KosmosTextAttention): nn.init.normal_(module.q_proj.weight, std=std) nn.init.normal_(module.k_proj.weight, std=std) nn.init.normal_(module.v_proj.weight, std=std) nn.init.normal_(module.out_proj.weight, std=std) elif isinstance(module, Kosmos2TextFFN): nn.init.normal_(module.fc1.weight, std=std) nn.init.normal_(module.fc2.weight, std=std) elif isinstance(module, Kosmos2TextForCausalLM): nn.init.normal_(module.lm_head.weight, std=std) elif isinstance(module, Kosmos2ImageToTextProjection): nn.init.normal_(module.dense.weight, std=std) nn.init.normal_(module.latent_query) elif isinstance(module, Kosmos2TextTransformer): module.embed_tokens.weight.data.normal_(mean=0.0, std=std) if module.embed_tokens.padding_idx is not None: module.embed_tokens.weight.data[module.embed_tokens.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.weight.data.fill_(1.0) module.bias.data.zero_() if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() class Kosmos2VisionModel(Kosmos2PreTrainedModel): config: Kosmos2VisionConfig main_input_name = "pixel_values" # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.__init__ with CLIP_VISION->KOSMOS2_VISION,CLIP->Kosmos2,self.vision_model->self.model def __init__(self, config: Kosmos2VisionConfig): super().__init__(config) self.model = Kosmos2VisionTransformer(config) # Initialize weights and apply final processing self.post_init() # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.get_input_embeddings with CLIP_VISION->KOSMOS2_VISION,CLIP->Kosmos2,self.vision_model->self.model def get_input_embeddings(self) -> nn.Module: return self.model.embeddings.patch_embedding @auto_docstring def forward( self, pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[tuple, BaseModelOutputWithPooling]: return self.model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, ) class Kosmos2TextModel(Kosmos2PreTrainedModel): config: Kosmos2TextConfig def __init__(self, config: Kosmos2TextConfig): super().__init__(config) self.model = Kosmos2TextTransformer(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.model.embed_tokens @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, image_embeds: Optional[torch.Tensor] = None, image_embeds_position_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.Tensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0, 1]`: - 1 for places where to put the image features, - 0 for places that are not for image features (i.e. for text tokens). cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. """ return self.model( input_ids=input_ids, attention_mask=attention_mask, image_embeds=image_embeds, image_embeds_position_mask=image_embeds_position_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, position_ids=position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, **kwargs, ) @auto_docstring( custom_intro=""" The text model from KOSMOS-2 with a language modeling head on top (linear layer with weights tied to the input embeddings). """ ) class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin): config: Kosmos2TextConfig _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: Kosmos2TextConfig): super().__init__(config) self.model = Kosmos2TextTransformer(config) self.lm_head = nn.Linear(in_features=config.embed_dim, out_features=config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.model.embed_tokens def get_output_embeddings(self) -> nn.Module: return self.lm_head @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, image_embeds: Optional[torch.Tensor] = None, image_embeds_position_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[list[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: r""" image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0, 1]`: - 1 for places where to put the image features, - 0 for places that are not for image features (i.e. for text tokens). cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: if use_cache: logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") use_cache = False outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, image_embeds=image_embeds, image_embeds_position_mask=image_embeds_position_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, position_ids=position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, **kwargs, ) lm_logits = self.lm_head(outputs[0]) loss = None if labels is not None: loss = self.loss_function(logits=lm_logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) return CausalLMOutputWithCrossAttentions( loss=loss, logits=lm_logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, ) def prepare_inputs_for_generation( self, input_ids, image_embeds=None, image_embeds_position_mask=None, past_key_values=None, attention_mask=None, inputs_embeds=None, use_cache=None, cache_position=None, **model_kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore if cache_position[0] != 0: image_embeds = None image_embeds_position_mask = None # appending `False` to `image_embeds_position_mask` (because `input_ids` grows during generation) elif image_embeds_position_mask is not None: batch_size, seq_len = inputs_embeds.size()[:-1] if inputs_embeds is not None else input_ids.size() mask_len = image_embeds_position_mask.size()[-1] image_embeds_position_mask = torch.cat( ( image_embeds_position_mask, torch.zeros(size=(batch_size, seq_len - mask_len), dtype=torch.bool, device=input_ids.device), ), dim=1, ) model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, image_embeds=image_embeds, image_embeds_position_mask=image_embeds_position_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, **model_kwargs, ) # Kosmos2 has offset for position ids, so we need to create them correctly in PositionEmbedding layer model_inputs.pop("position_ids", None) return model_inputs class Kosmos2ImageToTextProjection(nn.Module): """The layer that transforms the image model's output to part of the text model's input (namely, image features)""" def __init__(self, config: Kosmos2Config): super().__init__() self.dense = nn.Linear(config.vision_config.hidden_size, config.text_config.embed_dim) self.latent_query = nn.Parameter(torch.randn(config.latent_query_num, config.text_config.embed_dim)) self.x_attn = KosmosTextAttention( config.text_config, config.text_config.embed_dim, config.text_config.attention_heads, dropout=config.text_config.attention_dropout, is_decoder=False, add_inner_attn_layernorm=False, ) def forward(self, features): hidden_states = self.dense(features) # shape = [batch, latent_query_num, h_dim] latent_query = self.latent_query.unsqueeze(0).expand(hidden_states.size(0), -1, -1) key_value_states = torch.cat([hidden_states, latent_query], dim=1) hidden_states, attn_weights = self.x_attn( hidden_states=latent_query, encoder_hidden_states=key_value_states, past_key_value=None, attention_mask=None, output_attentions=None, ) return hidden_states, attn_weights @auto_docstring( custom_intro=""" KOSMOS-2 Model for generating text and image features. The model consists of a vision encoder and a language model. """ ) class Kosmos2Model(Kosmos2PreTrainedModel): config: Kosmos2Config main_input_name = "pixel_values" def __init__(self, config: Kosmos2Config): super().__init__(config) self.text_model = Kosmos2TextModel(config.text_config) self.vision_model = Kosmos2VisionModel(config.vision_config) self.image_to_text_projection = Kosmos2ImageToTextProjection(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.text_model.model.embed_tokens def set_input_embeddings(self, value): self.text_model.model.embed_tokens = value def get_image_features( self, pixel_values: torch.FloatTensor, return_attentions: Optional[bool] = False, interpolate_pos_encoding: Optional[bool] = False, ): """ Encodes images into continuous embeddings that can be forwarded to the language model. Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): The tensors corresponding to the input images. return_attentions (`bool`, *optional*, defaults to `False`): Whether to return `projection_attentions` or not. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): Whether to interpolate positional embeddings or not. """ vision_model_output = self.vision_model( pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, ) # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`. image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0]) # normalized features image_embeds = nn.functional.normalize(image_embeds, dim=-1) image_embeds, projection_attentions = self.image_to_text_projection(image_embeds) if return_attentions: return image_embeds, projection_attentions return image_embeds @can_return_tuple @auto_docstring def forward( self, pixel_values: Optional[torch.Tensor] = None, input_ids: Optional[torch.Tensor] = None, image_embeds_position_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[list[torch.FloatTensor]] = None, image_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[tuple, Kosmos2ModelOutput]: r""" image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0, 1]`: - 1 for places where to put the image features, - 0 for places that are not for image features (i.e. for text tokens). image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. Examples: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, Kosmos2Model >>> model = Kosmos2Model.from_pretrained("microsoft/kosmos-2-patch14-224") >>> processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224") >>> url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> text = ( ... " An image of a snowman" ... " warming himself by a fire" ... "" ... ) >>> inputs = processor(text=text, images=image, return_tensors="pt", add_eos_token=True) >>> last_hidden_state = model( ... pixel_values=inputs["pixel_values"], ... input_ids=inputs["input_ids"], ... attention_mask=inputs["attention_mask"], ... image_embeds_position_mask=inputs["image_embeds_position_mask"], ... ).last_hidden_state >>> list(last_hidden_state.shape) [1, 91, 2048] ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict vision_model_output = None projection_attentions = None if image_embeds is None: if pixel_values is None: raise ValueError("You have to specify either `pixel_values` or `image_embeds`.") image_embeds, projection_attentions = self.get_image_features( pixel_values, return_attentions=True, interpolate_pos_encoding=interpolate_pos_encoding ) outputs = self.text_model( input_ids=input_ids, attention_mask=attention_mask, image_embeds=image_embeds, image_embeds_position_mask=image_embeds_position_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, position_ids=position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, **kwargs, ) return Kosmos2ModelOutput( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_embeds=image_embeds, projection_attentions=projection_attentions, vision_model_output=vision_model_output, ) @auto_docstring( custom_intro=""" KOSMOS-2 Model for generating text and bounding boxes given an image. The model consists of a vision encoder and a language model. """ ) class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin): config: Kosmos2Config main_input_name = "pixel_values" _tied_weights_keys = ["text_model.lm_head.weight"] def __init__(self, config: Kosmos2Config): super().__init__(config) self.text_model = Kosmos2TextForCausalLM(config.text_config) self.vision_model = Kosmos2VisionModel(config.vision_config) self.image_to_text_projection = Kosmos2ImageToTextProjection(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.text_model.model.embed_tokens def set_input_embeddings(self, value): self.text_model.model.embed_tokens = value def get_output_embeddings(self) -> nn.Module: return self.text_model.get_output_embeddings() def set_output_embeddings(self, new_embeddings): self.text_model.set_output_embeddings(new_embeddings) @can_return_tuple @auto_docstring def forward( self, pixel_values: Optional[torch.Tensor] = None, input_ids: Optional[torch.Tensor] = None, image_embeds_position_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[list[torch.FloatTensor]] = None, image_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, Kosmos2ForConditionalGenerationModelOutput]: r""" image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0, 1]`: - 1 for places where to put the image features, - 0 for places that are not for image features (i.e. for text tokens). image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` Examples: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, Kosmos2ForConditionalGeneration >>> model = Kosmos2ForConditionalGeneration.from_pretrained("microsoft/kosmos-2-patch14-224") >>> processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224") >>> url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> prompt = " An image of" >>> inputs = processor(text=prompt, images=image, return_tensors="pt") >>> generated_ids = model.generate( ... pixel_values=inputs["pixel_values"], ... input_ids=inputs["input_ids"], ... attention_mask=inputs["attention_mask"], ... image_embeds=None, ... image_embeds_position_mask=inputs["image_embeds_position_mask"], ... use_cache=True, ... max_new_tokens=64, ... ) >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] >>> processed_text = processor.post_process_generation(generated_text, cleanup_and_extract=False) >>> processed_text ' An image of a snowman warming himself by a fire.' >>> caption, entities = processor.post_process_generation(generated_text) >>> caption 'An image of a snowman warming himself by a fire.' >>> entities [('a snowman', (12, 21), [(0.390625, 0.046875, 0.984375, 0.828125)]), ('a fire', (41, 47), [(0.171875, 0.015625, 0.484375, 0.890625)])] ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) vision_model_output = None projection_attentions = None if image_embeds is None: if pixel_values is None: raise ValueError("You have to specify either `pixel_values` or `image_embeds`.") vision_model_output = self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`. image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0]) # normalized features image_embeds = nn.functional.normalize(image_embeds, dim=-1) image_embeds, projection_attentions = self.image_to_text_projection(image_embeds) lm_outputs = self.text_model( input_ids=input_ids, attention_mask=attention_mask, image_embeds=image_embeds, image_embeds_position_mask=image_embeds_position_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, position_ids=position_ids, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, **kwargs, ) return Kosmos2ForConditionalGenerationModelOutput( loss=lm_outputs.loss, logits=lm_outputs.logits, past_key_values=lm_outputs.past_key_values, hidden_states=lm_outputs.hidden_states, attentions=lm_outputs.attentions, image_embeds=image_embeds, projection_attentions=projection_attentions, vision_model_output=vision_model_output, ) def generate( self, pixel_values: Optional[torch.Tensor] = None, image_embeds_position_mask: Optional[torch.Tensor] = None, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, image_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ): # in order to allow `inputs` argument (as in `GenerationMixin`) inputs = kwargs.pop("inputs", None) if pixel_values is not None and inputs is not None: raise ValueError( f"`inputs`: {inputs} were passed alongside `pixel_values` which is not allowed." f"Make sure to either pass `inputs` or pixel_values=..." ) if pixel_values is None and inputs is not None: pixel_values = inputs if image_embeds is None: vision_model_output = self.vision_model(pixel_values) # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`. image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0]) # normalized features image_embeds = nn.functional.normalize(image_embeds, dim=-1) image_embeds, projection_attentions = self.image_to_text_projection(image_embeds) output = self.text_model.generate( input_ids=input_ids, attention_mask=attention_mask, image_embeds=image_embeds, image_embeds_position_mask=image_embeds_position_mask, inputs_embeds=inputs_embeds, **kwargs, ) return output __all__ = ["Kosmos2ForConditionalGeneration", "Kosmos2Model", "Kosmos2PreTrainedModel"]