team-10/env/Lib/site-packages/transformers/models/chameleon/modeling_chameleon.py
2025-08-02 07:34:44 +02:00

1156 lines
49 KiB
Python

# coding=utf-8
# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Chameleon model."""
from functools import cached_property
from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
TransformersKwargs,
auto_docstring,
can_return_tuple,
is_torchdynamo_compiling,
logging,
)
from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig
logger = logging.get_logger(__name__)
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Chameleon
class ChameleonRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable gain."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        ChameleonRMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: size of the last (normalized) dimension; the gain is initialised to ones.
            eps: constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize `hidden_states` in float32, cast back to the input dtype, then apply the gain."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        # Mean of squares over the last dim (no mean subtraction, unlike LayerNorm).
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        """Return the weight shape and epsilon for the module's printed representation."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Chameleon
# TODO(joao): add me back asap :)
class ChameleonRotaryEmbedding(nn.Module):
    """Rotary position embedding: produces the cos/sin tables for a batch of position ids."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        """
        Args:
            dim: rotary dimension; `dim / 2` inverse frequencies are created.
            max_position_embeddings: stored on the module (also as `max_seq_len_cached`).
            base: base of the geometric frequency progression.
            device: device on which the inverse-frequency buffer is created.
            scaling_factor: stored for use by the scaling subclasses.
        """
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim
        inv_freq = 1.0 / (self.base**exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings

    @torch.no_grad()
    def forward(self, x, position_ids):
        """
        Compute `(cos, sin)` for `position_ids`.

        `x` is only consulted for its device type and dtype; the returned tensors have shape
        `(position_ids.shape[0], position_ids.shape[1], 2 * len(inv_freq))` and dtype `x.dtype`.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        batch = position_ids.shape[0]
        freq_column = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
        position_row = position_ids[:, None, :].float()

        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        if device_type == "mps":
            device_type = "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            angles = (freq_column.float() @ position_row.float()).transpose(1, 2)
            # Duplicate the angles so the table spans the full rotary dimension.
            table = torch.cat((angles, angles), dim=-1)
            cos = table.cos()
            sin = table.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding):
    """ChameleonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def forward(self, x, position_ids):
        """Divide the position ids by `scaling_factor`, then defer to the base rotary embedding."""
        # The only difference from plain RoPE is this rescaling of the positions.
        scaled_positions = position_ids.float() / self.scaling_factor
        return super().forward(x, scaled_positions)
class ChameleonDynamicNTKScalingRotaryEmbedding(ChameleonRotaryEmbedding):
    """ChameleonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def forward(self, x, position_ids):
        """
        Compute `(cos, sin)`, first re-deriving `inv_freq` from an enlarged base whenever the largest
        position id + 1 exceeds `max_position_embeddings`. The recomputed buffer replaces the stored
        one and stays in place for later calls.
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_position_embeddings:
            growth = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            base = self.base * growth ** (self.dim / (self.dim - 2))
            exponents = (
                torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=x.device, dtype=torch.float) / self.dim
            )
            inv_freq = 1.0 / (base**exponents)
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation

        return super().forward(x, position_ids)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Split the last dim at its midpoint and return `cat(-second_half, first_half)`."""
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate the query and key tensors with the given rotary cos/sin tables.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into `cos` and `sin` so that they broadcast against `q` and `k`. With
            `cos`/`sin` of shape [batch_size, seq_len, head_dim], use 1 when `q`/`k` are
            [batch_size, heads, seq_len, head_dim] and 2 when they are
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same rotation formula for queries and keys.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
# Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->Chameleon
class ChameleonMLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""

    def __init__(self, config):
        """Build the three projections from `config.hidden_size`, `config.intermediate_size` and `config.mlp_bias`."""
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        use_bias = config.mlp_bias
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
        # Activation looked up by name from the config.
        self.act_fn = ACT2FN[config.hidden_act]

    # Ignore copy
    def forward(self, x):
        """Apply the gated projection and map the result back to `hidden_size`."""
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
class ChameleonLayerNorm(nn.LayerNorm):
    """
    LayerNorm but computes stats only over the last dim because Chameleon applies gamma and beta
    from each shard separately to each head, instead of reducing. We can apply each head's own
    gamma/beta by repeat-interleaving weights from each shard, but the stats have to be computed
    in the last dimension. This module applies gamma/beta manually to fulfill this requirement.
    """

    def __init__(self, hidden_size, *args, **kwargs):
        """
        Args:
            hidden_size: shape of the affine parameters (e.g. `(num_heads, head_dim)`); only its last
                entry is used as the normalized shape.
            *args, **kwargs: forwarded unchanged to `nn.LayerNorm` (e.g. `eps`).
        """
        super().__init__(hidden_size, *args, **kwargs)
        # The parent allocated weight/bias with the full `hidden_size` shape; statistics, however,
        # are taken over the last dimension only.
        self.normalized_shape = (hidden_size[-1],)

    def forward(self, hidden_states):
        """Normalize over the last dim, then apply the per-head weight and bias by broadcasting."""
        # Honour the epsilon configured on the module instead of a hard-coded constant
        # (the `nn.LayerNorm` default is 1e-5, so default-constructed modules are unaffected).
        hidden_states = F.layer_norm(hidden_states, self.normalized_shape, None, None, eps=self.eps)
        # Affine parameters are `None` when the module was built without them.
        if self.weight is not None:
            hidden_states = hidden_states * self.weight
        if self.bias is not None:
            hidden_states = hidden_states + self.bias
        return hidden_states
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension, turning
    (batch, num_key_value_heads, seqlen, head_dim) into (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep); the input is returned as-is when `n_rep` is 1.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the heads.
    expanded = hidden_states[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
# Copied from transformers.models.llama.modeling_llama.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain matmul/softmax attention.

    Keys and values are first expanded to the query head count using `module.num_key_value_groups`.
    An optional additive mask (cropped to the key length on its last dim) is added to the scaled
    scores; softmax runs in float32 and is cast back to the query dtype; dropout is active only
    when `module.training` is true.

    Returns:
        `(attn_output, attn_weights)` where `attn_output` has dims 1 and 2 swapped relative to the
        matmul result and is made contiguous.
    """
    groups = module.num_key_value_groups
    key_states = repeat_kv(key, groups)
    value_states = repeat_kv(value, groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
    return context, probs
class ChameleonAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Queries and keys are each passed through a per-head `ChameleonLayerNorm` before rotary position
    embeddings are applied. Keys/values may use fewer heads than queries (`num_key_value_heads`).
    """

    def __init__(self, config: ChameleonConfig, layer_idx: Optional[int] = None):
        """
        Args:
            config: model configuration; the attributes read here are `attention_dropout`, `hidden_size`,
                `num_attention_heads`, `num_key_value_heads`, `max_position_embeddings`, `rope_theta`,
                `model_parallel_size`, `attention_bias` and `rope_scaling`.
            layer_idx: index of this layer, forwarded to the cache's `update` call in `forward`.

        Raises:
            ValueError: if `hidden_size` is not divisible by `num_attention_heads`, or (from
                `_init_rope`) if the RoPE scaling type is unknown.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads that share one key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.model_parallel_size = config.model_parallel_size
        self.scaling = self.head_dim**-0.5
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
        # Per-head norms: affine parameters have shape (heads, head_dim), statistics use head_dim only.
        self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
        self.k_norm = ChameleonLayerNorm((self.num_key_value_heads, self.head_dim))
        self._init_rope()

    # copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Chameleon
    # TODO(joao): add me back asap :)
    def _init_rope(self):
        """Create `self.rotary_emb` according to `config.rope_scaling` (None, "linear" or "dynamic")."""
        if self.config.rope_scaling is None:
            self.rotary_emb = ChameleonRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = ChameleonLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = ChameleonDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Run self-attention over `hidden_states`.

        Args:
            hidden_states: input of shape `(batch, seq_len, hidden_size)`.
            attention_mask: passed through to the selected attention implementation.
            position_ids: positions handed to the rotary embedding.
            past_key_value: optional cache; when given, its `update` method receives the rotated keys,
                the values, `self.layer_idx` and the sin/cos/cache_position kwargs, and its return
                values replace the key/value states.
            output_attentions, use_cache: accepted for interface compatibility; not read in this method.
            cache_position: forwarded to the cache update.
            **kwargs: forwarded to the attention implementation.

        Returns:
            `(attn_output, attn_weights)`: `attn_output` is reshaped to `(batch, seq_len, -1)` and
            projected by `o_proj`; `attn_weights` is whatever the attention implementation returns.
        """
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        # Flatten batch and sequence so the per-head norm sees (tokens, heads, head_dim).
        query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
        query_states = self.q_norm(query_states)
        key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
        key_states = self.k_norm(key_states)
        # Restore the batch dimension and move heads in front of the sequence dimension.
        query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        # The rotary module only reads the device type and dtype of its first argument.
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        if past_key_value is not None:
            # sin and cos are specific to RoPE models; position_ids needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        # Default to the eager implementation unless the config selects another registered backend.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )
        # Merge the head and head_dim axes back into a single feature axis.
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON
class ChameleonDecoderLayer(GradientCheckpointingLayer):
    """Decoder layer that normalizes *before* the attention and MLP sub-blocks (pre-norm residuals)."""

    def __init__(self, config: ChameleonConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = ChameleonAttention(config=config, layer_idx=layer_idx)

        self.mlp = ChameleonMLP(config)
        self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor`, *optional*): forwarded to the attention module.
            past_key_value (`Cache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                When truthy, the attention weights are appended to the returned tuple.
            use_cache (`bool`, *optional*): forwarded to the attention module.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            kwargs (`dict`, *optional*):
                Extra keyword arguments, passed on to the attention module.

        Returns:
            `(hidden_states,)`, or `(hidden_states, self_attn_weights)` when `output_attentions` is truthy.
        """
        # Self Attention: norm -> attention -> residual add.
        attn_input = self.input_layernorm(hidden_states)
        attn_out, self_attn_weights = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        after_attention = hidden_states + attn_out

        # Fully Connected: norm -> MLP -> residual add.
        mlp_out = self.mlp(self.post_attention_layernorm(after_attention))
        layer_output = after_attention + mlp_out

        if output_attentions:
            return (layer_output, self_attn_weights)
        return (layer_output,)
class ChameleonSwinDecoderLayer(GradientCheckpointingLayer):
    """Decoder layer that normalizes the *output* of the attention and MLP sub-blocks before each residual add."""

    def __init__(self, config: ChameleonConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = ChameleonAttention(config=config, layer_idx=layer_idx)

        self.mlp = ChameleonMLP(config)
        self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indices of positions of each input sequence tokens in the position embeddings
            past_key_value (`Cache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                When truthy, the attention weights are appended to the returned tuple.
            use_cache (`bool`, *optional*): forwarded to the attention module.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.

        Returns:
            `(hidden_states,)`, or `(hidden_states, self_attn_weights)` when `output_attentions` is truthy.
        """
        # Self Attention: attention on the raw input, then norm, then residual add.
        attn_out, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        after_attention = hidden_states + self.input_layernorm(attn_out)

        # Fully Connected: MLP, then norm, then residual add.
        mlp_out = self.post_attention_layernorm(self.mlp(after_attention))
        layer_output = after_attention + mlp_out

        if output_attentions:
            return (layer_output, self_attn_weights)
        return (layer_output,)
class ChameleonVQVAEVectorQuantizer(nn.Module):
    """
    Vector quantization against a learned codebook.

    Follows the quantization process described in the VQ-VAE (Vector Quantized Variational
    AutoEncoder) paper: every continuous input vector is replaced by its nearest codebook
    embedding (squared Euclidean distance), and the index of that embedding is returned
    alongside the quantized tensor and an embedding loss.
    """

    def __init__(self, config):
        """Read `num_embeddings`, `embed_dim` and (optionally, default 0.25) `beta` from `config`."""
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)

    def forward(self, hidden_state: torch.Tensor):
        """
        Quantize a 4-D tensor whose dim 1 holds `embedding_dim` features.

        Returns:
            `(quantized, loss, indices)`: the quantized tensor in the input layout, a scalar loss, and
            the flat tensor of selected codebook indices.
        """
        # Move the feature axis last so each row of the flattened view is one vector to quantize.
        channels_last = hidden_state.permute(0, 2, 3, 1).contiguous()
        flat = channels_last.view(-1, self.embedding_dim)
        codebook = self.embedding.weight

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        flat_sq = torch.sum(flat**2, dim=1, keepdim=True)
        code_sq = torch.sum(codebook**2, dim=1)
        cross = torch.einsum("bd,dn->bn", flat, codebook.transpose(0, 1))
        distances = flat_sq + code_sq - 2 * cross

        min_encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(min_encoding_indices).view(channels_last.shape)

        # Two squared-error terms: one with the codebook side detached, one (weighted by beta)
        # with the input side detached.
        input_side_term = torch.mean((quantized.detach() - channels_last) ** 2)
        codebook_side_term = torch.mean((quantized - channels_last.detach()) ** 2)
        loss = input_side_term + self.beta * codebook_side_term

        # Straight-through: forward value is the quantized tensor, gradient flows to the input.
        quantized = channels_last + (quantized - channels_last).detach()

        # Restore the original (feature-second) layout.
        quantized = quantized.permute(0, 3, 1, 2).contiguous()
        return quantized, loss, min_encoding_indices
class ChameleonVQVAEEncoderConvDownsample(nn.Module):
    """Stride-2 3x3 convolution preceded by one row/column of zero padding on the bottom/right."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, hidden_states):
        # no asymmetric padding in torch conv, must do it ourselves
        padded = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class ChameleonVQVAEEncoderResnetBlock(nn.Module):
    """
    Residual block of the VQ-VAE encoder: two (GroupNorm -> x*sigmoid(x) -> 3x3 conv) stages with
    dropout before the second conv, added to a shortcut of the input.

    Args:
        config: object providing `dropout`.
        in_channels: number of input channels (must be compatible with 32 GroupNorm groups).
        out_channels: number of output channels; defaults to `in_channels` when `None`.
        conv_shortcut: when the channel count changes, use a 3x3 conv for the shortcut instead of
            a 1x1 conv.
    """

    def __init__(
        self,
        config,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.use_conv_shortcut = conv_shortcut

        # Always size the layers from the resolved `self.out_channels`: the raw argument may be
        # None, which previously made the Conv2d/GroupNorm constructors fail.
        resolved_out = self.out_channels

        self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels, resolved_out, kernel_size=3, stride=1, padding=1)
        self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=resolved_out, eps=1e-6, affine=True)
        self.dropout = torch.nn.Dropout(config.dropout)
        self.conv2 = torch.nn.Conv2d(resolved_out, resolved_out, kernel_size=3, stride=1, padding=1)
        # A learned shortcut is only needed when the residual must change its channel count.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, resolved_out, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, resolved_out, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states):
        """Return `shortcut(input) + block(input)`; spatial size is preserved."""
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        # x * sigmoid(x), applied in place on the freshly created norm output.
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states
class ChameleonVQVAEEncoderAttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map, with a residual connection."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states):
        """Return `hidden_states + proj_out(attention(norm(hidden_states)))`, same shape as the input."""
        normed = self.norm(hidden_states)
        query_map = self.q(normed)
        key_map = self.k(normed)
        value_map = self.v(normed)

        # compute attention: every spatial position attends to every other position
        batch_size, channels, height, width = query_map.shape
        positions = height * width
        queries = query_map.reshape(batch_size, channels, positions).permute(0, 2, 1)
        keys = key_map.reshape(batch_size, channels, positions)
        scores = torch.bmm(queries, keys) * (int(channels) ** (-0.5))
        weights = F.softmax(scores, dim=2)

        # attend to values
        values = value_map.reshape(batch_size, channels, positions)
        attended = torch.bmm(values, weights.permute(0, 2, 1))
        attended = attended.reshape(batch_size, channels, height, width)

        return hidden_states + self.proj_out(attended)
class ChameleonVQVAEEncoder(nn.Module):
    """
    Convolutional encoder of the VQ-VAE: an input conv, a stack of resolution levels (ResNet blocks,
    optional attention, and a stride-2 downsample on every level but the last), a middle section
    (ResNet block -> attention or identity -> ResNet block) and an output norm + conv.
    """

    def __init__(self, config):
        """
        Build the encoder from `config`. Attributes read: `channel_multiplier`, `num_res_blocks`,
        `base_channels`, `resolution`, `in_channels`, `double_latent`, `latent_channels`,
        `attn_resolutions`, `attn_type` (and `dropout`, inside the ResNet blocks).
        """
        super().__init__()
        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        resolution = config.resolution
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier
        self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)
        # Tracks the spatial resolution at the current level; halved after every downsample.
        curr_res = resolution
        # Input multiplier of level i is the output multiplier of level i-1 (1 for the first level).
        in_channel_multiplier = (1,) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ChameleonVQVAEEncoderResnetBlock(
                        config=config,
                        in_channels=block_in,
                        out_channels=block_out,
                    )
                )
                # Only the first block of a level changes the channel count.
                block_in = block_out
                # One attention block per ResNet block, but only at the configured resolutions.
                if (
                    config.attn_resolutions is not None
                    and curr_res in config.attn_resolutions
                    and config.attn_type == "vanilla"
                ):
                    attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))
            # Plain container module grouping this level's sub-modules.
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)
        # `block_in` now holds the channel count produced by the last level.
        self.mid = nn.Module()
        self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )
        self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(block_in) if config.attn_type == "vanilla" else nn.Identity()
        self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )
        self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * latent_channels if double_latent else latent_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values: torch.FloatTensor):
        """
        Encode `pixel_values` (fed straight into `conv_in`) and return the output of `conv_out`,
        which has `latent_channels` channels (or twice that when `double_latent` is set).
        """
        # downsampling
        # Every intermediate activation is appended; only the last entry is ever read back.
        hidden_states = [self.conv_in(pixel_values)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                hidden_state = self.down[i_level].block[i_block](
                    hidden_states[-1],
                )
                if len(self.down[i_level].attn) > 0:
                    hidden_state = self.down[i_level].attn[i_block](hidden_state)
                hidden_states.append(hidden_state)
            if i_level != self.num_resolutions - 1:
                hidden_states.append(self.down[i_level].downsample(hidden_states[-1]))
        # middle
        last_hidden_state = hidden_states[-1]
        last_hidden_state = self.mid.block_1(last_hidden_state)
        last_hidden_state = self.mid.attn_1(last_hidden_state)
        last_hidden_state = self.mid.block_2(last_hidden_state)
        # end
        last_hidden_state = self.norm_out(last_hidden_state)
        # x * sigmoid(x), applied in place on the norm output.
        last_hidden_state *= torch.sigmoid(last_hidden_state)
        last_hidden_state = self.conv_out(last_hidden_state)
        return last_hidden_state
class ChameleonImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.

    Image tokens are the entries of `vocab_map` whose name starts with `"IMGIMG"`. The remainder of
    the name (minus its final character) spells the image-codebook index, with the letters `A`-`J`
    standing for the digits `0`-`9`.
    """

    def __init__(self, vocab_map):
        """
        Args:
            vocab_map: mapping from token name to BPE token id. `image_token_id` is the id of
                `"<image>"`, or `None` when that name is absent.
        """
        self.vocab_map = vocab_map
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def val2name(self):
        """Inverse of `vocab_map`: BPE token id -> token name."""
        return {v: k for k, v in self.vocab_map.items()}

    @cached_property
    def image_tokens(self):
        """Sorted BPE ids of all tokens whose name starts with `"IMGIMG"`."""
        return sorted(val for name, val in self.vocab_map.items() if name.startswith("IMGIMG"))

    @cached_property
    def bpe2img(self):
        """Mapping from image BPE token id to image-codebook index."""
        letter_to_digit = {chr(ord("A") + i): str(i) for i in range(10)}

        def remap(old_name: str) -> str:
            # Drop the "IMGIMG" prefix and the trailing character, then translate letters to digits.
            return "".join(letter_to_digit.get(c, c) for c in old_name[len("IMGIMG") : -1])

        return {tok: int(remap(self.val2name[tok])) for tok in self.image_tokens}

    @cached_property
    def img2bpe(self):
        """Mapping from image-codebook index to image BPE token id."""
        return {v: k for k, v in self.bpe2img.items()}

    @cached_property
    def bpe2img_search_tensors(self):
        """
        Pair of aligned tensors `(bpe_ids, img_ids)`: `bpe_ids` is sorted ascending and `img_ids[i]`
        is the image index of `bpe_ids[i]`.
        """
        # Sort by key only and look the values up, so both tensors stay aligned even when the
        # image indices are not monotonic in the BPE ids (sorting each side independently would
        # silently break the pairing).
        bpe_ids = sorted(self.bpe2img)
        img_ids = [self.bpe2img[tok] for tok in bpe_ids]
        return torch.tensor(bpe_ids), torch.tensor(img_ids)

    @cached_property
    def img2bpe_mapping_tensor(self):
        """Dense int lookup table: position = image index, value = BPE id (0 where no token exists)."""
        # `default=-1` yields an empty table instead of a ValueError when there are no image tokens.
        mapping = torch.zeros(max(self.img2bpe, default=-1) + 1, dtype=torch.int)
        for img_id, bpe_id in self.img2bpe.items():
            mapping[img_id] = bpe_id
        return mapping

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Translate a tensor of image indices to BPE ids; the result is returned on the input's device."""
        device = img_batch.device
        # The lookup table is built on the default device, so index it with a CPU copy.
        img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
        return img_tokens.to(device)
@auto_docstring
class ChameleonPreTrainedModel(PreTrainedModel):
    # Common base class of the Chameleon models in this file. It defines no methods of its own and
    # only sets class-level attributes that are read by the `PreTrainedModel` machinery.
    # NOTE(review): the exact meaning of each flag is defined by `PreTrainedModel`, not here —
    # confirm against its documentation before changing any of them.
    config: ChameleonConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Names of the two decoder-layer classes defined in this file.
    _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
    # Attention-backend capability flags.
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = True
    _supports_param_buffer_assignment = False
    _supports_flex_attn = True
    _supports_attention_backend = True
@auto_docstring(
    custom_intro="""
The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
[ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
Taigman](https://huggingface.co/papers/2203.13131).
"""
)
class ChameleonVQVAE(ChameleonPreTrainedModel):
    config: ChameleonVQVAEConfig
    _no_split_modules = ["ChameleonVQVAEVectorQuantizer"]

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__(config)

        # Image encoder followed by the codebook quantizer.
        self.encoder = ChameleonVQVAEEncoder(config)
        self.quantize = ChameleonVQVAEVectorQuantizer(config)

        # 1x1 convolutions mapping between the latent and codebook-embedding channel counts.
        self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)

        self.eval()  # Chameleon's VQ model is frozen

    def encode(self, pixel_values: torch.LongTensor):
        # encoder -> 1x1 projection to the embedding width -> nearest-codebook quantization
        projected = self.quant_conv(self.encoder(pixel_values))
        quant, emb_loss, indices = self.quantize(projected)
        return quant, emb_loss, indices
