team-10/env/Lib/site-packages/transformers/models/janus/modular_janus.py
2025-08-02 07:34:44 +02:00

1594 lines
66 KiB
Python

# coding=utf-8
# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Callable, Optional, Union
import numpy as np
import torch
from torch import nn
from transformers.models.blip.image_processing_blip import BlipImageProcessor
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...generation import ClassifierFreeGuidanceLogitsProcessor, GenerationMixin, GenerationMode, LogitsProcessorList
from ...generation.utils import GenerateDecoderOnlyOutput
from ...image_processing_utils import BatchFeature, get_size_dict
from ...image_transforms import resize, to_channel_dimension_format
from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
)
from ...modeling_outputs import ModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
TransformersKwargs,
auto_docstring,
can_return_tuple,
is_torch_available,
is_vision_available,
logging,
)
from ..auto import AutoModel
from ..blip_2.modeling_blip_2 import Blip2VisionModel
from ..chameleon.configuration_chameleon import ChameleonVQVAEConfig
from ..chameleon.modeling_chameleon import (
ChameleonVQVAE,
ChameleonVQVAEEncoderAttnBlock,
ChameleonVQVAEEncoderConvDownsample,
ChameleonVQVAEEncoderResnetBlock,
ChameleonVQVAEVectorQuantizer,
)
from ..idefics.modeling_idefics import IdeficsBaseModelOutputWithPast, IdeficsCausalLMOutputWithPast
from ..llama.modeling_llama import eager_attention_forward
from ..siglip.configuration_siglip import SiglipVisionConfig
from ..siglip.modeling_siglip import SiglipEncoder, SiglipEncoderLayer, SiglipVisionEmbeddings
if is_torch_available():
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
if is_vision_available():
import PIL
from ...configuration_utils import PretrainedConfig
from ..auto import CONFIG_MAPPING, AutoConfig
logger = logging.get_logger(__name__)
# General docstring
class JanusVisionConfig(SiglipVisionConfig):
    r"""
    This is the configuration class to store the configuration of a [`JanusVisionModel`]. It is used to instantiate a
    `JanusVisionModel` according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        image_size (`int`, *optional*, defaults to 384):
            The size (resolution) of each image.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for attention weights.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, and `"gelu_new"` are supported.
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            Ratio of MLP hidden dimensionality to embedding dimensionality.
        attention_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys, and values in the attention layers.
        hidden_dropout_rate (`float`, *optional*, defaults to 0.0):
            The dropout probability for fully connected layers in the encoder.
        projection_dim (`int`, *optional*, defaults to 2048):
            Dimensionality of the MLP projection head.
        projection_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for the projection layer.
        use_qk_norm (`bool`, *optional*, defaults to `False`):
            Whether to normalize the query and key matrices.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated normal initializer for initializing all weight matrices.
        depth (`int`, *optional*, defaults to 2):
            Number of hidden layers in the aligner module.
        num_image_tokens (`int`, *optional*, defaults to 576):
            Number of image tokens.
    """

    model_type = "janus_vision_model"
    base_config_key = "vision_config"

    def __init__(
        self,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        num_channels=3,
        patch_size=16,
        image_size=384,
        attention_dropout=0.0,
        layer_norm_eps=1e-6,
        hidden_act="gelu",
        mlp_ratio=4.0,
        attention_bias=True,
        hidden_dropout_rate=0.0,
        projection_dim=2048,
        projection_dropout=0.0,
        use_qk_norm=False,
        initializer_range=0.02,
        depth=2,
        num_image_tokens=576,
        **kwargs,
    ):
        # Only the arguments shared with the Siglip vision config are forwarded to the parent;
        # the Janus-specific ones are stored on the instance below.
        super().__init__(
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_channels=num_channels,
            patch_size=patch_size,
            image_size=image_size,
            attention_dropout=attention_dropout,
            layer_norm_eps=layer_norm_eps,
            hidden_act=hidden_act,
            **kwargs,
        )
        # Drop the inherited `intermediate_size`: the Janus MLP derives its width from
        # `hidden_size * mlp_ratio` instead (see `JanusVisionMLP`).
        del self.intermediate_size

        self.mlp_ratio = mlp_ratio
        self.attention_bias = attention_bias
        self.hidden_dropout_rate = hidden_dropout_rate
        self.projection_dim = projection_dim
        self.projection_dropout = projection_dropout
        self.use_qk_norm = use_qk_norm
        self.initializer_range = initializer_range
        self.depth = depth
        self.num_image_tokens = num_image_tokens
class JanusVQVAEConfig(ChameleonVQVAEConfig):
    r"""
    This is the configuration class to store the configuration of a [`JanusVQVAEModel`]. It is used to instantiate a
    `JanusVQVAEModel` according to the specified arguments, defining the model architecture.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information. Instantiating a
    configuration with the defaults will yield a similar configuration to the VQModel of the
    [deepseek-community/Janus-Pro-1B](https://huggingface.co/deepseek-community/Janus-Pro-1B).

    Args:
        embed_dim (`int`, *optional*, defaults to 8):
            Dimensionality of each embedding vector.
        num_embeddings (`int`, *optional*, defaults to 16384):
            Number of codebook embeddings.
        double_latent (`bool`, *optional*, defaults to `False`):
            Whether to use double z channels.
        latent_channels (`int`, *optional*, defaults to 256):
            Number of channels for the latent space.
        num_patches (`int`, *optional*, defaults to 32):
            Number of patches the input images can be divided into.
        in_channels (`int`, *optional*, defaults to 3):
            Number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            Number of output channels.
        base_channels (`int`, *optional*, defaults to 128):
            Base channel count.
        channel_multiplier (`list[int]`, *optional*, defaults to `[1, 1, 2, 2, 4]`):
            Channel multipliers for each resolution. The list is copied, so the caller's object (and the signature
            default) is never shared with the configuration instance.
        num_res_blocks (`int`, *optional*, defaults to 2):
            Number of residual blocks.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout rate.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        projection_dim (`int`, *optional*, defaults to 2048):
            Dimensionality of the MLP projection head.
        num_hidden_layers (`int`, *optional*, defaults to 2):
            Number of hidden layers in the VQVAE MLP connector module.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        image_token_embed_dim (`int`, *optional*, defaults to 2048):
            Dimension of image embeddings. It should be same as the dimensionality of text embeddings.
    """

    def __init__(
        self,
        embed_dim: int = 8,
        num_embeddings: int = 16384,
        double_latent: bool = False,
        latent_channels: int = 256,
        num_patches: int = 32,
        in_channels: int = 3,
        out_channels: int = 3,
        base_channels: int = 128,
        channel_multiplier: list[int] = [1, 1, 2, 2, 4],
        num_res_blocks: int = 2,
        dropout: float = 0.0,
        initializer_range=0.02,
        projection_dim=2048,
        num_hidden_layers=2,
        hidden_act="gelu",
        image_token_embed_dim=2048,
        **kwargs,
    ):
        super().__init__(
            embed_dim=embed_dim,
            num_embeddings=num_embeddings,
            double_latent=double_latent,
            latent_channels=latent_channels,
            in_channels=in_channels,
            base_channels=base_channels,
            # The signature default is a single list object created at definition time. Hand the
            # parent a fresh copy so that mutating one config's `channel_multiplier` can never leak
            # into other instances or into the default itself.
            channel_multiplier=list(channel_multiplier),
            num_res_blocks=num_res_blocks,
            dropout=dropout,
            initializer_range=initializer_range,
            **kwargs,
        )
        self.num_patches = num_patches
        self.out_channels = out_channels
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.hidden_act = hidden_act
        self.image_token_embed_dim = image_token_embed_dim

        # Attributes set by the parent initializer that the Janus VQ-VAE does not expose.
        del self.resolution
        del self.attn_resolutions
        del self.attn_type
class JanusConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`JanusModel`]. It is used to instantiate an
    Janus model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Janus-1B or Janus-7B models.

    e.g. [deepseek-community/Janus-Pro-1B](https://huggingface.co/deepseek-community/Janus-Pro-1B) or
    [deepseek-community/Janus-Pro-7B](https://huggingface.co/deepseek-community/Janus-Pro-7B)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
            The config object or dictionary of the text backbone. A dictionary without a `"model_type"` key is
            treated as a Llama configuration; the dictionary passed by the caller is not modified.
        vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `JanusVisionConfig`):
            The config object or dictionary of the vision backbone.
        vq_config (`Union[AutoConfig, dict]`, *optional*, defaults to `JanusVQVAEConfig`):
            The config object or dictionary of the VQVAE backbone. Its `num_patches` attribute is overwritten with
            `vision_config.image_size // vision_config.patch_size`.
        image_token_id (`int`, *optional*, defaults to 100581):
            Token index of a placeholder image token.

    Raises:
        ValueError: If any of the sub-configurations is neither `None`, a `dict`, nor an instance of the expected
            configuration class.

    Example:

    ```python
    >>> from transformers import JanusForConditionalGeneration, JanusConfig, JanusVisionConfig, JanusVQVAEConfig, LlamaConfig

    >>> # Initializing a Janus vision config
    >>> vision_config = JanusVisionConfig()

    >>> # Initializing a Llama config
    >>> text_config = LlamaConfig()

    >>> # Initializing a VQ config
    >>> vq_config = JanusVQVAEConfig()

    >>> # Initializing a Janus Pro 1B style configuration
    >>> configuration = JanusConfig(vision_config=vision_config, text_config=text_config, vq_config=vq_config)

    >>> # Initializing a model from the Janus Pro 1B style configuration
    >>> model = JanusForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "janus"
    sub_configs = {
        "text_config": AutoConfig,
        "vision_config": JanusVisionConfig,
        "vq_config": JanusVQVAEConfig,
    }

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        vq_config=None,
        image_token_id=100581,
        **kwargs,
    ):
        if isinstance(text_config, dict):
            # Fill in the default model type on a shallow copy so the caller's dictionary is left untouched.
            text_config = {**text_config, "model_type": text_config.get("model_type", "llama")}
            self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            logger.info("`text_config` is None. Initializing with default values")
            self.text_config = CONFIG_MAPPING["llama"]()
        elif isinstance(text_config, PretrainedConfig):
            self.text_config = text_config
        else:
            raise ValueError(
                "Invalid type for `text_config`. Must be either `dict` or `LlamaConfig`."
                f" Type found: {type(text_config)}"
            )

        if vision_config is None:
            logger.info("`vision_config` is None. Initializing with default JanusVisionConfig values")
            self.vision_config = JanusVisionConfig()
        elif isinstance(vision_config, dict):
            self.vision_config = JanusVisionConfig(**vision_config)
        elif isinstance(vision_config, JanusVisionConfig):
            self.vision_config = vision_config
        else:
            raise ValueError(
                "Invalid type for `vision_config`. Must be either `dict` or `JanusVisionConfig`."
                f" Type found: {type(vision_config)}"
            )

        if vq_config is None:
            logger.info("`vq_config` is None. Initializing with default JanusVQVAEConfig values")
            self.vq_config = JanusVQVAEConfig()
        elif isinstance(vq_config, dict):
            self.vq_config = JanusVQVAEConfig(**vq_config)
        elif isinstance(vq_config, JanusVQVAEConfig):
            self.vq_config = vq_config
        else:
            raise ValueError(
                "Invalid type for `vq_config`. Must be either `dict` or `JanusVQVAEConfig`."
                f" Type found: {type(vq_config)}"
            )

        self.initializer_range = self.vision_config.initializer_range
        # This dimension is required when decoding discrete image tokens to continuous input.
        self.vq_config.num_patches = self.vision_config.image_size // self.vision_config.patch_size
        # The default is only the index for the 1B model, 7B uses a different one
        self.image_token_id = image_token_id
        super().__init__(**kwargs)
# Shared base class for the Janus models defined in this file. It carries no logic of its own:
# it only declares class-level settings that are read by the `PreTrainedModel` machinery
# (exact semantics of each flag are defined there, not in this file).
@auto_docstring
class JanusPreTrainedModel(PreTrainedModel):
    config: JanusConfig
    # Attribute name under which `JanusForConditionalGeneration` stores its `JanusModel`.
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Module class names listed for device placement; presumably these must not be split
    # across devices — confirm against `PreTrainedModel`.
    _no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"]
    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = True
    _supports_param_buffer_assignment = False
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Janus VQ-VAE mode model outputs.
    """
)
class JanusVQVAEOutput(ModelOutput):
    r"""
    decoded_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
        Reconstructed pixel values after encoding and decoding the input.
    embedding_loss (`torch.FloatTensor`):
        Embedding loss.
    """

    # Both fields default to `None`, so both are annotated as optional.
    decoded_pixel_values: Optional[torch.FloatTensor] = None
    embedding_loss: Optional[torch.FloatTensor] = None
# Janus-named alias of the Idefics base-model output container; it adds no fields or behaviour.
# Returned by `JanusModel.forward`.
class JanusBaseModelOutputWithPast(IdeficsBaseModelOutputWithPast):
    pass
# Janus-named alias of the Idefics causal-LM output container; it adds no fields or behaviour.
class JanusCausalLMOutputWithPast(IdeficsCausalLMOutputWithPast):
    pass
class JanusVisionEmbeddings(SiglipVisionEmbeddings):
    """Patch + position embeddings for the Janus vision encoder."""

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """
        Embed an image batch into a sequence of patch tokens with position information added.

        Args:
            pixel_values (`torch.Tensor`):
                4-dimensional image batch; the last two dimensions are read as height and width.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                If `True`, position embeddings come from `self.interpolate_pos_encoding`; otherwise the
                stored `position_embedding` table is looked up with `self.position_ids`.

        Returns:
            `torch.Tensor`: patch embeddings plus position embeddings.
        """
        _batch, _channels, height, width = pixel_values.shape

        # Cast the input to the dtype of the patch projection weights before applying it.
        weight_dtype = self.patch_embedding.weight.dtype
        patches = self.patch_embedding(pixel_values.to(dtype=weight_dtype))

        # Collapse everything after the channel dimension and move it in front of the channels.
        tokens = patches.flatten(2).transpose(1, 2)

        positions = (
            self.interpolate_pos_encoding(tokens, height, width)
            if interpolate_pos_encoding
            else self.position_embedding(self.position_ids)
        )
        return tokens + positions
class JanusVisionAttention(nn.Module):
    """Attention Class for Janus Vision Encoder

    Non-causal multi-head self-attention with optional per-head LayerNorm on queries and keys
    (`config.use_qk_norm`) and an output projection followed by optional dropout.
    """

    def __init__(self, config: JanusVisionConfig):
        """
        Args:
            config (`JanusVisionConfig`):
                Reads `hidden_size`, `num_attention_heads`, `attention_dropout`, `projection_dropout`,
                `use_qk_norm` and `attention_bias`.

        Raises:
            ValueError: If `hidden_size` is not divisible by `num_attention_heads`.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        proj_dropout = config.projection_dropout
        qk_norm = config.use_qk_norm
        self.is_causal = False

        # Keys/values use the same number of heads as the queries (no grouped-query attention),
        # hence for the `eager_attention_forward` call `num_key_value_groups` is set to 1.
        self.num_key_value_groups = 1

        self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.projection_layer = nn.Linear(self.embed_dim, self.embed_dim)
        self.projection_dropout = nn.Dropout(proj_dropout) if proj_dropout > 0 else nn.Identity()

        # The norms are applied in `forward` to tensors of shape (..., num_heads, head_dim), so they
        # must normalize over `head_dim`. A LayerNorm over `embed_dim` would raise a shape error as
        # soon as `use_qk_norm=True` is used with more than one head.
        self.q_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Args:
            hidden_states (`torch.Tensor`):
                3-dimensional input, unpacked as `(batch_size, seq_len, embed_dim)`.
            attention_mask (`torch.Tensor`, *optional*):
                Passed unchanged to the selected attention implementation.
            **kwargs:
                Extra keyword arguments forwarded to the attention implementation.

        Returns:
            `tuple`: `(output, attn_weights)` where `output` has shape `(batch_size, seq_len, embed_dim)` and
            `attn_weights` is whatever the attention implementation returns as its second value.
        """
        batch_size, seq_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Split into heads first so that the optional q/k normalization acts on each head separately.
        query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
        query_states = self.q_norm(query_states)

        key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
        key_states = self.k_norm(key_states)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        query_states = query_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Default to the eager implementation unless the config selects another registered backend.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            # Attention dropout is only active in training mode.
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scale,
            is_causal=self.is_causal,
            **kwargs,
        )
        attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)

        output = self.projection_layer(attn_output)
        output = self.projection_dropout(output)
        return output, attn_weights
class JanusVisionMLP(nn.Module):
    """Two-layer feed-forward block of the Janus vision encoder (linear, activation, dropout, linear, dropout)."""

    def __init__(self, config: JanusVisionConfig):
        super().__init__()
        self.config = config

        hidden_size = config.hidden_size
        drop_rate = config.hidden_dropout_rate

        # The inner width is not stored in the config; it is derived from the MLP ratio.
        self.intermediate_size = int(hidden_size * config.mlp_ratio)
        self.activation_fn = ACT2FN[config.hidden_act]  # Gelu act
        self.fc1 = nn.Linear(hidden_size, self.intermediate_size)
        self.fc2 = nn.Linear(self.intermediate_size, hidden_size)
        self.dropout1 = nn.Dropout(drop_rate)
        self.dropout2 = nn.Dropout(drop_rate)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `fc1 -> activation -> dropout1 -> fc2 -> dropout2` to `hidden_states`."""
        expanded = self.dropout1(self.activation_fn(self.fc1(hidden_states)))
        return self.dropout2(self.fc2(expanded))
# One encoder layer of the Janus vision tower: the Siglip encoder layer with its attention and
# MLP sub-modules replaced by the Janus variants. No `forward` is defined here, so the parent's
# forward pass is used as is.
class JanusVisionEncoderLayer(SiglipEncoderLayer):
    def __init__(self, config: JanusVisionConfig):
        # NOTE(review): the parent initializer is called without `config`. This file is a
        # "modular" definition, so presumably the call is expanded by the modular converter
        # rather than executed directly — confirm before running this module as is.
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = JanusVisionAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = JanusVisionMLP(config)
# Siglip encoder whose layer stack is made of `JanusVisionEncoderLayer` modules.
class JanusVisionEncoder(SiglipEncoder):
    def __init__(self, config: JanusVisionConfig):
        super().__init__(config)
        # Assign `layers` after the parent initializer has run: `config.num_hidden_layers`
        # Janus encoder layers (presumably replacing a stack built by the parent — confirm).
        self.layers = nn.ModuleList([JanusVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
# Vision tower of Janus: the BLIP-2 vision model wrapper with its `encoder` attribute set to a
# `JanusVisionEncoder`. Everything else is inherited from `Blip2VisionModel`.
class JanusVisionModel(Blip2VisionModel):
    def __init__(self, config: JanusVisionConfig):
        super().__init__(config)
        # Assigned after the parent initializer so the Janus encoder is the one that is kept.
        self.encoder = JanusVisionEncoder(config)
class JanusVisionAlignerMLP(nn.Module):
    """MLP that projects vision features from `hidden_size` to `projection_dim`.

    It consists of one input projection followed by `depth - 1` square linear layers, each preceded by
    the configured activation function.
    """

    def __init__(self, config: JanusVisionConfig):
        super().__init__()
        width = config.projection_dim

        self.fc1 = nn.Linear(config.hidden_size, width)
        # `fc1` counts as the first of the `depth` layers, hence the range starting at 1.
        self.hidden_layers = nn.ModuleList()
        for _ in range(1, config.depth):
            self.hidden_layers.append(nn.Linear(width, width))
        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        """Project `hidden_states`; no activation is applied after the last linear layer."""
        projected = self.fc1(hidden_states)
        for linear in self.hidden_layers:
            projected = linear(self.activation_fn(projected))
        return projected
class JanusVQVAEVectorQuantizer(ChameleonVQVAEVectorQuantizer):
    """Chameleon vector quantizer extended with a token-id -> latent-grid lookup used for decoding."""

    def __init__(self, config: JanusVQVAEConfig):
        super().__init__(config)
        # Square latent grid: `num_patches` positions along each spatial axis.
        self.quant_state_dims = [config.num_patches, config.num_patches]

    def get_codebook_entry(self, image_tokens: torch.LongTensor) -> torch.FloatTensor:
        """
        Look up the codebook vectors of `image_tokens` and arrange them as a channel-first grid.

        Args:
            image_tokens (`torch.LongTensor`):
                Token ids; the first dimension is the batch, the remaining entries must fill the
                `quant_state_dims` grid.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, emb_dim, *quant_state_dims)`.
        """
        n_images = image_tokens.shape[0]
        code_dim: int = self.embedding.weight.shape[-1]

        # Fetch the quantized latent vectors and L2-normalize them along the embedding axis.
        codes = F.normalize(self.embedding(image_tokens), p=2, dim=-1)

        # Restore the spatial grid, then move the embedding axis in front of it.
        grid = codes.view((n_images, *self.quant_state_dims, code_dim))
        return grid.permute(0, 3, 1, 2).contiguous()
# Janus-named alias of the Chameleon VQ-VAE residual block; no behaviour is added or changed.
# Used by the Janus VQ-VAE encoder, decoder and mid block in this file.
class JanusVQVAEResnetBlock(ChameleonVQVAEEncoderResnetBlock):
    pass
# Janus-named alias of the Chameleon VQ-VAE attention block; no behaviour is added or changed.
class JanusVQVAEAttnBlock(ChameleonVQVAEEncoderAttnBlock):
    pass
# Janus-named alias of the Chameleon VQ-VAE convolutional downsampling layer; no behaviour is
# added or changed. Used between resolution levels of `JanusVQVAEEncoder`.
class JanusVQVAEConvDownsample(ChameleonVQVAEEncoderConvDownsample):
    pass
class JanusVQVAEConvUpsample(nn.Module):
    """Doubles the spatial resolution (nearest neighbour) and applies a 3x3 convolution.

    The convolution keeps the channel count (`in_channels -> in_channels`) and, with stride 1 and
    padding 1, the spatial size of the upsampled map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states):
        """Return `conv(upsample_x2(hidden_states))`."""
        upsampled = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled)
