team-10/env/Lib/site-packages/transformers/models/janus/modeling_janus.py
2025-08-02 07:34:44 +02:00

1424 lines
59 KiB
Python

# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/janus/modular_janus.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_janus.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from dataclasses import dataclass
from typing import Callable, Optional, Union
import torch
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...generation import ClassifierFreeGuidanceLogitsProcessor, GenerationMixin, GenerationMode, LogitsProcessorList
from ...generation.utils import GenerateDecoderOnlyOutput
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
TransformersKwargs,
auto_docstring,
can_return_tuple,
is_torch_available,
logging,
torch_int,
)
from ..auto import AutoModel
from .configuration_janus import JanusConfig, JanusVisionConfig, JanusVQVAEConfig
if is_torch_available():
import torch.nn.functional as F
logger = logging.get_logger(__name__)
@auto_docstring
class JanusPreTrainedModel(PreTrainedModel):
    # Common base class of the Janus models in this file: it pins the config class and the
    # weight-name prefix, and sets the class-level flags read by the `PreTrainedModel` base.
    config: JanusConfig
    base_model_prefix = "model"

    # Supported attention backends / compilation.
    _supports_sdpa = True
    _supports_flash_attn = True
    _can_compile_fullgraph = True

    # Training / weight-loading behaviour.
    supports_gradient_checkpointing = True
    _supports_param_buffer_assignment = False

    # Device-placement hints (module class names and cache keys).
    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
    _no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"]
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Janus VQ-VAE model outputs.
    """
)
class JanusVQVAEOutput(ModelOutput):
    r"""
    decoded_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
        Reconstructed pixel values after encoding and decoding the input.
    embedding_loss (`torch.FloatTensor`, *optional*):
        Embedding loss.
    """

    decoded_pixel_values: Optional[torch.FloatTensor] = None
    # Defaults to None like every other field, so the annotation must be Optional as well.
    embedding_loss: Optional[torch.FloatTensor] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Janus model's outputs that may also contain past key/values (to speed up sequential decoding).
    """
)
class JanusBaseModelOutputWithPast(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.

        If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
        hidden_size)` is output.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Cache holding pre-computed hidden-states (key and values in the self-attention blocks) that can be used
        (see `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
        sequence_length, hidden_size)`.

        Image hidden-states of the model produced by the vision encoder.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    # Annotated as `Cache` to agree with the documented type above.
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Janus causal language model (or autoregressive) outputs.
    """
)
class JanusCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Cache holding pre-computed hidden-states (key and values in the self-attention blocks) that can be used
        (see `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
        sequence_length, hidden_size)`.

        Image hidden-states of the model produced by the vision encoder.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    # Annotated as `Cache` to agree with the documented type above.
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
class JanusVisionEmbeddings(nn.Module):
    """
    Patchify an image and add learned position embeddings.

    A convolution whose kernel and stride both equal `patch_size` cuts the image into non-overlapping
    patches and projects each one to `hidden_size`. One learned position vector exists per grid cell
    (`(image_size // patch_size) ** 2` of them); no extra class-token position is allocated.
    """

    def __init__(self, config: JanusVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Resize the learned (square) position grid to the patch grid of a `height` x `width` image.

        The stored table is viewed as a `sqrt(num_positions)` x `sqrt(num_positions)` grid, resized with
        bicubic interpolation to `(height // patch_size, width // patch_size)` and flattened back to
        `(1, num_new_patches, dim)`. Supports torch.jit tracing and has no class-embedding handling.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1]
        table = self.position_embedding.weight
        num_positions = table.shape[0]

        # Reuse the stored table when it already fits a square input; while tracing, always take the
        # interpolation path so the exported graph stays valid for other input sizes.
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        dim = embeddings.shape[-1]
        side = torch_int(num_positions**0.5)
        target_rows = height // self.patch_size
        target_cols = width // self.patch_size

        # (num_positions, dim) -> (1, dim, side, side) so interpolate works on the spatial axes.
        grid = table.unsqueeze(0).reshape(1, side, side, dim).permute(0, 3, 1, 2)
        grid = nn.functional.interpolate(
            grid,
            size=(target_rows, target_cols),
            mode="bicubic",
            align_corners=False,
        )
        # (1, dim, rows, cols) -> (1, rows * cols, dim)
        return grid.permute(0, 2, 3, 1).view(1, -1, dim)

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        _batch, _channels, height, width = pixel_values.shape
        weight_dtype = self.patch_embedding.weight.dtype

        # (batch, embed_dim, rows, cols) -> (batch, rows * cols, embed_dim)
        patches = self.patch_embedding(pixel_values.to(dtype=weight_dtype))
        embeddings = patches.flatten(2).transpose(1, 2)

        pos_embeds = (
            self.interpolate_pos_encoding(embeddings, height, width)
            if interpolate_pos_encoding
            else self.position_embedding(self.position_ids)
        )
        return embeddings + pos_embeds
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis.

    Same result as `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)`. A tensor of shape
    `(batch, num_key_value_heads, seq_len, head_dim)` becomes
    `(batch, num_key_value_heads * n_rep, seq_len, head_dim)`. When `n_rep == 1` the input object is
    returned as-is.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis right after the head axis, broadcast it, then fold it into the heads.
        tiled = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
        return tiled.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
    return hidden_states
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain-PyTorch scaled dot-product attention.

    Args:
        module: the calling attention layer. Its `num_key_value_groups` sets how many times key/value
            heads are repeated (via `repeat_kv`), and its `training` flag gates dropout.
        query, key, value: `(batch, heads, seq_len, head_dim)` tensors.
        attention_mask: optional additive mask. It is trimmed to the key length on its last axis and
            added to the raw scores.
        scaling: factor applied to `query @ key^T`.
        dropout: dropout probability on the attention probabilities (active only in training mode).
        **kwargs: accepted and ignored (callers pass extra options such as `is_causal` through here).

    Returns:
        `(attn_output, attn_weights)`:
        - `attn_output` is `(batch, seq_len, heads, head_dim)` (contiguous).
        - `attn_weights` are the post-dropout probabilities. They are computed with a float32 softmax
          and cast back to the query dtype.
    """
    groups = module.num_key_value_groups
    keys = repeat_kv(key, groups)
    values = repeat_kv(value, groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return context, probs
class JanusVisionAttention(nn.Module):
    """
    Multi-head self-attention for the Janus vision encoder.

    Queries, keys and values come from separate linear projections. When `config.use_qk_norm` is set,
    queries and keys are optionally LayerNorm-ed per head. Attention runs through the backend named by
    `config._attn_implementation`, and the result goes through an output projection and optional dropout.
    """

    def __init__(self, config: JanusVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        proj_dropout = config.projection_dropout
        qk_norm = config.use_qk_norm
        self.is_causal = False

        # Every query head has its own key/value head (no grouped-query attention), so
        # `eager_attention_forward` must not repeat key/value heads: one group.
        self.num_key_value_groups = 1

        self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.projection_layer = nn.Linear(self.embed_dim, self.embed_dim)
        self.projection_dropout = nn.Dropout(proj_dropout) if proj_dropout > 0 else nn.Identity()

        # QK-norm is applied after splitting into heads (last dim == head_dim), so the LayerNorm must be
        # sized `head_dim`. A LayerNorm over `embed_dim` cannot be applied to that tensor unless
        # num_heads == 1, and in that case head_dim == embed_dim, so this sizing is unchanged there.
        self.q_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Args:
            hidden_states: input of shape `(batch_size, seq_len, embed_dim)`.
            attention_mask: optional mask handed unchanged to the attention backend.
            **kwargs: forwarded to the attention backend.

        Returns:
            `(output, attn_weights)`:
            - `output` has shape `(batch_size, seq_len, embed_dim)`.
            - `attn_weights` is the second value returned by the selected attention backend.
        """
        batch_size, seq_len, _ = hidden_states.size()
        per_head = (batch_size, seq_len, self.num_heads, self.head_dim)

        # Split into heads before the (optional) QK-norm so each head's vector is normalized on its own,
        # then move heads in front of the sequence axis: (batch, heads, seq, head_dim).
        query_states = self.q_norm(self.q_proj(hidden_states).reshape(per_head)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).reshape(per_head)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(per_head).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scale,
            is_causal=self.is_causal,
            **kwargs,
        )
        attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)

        output = self.projection_layer(attn_output)
        output = self.projection_dropout(output)
        return output, attn_weights
class JanusVisionMLP(nn.Module):
    """
    Feed-forward sub-block of a vision encoder layer.

    Computes `dropout2(fc2(dropout1(act(fc1(x)))))`:
    - the hidden width is `int(config.hidden_size * config.mlp_ratio)`;
    - `act` is `ACT2FN[config.hidden_act]`;
    - both dropouts use `config.hidden_dropout_rate`.
    """

    def __init__(self, config: JanusVisionConfig):
        super().__init__()
        self.config = config
        width = config.hidden_size
        self.intermediate_size = int(width * config.mlp_ratio)
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(width, self.intermediate_size)
        self.fc2 = nn.Linear(self.intermediate_size, width)
        self.dropout1 = nn.Dropout(config.hidden_dropout_rate)
        self.dropout2 = nn.Dropout(config.hidden_dropout_rate)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.dropout1(self.activation_fn(self.fc1(hidden_states)))
        return self.dropout2(self.fc2(expanded))
class JanusVisionEncoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer block: LayerNorm -> self-attention -> residual add, then LayerNorm -> MLP -> residual add."""

    def __init__(self, config: JanusVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = JanusVisionAttention(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = JanusVisionMLP(config)
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Layer input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Mask of shape `(batch, 1, q_len, k_v_seq_len)`; masked (padding) positions hold very large
                negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                When true, the attention weights are appended to the returned tuple. See `attentions` under
                returned tensors for more detail.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is true.
        """
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_out

        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))

        return (hidden_states, attn_weights) if output_attentions else (hidden_states,)
class JanusVisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`JanusVisionEncoderLayer`].

    Args:
        config: JanusVisionConfig
    """

    def __init__(self, config: JanusVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([JanusVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    # Ignore copy
    @can_return_tuple
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutput:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded input sequence fed to the first encoder layer (e.g. the output of
                `JanusVisionEmbeddings`).
            attention_mask (`torch.Tensor`, *optional*):
                Passed unchanged to every layer; it is not converted here. It must therefore already be an
                additive float mask of shape `(batch_size, 1, sequence_length, sequence_length)`, with very
                large negative values at masked positions (the eager attention path adds it to the scores).
                A 0/1 padding mask of shape `(batch_size, sequence_length)` is not supported.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail. Defaults to `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail. Defaults to `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Not a parameter of this function body: it is consumed by the `@can_return_tuple` wrapper, which
                presumably converts the result to a plain tuple when it is `False`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                # Record the input of every layer; the final output is appended after the loop.
                encoder_states = encoder_states + (hidden_states,)
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )
@auto_docstring
class JanusVisionModel(JanusPreTrainedModel):
    # Vision tower: patch/position embeddings -> transformer encoder -> final LayerNorm.
    main_input_name = "pixel_values"
    config: JanusVisionConfig

    def __init__(self, config: JanusVisionConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = JanusVisionEmbeddings(config)
        self.encoder = JanusVisionEncoder(config)
        self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        # Unspecified switches fall back to the config.
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedded = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        encoder_outputs = self.encoder(
            inputs_embeds=embedded,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.post_layernorm(encoder_outputs[0])
        # NOTE(review): the pooled output is token 0 with `post_layernorm` applied a second time.
        # `JanusVisionEmbeddings` allocates no class-token position, so token 0 is simply the first patch.
        pooled_output = self.post_layernorm(sequence_output[:, 0, :])

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def get_input_embeddings(self):
        return self.embeddings
class JanusVisionAlignerMLP(nn.Module):
    """
    MLP mapping `hidden_size`-dim features to `projection_dim`.

    `fc1` performs `hidden_size -> projection_dim`. It is followed by `config.depth - 1` more
    `projection_dim -> projection_dim` linear layers, each preceded by `ACT2FN[config.hidden_act]`
    (no activation after the last layer).
    """

    def __init__(self, config: JanusVisionConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size, config.projection_dim)
        self.hidden_layers = nn.ModuleList(
            nn.Linear(config.projection_dim, config.projection_dim) for _ in range(config.depth - 1)
        )
        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        projected = self.fc1(hidden_states)
        for linear in self.hidden_layers:
            projected = linear(self.activation_fn(projected))
        return projected
class JanusVQVAEVectorQuantizer(nn.Module):
    """
    Vector quantizer with a learned codebook, as used in VQ-VAE (Vector Quantized Variational AutoEncoder).

    `forward` takes a channels-first feature map whose channel axis has `embed_dim` entries. It replaces
    every spatial feature vector with its nearest codebook entry under squared Euclidean distance and
    returns three things:
    - the quantized map (with a straight-through gradient);
    - the quantization loss;
    - the selected codebook indices.

    `get_codebook_entry` does the reverse lookup, turning token indices back into a feature map.
    """

    def __init__(self, config: JanusVQVAEConfig):
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        # Spatial size of the latent grid that `get_codebook_entry` rebuilds (square).
        self.quant_state_dims = [config.num_patches] * 2

    def forward(self, hidden_state: torch.Tensor):
        # Channels-last, so that each spatial position is one `embedding_dim` vector.
        hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
        hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)

        # Squared distances from every vector z to every codebook entry e_j: ||z||^2 + ||e_j||^2 - 2 z.e_j
        distances = (
            torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, self.embedding.weight.transpose(0, 1))
        )

        min_encoding_indices = torch.argmin(distances, dim=1)
        hidden_state_quant = self.embedding(min_encoding_indices).view(hidden_state.shape)

        # First term: commitment (gradient reaches the encoder output).
        # Second term: codebook (gradient reaches the embeddings).
        # NOTE(review): `beta` weights the codebook term here, whereas the VQ-VAE paper weights the
        # commitment term; kept as-is to preserve existing behaviour — confirm this is intended.
        loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean(
            (hidden_state_quant - hidden_state.detach()) ** 2
        )

        # Straight-through estimator: the forward value is the quantized vector, while the gradient
        # w.r.t. the input passes through unchanged.
        hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach()

        # Back to channels-first, matching the input layout.
        hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()

        return hidden_state_quant, loss, min_encoding_indices

    def get_codebook_entry(self, image_tokens: torch.LongTensor) -> torch.FloatTensor:
        """
        Look up codebook vectors for `image_tokens` and arrange them as a channels-first feature map.

        Args:
            image_tokens (`torch.LongTensor`): codebook indices with `batch_size` as the leading dimension and
                `quant_state_dims[0] * quant_state_dims[1]` indices per sample.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, embed_dim, num_patches, num_patches)`. Each vector is
            L2-normalized.
        """
        batch_size = image_tokens.shape[0]
        emb_dim: int = self.embedding.weight.shape[-1]

        hidden_state_quant = self.embedding(image_tokens)
        # NOTE(review): the looked-up vectors are L2-normalized here, but `forward` searches the raw,
        # un-normalized codebook — confirm this asymmetry is intended.
        hidden_state_quant = F.normalize(hidden_state_quant, p=2, dim=-1)

        # (batch, rows * cols, emb_dim) -> (batch, emb_dim, rows, cols)
        hidden_state_quant = hidden_state_quant.view((batch_size, *self.quant_state_dims, emb_dim))
        hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()

        return hidden_state_quant
class JanusVQVAEResnetBlock(nn.Module):
    """
    Residual block of the VQ-VAE: two (GroupNorm -> SiLU -> 3x3 conv) stages plus a skip connection.

    Args:
        config: object providing `dropout`, the dropout probability applied before the second convolution.
        in_channels (`int`): channels of the input feature map (must be divisible by 32 for `GroupNorm(32, ...)`).
        out_channels (`int`, *optional*): channels of the output feature map; defaults to `in_channels`.
        conv_shortcut (`bool`, *optional*, defaults to `False`): when the channel count changes, project the
            skip connection with a 3x3 convolution instead of a 1x1 convolution.
    """

    def __init__(
        self,
        config,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.use_conv_shortcut = conv_shortcut

        # Every layer is sized from the resolved `self.out_channels`, never the raw argument, so that
        # `out_channels=None` really means "same as `in_channels`" (the raw None used to reach Conv2d).
        self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=self.out_channels, eps=1e-6, affine=True)
        self.dropout = torch.nn.Dropout(config.dropout)
        self.conv2 = torch.nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            # The skip path must be projected to the new width before it can be added back.
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, self.out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, self.out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, hidden_states):
        residual = hidden_states

        hidden_states = self.norm1(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)  # SiLU (x * sigmoid(x)), in place
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)  # SiLU (x * sigmoid(x)), in place
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states
class JanusVQVAEAttnBlock(nn.Module):
    """
    Single-head spatial self-attention over a feature map, with a residual connection.

    The steps are:
    - GroupNorm the input (32 groups);
    - project it to queries, keys and values with 1x1 convolutions;
    - let every spatial position attend to every other, with scores scaled by `channels ** -0.5`;
    - pass the attended map through a 1x1 output convolution and add it back to the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        def pointwise_conv():
            return torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def forward(self, hidden_states):
        normed = self.norm(hidden_states)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        batch_size, channels, height, width = query.shape
        positions = height * width

        # (batch, positions, channels) @ (batch, channels, positions) -> (batch, positions, positions)
        flat_query = query.reshape(batch_size, channels, positions).permute(0, 2, 1)
        flat_key = key.reshape(batch_size, channels, positions)
        scores = torch.bmm(flat_query, flat_key) * (int(channels) ** (-0.5))
        weights = F.softmax(scores, dim=2)

        # out[b, c, i] = sum_j value[b, c, j] * weights[b, i, j]
        flat_value = value.reshape(batch_size, channels, positions)
        attended = torch.bmm(flat_value, weights.permute(0, 2, 1)).reshape(batch_size, channels, height, width)

        return hidden_states + self.proj_out(attended)
class JanusVQVAEConvDownsample(nn.Module):
    """
    Downsample a feature map with a stride-2 3x3 convolution (exactly halves even spatial sizes).

    `nn.Conv2d` only pads symmetrically, so the input is zero-padded by one pixel on the right and
    bottom by hand, and the convolution itself uses no padding.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, hidden_states):
        # `F.pad` ordering for the last two dims: (left, right, top, bottom).
        padded = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class JanusVQVAEConvUpsample(nn.Module):
    """Double the spatial size: nearest-neighbour 2x upsampling followed by a size-preserving 3x3 convolution."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states):
        enlarged = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        return self.conv(enlarged)
class JanusVQVAEMidBlock(nn.Module):
    """
    Bottleneck stage used by both the VQ-VAE encoder and decoder.

    It runs a ResNet block, then spatial attention, then a second ResNet block; every stage keeps
    `channels` channels.
    """

    def __init__(self, config: JanusVQVAEConfig, channels: int):
        super().__init__()
        self.block_1 = JanusVQVAEResnetBlock(config=config, in_channels=channels, out_channels=channels)
        self.attn_1 = JanusVQVAEAttnBlock(channels)
        self.block_2 = JanusVQVAEResnetBlock(config=config, in_channels=channels, out_channels=channels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for stage in (self.block_1, self.attn_1, self.block_2):
            hidden_states = stage(hidden_states)
        return hidden_states
class JanusVQVAEEncoder(nn.Module):
    """
    Convolutional encoder of the VQ-VAE: maps pixel values to a latent feature map.

    Layout:
    - `conv_in`;
    - one `down` stage per entry of `config.channel_multiplier`. Each stage has `config.num_res_blocks`
      ResNet blocks; attention follows every block in the last stage only, and a stride-2 downsample
      follows every stage except the last;
    - the `mid` block;
    - GroupNorm -> SiLU -> `conv_out`. `conv_out` emits `2 * latent_channels` channels when
      `config.double_latent` is set, otherwise `latent_channels`.
    """

    def __init__(self, config):
        super().__init__()
        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier

        self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)

        # Stage `i` reads `base_channels * in_channel_multiplier[i]` channels, i.e. the previous stage's
        # output width. The first stage reads the `conv_in` output, hence the leading 1.
        in_channel_multiplier = (1,) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    JanusVQVAEResnetBlock(
                        config=config,
                        in_channels=block_in,
                        out_channels=block_out,
                    )
                )
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn.append(JanusVQVAEAttnBlock(block_in))

            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = JanusVQVAEConvDownsample(block_in)
            self.down.append(down)

        self.mid = JanusVQVAEMidBlock(config, block_in)

        self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * latent_channels if double_latent else latent_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values: torch.FloatTensor):
        # Only the running activation is kept (no list of intermediates), so each earlier feature map
        # can be released as soon as the next layer has consumed it.
        hidden_state = self.conv_in(pixel_values)

        # downsampling
        for i_level, down in enumerate(self.down):
            for i_block in range(self.num_res_blocks):
                hidden_state = down.block[i_block](hidden_state)
                if len(down.attn) > 0:
                    hidden_state = down.attn[i_block](hidden_state)
            if i_level != self.num_resolutions - 1:
                hidden_state = down.downsample(hidden_state)

        # middle
        hidden_state = self.mid(hidden_state)

        # end: GroupNorm -> SiLU (x * sigmoid(x), in place) -> conv
        hidden_state = self.norm_out(hidden_state)
        hidden_state *= torch.sigmoid(hidden_state)
        hidden_state = self.conv_out(hidden_state)

        return hidden_state
class JanusVQVAEDecoder(nn.Module):
    """
    Convolutional decoder of the VQ-VAE: maps a latent feature map back to pixel space.

    Layout:
    - `conv_in`, from `latent_channels` to `base_channels * channel_multiplier[-1]`;
    - the `mid` block;
    - one `up` stage per multiplier, visited from the last multiplier to the first. Each stage has
      `num_res_blocks + 1` ResNet blocks; attention follows every block in the first visited stage only,
      and a 2x upsample follows every stage except the last one visited;
    - GroupNorm -> SiLU -> `conv_out`.
    """

    def __init__(self, config):
        super().__init__()
        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        multipliers = config.channel_multiplier
        last_level = self.num_resolutions - 1

        # The decoder starts at the width of the last multiplier.
        channels = base_channels * multipliers[last_level]

        self.conv_in = torch.nn.Conv2d(config.latent_channels, channels, kernel_size=3, stride=1, padding=1)
        self.mid = JanusVQVAEMidBlock(config, channels)

        # `self.up[0]` corresponds to the last multiplier, `self.up[-1]` to the first one.
        self.up = nn.ModuleList()
        for level in range(last_level, -1, -1):
            target = base_channels * multipliers[level]
            resnets = nn.ModuleList()
            attentions = nn.ModuleList()
            for _ in range(self.num_res_blocks + 1):
                resnets.append(JanusVQVAEResnetBlock(config=config, in_channels=channels, out_channels=target))
                channels = target
                if level == last_level:
                    attentions.append(JanusVQVAEAttnBlock(channels))

            stage = nn.Module()
            stage.block = resnets
            stage.attn = attentions
            if level != 0:
                stage.upsample = JanusVQVAEConvUpsample(channels)
            self.up.append(stage)

        self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)
        self.conv_out = torch.nn.Conv2d(channels, config.out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_state: torch.FloatTensor) -> torch.FloatTensor:
        hidden_state = self.mid(self.conv_in(hidden_state))

        final_index = self.num_resolutions - 1
        for index, stage in enumerate(self.up):
            for block_index, resnet in enumerate(stage.block):
                hidden_state = resnet(hidden_state)
                if len(stage.attn) > 0:
                    hidden_state = stage.attn[block_index](hidden_state)
            if index != final_index:
                hidden_state = stage.upsample(hidden_state)

        # GroupNorm -> SiLU (x * sigmoid(x), in place) -> conv
        hidden_state = self.norm_out(hidden_state)
        hidden_state *= torch.sigmoid(hidden_state)
        return self.conv_out(hidden_state)
@auto_docstring(
    custom_intro="""
    The VQ-VAE model used in Janus for encoding/decoding images into discrete tokens.
    This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
    [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
    Taigman](https://huggingface.co/papers/2203.13131).
    """
)
class JanusVQVAE(JanusPreTrainedModel):
    config: JanusVQVAEConfig
    _no_split_modules = [
        "JanusVQVAEAttnBlock",
        "JanusVQVAEResnetBlock",
        "JanusVQVAEVectorQuantizer",
    ]
    main_input_name = "pixel_values"

    def __init__(self, config: JanusVQVAEConfig):
        super().__init__(config)

        self.encoder = JanusVQVAEEncoder(config)
        self.quantize = JanusVQVAEVectorQuantizer(config)
        self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
        self.decoder = JanusVQVAEDecoder(config)
        self.gradient_checkpointing = False

        # Janus's VQ model is frozen. `eval()` only flips the flag on submodules that already
        # exist, so it must run after *every* submodule (including the decoder) is created;
        # calling it earlier left the decoder in training mode.
        self.eval()

        # Initialize the VQVAE model.
        self.post_init()

    def encode(self, pixel_values: torch.FloatTensor):
        """
        Encodes images into quantized latents.

        Args:
            pixel_values (`torch.FloatTensor`): Batch of images fed to the convolutional encoder.

        Returns:
            The `(quant, emb_loss, indices)` triple produced by `self.quantize` on the
            `quant_conv`-projected encoder output.
        """
        hidden_states = self.encoder(pixel_values)
        hidden_states = self.quant_conv(hidden_states)
        quant, emb_loss, indices = self.quantize(hidden_states)
        return quant, emb_loss, indices

    def decode(self, image_tokens: torch.LongTensor) -> torch.FloatTensor:
        """
        Decodes quantized token IDs into pixel values.

        Args:
            image_tokens (torch.LongTensor): Batch of token IDs.

        Returns:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                Pixel values decoded from the token IDs.

        Raises:
            ValueError: if `image_tokens.shape[1]` differs from the number of positions in the
                quantizer's latent grid (`quant_state_dims[0] * quant_state_dims[1]`).
        """
        expected_tokens = self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]
        if image_tokens.shape[1] != expected_tokens:
            raise ValueError(
                f"Expected `image_tokens` to have shape `(batch_size, {expected_tokens})`, "
                f"but got shape `{image_tokens.shape}`."
            )
        codebook_entry = self.quantize.get_codebook_entry(image_tokens)
        hidden_states = self.post_quant_conv(codebook_entry)
        pixel_values = self.decoder(hidden_states)
        return pixel_values

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
    ) -> Union[tuple, JanusVQVAEOutput]:
        batch_size = pixel_values.shape[0]
        # Round-trip: encode to codebook indices, then decode those indices back to pixels.
        _, embedding_loss, indices = self.encode(pixel_values)
        decoded_pixel_values = self.decode(indices.view(batch_size, -1))

        return JanusVQVAEOutput(decoded_pixel_values, embedding_loss)
class JanusVQVAEAlignerMLP(nn.Module):
    """MLP that re-projects VQ embeddings for the image-generation path.

    ``fc1`` maps ``config.embed_dim`` to ``config.projection_dim``; it is followed by
    ``config.num_hidden_layers - 1`` stages of ``activation_fn`` then a square
    ``Linear(projection_dim, projection_dim)``. No activation follows the last layer.
    """

    def __init__(self, config: JanusVQVAEConfig):
        super().__init__()

        self.fc1 = nn.Linear(config.embed_dim, config.projection_dim)
        # One square layer fewer than `num_hidden_layers`: `fc1` counts as the first layer.
        self.hidden_layers = nn.ModuleList(
            nn.Linear(config.projection_dim, config.projection_dim)
            for _ in range(config.num_hidden_layers - 1)
        )
        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        projected = self.fc1(hidden_states)
        for linear in self.hidden_layers:
            projected = linear(self.activation_fn(projected))
        return projected
class JanusVQVAEHead(nn.Module):
    """Head used for sampling tokens in image generation, replacing the usual lm head.

    Maps hidden states of width ``config.image_token_embed_dim`` to one score per VQ codebook
    entry (``config.num_embeddings``) via ``proj_out -> activation_fn -> vision_head``.
    """

    def __init__(self, config: JanusVQVAEConfig):
        super().__init__()
        self.proj_out = nn.Linear(config.image_token_embed_dim, config.projection_dim)
        self.activation_fn = ACT2FN[config.hidden_act]
        self.vision_head = nn.Linear(config.projection_dim, config.num_embeddings)

    # Return annotation fixed: `torch.tensor` is a factory function, the type is `torch.Tensor`.
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return unnormalized scores over the VQ codebook for each input hidden state."""
        hidden_states = self.proj_out(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.vision_head(hidden_states)
        return hidden_states
@auto_docstring(
    custom_intro="""
    The Janus model which consists of a siglip vision backbone, a Llama language model and a VQ model.
    """
)
class JanusModel(JanusPreTrainedModel):
    def __init__(self, config: JanusConfig):
        super().__init__(config)
        self.config = config
        # This is necessary for backward compatibility, see SiglipModel initialization
        self.vision_model = JanusVisionModel._from_config(config.vision_config)
        self.aligner = JanusVisionAlignerMLP(self.vision_model.config)

        self.vqmodel = JanusVQVAE._from_config(config.vq_config)

        # Below generation_* modules are used for Image generation.
        # Embeddings used for image generation, instead of Janus vision embeddings.
        self.generation_embeddings = nn.Embedding(self.vqmodel.config.num_embeddings, self.vqmodel.config.embed_dim)
        self.generation_aligner = JanusVQVAEAlignerMLP(self.vqmodel.config)
        self.generation_head = JanusVQVAEHead(self.vqmodel.config)

        self.language_model = AutoModel.from_config(config=config.text_config)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_image_features(self, pixel_values):
        """Encode images with the vision tower and project its last hidden state with `self.aligner`."""
        image_embeds = self.vision_model(pixel_values)
        image_embeds = self.aligner(image_embeds.last_hidden_state)
        return image_embeds

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ):
        # Exactly one of `input_ids` / `inputs_embeds` must be given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        image_embeds = None
        if pixel_values is not None:
            # Locate the image placeholder positions, shape (batch, seq).
            if input_ids is None:
                # Without ids, find positions whose embedding equals the image token's embedding.
                image_token_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
                image_token_mask = image_token_mask.all(-1)
            else:
                image_token_mask = input_ids == self.config.image_token_id

            image_embeds = self.get_image_features(pixel_values)
            image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1])

            # `masked_scatter` silently ignores surplus source rows, so a token/feature count
            # mismatch would otherwise go unnoticed. Fail loudly instead.
            n_image_tokens = int(image_token_mask.sum().item())
            n_image_features = image_features.shape[0]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, "
                    f"features: {n_image_features}"
                )

            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            image_token_mask = image_token_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
            inputs_embeds = inputs_embeds.masked_scatter(image_token_mask, image_features)

        lm_output = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        return JanusBaseModelOutputWithPast(
            last_hidden_state=lm_output.last_hidden_state,
            past_key_values=lm_output.past_key_values,
            hidden_states=lm_output.hidden_states,
            attentions=lm_output.attentions,
            image_hidden_states=image_embeds,
        )
class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"]
    _can_compile_fullgraph = True

    def __init__(self, config: JanusConfig):
        super().__init__(config)
        self.config = config
        self.model = JanusModel(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self):
        return self.model.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.language_model.set_input_embeddings(value)

    def prepare_embeddings_for_image_generation(self, inputs: torch.Tensor) -> torch.Tensor:
        """Embed image-token ids with `generation_embeddings` and project them with `generation_aligner`."""
        hidden_state = self.model.generation_embeddings(inputs)
        hidden_state = self.model.generation_aligner(hidden_state)
        return hidden_state

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )

        return JanusCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        pixel_values=None,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten -- extra custom processing

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
        # Otherwise we need pixel values to be passed to model
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values

        return model_inputs

    def decode_image_tokens(self, image_tokens: torch.Tensor):
        """
        Decodes generated image tokens from language model to continuous pixel values
        with VQGAN module via upsampling.

        Args:
            image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
                The tensors corresponding to the input images.

        Returns:
            The output of `self.model.vqmodel.decode` permuted from `(batch, channels, height, width)`
            to channels-last `(batch, height, width, channels)`.
        """
        decoded_image = self.model.vqmodel.decode(image_tokens)
        decoded_image = decoded_image.permute(0, 2, 3, 1)
        return decoded_image

    @torch.no_grad
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        **kwargs,
    ):
        """
        Generates text (default) or image tokens.

        With `generation_mode="text"` (the default) this defers to `GenerationMixin.generate` with
        `guidance_scale=None`. Any other `generation_mode` runs the image-token loop below: the prompt
        batch is duplicated into a conditional half and an unconditional half (in which every token that
        is neither BOS nor `generation_config.generation_kwargs["boi_token_id"]` is replaced by the pad
        token), and `self.model.vision_model.config.num_image_tokens` tokens are produced from
        `self.model.generation_head` with classifier-free guidance (`guidance_scale` defaults to 5).

        Args:
            inputs (`torch.Tensor`, *optional*):
                Prompt token ids of shape `(batch_size, seq_len)`. Input embeddings are not supported in
                image mode.
            attention_mask (`torch.LongTensor`, *optional*):
                Prompt attention mask. In image mode, an all-ones mask is used when none is given.
            logits_processor (`LogitsProcessorList`, *optional*):
                Additional logits processors. The caller's list is not modified.
            **kwargs:
                `generation_config`, `generation_mode`, generation-config overrides and model kwargs.

        Returns:
            In image mode, a tensor of shape `(batch_size * num_return_sequences, num_image_tokens)` with the
            generated image-token ids, or a `GenerateDecoderOnlyOutput` when `return_dict_in_generate=True`.
        """
        # 1. Handle generation config and model kwargs
        generation_config = kwargs.pop("generation_config", self.generation_config)
        generation_config = copy.deepcopy(generation_config)

        # Default to "text" generation if mode isn't provided
        generation_mode = kwargs.pop("generation_mode", "text")
        if generation_mode == "text":
            # Set guidance_scale=None to prevent running UnbatchedCFG processor.
            return super().generate(
                inputs=inputs,
                attention_mask=attention_mask,
                generation_config=generation_config,
                guidance_scale=None,
                **kwargs,
            )

        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs

        # Validate generation mode
        if generation_config.get_generation_mode() not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
            raise ValueError(
                "Got incompatible mode for Image Generation, should be one of greedy or sampling. "
                "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
            )

        # Validate the configuration and model kwargs
        generation_config.validate()
        self._validate_model_kwargs(model_kwargs.copy())

        # 2. Initialize logit processors.
        # Copy the caller's list: the CFG processor is appended below, and mutating the caller's list
        # would stack a second CFG processor onto it on the next call.
        logits_processor = (
            LogitsProcessorList(logits_processor) if logits_processor is not None else LogitsProcessorList()
        )

        # Set `use_cache=True` as we will be using input embeds for generation.
        model_kwargs["use_cache"] = True

        if generation_config.guidance_scale is None:
            logger.warning("`guidance_scale` is required for CFG but not provided. Setting to default value of 5.")
            generation_config.guidance_scale = 5
        model_kwargs["guidance_scale"] = generation_config.guidance_scale

        # 3. Prepare model inputs
        input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        dtype, device = input_ids.dtype, input_ids.device

        if len(input_ids.shape) != 2:
            raise ValueError(
                f"Expected input ids of shape (batch_size, seq_len), but got {input_ids.shape}. "
                "Passing `inputs embeds` is not supported currently."
            )

        # Prepare special tokens which will be used generate internally.
        kwargs_has_attention_mask = attention_mask is not None
        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)

        # 4. Add CFG processor along with user passed logit processor.
        if generation_config.guidance_scale and generation_config.guidance_scale > 1:
            logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
            generation_config.guidance_scale = None  # Reset to prevent processor duplication.

        # 5. Prepare logits processor
        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids.shape[1],
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=None,
            logits_processor=logits_processor,
            device=device,
        )

        # 6. Expand inputs for multiple image generations per prompt.
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            attention_mask=attention_mask,
            expand_size=generation_config.num_return_sequences,
            **model_kwargs,
        )

        # 7. Prepare input and model caches
        num_image_tokens = self.model.vision_model.config.num_image_tokens
        batch_size, seq_len = input_ids.shape

        input_tokens = input_ids.repeat(2, 1)  # Double batch size for conditional/unconditional logits
        attention_mask = model_kwargs.pop("attention_mask", None)
        if attention_mask is None:
            # No mask supplied: attend to every prompt position (this used to crash on `None.repeat`).
            attention_mask = torch.ones_like(input_ids)
        attention_mask = attention_mask.repeat(2, 1)
        model_kwargs["attention_mask"] = attention_mask

        # Mask all the tokens that are neither BOS nor BOI with pad token in the unconditional logits.
        mask = (input_tokens[batch_size:, :] != generation_config.bos_token_id) & (
            input_tokens[batch_size:, :] != generation_config.generation_kwargs["boi_token_id"]
        )
        input_tokens[batch_size:, :].masked_fill_(mask, generation_config.pad_token_id)

        inputs_embeds = self.get_input_embeddings()(input_tokens)

        model_kwargs = self._get_initial_cache_position(seq_len, device, model_kwargs)

        if model_kwargs.get("past_key_values", None) is None:
            # Prepare cache if not provided.
            model_kwargs["past_key_values"] = self._get_cache(
                cache_implementation=generation_config.cache_implementation or "static",
                # batch_size should account for both conditional/unconditional input; hence multiplied by 2.
                batch_size=batch_size * 2,
                # we should have at least a cache len of seq_len + num_image_tokens.
                max_cache_len=max(generation_config.max_length, num_image_tokens + seq_len),
                device=device,
                model_kwargs=model_kwargs,
            )

        # Placeholder for generated tokens.
        generated_tokens = torch.zeros((batch_size, num_image_tokens), dtype=dtype, device=device)

        # 8. init attention / hidden states / scores tuples
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate

        processed_scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None

        for i in range(num_image_tokens):
            model_inputs = self.prepare_inputs_for_generation(
                inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs
            )

            model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device)
            model_inputs["cache_position"] = model_inputs["cache_position"].to(inputs_embeds.device)

            outputs = self.model.language_model(
                **model_inputs,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            # Update model_kwargs like cache_position for next generation.
            model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
            hidden_state = outputs.last_hidden_state[:, -1, :].clone()

            # Generate scores using the generation head (Not using above defined LM Head)
            scores = self.model.generation_head(hidden_state)
            next_token_scores = logits_processor(input_ids, scores)

            # Sample next token.
            if generation_config.do_sample:
                probs = torch.softmax(next_token_scores, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
            else:
                next_token = torch.argmax(next_token_scores, dim=-1)

            generated_tokens[:, i] = next_token

            # Prepare embeddings for the next step: the same token feeds both CFG halves.
            next_token = torch.cat([next_token, next_token])
            next_token = next_token.unsqueeze(-1)

            inputs_embeds = self.prepare_embeddings_for_image_generation(next_token)

            if return_dict_in_generate:
                # Follow the `GenerateDecoderOnlyOutput` contract: `scores` are the processed scores,
                # `logits` the raw head outputs, and attentions/hidden states one tuple per step.
                if output_scores:
                    processed_scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (scores.float(),)
                if output_attentions:
                    decoder_attentions += (outputs.attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (outputs.hidden_states,)

        if return_dict_in_generate:
            return GenerateDecoderOnlyOutput(
                sequences=generated_tokens,
                scores=processed_scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=outputs.past_key_values,
            )
        return generated_tokens
# Public API of this module.
__all__ = [
    "JanusPreTrainedModel",
    "JanusForConditionalGeneration",
    "JanusModel",
    "JanusVQVAE",
    "JanusVisionModel",
]