@auto_docstring
class ChameleonModel(ChameleonPreTrainedModel):
def __init__(self, config: ChameleonConfig):
    # Builds the token embedding, the image-token vocabulary mapping, the decoder stack, the final
    # norm and the VQ-VAE image tokenizer from `config`.
    super().__init__(config)
    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size
    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
    # Translates VQ-VAE codebook indices into BPE token ids (see `get_image_tokens`).
    self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map)
    # `config.swin_norm` selects the layer variant that normalizes sub-block outputs instead of inputs.
    decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm else ChameleonSwinDecoderLayer
    self.layers = nn.ModuleList(
        [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
    )
    self.norm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    # Image tokenizer, instantiated from the nested VQ config.
    self.vqmodel = ChameleonVQVAE._from_config(config.vq_config)
    self.gradient_checkpointing = False
    # Initialize weights and apply final processing
    self.post_init()
def get_image_tokens(self, pixel_values: torch.FloatTensor):
"""
Tokenizes images into discrete tokens with VQGAN module. Converts
obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
special tokens.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
The tensors corresponding to the input images.
"""
batch_size = pixel_values.shape[0]
_, _, image_toks = self.vqmodel.encode(pixel_values)
bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
bpe_toks = bpe_toks.view(batch_size, -1)
return bpe_toks
def get_image_features(self, pixel_values: torch.FloatTensor):
"""
Tokenizes images into discrete tokens with VQGAN module and embeds
them with text embeddings layer
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
The tensors corresponding to the input images.
"""
image_tokens = self.get_image_tokens(pixel_values)
vision_embeddings = self.get_input_embeddings()(image_tokens)
return vision_embeddings
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if pixel_values is not None:
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
n_image_tokens_in_text = special_image_mask.sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_embeds = self.get_image_features(pixel_values)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_embeds.numel():
n_image_features = image_embeds.shape[0] * image_embeds.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)
# torch.jit.trace() doesn't support cache objects in the output
if use_cache and past_key_values is None and not torch.jit.is_tracing():
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
# embed positions
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
if not return_dict:
return tuple(
v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
@auto_docstring(
    custom_intro="""
    Chameleon Model with a head on top used for outputting logits for next token prediction.
    """
)
class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = ChameleonModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def set_decoder(self, decoder):
        """Replaces the wrapped base model with `decoder`."""
        self.model = decoder

    def get_decoder(self):
        """Returns the wrapped base model."""
        return self.model

    def get_image_tokens(self, pixel_values):
        """Delegates to `self.model.get_image_tokens`."""
        return self.model.get_image_tokens(pixel_values)

    def get_image_features(self, pixel_values):
        """Delegates to `self.model.get_image_features`."""
        return self.model.get_image_features(pixel_values)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
        >>> import torch
        >>> import requests
        >>> from PIL import Image

        >>> model = ChameleonForConditionalGeneration.from_pretrained("facebook/chameleon-7b", torch_dtype=torch.bfloat16)
        >>> processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")

        >>> prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
        >>> image = Image.open(requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw)
        >>> image_2 = Image.open(requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw)

        >>> inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.bfloat16)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        # Disallow image tokens, which do not include the special begin-image and end-image tokens
        image_tokens = self.model.vocabulary_mapping.image_tokens
        # Written in place: those vocabulary columns get the lowest finite value of the logits dtype.
        logits[:, :, image_tokens] = torch.finfo(logits.dtype).min

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        pixel_values=None,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        """
        Builds the model inputs for one generation step, dropping `pixel_values` once decoding has moved past
        the first cache position.

        All arguments are forwarded unchanged to the parent implementation. When `cache_position` is `None`
        (its default), `pixel_values` is left as prepared by the parent instead of raising on the subscript.

        Returns:
            The inputs prepared by the parent implementation, with `"pixel_values"` set to `None` when
            `cache_position[0] != 0`.
        """
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            pixel_values=pixel_values,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            **kwargs,
        )

        if cache_position is not None and cache_position[0] != 0:
            # If we're in cached decoding stage, pixel values should be `None` because input ids do not contain special image token anymore
            # Otherwise we need pixel values to be passed to model
            model_inputs["pixel_values"] = None

        return model_inputs
# Names exported by `from ... import *` for this module.
__all__ = [
    "ChameleonForConditionalGeneration",
    "ChameleonModel",
    "ChameleonPreTrainedModel",
    "ChameleonVQVAE",
]