class JanusVQVAEMidBlock(nn.Module):
    """Middle stage of the VQ-VAE encoder/decoder: residual block, attention block, residual block.

    All three sub-modules keep the channel count at `channels`.
    """

    def __init__(self, config: JanusVQVAEConfig, channels: int):
        super().__init__()
        resnet_kwargs = {"config": config, "in_channels": channels, "out_channels": channels}
        self.block_1 = JanusVQVAEResnetBlock(**resnet_kwargs)
        self.attn_1 = JanusVQVAEAttnBlock(channels)
        self.block_2 = JanusVQVAEResnetBlock(**resnet_kwargs)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run `block_1`, `attn_1` and `block_2` in sequence."""
        for stage in (self.block_1, self.attn_1, self.block_2):
            hidden_states = stage(hidden_states)
        return hidden_states
class JanusVQVAEEncoder(nn.Module):
    """Convolutional encoder of the Janus VQ-VAE.

    Layout: input convolution, `len(config.channel_multiplier)` resolution levels (each with
    `config.num_res_blocks` residual blocks, attention blocks on the last level only, and a
    downsampling layer on every level except the last), a mid block, and a GroupNorm / swish /
    convolution output head producing `latent_channels` (or twice that if `config.double_latent`).
    """

    def __init__(self, config):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier

        self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)

        # Level `i` takes `base_channels * in_channel_multiplier[i]` channels in and produces
        # `base_channels * channel_multiplier[i]`; the leading 1 accounts for `conv_in`'s output.
        in_channel_multiplier = (1,) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    JanusVQVAEResnetBlock(
                        config=config,
                        in_channels=block_in,
                        out_channels=block_out,
                    )
                )
                block_in = block_out
                # Attention is only used at the last (lowest) resolution level.
                if i_level == self.num_resolutions - 1:
                    attn.append(JanusVQVAEAttnBlock(block_in))

            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = JanusVQVAEConvDownsample(block_in)
            self.down.append(down)

        self.mid = JanusVQVAEMidBlock(config, block_in)

        self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * latent_channels if double_latent else latent_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
        """
        Encode images into the (pre-quantization) latent feature map.

        Args:
            pixel_values (`torch.FloatTensor`):
                Image batch fed to the input convolution (`config.in_channels` channels).

        Returns:
            `torch.FloatTensor`: output of the final convolution.
        """
        # Only the most recent activation is ever consumed, so a single running tensor is kept
        # instead of accumulating every intermediate feature map in a list.
        hidden_state = self.conv_in(pixel_values)

        # downsampling
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                hidden_state = level.block[i_block](hidden_state)
                if len(level.attn) > 0:
                    hidden_state = level.attn[i_block](hidden_state)
            if i_level != self.num_resolutions - 1:
                hidden_state = level.downsample(hidden_state)

        # middle
        last_hidden_state = self.mid(hidden_state)

        # end: GroupNorm, swish (x * sigmoid(x)), output convolution
        last_hidden_state = self.norm_out(last_hidden_state)
        last_hidden_state *= torch.sigmoid(last_hidden_state)
        last_hidden_state = self.conv_out(last_hidden_state)
        return last_hidden_state
class JanusVQVAEDecoder(nn.Module):
    """Convolutional decoder of the Janus VQ-VAE.

    Layout: input convolution from `latent_channels`, mid block, one stage per entry of
    `config.channel_multiplier` (stored in `self.up` from the lowest resolution to the highest), and
    a GroupNorm / swish / convolution head producing `out_channels`.
    """

    def __init__(self, config):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        latent_channels = config.latent_channels
        out_channels = config.out_channels
        lowest_level = self.num_resolutions - 1

        # Channel count at the lowest resolution, where decoding starts.
        block_in = base_channels * config.channel_multiplier[lowest_level]

        # z to block_in
        self.conv_in = torch.nn.Conv2d(latent_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = JanusVQVAEMidBlock(config, block_in)

        # upsampling: levels are visited from the last multiplier to the first, so `self.up[0]`
        # is the stage built for the last entry of `channel_multiplier`.
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            resnet_blocks = nn.ModuleList()
            attention_blocks = nn.ModuleList()
            block_out = base_channels * config.channel_multiplier[i_level]
            for _ in range(self.num_res_blocks + 1):
                resnet_blocks.append(
                    JanusVQVAEResnetBlock(
                        config=config,
                        in_channels=block_in,
                        out_channels=block_out,
                    )
                )
                block_in = block_out
                # Attention only accompanies the residual blocks of the lowest-resolution stage.
                if i_level == lowest_level:
                    attention_blocks.append(JanusVQVAEAttnBlock(block_in))

            stage = nn.Module()
            stage.block = resnet_blocks
            stage.attn = attention_blocks
            # Every stage except the one built for level 0 ends with an upsampling layer.
            if i_level != 0:
                stage.upsample = JanusVQVAEConvUpsample(block_in)
            self.up.append(stage)

        # end
        self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_state: torch.FloatTensor) -> torch.FloatTensor:
        """Decode a latent feature map (`latent_channels` channels) into `out_channels` image channels."""
        hidden_state = self.mid(self.conv_in(hidden_state))

        # upsampling
        final_stage = self.num_resolutions - 1
        for i_level in range(self.num_resolutions):
            stage = self.up[i_level]
            uses_attention = len(stage.attn) > 0
            for i_block in range(self.num_res_blocks + 1):
                hidden_state = stage.block[i_block](hidden_state)
                if uses_attention:
                    hidden_state = stage.attn[i_block](hidden_state)
            # The final stage in `self.up` has no `upsample` module.
            if i_level != final_stage:
                hidden_state = stage.upsample(hidden_state)

        # GroupNorm, swish (x * sigmoid(x)), output convolution
        hidden_state = self.norm_out(hidden_state)
        hidden_state *= torch.sigmoid(hidden_state)
        return self.conv_out(hidden_state)
# Janus VQ-VAE: the Chameleon VQ-VAE extended with a decoder, so that image tokens can be mapped
# back to pixel values (`decode`) and inputs can be reconstructed end to end (`forward`).
class JanusVQVAE(ChameleonVQVAE):
    _no_split_modules = [
        "JanusVQVAEAttnBlock",
        "JanusVQVAEResnetBlock",
        "JanusVQVAEVectorQuantizer",
    ]
    main_input_name = "pixel_values"

    def __init__(self, config: JanusVQVAEConfig):
        super().__init__(config)
        self.decoder = JanusVQVAEDecoder(config)
        self.gradient_checkpointing = False

        # Initialize the VQVAE model.
        self.post_init()

    def decode(self, image_tokens: torch.LongTensor) -> torch.FloatTensor:
        """
        Decodes quantized token IDs into pixel values.
        Args:
            image_tokens (torch.LongTensor): Batch of token IDs.
        Returns:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                Pixel values decoded from the token IDs.
        Raises:
            ValueError: If the second dimension of `image_tokens` does not equal the number of
                positions in the quantizer's latent grid.
        """
        grid_height, grid_width = self.quantize.quant_state_dims[0], self.quantize.quant_state_dims[1]
        expected_tokens = grid_height * grid_width
        if image_tokens.shape[1] != expected_tokens:
            raise ValueError(
                f"Expected `image_tokens` to have shape `(batch_size, {expected_tokens})`, "
                f"but got shape `{image_tokens.shape}`."
            )

        # token ids -> codebook vectors on the latent grid -> post-quantization conv -> decoder
        latents = self.post_quant_conv(self.quantize.get_codebook_entry(image_tokens))
        return self.decoder(latents)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        # Encode to token indices, then decode those indices back to pixels. The quantized
        # latents returned by `encode` are not needed because `decode` re-derives them.
        _, embedding_loss, indices = self.encode(pixel_values)
        flat_tokens = indices.view(pixel_values.shape[0], -1)
        reconstruction = self.decode(flat_tokens)
        return JanusVQVAEOutput(decoded_pixel_values=reconstruction, embedding_loss=embedding_loss)
class JanusVQVAEAlignerMLP(nn.Module):
    """MLP that projects VQ codebook embeddings from `embed_dim` to `projection_dim`.

    It consists of one input projection followed by `num_hidden_layers - 1` square linear layers,
    each preceded by the configured activation function.
    """

    def __init__(self, config: JanusVQVAEConfig):
        super().__init__()
        width = config.projection_dim

        self.fc1 = nn.Linear(config.embed_dim, width)
        # `fc1` counts as the first of the `num_hidden_layers` layers, hence the range starting at 1.
        self.hidden_layers = nn.ModuleList()
        for _ in range(1, config.num_hidden_layers):
            self.hidden_layers.append(nn.Linear(width, width))
        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        """Project `hidden_states`; no activation is applied after the last linear layer."""
        projected = self.fc1(hidden_states)
        for linear in self.hidden_layers:
            projected = linear(self.activation_fn(projected))
        return projected
class JanusVQVAEHead(nn.Module):
    """Head used for sampling tokens in image generation, replacing the usual lm head.

    Maps hidden states of size `image_token_embed_dim` to `projection_dim`, applies the configured
    activation, and projects to one logit per codebook entry (`num_embeddings`).
    """

    def __init__(self, config: JanusVQVAEConfig):
        super().__init__()
        self.proj_out = nn.Linear(config.image_token_embed_dim, config.projection_dim)
        self.activation_fn = ACT2FN[config.hidden_act]
        self.vision_head = nn.Linear(config.projection_dim, config.num_embeddings)

    # The return annotation is the tensor *type* `torch.Tensor`; `torch.tensor` is a factory function.
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`):
                Input whose last dimension is `config.image_token_embed_dim`.

        Returns:
            `torch.Tensor`: scores whose last dimension is `config.num_embeddings`.
        """
        hidden_states = self.proj_out(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.vision_head(hidden_states)
        return hidden_states
@auto_docstring(
    custom_intro="""
    The Janus model which consists of a siglip vision backbone, a Llama language model and a VQ model.
    """
)
class JanusModel(JanusPreTrainedModel):
    def __init__(self, config: JanusConfig):
        super().__init__(config)
        self.config = config

        # Image-understanding path: vision tower plus the aligner that maps its features to
        # `projection_dim`.
        # This is necessary for backward compatibility, see SiglipModel initialization
        self.vision_model = JanusVisionModel._from_config(config.vision_config)
        self.aligner = JanusVisionAlignerMLP(self.vision_model.config)

        self.vqmodel = JanusVQVAE._from_config(config.vq_config)
        vq_config = self.vqmodel.config

        # Below generation_* modules are used for Image generation.
        # Embeddings used for image generation, instead of Janus vision embeddings.
        self.generation_embeddings = nn.Embedding(vq_config.num_embeddings, vq_config.embed_dim)
        self.generation_aligner = JanusVQVAEAlignerMLP(vq_config)
        self.generation_head = JanusVQVAEHead(vq_config)

        self.language_model = AutoModel.from_config(config=config.text_config)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self):
        # The text embedding table lives in the language model.
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_image_features(self, pixel_values):
        # Vision tower -> last hidden state -> aligner projection.
        vision_outputs = self.vision_model(pixel_values)
        return self.aligner(vision_outputs.last_hidden_state)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ):
        # Exactly one of `input_ids` / `inputs_embeds` has to be provided.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        image_embeds = None
        if pixel_values is not None:
            # Locate the image placeholder positions in the sequence.
            if input_ids is not None:
                placeholder_mask = input_ids == self.config.image_token_id
            else:
                # Without token ids, compare every embedding against the embedding of the image
                # placeholder token; a position matches only if all of its components are equal.
                placeholder_id = torch.tensor(
                    self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device
                )
                placeholder_mask = (inputs_embeds == self.get_input_embeddings()(placeholder_id)).all(-1)

            placeholder_mask = placeholder_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

            # Overwrite the placeholder embeddings with the projected image features.
            image_embeds = self.get_image_features(pixel_values)
            image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1]).to(
                inputs_embeds.device, inputs_embeds.dtype
            )
            inputs_embeds = inputs_embeds.masked_scatter(placeholder_mask, image_features)

        lm_output = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        return JanusBaseModelOutputWithPast(
            last_hidden_state=lm_output.last_hidden_state,
            past_key_values=lm_output.past_key_values,
            hidden_states=lm_output.hidden_states,
            attentions=lm_output.attentions,
            # `None` when no `pixel_values` were passed.
            image_hidden_states=image_embeds,
        )
