# coding=utf-8 # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch Chameleon model.""" from functools import cached_property from typing import Callable, Optional, Union import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging, ) from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig logger = logging.get_logger(__name__) # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Chameleon class ChameleonRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ ChameleonRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Chameleon # TODO(joao): add me back asap :) class ChameleonRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): super().__init__() self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / ( self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim) ) self.register_buffer("inv_freq", inv_freq, persistent=False) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings @torch.no_grad() def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type device_type = device_type if device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding): """ChameleonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" def forward(self, x, position_ids): # difference to the original RoPE: a scaling factor is applied to the position ids position_ids = position_ids.float() / self.scaling_factor cos, sin = super().forward(x, position_ids) return cos, sin class ChameleonDynamicNTKScalingRotaryEmbedding(ChameleonRotaryEmbedding): """ChameleonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" def forward(self, x, position_ids): # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length seq_len = torch.max(position_ids) + 1 if seq_len > self.max_position_embeddings: base = self.base * ( (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) inv_freq = 1.0 / ( base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=x.device, dtype=torch.float) / self.dim) ) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation cos, sin = super().forward(x, position_ids) return cos, sin # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed # Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->Chameleon class ChameleonMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) self.act_fn = ACT2FN[config.hidden_act] # Ignore copy def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj class ChameleonLayerNorm(nn.LayerNorm): """ LayerNorm but computes stats only over the last dim because Chameleon applies gamma and beta from each shard separately to each head, instead of reducing. We can apply each head's own gamma/beta by repeat-interleaving weights from each shard, but the stats have to be computed in the last dimension. This module applies gamma/beta manually to fulfill this requirement. """ def __init__(self, hidden_size, *args, **kwargs): super().__init__(hidden_size, *args, **kwargs) self.normalized_shape = (hidden_size[-1],) def forward(self, hidden_states): hidden_states = F.layer_norm(hidden_states, self.normalized_shape, None, None, eps=1e-5) hidden_states = hidden_states * self.weight + self.bias return hidden_states # Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) # Copied from transformers.models.llama.modeling_llama.eager_attention_forward def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class ChameleonAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: ChameleonConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True self.model_parallel_size = config.model_parallel_size self.scaling = self.head_dim**-0.5 if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." ) self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim)) self.k_norm = ChameleonLayerNorm((self.num_key_value_heads, self.head_dim)) self._init_rope() # copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Chameleon # TODO(joao): add me back asap :) def _init_rope(self): if self.config.rope_scaling is None: self.rotary_emb = ChameleonRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] if scaling_type == "linear": self.rotary_emb = ChameleonLinearScalingRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, base=self.rope_theta, ) elif scaling_type == "dynamic": self.rotary_emb = ChameleonDynamicNTKScalingRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, base=self.rope_theta, ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.reshape(-1, self.num_heads, self.head_dim) query_states = self.q_norm(query_states) key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) key_states = self.k_norm(key_states) query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON class ChameleonDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: ChameleonConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = ChameleonAttention(config=config, layer_idx=layer_idx) self.mlp = ChameleonMLP(config) self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence kwargs (`dict`, *optional*): Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code into the model """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, **kwargs, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) return outputs class ChameleonSwinDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: ChameleonConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = ChameleonAttention(config=config, layer_idx=layer_idx) self.mlp = ChameleonMLP(config) self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. """ residual = hidden_states # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, **kwargs, ) hidden_states = self.input_layernorm(hidden_states) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.mlp(hidden_states) hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) return outputs class ChameleonVQVAEVectorQuantizer(nn.Module): """ A module for vector quantization using learned embedding vectors. This module implements the quantization process similar to te one described in the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous input vectors into discrete codebook vectors, which are learned during training. Current implementation improves over previous ones by avoiding costly matrix multiplications and allowing for post-hoc remapping of indices. """ def __init__(self, config): super().__init__() self.num_embeddings = config.num_embeddings self.embedding_dim = config.embed_dim self.beta = getattr(config, "beta", 0.25) self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) def forward(self, hidden_state: torch.Tensor): hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous() hidden_state_flattened = hidden_state.view(-1, self.embedding_dim) # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z distances = ( torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + torch.sum(self.embedding.weight**2, dim=1) - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, self.embedding.weight.transpose(0, 1)) ) min_encoding_indices = torch.argmin(distances, dim=1) hidden_state_quant = self.embedding(min_encoding_indices).view(hidden_state.shape) # compute loss for embedding loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean( (hidden_state_quant - hidden_state.detach()) ** 2 ) # preserve gradients hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach() # reshape back to match original input shape hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous() return hidden_state_quant, loss, min_encoding_indices class ChameleonVQVAEEncoderConvDownsample(nn.Module): def __init__(self, in_channels): super().__init__() self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) def forward(self, hidden_states): # no asymmetric padding in torch conv, must do it ourselves hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0) hidden_states = self.conv(hidden_states) return hidden_states class ChameleonVQVAEEncoderResnetBlock(nn.Module): def __init__( self, config, in_channels, out_channels=None, conv_shortcut=False, ): super().__init__() self.in_channels = in_channels self.out_channels = in_channels if out_channels is None else out_channels self.use_conv_shortcut = conv_shortcut self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) self.dropout = torch.nn.Dropout(config.dropout) self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, hidden_states): residual = hidden_states hidden_states = self.norm1(hidden_states) hidden_states *= torch.sigmoid(hidden_states) hidden_states = self.conv1(hidden_states) hidden_states = self.norm2(hidden_states) hidden_states *= torch.sigmoid(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) if self.in_channels != self.out_channels: if self.use_conv_shortcut: residual = self.conv_shortcut(residual) else: residual = self.nin_shortcut(residual) return residual + hidden_states class ChameleonVQVAEEncoderAttnBlock(nn.Module): def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, hidden_states): residual = hidden_states hidden_states = self.norm(hidden_states) query_states = self.q(hidden_states) key_states = self.k(hidden_states) value_states = self.v(hidden_states) # compute attention batch_size, channels, height, width = query_states.shape query_states = query_states.reshape(batch_size, channels, height * width).permute(0, 2, 1) key_states = key_states.reshape(batch_size, channels, height * width) attn_weights = torch.bmm(query_states, key_states) attn_weights = attn_weights * (int(channels) ** (-0.5)) attn_weights = F.softmax(attn_weights, dim=2) # attend to values value_states = value_states.reshape(batch_size, channels, height * width) attn_weights = attn_weights.permute(0, 2, 1) attn_output = torch.bmm(value_states, attn_weights).reshape(batch_size, channels, height, width) attn_output = self.proj_out(attn_output) return residual + attn_output class ChameleonVQVAEEncoder(nn.Module): def __init__(self, config): super().__init__() self.num_resolutions = len(config.channel_multiplier) self.num_res_blocks = config.num_res_blocks base_channels = config.base_channels resolution = config.resolution in_channels = config.in_channels double_latent = config.double_latent latent_channels = config.latent_channels channel_multiplier = config.channel_multiplier self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1) curr_res = resolution in_channel_multiplier = (1,) + tuple(channel_multiplier) self.in_channel_multiplier = in_channel_multiplier self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() block_in = base_channels * in_channel_multiplier[i_level] block_out = base_channels * channel_multiplier[i_level] for i_block in range(self.num_res_blocks): block.append( ChameleonVQVAEEncoderResnetBlock( config=config, in_channels=block_in, out_channels=block_out, ) ) block_in = block_out if ( config.attn_resolutions is not None and curr_res in config.attn_resolutions and config.attn_type == "vanilla" ): attn.append(ChameleonVQVAEEncoderAttnBlock(block_in)) down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions - 1: down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in) curr_res = curr_res // 2 self.down.append(down) self.mid = nn.Module() self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock( config=config, in_channels=block_in, out_channels=block_in, ) self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(block_in) if config.attn_type == "vanilla" else nn.Identity() self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock( config=config, in_channels=block_in, out_channels=block_in, ) self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) self.conv_out = torch.nn.Conv2d( block_in, 2 * latent_channels if double_latent else latent_channels, kernel_size=3, stride=1, padding=1, ) def forward(self, pixel_values: torch.LongTensor): # downsampling hidden_states = [self.conv_in(pixel_values)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): hidden_state = self.down[i_level].block[i_block]( hidden_states[-1], ) if len(self.down[i_level].attn) > 0: hidden_state = self.down[i_level].attn[i_block](hidden_state) hidden_states.append(hidden_state) if i_level != self.num_resolutions - 1: hidden_states.append(self.down[i_level].downsample(hidden_states[-1])) # middle last_hidden_state = hidden_states[-1] last_hidden_state = self.mid.block_1(last_hidden_state) last_hidden_state = self.mid.attn_1(last_hidden_state) last_hidden_state = self.mid.block_2(last_hidden_state) # end last_hidden_state = self.norm_out(last_hidden_state) last_hidden_state *= torch.sigmoid(last_hidden_state) last_hidden_state = self.conv_out(last_hidden_state) return last_hidden_state class ChameleonImageVocabularyMapping: """ A class for mapping discrete image tokens from VQGAN to BPE tokens. """ def __init__(self, vocab_map): self.vocab_map = vocab_map self.image_token_id = vocab_map.get("") @cached_property def val2name(self): return {v: k for k, v in self.vocab_map.items()} @cached_property def image_tokens(self): return sorted([val for name, val in self.vocab_map.items() if name.startswith("IMGIMG")]) @cached_property def bpe2img(self): img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)} def remap(old_name: str) -> str: return "".join(img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1]) return {tok: int(remap(self.val2name[tok])) for tok in self.image_tokens} @cached_property def img2bpe(self): return {v: k for k, v in self.bpe2img.items()} @cached_property def bpe2img_search_tensors(self): return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(sorted(self.bpe2img.values())) @cached_property def img2bpe_mapping_tensor(self): mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int) for k, v in self.img2bpe.items(): mapping[k] = v return mapping def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor: device = img_batch.device img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")] return img_tokens.to(device) @auto_docstring class ChameleonPreTrainedModel(PreTrainedModel): config: ChameleonConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"] _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn = True _supports_sdpa = True _can_compile_fullgraph = True _supports_param_buffer_assignment = False _supports_flex_attn = True _supports_attention_backend = True @auto_docstring( custom_intro=""" The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens. This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://huggingface.co/papers/2203.13131). """ ) class ChameleonVQVAE(ChameleonPreTrainedModel): config: ChameleonVQVAEConfig _no_split_modules = ["ChameleonVQVAEVectorQuantizer"] def __init__(self, config: ChameleonVQVAEConfig): super().__init__(config) self.encoder = ChameleonVQVAEEncoder(config) self.quantize = ChameleonVQVAEVectorQuantizer(config) self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1) self.eval() # Chameleon's VQ model is frozen def encode(self, pixel_values: torch.LongTensor): hidden_states = self.encoder(pixel_values) hidden_states = self.quant_conv(hidden_states) quant, emb_loss, indices = self.quantize(hidden_states) return quant, emb_loss, indices @auto_docstring class ChameleonModel(ChameleonPreTrainedModel): def __init__(self, config: ChameleonConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map) decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm else ChameleonSwinDecoderLayer self.layers = nn.ModuleList( [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.vqmodel = ChameleonVQVAE._from_config(config.vq_config) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_image_tokens(self, pixel_values: torch.FloatTensor): """ Tokenizes images into discrete tokens with VQGAN module. Converts obtained image tokens into BPE tokens and wraps with "boi" and "eoi" special tokens. Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images. """ batch_size = pixel_values.shape[0] _, _, image_toks = self.vqmodel.encode(pixel_values) bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks) bpe_toks = bpe_toks.view(batch_size, -1) return bpe_toks def get_image_features(self, pixel_values: torch.FloatTensor): """ Tokenizes images into discrete tokens with VQGAN module and embeds them with text embeddings layer Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images. """ image_tokens = self.get_image_tokens(pixel_values) vision_embeddings = self.get_input_embeddings()(image_tokens) return vision_embeddings @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." ) use_cache = False if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if pixel_values is not None: if input_ids is None: special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) special_image_mask = special_image_mask.all(-1) else: special_image_mask = input_ids == self.vocabulary_mapping.image_token_id n_image_tokens_in_text = special_image_mask.sum() special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) image_embeds = self.get_image_features(pixel_values) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_embeds.numel(): n_image_features = image_embeds.shape[0] * image_embeds.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}" ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds) # torch.jit.trace() doesn't support cache objects in the output if use_cache and past_key_values is None and not torch.jit.is_tracing(): past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = create_causal_mask( config=self.config, input_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, position_ids=position_ids, ) # embed positions hidden_states = inputs_embeds # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, **kwargs, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) if not return_dict: return tuple( v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) @auto_docstring( custom_intro=""" Chameleon Model with a head on top used for outputting logits for next token prediction. """ ) class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) self.model = ChameleonModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model def get_image_tokens(self, pixel_values): return self.model.get_image_tokens(pixel_values) def get_image_features(self, pixel_values): return self.model.get_image_features(pixel_values) @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Example: ```python >>> from transformers import ChameleonProcessor, ChameleonForConditionalGeneration >>> import torch >>> import requests >>> from PIL import Image >>> model = ChameleonForConditionalGeneration.from_pretrained("facebook/chameleon-7b", torch_dtype=torch.bfloat16) >>> processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b") >>> prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation." >>> image = Image.open(requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw) >>> image_2 = Image.open(requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw) >>> inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.bfloat16) >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, **kwargs, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) # Disallow image tokens which does not include special begin-image and end-image tokens image_tokens = self.model.vocabulary_mapping.image_tokens logits[:, :, image_tokens] = torch.finfo(logits.dtype).min loss = None if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, pixel_values=None, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, position_ids=None, use_cache=True, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model model_inputs = super().prepare_inputs_for_generation( input_ids, pixel_values=pixel_values, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, cache_position=cache_position, position_ids=position_ids, use_cache=use_cache, **kwargs, ) if cache_position[0] != 0: # If we're in cached decoding stage, pixel values should be `None` because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model model_inputs["pixel_values"] = None return model_inputs __all__ = ["ChameleonForConditionalGeneration", "ChameleonModel", "ChameleonPreTrainedModel", "ChameleonVQVAE"]