# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/janus/modular_janus.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_janus.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import copy
from dataclasses import dataclass
from typing import Callable, Optional, Union

import torch
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache
from ...generation import ClassifierFreeGuidanceLogitsProcessor, GenerationMixin, GenerationMode, LogitsProcessorList
from ...generation.utils import GenerateDecoderOnlyOutput
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
    TransformersKwargs,
    auto_docstring,
    can_return_tuple,
    is_torch_available,
    logging,
    torch_int,
)
from ..auto import AutoModel
from .configuration_janus import JanusConfig, JanusVisionConfig, JanusVQVAEConfig


if is_torch_available():
    import torch.nn.functional as F


logger = logging.get_logger(__name__)


# Shared base class of the Janus models: carries the config type, the weight-prefix name and the
# capability flags (gradient checkpointing, flash/sdpa attention, fullgraph compile) read by `PreTrainedModel`.
@auto_docstring
class JanusPreTrainedModel(PreTrainedModel):
    config: JanusConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Module class names listed here are not split across devices when the model is dispatched.
    _no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"]
    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
    _supports_flash_attn = True
    _supports_sdpa = True

    _can_compile_fullgraph = True
    _supports_param_buffer_assignment = False


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Janus VQ-VAE model outputs.
    """
)
class JanusVQVAEOutput(ModelOutput):
    r"""
    decoded_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
        Reconstructed pixel values after encoding and decoding the input.
    embedding_loss (`torch.FloatTensor`):
        Scalar vector-quantization loss returned by the VQ-VAE quantizer when encoding the input.
    """

    decoded_pixel_values: Optional[torch.FloatTensor] = None
    # Defaults to None like every other ModelOutput field, so the annotation must be Optional.
    embedding_loss: Optional[torch.FloatTensor] = None


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Janus model's outputs that may also contain a past key/values (to speed up sequential decoding).
    """
)
class JanusBaseModelOutputWithPast(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.

        If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
        hidden_size)` is output.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Cache object holding the pre-computed key and value states of the self-attention blocks, which can be
        passed back in (see `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
        sequence_length, hidden_size)`.

        image_hidden_states of the model produced by the vision encoder.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    # Annotated as `Cache` to agree with the documented type above.
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[tuple[torch.FloatTensor]] = None


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Janus causal language model (or autoregressive) outputs.
    """
)
class JanusCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Cache object holding the pre-computed key and value states of the self-attention blocks, which can be
        passed back in (see `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
        sequence_length, hidden_size)`.

        image_hidden_states of the model produced by the vision encoder.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    # Annotated as `Cache` to agree with the documented type above.
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[tuple[torch.FloatTensor]] = None


class JanusVisionEmbeddings(nn.Module):
    """
    Turns pixel values into a sequence of patch embeddings with learned absolute position embeddings added.

    Patches are extracted by a `Conv2d` whose kernel size and stride both equal `config.patch_size`, so a
    `config.image_size` input yields `(image_size // patch_size) ** 2` patch tokens of size `config.hidden_size`.
    """

    def __init__(self, config: JanusVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # Non-persistent buffer of shape (1, num_positions): not saved in the state dict.
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images. This method is also adapted to support torch.jit tracing and no class embeddings.

        Args:
            embeddings (`torch.Tensor`): Patch embeddings; dim 1 is the number of patches and the last dim the
                embedding size.
            height (`int`): Height in pixels of the input image.
            width (`int`): Width in pixels of the input image.

        Returns:
            `torch.Tensor` of shape `(1, (height // patch_size) * (width // patch_size), dim)` (bicubically
            resized from the square learned grid), or the unmodified learned position embeddings when the input
            already matches the trained square grid.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1]
        num_positions = self.position_embedding.weight.shape[0]

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        # The learned positions are laid out on a square sqrt(N) x sqrt(N) grid.
        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """
        Args:
            pixel_values (`torch.Tensor`): Images of shape `(batch_size, num_channels, height, width)`.
            interpolate_pos_encoding (`bool`): Resize the learned position embeddings to the input resolution.

        Returns:
            `torch.Tensor` of shape `(batch_size, num_patches, embed_dim)`.
        """
        _, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # (batch, embed_dim, grid_h, grid_w)
        # Flatten the patch grid into a token sequence: (batch, grid_h * grid_w, embed_dim).
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            pos_embeds = self.interpolate_pos_encoding(embeddings, height, width)
        else:
            pos_embeds = self.position_embedding(self.position_ids)

        embeddings = embeddings + pos_embeds

        return embeddings
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class JanusVisionAttention(nn.Module): """Attention Class for Janus Vision Encoder""" def __init__(self, config: JanusVisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads})." 
) self.scale = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout proj_dropout = config.projection_dropout qk_norm = config.use_qk_norm self.is_causal = False # Janus has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1. self.num_key_value_groups = 1 self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias) self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias) self.projection_layer = nn.Linear(self.embed_dim, self.embed_dim) self.projection_dropout = nn.Dropout(proj_dropout) if proj_dropout > 0 else nn.Identity() self.q_norm = nn.LayerNorm(self.embed_dim) if qk_norm else nn.Identity() self.k_norm = nn.LayerNorm(self.embed_dim) if qk_norm else nn.Identity() def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ): batch_size, seq_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.reshape(-1, self.num_heads, self.head_dim) query_states = self.q_norm(query_states) key_states = key_states.reshape(-1, self.num_heads, self.head_dim) key_states = self.k_norm(key_states) query_states = query_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, 
value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scale, is_causal=self.is_causal, **kwargs, ) attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim) output = self.projection_layer(attn_output) output = self.projection_dropout(output) return output, attn_weights class JanusVisionMLP(nn.Module): def __init__(self, config: JanusVisionConfig): super().__init__() self.config = config self.intermediate_size = int(config.hidden_size * config.mlp_ratio) self.activation_fn = ACT2FN[config.hidden_act] # Gelu act self.fc1 = nn.Linear(config.hidden_size, self.intermediate_size) self.fc2 = nn.Linear(self.intermediate_size, config.hidden_size) self.dropout1 = nn.Dropout(config.hidden_dropout_rate) self.dropout2 = nn.Dropout(config.hidden_dropout_rate) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = self.dropout1(hidden_states) hidden_states = self.fc2(hidden_states) hidden_states = self.dropout2(hidden_states) return hidden_states class JanusVisionEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: JanusVisionConfig): super().__init__() self.embed_dim = config.hidden_size self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.self_attn = JanusVisionAttention(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = JanusVisionMLP(config) self.config = config def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: Optional[bool] = False, ) -> tuple[torch.FloatTensor]: """ Args: hidden_states (`torch.FloatTensor`): Input to the layer of shape `(batch, seq_len, embed_dim)`. attention_mask (`torch.FloatTensor`): Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. 
output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. """ residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs class JanusVisionEncoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`JanusVisionEncoderLayer`]. Args: config: JanusVisionConfig """ def __init__(self, config: JanusVisionConfig): super().__init__() self.config = config self.layers = nn.ModuleList([JanusVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False # Ignore copy @can_return_tuple def forward( self, inputs_embeds, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, ) -> BaseModelOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. 
[What are attention masks?](../glossary#attention-mask) output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None hidden_states = inputs_embeds for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) layer_outputs = encoder_layer( hidden_states, attention_mask, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions, ) @auto_docstring class JanusVisionModel(JanusPreTrainedModel): main_input_name = "pixel_values" config: JanusVisionConfig def __init__(self, config: JanusVisionConfig): super().__init__(config) self.config = config embed_dim = config.hidden_size self.embeddings = JanusVisionEmbeddings(config) self.encoder = JanusVisionEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_init() @auto_docstring def forward( self, pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: 
Optional[bool] = None, interpolate_pos_encoding: bool = False, ) -> Union[tuple, BaseModelOutputWithPooling]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if pixel_values is None: raise ValueError("You have to specify pixel_values") hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) encoder_outputs = self.encoder( inputs_embeds=hidden_states, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = encoder_outputs[0] last_hidden_state = self.post_layernorm(last_hidden_state) pooled_output = last_hidden_state[:, 0, :] pooled_output = self.post_layernorm(pooled_output) if not return_dict: return (last_hidden_state, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) def get_input_embeddings(self): return self.embeddings class JanusVisionAlignerMLP(nn.Module): def __init__(self, config: JanusVisionConfig): super().__init__() self.fc1 = nn.Linear(config.hidden_size, config.projection_dim) self.hidden_layers = nn.ModuleList( [nn.Linear(config.projection_dim, config.projection_dim) for _ in range(1, config.depth)] ) self.activation_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states): hidden_states = self.fc1(hidden_states) for layer in self.hidden_layers: hidden_states = self.activation_fn(hidden_states) hidden_states = layer(hidden_states) return hidden_states class JanusVQVAEVectorQuantizer(nn.Module): """ A module for vector quantization using learned embedding vectors. 
This module implements the quantization process similar to te one described in the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous input vectors into discrete codebook vectors, which are learned during training. Current implementation improves over previous ones by avoiding costly matrix multiplications and allowing for post-hoc remapping of indices. """ def __init__(self, config: JanusVQVAEConfig): super().__init__() self.num_embeddings = config.num_embeddings self.embedding_dim = config.embed_dim self.beta = getattr(config, "beta", 0.25) self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) self.quant_state_dims = [config.num_patches] * 2 def forward(self, hidden_state: torch.Tensor): hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous() hidden_state_flattened = hidden_state.view(-1, self.embedding_dim) # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z distances = ( torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + torch.sum(self.embedding.weight**2, dim=1) - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, self.embedding.weight.transpose(0, 1)) ) min_encoding_indices = torch.argmin(distances, dim=1) hidden_state_quant = self.embedding(min_encoding_indices).view(hidden_state.shape) # compute loss for embedding loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean( (hidden_state_quant - hidden_state.detach()) ** 2 ) # preserve gradients hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach() # reshape back to match original input shape hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous() return hidden_state_quant, loss, min_encoding_indices def get_codebook_entry(self, image_tokens: torch.LongTensor) -> torch.FloatTensor: batch_size = image_tokens.shape[0] emb_dim: int = self.embedding.weight.shape[-1] # get quantized latent vectors hidden_state_quant = self.embedding(image_tokens) # l2 
normalization on the last dimension hidden_state_quant = F.normalize(hidden_state_quant, p=2, dim=-1) # reshape back to match original input shape hidden_state_quant = hidden_state_quant.view((batch_size, *self.quant_state_dims, emb_dim)) hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous() return hidden_state_quant class JanusVQVAEResnetBlock(nn.Module): def __init__( self, config, in_channels, out_channels=None, conv_shortcut=False, ): super().__init__() self.in_channels = in_channels self.out_channels = in_channels if out_channels is None else out_channels self.use_conv_shortcut = conv_shortcut self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) self.dropout = torch.nn.Dropout(config.dropout) self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, hidden_states): residual = hidden_states hidden_states = self.norm1(hidden_states) hidden_states *= torch.sigmoid(hidden_states) hidden_states = self.conv1(hidden_states) hidden_states = self.norm2(hidden_states) hidden_states *= torch.sigmoid(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) if self.in_channels != self.out_channels: if self.use_conv_shortcut: residual = self.conv_shortcut(residual) else: residual = self.nin_shortcut(residual) return residual + hidden_states class JanusVQVAEAttnBlock(nn.Module): def __init__(self, in_channels): super().__init__() self.in_channels = in_channels 
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, hidden_states): residual = hidden_states hidden_states = self.norm(hidden_states) query_states = self.q(hidden_states) key_states = self.k(hidden_states) value_states = self.v(hidden_states) # compute attention batch_size, channels, height, width = query_states.shape query_states = query_states.reshape(batch_size, channels, height * width).permute(0, 2, 1) key_states = key_states.reshape(batch_size, channels, height * width) attn_weights = torch.bmm(query_states, key_states) attn_weights = attn_weights * (int(channels) ** (-0.5)) attn_weights = F.softmax(attn_weights, dim=2) # attend to values value_states = value_states.reshape(batch_size, channels, height * width) attn_weights = attn_weights.permute(0, 2, 1) attn_output = torch.bmm(value_states, attn_weights).reshape(batch_size, channels, height, width) attn_output = self.proj_out(attn_output) return residual + attn_output class JanusVQVAEConvDownsample(nn.Module): def __init__(self, in_channels): super().__init__() self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) def forward(self, hidden_states): # no asymmetric padding in torch conv, must do it ourselves hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0) hidden_states = self.conv(hidden_states) return hidden_states class JanusVQVAEConvUpsample(nn.Module): def __init__(self, in_channels): super().__init__() self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) def forward(self, hidden_states): 
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") hidden_states = self.conv(hidden_states) return hidden_states class JanusVQVAEMidBlock(nn.Module): def __init__(self, config: JanusVQVAEConfig, channels: int): super().__init__() self.block_1 = JanusVQVAEResnetBlock( config=config, in_channels=channels, out_channels=channels, ) self.attn_1 = JanusVQVAEAttnBlock(channels) self.block_2 = JanusVQVAEResnetBlock( config=config, in_channels=channels, out_channels=channels, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.block_1(hidden_states) hidden_states = self.attn_1(hidden_states) hidden_states = self.block_2(hidden_states) return hidden_states class JanusVQVAEEncoder(nn.Module): def __init__(self, config): super().__init__() self.num_resolutions = len(config.channel_multiplier) self.num_res_blocks = config.num_res_blocks base_channels = config.base_channels in_channels = config.in_channels double_latent = config.double_latent latent_channels = config.latent_channels channel_multiplier = config.channel_multiplier self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1) in_channel_multiplier = (1,) + tuple(channel_multiplier) self.in_channel_multiplier = in_channel_multiplier self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() block_in = base_channels * in_channel_multiplier[i_level] block_out = base_channels * channel_multiplier[i_level] for i_block in range(self.num_res_blocks): block.append( JanusVQVAEResnetBlock( config=config, in_channels=block_in, out_channels=block_out, ) ) block_in = block_out if i_level == self.num_resolutions - 1: attn.append(JanusVQVAEAttnBlock(block_in)) down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions - 1: down.downsample = JanusVQVAEConvDownsample(block_in) self.down.append(down) self.mid = JanusVQVAEMidBlock(config, block_in) 
self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) self.conv_out = torch.nn.Conv2d( block_in, 2 * latent_channels if double_latent else latent_channels, kernel_size=3, stride=1, padding=1, ) def forward(self, pixel_values: torch.LongTensor): # downsampling hidden_states = [self.conv_in(pixel_values)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): hidden_state = self.down[i_level].block[i_block]( hidden_states[-1], ) if len(self.down[i_level].attn) > 0: hidden_state = self.down[i_level].attn[i_block](hidden_state) hidden_states.append(hidden_state) if i_level != self.num_resolutions - 1: hidden_states.append(self.down[i_level].downsample(hidden_states[-1])) # middle last_hidden_state = hidden_states[-1] last_hidden_state = self.mid(last_hidden_state) # end last_hidden_state = self.norm_out(last_hidden_state) last_hidden_state *= torch.sigmoid(last_hidden_state) last_hidden_state = self.conv_out(last_hidden_state) return last_hidden_state class JanusVQVAEDecoder(nn.Module): def __init__(self, config): super().__init__() self.num_resolutions = len(config.channel_multiplier) self.num_res_blocks = config.num_res_blocks base_channels = config.base_channels latent_channels = config.latent_channels out_channels = config.out_channels # compute in_ch_mult, block_in and curr_res at lowest res block_in = base_channels * config.channel_multiplier[self.num_resolutions - 1] # z to block_in self.conv_in = torch.nn.Conv2d(latent_channels, block_in, kernel_size=3, stride=1, padding=1) # middle self.mid = JanusVQVAEMidBlock(config, block_in) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() block_out = base_channels * config.channel_multiplier[i_level] for i_block in range(self.num_res_blocks + 1): block.append( JanusVQVAEResnetBlock( config=config, in_channels=block_in, out_channels=block_out, ) ) block_in = 
block_out if i_level == self.num_resolutions - 1: attn.append(JanusVQVAEAttnBlock(block_in)) up = nn.Module() up.block = block up.attn = attn if i_level != 0: up.upsample = JanusVQVAEConvUpsample(block_in) self.up.append(up) # end self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1) def forward(self, hidden_state: torch.FloatTensor) -> torch.FloatTensor: hidden_state = self.conv_in(hidden_state) # middle hidden_state = self.mid(hidden_state) # upsampling for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks + 1): hidden_state = self.up[i_level].block[i_block](hidden_state) if len(self.up[i_level].attn) > 0: hidden_state = self.up[i_level].attn[i_block](hidden_state) if i_level != self.num_resolutions - 1: hidden_state = self.up[i_level].upsample(hidden_state) hidden_state = self.norm_out(hidden_state) hidden_state *= torch.sigmoid(hidden_state) hidden_state = self.conv_out(hidden_state) return hidden_state @auto_docstring( custom_intro=""" The VQ-VAE model used in Janus for encoding/decoding images into discrete tokens. This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://huggingface.co/papers/2203.13131). 
""" ) class JanusVQVAE(JanusPreTrainedModel): config: JanusVQVAEConfig _no_split_modules = [ "JanusVQVAEAttnBlock", "JanusVQVAEResnetBlock", "JanusVQVAEVectorQuantizer", ] main_input_name = "pixel_values" def __init__(self, config: JanusVQVAEConfig): super().__init__(config) self.encoder = JanusVQVAEEncoder(config) self.quantize = JanusVQVAEVectorQuantizer(config) self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1) self.eval() # Janus's VQ model is frozen self.decoder = JanusVQVAEDecoder(config) self.gradient_checkpointing = False # Initialize the VQVAE model. self.post_init() def encode(self, pixel_values: torch.LongTensor): hidden_states = self.encoder(pixel_values) hidden_states = self.quant_conv(hidden_states) quant, emb_loss, indices = self.quantize(hidden_states) return quant, emb_loss, indices def decode(self, image_tokens: torch.LongTensor) -> torch.FloatTensor: """ Decodes quantized token IDs into pixel values. Args: image_tokens (torch.LongTensor): Batch of token IDs. Returns: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): Pixel values decoded from the token IDs. """ if image_tokens.shape[1] != self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]: raise ValueError( f"Expected `image_tokens` to have shape `(batch_size, {self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]})`, " f"but got shape `{image_tokens.shape}`." 
) codebook_entry = self.quantize.get_codebook_entry(image_tokens) hidden_states = self.post_quant_conv(codebook_entry) pixel_values = self.decoder(hidden_states) return pixel_values @can_return_tuple @auto_docstring def forward( self, pixel_values: torch.FloatTensor, ) -> tuple[torch.FloatTensor, torch.FloatTensor]: batch_size = pixel_values.shape[0] quant, embedding_loss, indices = self.encode(pixel_values) decoded_pixel_values = self.decode(indices.view(batch_size, -1)) return JanusVQVAEOutput(decoded_pixel_values, embedding_loss) class JanusVQVAEAlignerMLP(nn.Module): def __init__(self, config: JanusVQVAEConfig): super().__init__() self.fc1 = nn.Linear(config.embed_dim, config.projection_dim) self.hidden_layers = nn.ModuleList( [nn.Linear(config.projection_dim, config.projection_dim) for _ in range(1, config.num_hidden_layers)] ) self.activation_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states): hidden_states = self.fc1(hidden_states) for layer in self.hidden_layers: hidden_states = self.activation_fn(hidden_states) hidden_states = layer(hidden_states) return hidden_states class JanusVQVAEHead(nn.Module): """Head used for sampling tokens in image generation, replacing the usual lm head.""" def __init__(self, config: JanusVQVAEConfig): super().__init__() self.proj_out = nn.Linear(config.image_token_embed_dim, config.projection_dim) self.activation_fn = ACT2FN[config.hidden_act] self.vision_head = nn.Linear(config.projection_dim, config.num_embeddings) def forward(self, hidden_states: torch.Tensor) -> torch.tensor: hidden_states = self.proj_out(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = self.vision_head(hidden_states) return hidden_states @auto_docstring( custom_intro=""" The Janus model which consists of a siglip vision backbone, a Llama language model and a VQ model. 
""" ) class JanusModel(JanusPreTrainedModel): def __init__(self, config: JanusConfig): super().__init__(config) self.config = config # This is necessary for backward compatibility, see SiglipModel initialization self.vision_model = JanusVisionModel._from_config(config.vision_config) self.aligner = JanusVisionAlignerMLP(self.vision_model.config) self.vqmodel = JanusVQVAE._from_config(config.vq_config) # Below generation_* modules are used for Image generation. # Embeddings used for image generation, instead of Janus vision embeddings. self.generation_embeddings = nn.Embedding(self.vqmodel.config.num_embeddings, self.vqmodel.config.embed_dim) self.generation_aligner = JanusVQVAEAlignerMLP(self.vqmodel.config) self.generation_head = JanusVQVAEHead(self.vqmodel.config) self.language_model = AutoModel.from_config(config=config.text_config) self.gradient_checkpointing = False # Initialize weights and apply final processing. self.post_init() def get_input_embeddings(self): return self.language_model.get_input_embeddings() def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) def get_image_features(self, pixel_values): image_embeds = self.vision_model(pixel_values) image_embeds = self.aligner(image_embeds.last_hidden_state) return image_embeds @can_return_tuple @auto_docstring def forward( self, input_ids: torch.LongTensor = None, pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) if inputs_embeds is None: inputs_embeds = 
self.get_input_embeddings()(input_ids) if pixel_values is not None: if input_ids is None: image_attention_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) image_attention_mask = image_attention_mask.all(-1) else: image_attention_mask = input_ids == self.config.image_token_id image_attention_mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) image_embeds = self.get_image_features(pixel_values) image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1]) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features) lm_output = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, logits_to_keep=logits_to_keep, **kwargs, ) return JanusBaseModelOutputWithPast( last_hidden_state=lm_output.last_hidden_state, past_key_values=lm_output.past_key_values, hidden_states=lm_output.hidden_states, attentions=lm_output.attentions, image_hidden_states=image_embeds if pixel_values is not None else None, ) class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin): _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"] _can_compile_fullgraph = True def __init__(self, config: JanusConfig): super().__init__(config) self.config = config self.model = JanusModel(config) self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) # Initialize weights and apply final processing. 
self.post_init() def get_input_embeddings(self): return self.model.language_model.get_input_embeddings() def set_input_embeddings(self, value): self.model.language_model.set_input_embeddings(value) def prepare_embeddings_for_image_generation(self, inputs: torch.Tensor) -> torch.Tensor: hidden_state = self.model.generation_embeddings(inputs) hidden_state = self.model.generation_aligner(hidden_state) return hidden_state def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model @can_return_tuple @auto_docstring def forward( self, input_ids: torch.LongTensor = None, pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], ): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
""" outputs = self.model( input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function( logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs ) return JanusCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=outputs.image_hidden_states, ) def prepare_inputs_for_generation( self, input_ids, pixel_values=None, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, logits_to_keep=None, **kwargs, ): # Overwritten -- extra custom processing model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, logits_to_keep=logits_to_keep, **kwargs, ) # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model if cache_position[0] == 0: model_inputs["pixel_values"] = pixel_values return model_inputs def decode_image_tokens(self, image_tokens: torch.Tensor): """ Decodes generated image tokens from language model to continuous pixel values with VQGAN module via upsampling. Args: image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`): The tensors corresponding to the input images. 
""" decoded_image = self.model.vqmodel.decode(image_tokens) decoded_image = decoded_image.permute(0, 2, 3, 1) return decoded_image @torch.no_grad def generate( self, inputs: torch.Tensor = None, attention_mask: Optional[torch.LongTensor] = None, logits_processor: Optional[LogitsProcessorList] = None, **kwargs, ): # 1. Handle generation config and model kwargs generation_config = kwargs.pop("generation_config", self.generation_config) generation_config = copy.deepcopy(generation_config) # Default to "text" generation if mode isn't provided generation_mode = kwargs.pop("generation_mode", "text") if generation_mode == "text": # Set guidance_scale=None to prevent running UnbatchedCFG processor. return super().generate( inputs=inputs, attention_mask=attention_mask, generation_config=generation_config, guidance_scale=None, **kwargs, ) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs # Validate generation mode if generation_config.get_generation_mode() not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): raise ValueError( "Got incompatible mode for Image Generation, should be one of greedy or sampling. " "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`." ) # Validate the configuration and model kwargs generation_config.validate() self._validate_model_kwargs(model_kwargs.copy()) # 2. Initialize logit processors logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() # Set `use_cache=True` as we will be using input embeds for generation. model_kwargs["use_cache"] = True if generation_config.guidance_scale is None: logger.warning("`guidance_scale` is required for CFG but not provided. Setting to default value of 5.") generation_config.guidance_scale = 5 model_kwargs["guidance_scale"] = generation_config.guidance_scale # 3. 
Prepare model inputs input_ids, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) dtype, device = input_ids.dtype, input_ids.device if len(input_ids.shape) != 2: raise ValueError( f"Expected input ids of shape (batch_size, seq_len), but got {input_ids.shape}" "Passing `inputs embeds` is not supported currently." ) # Prepare special tokens which will be used generate internally. kwargs_has_attention_mask = attention_mask is not None self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device) # 4. Add CFG processor along with user passed logit processor. if generation_config.guidance_scale and generation_config.guidance_scale > 1: logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale)) generation_config.guidance_scale = None # Reset to prevent processor duplication. # 5. Prepare logits processor logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids.shape[1], encoder_input_ids=input_ids, prefix_allowed_tokens_fn=None, logits_processor=logits_processor, device=device, ) # 6. Expand inputs for multiple image generations per prompt. input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, attention_mask=attention_mask, expand_size=generation_config.num_return_sequences, **model_kwargs, ) # 7. Prepare input and model caches num_image_tokens = self.model.vision_model.config.num_image_tokens batch_size, seq_len = input_ids.shape input_tokens = input_ids.repeat(2, 1) # Double batch size for conditional/unconditional logits attention_mask = model_kwargs.pop("attention_mask", None) attention_mask = attention_mask.repeat(2, 1) model_kwargs["attention_mask"] = attention_mask # Mask all the tokens that are neither BOS nor BOI with pad token in the unconditional logits. 
mask = (input_tokens[batch_size:, :] != generation_config.bos_token_id) & ( input_tokens[batch_size:, :] != generation_config.generation_kwargs["boi_token_id"] ) input_tokens[batch_size:, :].masked_fill_(mask, generation_config.pad_token_id) inputs_embeds = self.get_input_embeddings()(input_tokens) model_kwargs = self._get_initial_cache_position(seq_len, device, model_kwargs) if model_kwargs.get("past_key_values", None) is None: # Prepare cache if not provided. model_kwargs["past_key_values"] = self._get_cache( cache_implementation=generation_config.cache_implementation or "static", # batch_size should account for both conditional/unconditional input; hence multiplied by 2. batch_size=batch_size * 2, # we should have at least a cache len of seq_len + num_image_tokens. max_cache_len=max(generation_config.max_length, num_image_tokens + seq_len), device=device, model_kwargs=model_kwargs, ) # Placeholder for generated tokens. generated_tokens = torch.zeros((batch_size, num_image_tokens), dtype=dtype, device=device) # 8. 
init attention / hidden states / scores tuples output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores output_logits = generation_config.output_logits return_dict_in_generate = generation_config.return_dict_in_generate raw_scores = () if (return_dict_in_generate and output_scores) else None raw_logits = () if (return_dict_in_generate and output_logits) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None for i in range(num_image_tokens): model_inputs = self.prepare_inputs_for_generation( inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs ) model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device) model_inputs["cache_position"] = model_inputs["cache_position"].to(inputs_embeds.device) outputs = self.model.language_model( **model_inputs, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) # Update model_kwargs like cache_position for next generation. model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs) hidden_state = outputs.last_hidden_state[:, -1, :].clone() # Generate scores using the generation head (Not using above defined LM Head) scores = self.model.generation_head(hidden_state) next_token_scores = logits_processor(input_ids, scores) # Sample next token. if generation_config.do_sample: probs = torch.softmax(next_token_scores, dim=-1) next_token = torch.multinomial(probs, num_samples=1).squeeze(-1) else: next_token = torch.argmax(next_token_scores, dim=-1) generated_tokens[:, i] = next_token # Prepare embeddings for the next step. 
next_token = torch.cat([next_token, next_token]) next_token = next_token.unsqueeze(-1) inputs_embeds = self.prepare_embeddings_for_image_generation(next_token) if return_dict_in_generate: if output_scores: raw_scores += (scores,) if output_logits: raw_logits += (hidden_state.float(),) if output_attentions: decoder_attentions += outputs.attentions if output_hidden_states: decoder_hidden_states += outputs.hidden_states if return_dict_in_generate: return GenerateDecoderOnlyOutput( sequences=generated_tokens, scores=scores, logits=raw_logits, attentions=decoder_attentions, hidden_states=decoder_hidden_states, past_key_values=outputs.past_key_values, ) else: return generated_tokens __all__ = ["JanusPreTrainedModel", "JanusForConditionalGeneration", "JanusModel", "JanusVQVAE", "JanusVisionModel"]