class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
    # Wraps `JanusModel` with a linear text head (`lm_head`) and the generation mixin.

    # Parameter names declared as tied; how they are consumed is defined by the base
    # classes outside this chunk.
    _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"]
    _can_compile_fullgraph = True

    def __init__(self, config: JanusConfig):
        """Build the backbone model and the text LM head described by `config`."""
        super().__init__(config)
        self.config = config
        self.model = JanusModel(config)

        # The head maps text hidden states to text-vocabulary logits (no bias term).
        text_config = config.text_config
        self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()
def get_input_embeddings(self):
    """Return the input embeddings exposed by the inner language model."""
    text_decoder = self.model.language_model
    return text_decoder.get_input_embeddings()
def set_input_embeddings(self, value):
    """Forward `value` to the inner language model's `set_input_embeddings`."""
    text_decoder = self.model.language_model
    text_decoder.set_input_embeddings(value)
def prepare_embeddings_for_image_generation(self, inputs: torch.Tensor) -> torch.Tensor:
    """
    Turn generated image tokens into inputs for the next decoding step.

    `inputs` is first passed through `self.model.generation_embeddings` and the result
    through `self.model.generation_aligner`; the aligner's output is returned.
    """
    backbone = self.model
    embedded_tokens = backbone.generation_embeddings(inputs)
    return backbone.generation_aligner(embedded_tokens)
def set_decoder(self, decoder):
    """Replace the wrapped model (`self.model`) with `decoder`."""
    setattr(self, "model", decoder)
def get_decoder(self):
    """Return the wrapped model stored in `self.model`."""
    decoder = self.model
    return decoder
@can_return_tuple
@auto_docstring
def forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: torch.FloatTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[TransformersKwargs],
):
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    """
    # `logits_to_keep` is applied below on the hidden states; it is not forwarded to the backbone.
    backbone_outputs = self.model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        cache_position=cache_position,
        **kwargs,
    )

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
    # An int keeps the trailing positions (0 -> `slice(0, None)`, i.e. everything); a tensor is used
    # directly as an index along the sequence axis.
    if isinstance(logits_to_keep, int):
        kept_positions = slice(-logits_to_keep, None)
    else:
        kept_positions = logits_to_keep
    logits = self.lm_head(backbone_outputs.last_hidden_state[:, kept_positions, :])

    loss = (
        self.loss_function(
            logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
        )
        if labels is not None
        else None
    )

    return JanusCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=backbone_outputs.past_key_values,
        hidden_states=backbone_outputs.hidden_states,
        attentions=backbone_outputs.attentions,
        image_hidden_states=backbone_outputs.image_hidden_states,
    )
def prepare_inputs_for_generation(
    self,
    input_ids,
    pixel_values=None,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    logits_to_keep=None,
    **kwargs,
):
    """
    Build the per-step model inputs, attaching `pixel_values` only on the first step.

    Everything except `pixel_values` is delegated to the parent implementation. The
    returned dict gains a `"pixel_values"` entry only when `cache_position[0] == 0`.
    """
    # Overwritten -- extra custom processing
    prepared = super().prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        logits_to_keep=logits_to_keep,
        **kwargs,
    )

    # If we're in cached decoding stage, pixel values should be None because input ids do not contain
    # special image token anymore. Otherwise we need pixel values to be passed to model.
    is_first_step = cache_position[0] == 0
    if is_first_step:
        prepared["pixel_values"] = pixel_values
    return prepared
def decode_image_tokens(self, image_tokens: torch.Tensor):
    """
    Decode generated image tokens into continuous pixel values with the VQ model.

    Args:
        image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
            Image tokens produced by the language model.

    Returns:
        The output of `self.model.vqmodel.decode` with axis 1 moved to the last position
        (presumably channels-first -> channels-last; confirm against the VQ decoder).
    """
    decoder_output = self.model.vqmodel.decode(image_tokens)
    return decoder_output.permute(0, 2, 3, 1)
@torch.no_grad
def generate(
    self,
    inputs: torch.Tensor = None,
    attention_mask: Optional[torch.LongTensor] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    **kwargs,
):
    """
    Generate text (default) or image tokens.

    Args:
        inputs (`torch.Tensor`, *optional*):
            Prompt token ids. Image generation requires them to resolve to a 2D tensor of shape
            `(batch_size, seq_len)`.
        attention_mask (`torch.LongTensor`, *optional*):
            Attention mask for `inputs`. If omitted in image mode, an all-ones mask is used, i.e. the
            prompt is assumed to contain no padding.
        logits_processor (`LogitsProcessorList`, *optional*):
            Custom logits processors. They are forwarded in text mode; in image mode they are copied
            (the caller's list is never mutated) and combined with the classifier-free-guidance processor.
        **kwargs:
            `generation_config` (defaults to `self.generation_config`, always deep-copied) and
            `generation_mode` (`"text"` by default; any other value selects image generation) are
            consumed here. The rest is handed to the parent `generate` in text mode, or used to update
            the generation config in image mode, with the leftovers treated as model kwargs.

    Returns:
        Text mode: whatever the parent `generate` returns.
        Image mode: a tensor of shape `(batch_size, num_image_tokens)` holding the generated tokens, or,
        when `return_dict_in_generate` is set, a `GenerateDecoderOnlyOutput` whose `scores` (processed
        scores, when `output_scores`), `logits` (raw generation-head outputs for the doubled
        conditional/unconditional batch, when `output_logits`), `attentions` and `hidden_states` are
        tuples with one entry per generated token. `batch_size` includes `num_return_sequences`.

    Raises:
        ValueError: In image mode, if the generation mode is neither greedy nor sampling, or if the
            prepared inputs are not 2D.
    """
    # 1. Handle generation config and model kwargs
    generation_config = kwargs.pop("generation_config", self.generation_config)
    generation_config = copy.deepcopy(generation_config)

    # Default to "text" generation if mode isn't provided
    generation_mode = kwargs.pop("generation_mode", "text")
    if generation_mode == "text":
        # Set guidance_scale=None to prevent running UnbatchedCFG processor.
        # `logits_processor` is a named parameter here, so it has to be forwarded explicitly;
        # otherwise user-supplied processors would be silently dropped for text generation.
        return super().generate(
            inputs=inputs,
            attention_mask=attention_mask,
            logits_processor=logits_processor,
            generation_config=generation_config,
            guidance_scale=None,
            **kwargs,
        )

    model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs

    # Validate generation mode
    if generation_config.get_generation_mode() not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
        raise ValueError(
            "Got incompatible mode for Image Generation, should be one of greedy or sampling. "
            "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
        )

    # Validate the configuration and model kwargs
    generation_config.validate()
    self._validate_model_kwargs(model_kwargs.copy())

    # 2. Initialize logit processors
    # Work on a copy: the CFG processor is appended below, and appending it to the caller's list would
    # make a second `generate` call with the same list apply CFG twice.
    logits_processor = (
        LogitsProcessorList(logits_processor) if logits_processor is not None else LogitsProcessorList()
    )

    # Set `use_cache=True` as we will be using input embeds for generation.
    model_kwargs["use_cache"] = True

    if generation_config.guidance_scale is None:
        logger.warning("`guidance_scale` is required for CFG but not provided. Setting to default value of 5.")
        generation_config.guidance_scale = 5
    model_kwargs["guidance_scale"] = generation_config.guidance_scale

    # 3. Prepare model inputs
    input_ids, _, model_kwargs = self._prepare_model_inputs(
        inputs, generation_config.bos_token_id, model_kwargs
    )
    dtype, device = input_ids.dtype, input_ids.device

    if len(input_ids.shape) != 2:
        raise ValueError(
            f"Expected input ids of shape (batch_size, seq_len), but got {input_ids.shape}. "
            "Passing `inputs embeds` is not supported currently."
        )

    # Prepare special tokens which will be used generate internally.
    kwargs_has_attention_mask = attention_mask is not None
    self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)

    # The mask is repeated and moved across devices below, so it must be a tensor. Without one, attend
    # to every prompt token (correct for unpadded prompts; padded batches should pass a mask).
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    # 4. Add CFG processor along with user passed logit processor.
    if generation_config.guidance_scale and generation_config.guidance_scale > 1:
        logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
        generation_config.guidance_scale = None  # Reset to prevent processor duplication.

    # 5. Prepare logits processor
    logits_processor = self._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids.shape[1],
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=None,
        logits_processor=logits_processor,
        device=device,
    )

    # 6. Expand inputs for multiple image generations per prompt.
    input_ids, model_kwargs = self._expand_inputs_for_generation(
        input_ids=input_ids,
        attention_mask=attention_mask,
        expand_size=generation_config.num_return_sequences,
        **model_kwargs,
    )

    # 7. Prepare input and model caches
    num_image_tokens = self.model.vision_model.config.num_image_tokens
    batch_size, seq_len = input_ids.shape

    input_tokens = input_ids.repeat(2, 1)  # Double batch size for conditional/unconditional logits
    attention_mask = model_kwargs.pop("attention_mask", None)
    attention_mask = attention_mask.repeat(2, 1)
    model_kwargs["attention_mask"] = attention_mask

    # Mask all the tokens that are neither BOS nor BOI with pad token in the unconditional logits.
    unconditional_tokens = input_tokens[batch_size:, :]
    mask = (unconditional_tokens != generation_config.bos_token_id) & (
        unconditional_tokens != generation_config.generation_kwargs["boi_token_id"]
    )
    # `unconditional_tokens` is a view, so this edits the second half of `input_tokens` in place.
    unconditional_tokens.masked_fill_(mask, generation_config.pad_token_id)

    inputs_embeds = self.get_input_embeddings()(input_tokens)

    model_kwargs = self._get_initial_cache_position(seq_len, device, model_kwargs)

    if model_kwargs.get("past_key_values", None) is None:
        # Prepare cache if not provided.
        model_kwargs["past_key_values"] = self._get_cache(
            cache_implementation=generation_config.cache_implementation or "static",
            # batch_size should account for both conditional/unconditional input; hence multiplied by 2.
            batch_size=batch_size * 2,
            # we should have at least a cache len of seq_len + num_image_tokens.
            max_cache_len=max(generation_config.max_length, num_image_tokens + seq_len),
            device=device,
            model_kwargs=model_kwargs,
        )

    # Placeholder for generated tokens.
    generated_tokens = torch.zeros((batch_size, num_image_tokens), dtype=dtype, device=device)

    # 8. init attention / hidden states / scores tuples
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate

    raw_scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None

    for i in range(num_image_tokens):
        model_inputs = self.prepare_inputs_for_generation(
            inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs
        )

        model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device)
        model_inputs["cache_position"] = model_inputs["cache_position"].to(inputs_embeds.device)

        outputs = self.model.language_model(
            **model_inputs,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # Update model_kwargs like cache_position for next generation.
        model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
        hidden_state = outputs.last_hidden_state[:, -1, :].clone()

        # Generate scores using the generation head (Not using above defined LM Head)
        next_token_logits = self.model.generation_head(hidden_state)
        # `input_ids` has `batch_size` rows while the logits have `2 * batch_size`; the CFG processor
        # relies on that ratio to combine the conditional and unconditional halves.
        next_token_scores = logits_processor(input_ids, next_token_logits)

        # Sample next token.
        if generation_config.do_sample:
            probs = torch.softmax(next_token_scores, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
        else:
            next_token = torch.argmax(next_token_scores, dim=-1)

        generated_tokens[:, i] = next_token

        # Prepare embeddings for the next step (same token for both CFG halves).
        next_token = torch.cat([next_token, next_token])
        next_token = next_token.unsqueeze(-1)

        inputs_embeds = self.prepare_embeddings_for_image_generation(next_token)

        if return_dict_in_generate:
            if output_scores:
                # Processed scores, i.e. after the logits processors.
                raw_scores += (next_token_scores,)
            if output_logits:
                # Unprocessed head outputs (not the hidden state feeding the head).
                raw_logits += (next_token_logits.float(),)
            if output_attentions:
                # One entry per generated token, each holding that step's attentions.
                decoder_attentions += (outputs.attentions,)
            if output_hidden_states:
                decoder_hidden_states += (outputs.hidden_states,)

    if return_dict_in_generate:
        return GenerateDecoderOnlyOutput(
            sequences=generated_tokens,
            scores=raw_scores,
            logits=raw_logits,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
            past_key_values=outputs.past_key_values,
        )
    else:
        return generated_tokens
class JanusImageProcessor(BlipImageProcessor):
    r"""
    Constructs a JANUS image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
            `do_resize` parameter in the `preprocess` method.
        size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
            method.
        min_size (`int`, *optional*, defaults to 14):
            The minimum allowed size for the resized image. Ensures that neither the height nor width
            falls below this value after resizing.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
            overridden by the `resample` parameter in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
            `do_rescale` parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
            overridden by the `rescale_factor` parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
            It also determines the padding colour used when images are padded to a square: each value is
            scaled by 255 and truncated to an integer (mid-grey `(127, 127, 127)` when unset).
        image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
    """

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[dict[str, int]] = None,
        min_size: int = 14,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        **kwargs,
    ):
        # NOTE(review): only `**kwargs` reaches the parent initializer; the named arguments other than
        # `min_size` and `image_mean` are not used in this body. Presumably the modular-to-modeling
        # conversion expands this `super().__init__` call -- confirm before changing it.
        super().__init__(**kwargs)

        self.min_size = min_size

        # Padding colour for `pad_to_square`, expressed on the 0-255 scale.
        if image_mean is None:
            self.background_color = (127, 127, 127)
        elif isinstance(image_mean, (int, float)):
            # A scalar mean is an accepted input; use it for every channel. Three channels matches
            # the default colour above and the channel count hard-coded in `unnormalize`.
            self.background_color = (int(image_mean * 255),) * 3
        else:
            self.background_color = tuple(int(x * 255) for x in image_mean)
def pad_to_square(
    self,
    image: np.ndarray,
    background_color: Union[int, tuple[int, int, int]] = 0,
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Pads an image to a square based on the longest edge.

    The image is centred along its shorter axis and the remaining area is filled with
    `background_color`. An already-square image is returned without padding.

    Args:
        image (`np.ndarray`):
            The image to pad.
        background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
            The color to use for the padding. Can be an integer for single channel or a
            tuple of integers representing for multi-channel images. If passed as integer
            in multi-channel mode, it will default to `0` in subsequent channels.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use same as the input image.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the input image. If unset, it is inferred from the image.
            Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

    Returns:
        `np.ndarray`: The padded image.

    Raises:
        ValueError: If padding is needed and `background_color` is a sequence whose length differs from
            the number of channels.
    """
    # Resolve the layout up front so the channel count below is read from the right axis even when
    # the caller did not specify a format.
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    height, width = get_image_size(image, input_data_format)
    channels_first = input_data_format == ChannelDimension.FIRST
    num_channels = image.shape[0] if channels_first else image.shape[-1]

    if height == width:
        # Nothing to pad; `background_color` is intentionally not validated on this path.
        result = image
    else:
        max_dim = max(height, width)

        # Ensure background_color is the correct shape
        if isinstance(background_color, int):
            # Only the first channel receives the colour; the others keep the zero fill.
            background_color = [background_color]
        elif len(background_color) != num_channels:
            raise ValueError(
                f"background_color must have exactly {num_channels} elements to match the number of channels"
            )

        if channels_first:
            result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype)
            for i, color in enumerate(background_color):
                result[i, :, :] = color
            if width > height:
                start = (max_dim - height) // 2
                result[:, start : start + height, :] = image
            else:
                start = (max_dim - width) // 2
                result[:, :, start : start + width] = image
        else:
            result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype)
            for i, color in enumerate(background_color):
                result[:, :, i] = color
            if width > height:
                start = (max_dim - height) // 2
                result[start : start + height, :, :] = image
            else:
                start = (max_dim - width) // 2
                result[:, start : start + width, :] = image

    # Honour the requested output layout for padded and unpadded images alike.
    if data_format is not None:
        result = to_channel_dimension_format(result, data_format, input_data_format)
    return result
def resize(
    self,
    image: np.ndarray,
    size: Union[dict[str, int], int],
    resample: PILImageResampling = PILImageResampling.BICUBIC,
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
    **kwargs,
) -> np.ndarray:
    """
    Resize an image to dynamically calculated size.

    The longest side is scaled to `size` and the other side follows the aspect ratio (neither side
    goes below `self.min_size`); the result is then padded to a square with `self.background_color`.

    Args:
        image (`np.ndarray`):
            Image to resize.
        size (`dict[str, int]` or `int`):
            Target size. It must resolve to equal `"height"` and `"width"`.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
            `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
        data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format for the output image. If unset, the channel dimension format of the input
            image is used. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - `None`: will be inferred from input
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format for the input image. If unset, the channel dimension format is inferred
            from the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

    Returns:
        `np.ndarray`: The resized image.

    Raises:
        ValueError: If the resolved `size` is not square.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    height, width = get_image_size(image, input_data_format)
    max_size = max(height, width)

    size = get_size_dict(size, default_to_square=True)
    if size["height"] != size["width"]:
        raise ValueError(
            f"Output height and width must be the same. Got height={size['height']} and width={size['width']}"
        )
    size = size["height"]

    delta = size / max_size
    # Largest side becomes `size` and the other side is scaled according to the aspect ratio.
    output_size_nonpadded = [
        max(int(height * delta), self.min_size),
        max(int(width * delta), self.min_size),
    ]

    # Inside the method body `resize` resolves to the module-level transform, not this method.
    image = resize(
        image,
        size=output_size_nonpadded,
        resample=resample,
        data_format=data_format,
        input_data_format=input_data_format,
        **kwargs,
    )

    # The transform above already emitted the image in `data_format` (when given), so the padding
    # step must read it in that layout rather than in the original input layout.
    resized_format = input_data_format if data_format is None else data_format

    # Expand and pad the images to obtain a square image of dimensions `size x size`
    image = self.pad_to_square(
        image=image,
        background_color=self.background_color,
        input_data_format=resized_format,
    )
    return image
def postprocess(
    self,
    images: ImageInput,
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[list[float]] = None,
    image_std: Optional[list[float]] = None,
    input_data_format: Optional[str] = None,
    return_tensors: Optional[str] = None,
):
    """
    Applies post-processing to the decoded image tokens by reversing transformations applied during preprocessing.

    Args:
        images (`ImageInput`):
            Image or images to post-process. PIL inputs are returned untouched (a single image is
            returned on its own rather than in a list).
        do_rescale (`bool`, *optional*):
            Whether to rescale, clip to `[0, 255]` and cast to `uint8`. Defaults to `self.do_rescale`.
        rescale_factor (`float`, *optional*):
            Scale applied when rescaling. Defaults to `1.0 / self.rescale_factor`.
        do_normalize (`bool`, *optional*):
            Whether to undo the normalization. Defaults to `self.do_normalize`.
        image_mean (`list[float]`, *optional*):
            Mean used to undo the normalization. Defaults to `self.image_mean`.
        image_std (`list[float]`, *optional*):
            Standard deviation used to undo the normalization. Defaults to `self.image_std`.
        input_data_format (`str`, *optional*):
            Channel layout of the inputs; inferred from the first image when unset.
        return_tensors (`str`, *optional*):
            Tensor type for the returned `BatchFeature`. The special value `"PIL.Image.Image"` yields PIL
            images, but only when both un-normalizing and rescaling are enabled.

    Returns:
        `BatchFeature` with a `"pixel_values"` entry, or the untouched input when it is already PIL.
    """
    # Fall back to the processor's own settings for anything the caller left unset.
    if do_rescale is None:
        do_rescale = self.do_rescale
    if rescale_factor is None:
        rescale_factor = 1.0 / self.rescale_factor
    if do_normalize is None:
        do_normalize = self.do_normalize
    if image_mean is None:
        image_mean = self.image_mean
    if image_std is None:
        image_std = self.image_std

    images = make_list_of_images(images)  # Ensures input is a list
    first_image = images[0]

    if isinstance(first_image, PIL.Image.Image):
        return first_image if len(images) <= 1 else images

    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(first_image)  # Determine format dynamically

    wants_pil = return_tensors == "PIL.Image.Image"
    convert_to_pil = do_normalize and do_rescale and wants_pil

    pixel_values = []
    for item in images:
        array = to_numpy_array(item)  # Ensure NumPy format
        if do_normalize:
            array = self.unnormalize(
                image=array, image_mean=image_mean, image_std=image_std, input_data_format=input_data_format
            )
        if do_rescale:
            array = self.rescale(array, scale=rescale_factor, input_data_format=input_data_format)
            array = array.clip(0, 255).astype(np.uint8)
        if convert_to_pil:
            # The array is moved to channels-last before being handed to PIL.
            array = to_channel_dimension_format(array, ChannelDimension.LAST, input_channel_dim=input_data_format)
            array = PIL.Image.fromarray(array)
        pixel_values.append(array)

    # "PIL.Image.Image" is handled above and is not forwarded as a tensor type.
    tensor_type = None if wants_pil else return_tensors
    return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=tensor_type)
def unnormalize(
    self,
    image: np.array,
    image_mean: Union[float, Iterable[float]],
    image_std: Union[float, Iterable[float]],
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.array:
    """
    Undo a mean/std normalization: `image = (image * image_std) + image_mean`.

    This is done by calling `self.normalize` with mean `-image_mean / image_std` and std
    `1 / image_std`, which is algebraically the inverse transform.

    Args:
        image (`np.ndarray`):
            Normalized pixel values.
        image_mean (`float` or `Iterable[float]`):
            The mean to use for unnormalization. A scalar is repeated for 3 channels.
        image_std (`float` or `Iterable[float]`):
            The standard deviation to use for unnormalization. A scalar is repeated for 3 channels.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format for the input image, forwarded to `self.normalize`. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

    Returns:
        The value returned by `self.normalize` for the inverted statistics.

    Raises:
        ValueError: If an iterable mean or std does not have exactly 3 elements.
    """
    num_channels = 3

    # Scalars are broadcast to one value per channel; iterables must already have that length.
    if not isinstance(image_mean, Iterable):
        image_mean = [image_mean] * num_channels
    elif len(image_mean) != num_channels:
        raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(image_mean)}")

    if not isinstance(image_std, Iterable):
        image_std = [image_std] * num_channels
    elif len(image_std) != num_channels:
        raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(image_std)}")

    # normalize(x; m', s') = (x - m') / s' with m' = -m / s and s' = 1 / s gives x * s + m.
    inverse_mean = tuple(-channel_mean / channel_std for channel_mean, channel_std in zip(image_mean, image_std))
    inverse_std = tuple(1 / channel_std for channel_std in image_std)

    return self.normalize(
        image=image, mean=inverse_mean, std=inverse_std, input_data_format=input_data_format
    )
# Names exported by `from <module> import *`; everything else in the module is left out of
# the star-import surface.
__all__ = [
    "JanusImageProcessor",
    "JanusPreTrainedModel",
    "JanusForConditionalGeneration",
    "JanusModel",
    "JanusVQVAE",
    "JanusVisionModel",
    "JanusVQVAEConfig",
    "JanusVisionConfig",
    "JanusConfig",
]