1732 lines
79 KiB
Python
1732 lines
79 KiB
Python
# coding=utf-8
|
|
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import itertools
|
|
from typing import Callable, Optional, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.utils.checkpoint
|
|
from torch.nn import LayerNorm
|
|
|
|
from ...activations import ACT2FN
|
|
from ...cache_utils import Cache, DynamicCache
|
|
from ...configuration_utils import PretrainedConfig
|
|
from ...feature_extraction_utils import BatchFeature
|
|
from ...image_utils import ImageInput
|
|
from ...masking_utils import create_causal_mask
|
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
from ...modeling_outputs import BaseModelOutputWithPast
|
|
from ...modeling_rope_utils import rope_config_validation
|
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
|
from ...processing_utils import ImagesKwargs, Unpack
|
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
|
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
|
|
from ...video_utils import VideoInput
|
|
from ..glm4.modeling_glm4 import Glm4MLP, Glm4RMSNorm, eager_attention_forward
|
|
from ..qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig
|
|
from ..qwen2_5_vl.modeling_qwen2_5_vl import (
|
|
Qwen2_5_VisionPatchEmbed,
|
|
Qwen2_5_VisionRotaryEmbedding,
|
|
Qwen2_5_VLCausalLMOutputWithPast,
|
|
Qwen2_5_VLForConditionalGeneration,
|
|
Qwen2_5_VLMLP,
|
|
Qwen2_5_VLModel,
|
|
Qwen2_5_VLModelOutputWithPast,
|
|
Qwen2_5_VLPreTrainedModel,
|
|
Qwen2_5_VLRotaryEmbedding,
|
|
Qwen2_5_VLTextModel,
|
|
Qwen2_5_VLVisionAttention,
|
|
Qwen2_5_VLVisionBlock,
|
|
)
|
|
from ..qwen2_5_vl.processing_qwen2_5_vl import (
|
|
Qwen2_5_VLProcessor,
|
|
Qwen2_5_VLProcessorKwargs,
|
|
Qwen2_5_VLVideosProcessorKwargs,
|
|
)
|
|
|
|
|
|
# Module-level logger obtained from the package's `utils.logging` helpers (imported above as `logging`).
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class Glm4vVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Glm4vVisionModel`]. It is used to instantiate a
    Glm4vVisionModel according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of
    GLM-4.1V-9B-Thinking [THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking).

    Args:
        depth (`int`, *optional*, defaults to 24):
            Number of layers (depth) in the model.
        hidden_size (`int`, *optional*, defaults to 1536):
            Dimensionality of the encoder layers.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) used in the vision MLPs and the patch merger.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to add a bias to the queries, keys and values.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for attention weights.
        num_heads (`int`, *optional*, defaults to 12):
            Number of attention heads in each vision attention layer.
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        image_size (`int`, *optional*, defaults to 336):
            Image resolution that sizes the learned position-embedding grid
            (`(image_size // patch_size) ** 2` positions).
        patch_size (`int`, *optional*, defaults to 14):
            The size (resolution) of each patch.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        spatial_merge_size (`int`, *optional*, defaults to 2):
            The size used for merging spatial dimensions.
        temporal_patch_size (`int`, *optional*, defaults to 2):
            The size used for patches along the temporal dimension.
        out_hidden_size (`int`, *optional*, defaults to 4096):
            The output hidden size of the vision model.
        intermediate_size (`int`, *optional*, defaults to 13696):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

    Example:

    ```python
    >>> from transformers import Glm4vVisionConfig, Glm4vVisionModel

    >>> # Initializing a Glm4vVisionConfig GLM-4.1V-9B style configuration
    >>> configuration = Glm4vVisionConfig()

    >>> # Initializing a model (with random weights) from the GLM-4.1V-9B configuration
    >>> model = Glm4vVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "glm4v"
    base_config_key = "vision_config"

    def __init__(
        self,
        depth=24,
        hidden_size=1536,
        hidden_act="silu",
        attention_bias=False,
        attention_dropout=0.0,
        num_heads=12,
        in_channels=3,
        image_size=336,
        patch_size=14,
        rms_norm_eps=1e-05,
        spatial_merge_size=2,
        # Default aligned with the documented value; it sets the temporal kernel of the patch-embedding Conv3d.
        temporal_patch_size=2,
        out_hidden_size=4096,
        intermediate_size=13696,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.depth = depth
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.image_size = image_size
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size
        self.out_hidden_size = out_hidden_size
        self.intermediate_size = intermediate_size
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
|
|
|
|
|
|
class Glm4vTextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Glm4vModel`]. It is used to instantiate a
    GLM-4.1V model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of
    GLM-4.1V-9B-Thinking [THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 151552):
            Vocabulary size of the Glm4v model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Glm4vModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 13696):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 40):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 2):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If explicitly set to `None`, it falls back to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly. The dict is copied, so the caller's object is never modified.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation. A legacy `type` key is moved to
                    `rope_type`, and the legacy value `"mrope"` is mapped to `"default"`.
                `mrope_section` (`list[int]`):
                    Channel split of the rotary dimension into temporal, height and width sections.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
        image_token_id (`int`, *optional*):
            Token index used as placeholder for image embeddings.
        video_token_id (`int`, *optional*):
            Token index used as placeholder for video embeddings.

    ```python
    >>> from transformers import Glm4vTextModel, Glm4vTextConfig

    >>> # Initializing a GLM-4.1V style text configuration
    >>> configuration = Glm4vTextConfig()

    >>> # Initializing a model from the GLM-4.1V style configuration
    >>> model = Glm4vTextModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "glm4v_text"
    base_config_key = "text_config"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `Glm4v`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_up_proj": "colwise_rep",  # we need to replicate here due to the `chunk` operation
        "layers.*.mlp.down_proj": "rowwise_rep",  # we need to replicate here due to the `chunk` operation
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=151552,
        hidden_size=4096,
        intermediate_size=13696,
        num_hidden_layers=40,
        num_attention_heads=32,
        num_key_value_heads=2,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        attention_dropout=0.0,
        rope_scaling=None,
        image_token_id=None,
        video_token_id=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout

        if rope_scaling is not None:
            # Work on a copy so the BC rewrites below never mutate the caller's dict.
            rope_scaling = dict(rope_scaling)
            # BC: if there is a 'type' field, move it to 'rope_type'. The legacy "mrope" type is mapped to
            # "default" (as the Qwen2-VL family configs do); the multimodal split itself comes from
            # `mrope_section`.
            if "type" in rope_scaling:
                if rope_scaling["type"] == "mrope":
                    rope_scaling["type"] = "default"
                rope_scaling["rope_type"] = rope_scaling["type"]
        self.rope_scaling = rope_scaling

        # Validate the correctness of rotary position embeddings parameters
        rope_config_validation(self, ignore_keys={"mrope_section"})
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
|
|
|
|
|
class Glm4vConfig(Qwen2_5_VLConfig):
    r"""
    This is the configuration class to store the configuration of a [`Glm4vModel`]. It is used to instantiate a
    GLM-4.1V model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of
    GLM-4.1V-9B-Thinking [THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        text_config (`Union[PretrainedConfig, dict]`, *optional*, defaults to `Glm4vTextConfig`):
            The config object or dictionary of the text backbone.
        vision_config (`Union[PretrainedConfig, dict]`, *optional*, defaults to `Glm4vVisionConfig`):
            The config object or dictionary of the vision backbone.
        image_token_id (`int`, *optional*, defaults to 151343):
            The image token index to encode the image prompt.
        video_token_id (`int`, *optional*, defaults to 151344):
            The video token index to encode the video prompt.
        image_start_token_id (`int`, *optional*, defaults to 151339):
            The image start token index to encode the start of image.
        image_end_token_id (`int`, *optional*, defaults to 151340):
            The image end token index to encode the end of image.
        video_start_token_id (`int`, *optional*, defaults to 151341):
            The video start token index to encode the start of video.
        video_end_token_id (`int`, *optional*, defaults to 151342):
            The video end token index to encode the end of video.

    ```python
    >>> from transformers import Glm4vForConditionalGeneration, Glm4vConfig

    >>> # Initializing a GLM-4.1V style configuration
    >>> configuration = Glm4vConfig()

    >>> # Initializing a model from the GLM-4.1V style configuration
    >>> model = Glm4vForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "glm4v"
    # The parent builds `vision_config` / `text_config` through `self.sub_configs`; point it at the GLM classes.
    sub_configs = {"vision_config": Glm4vVisionConfig, "text_config": Glm4vTextConfig}

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        image_token_id=151343,
        video_token_id=151344,
        image_start_token_id=151339,
        image_end_token_id=151340,
        video_start_token_id=151341,
        video_end_token_id=151342,
        **kwargs,
    ):
        # Forward the sub-configs, the GLM placeholder token ids and the remaining kwargs. A bare
        # `super().__init__()` discarded all of them, so the parent's defaults were used instead.
        super().__init__(
            text_config=text_config,
            vision_config=vision_config,
            image_token_id=image_token_id,
            video_token_id=video_token_id,
            **kwargs,
        )
        self.video_start_token_id = video_start_token_id
        self.video_end_token_id = video_end_token_id
        self.image_start_token_id = image_start_token_id
        self.image_end_token_id = image_end_token_id
|
|
|
|
|
|
# Will be used for both Text and Vision modalities
class Glm4vRMSNorm(Glm4RMSNorm):
    """RMS normalization layer for GLM-4.1V, inherited unchanged from `Glm4RMSNorm`.

    In this file it backs the vision blocks' `norm1`/`norm2` and the vision model's `post_conv_layernorm` and
    `post_layernorm`; per the note above it is meant for the text stack as well.
    """

    pass
|
|
|
|
|
|
class Glm4VisionMlp(Qwen2_5_VLMLP):
    """Feed-forward layer used inside each `Glm4vVisionBlock`; layers and `forward` come from `Qwen2_5_VLMLP`."""

    def __init__(self, config, bias: bool = False):
        # `config` is the vision config here (built by `Glm4vVisionBlock`); assumed to expose every attribute
        # `Qwen2_5_VLMLP` reads — TODO confirm.
        super().__init__(config, bias)
        # NOTE(review): in plain Python this runs after the parent has already created its projections, so it only
        # changes the stored attribute. Presumably the modular converter folds it into the inherited body so the
        # projections are sized with `out_hidden_size` — confirm against the generated modeling file.
        self.intermediate_size = config.out_hidden_size
|
|
|
|
|
|
class Glm4vVisionPatchEmbed(Qwen2_5_VisionPatchEmbed):
    """3D (temporal x height x width) patch embedding configured from a `Glm4vVisionConfig`.

    `forward` is inherited from `Qwen2_5_VisionPatchEmbed`. Only the attributes it reads (`patch_size`,
    `temporal_patch_size`, `in_channels`, `embed_dim`, `proj`) are set up here.
    """

    def __init__(self, config: Glm4vVisionConfig) -> None:
        # Initialise only `nn.Module`. The parent's constructor takes different arguments and builds its own
        # Conv3d, which is replaced below by one using Conv3d's default bias. The previous call
        # `Qwen2_5_VisionPatchEmbed.__init__()` omitted `self` and raised a TypeError.
        nn.Module.__init__(self)
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size

        # Non-overlapping patches: the stride equals the kernel in every dimension.
        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size)
|
|
|
|
|
|
class Glm4vVisionRotaryEmbedding(Qwen2_5_VisionRotaryEmbedding):
    """Rotary frequency table for the vision tower, inherited unchanged from `Qwen2_5_VisionRotaryEmbedding`.

    `Glm4vVisionModel` constructs it with `head_dim // 2`, calls it with the largest grid side, and indexes the
    result by per-patch (h, w) positions.
    """

    pass
|
|
|
|
|
|
class Glm4vVisionPatchMerger(nn.Module):
    """Output head of the vision tower.

    The input is first passed through `Linear(dim, dim)`, then LayerNorm, then GELU. The result goes through a
    gated MLP, `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, whose hidden width is `context_dim`. The output
    width stays `dim`.
    """

    def __init__(self, dim: int, context_dim: int, hidden_act: str, bias: bool = False) -> None:
        super().__init__()
        # Sub-modules are created in a fixed order: parameter registration and random initialisation follow it.
        self.proj = nn.Linear(dim, dim, bias=bias)
        self.post_projection_norm = LayerNorm(dim)
        self.gate_proj = nn.Linear(dim, context_dim, bias=bias)
        self.up_proj = nn.Linear(dim, context_dim, bias=bias)
        self.down_proj = nn.Linear(context_dim, dim, bias=bias)
        self.act1 = nn.GELU()
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        normed = self.act1(self.post_projection_norm(self.proj(hidden_state)))
        gated = self.act_fn(self.gate_proj(normed))
        lifted = self.up_proj(normed)
        return self.down_proj(gated * lifted)
|
|
|
|
|
|
class Glm4vVisionEmbeddings(nn.Module):
    """Learned 2D position embeddings, resampled to every image's own patch grid.

    The table stores `(image_size // patch_size) ** 2` vectors, viewed as a square grid. At run time the grid is
    sampled bicubically (`F.grid_sample`) at each patch's (h, w) location. This gives a position encoding for
    grids of any size.
    """

    def __init__(self, config: Glm4vVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def _resample_positions(self, table, lengths, image_shapes, h_coords, w_coords):
        """Sample `table` (viewed as a square grid) at every patch centre; returns a float32 `(total_seq, dim)`."""
        device = table.device
        dim = table.shape[1]

        if isinstance(lengths, list):
            lengths = torch.tensor(lengths, device=device, dtype=torch.long)
        if not isinstance(image_shapes, torch.Tensor):
            image_shapes = torch.tensor(image_shapes, device=device, dtype=torch.long)

        side = int(table.shape[0] ** 0.5)
        # (side*side, dim) -> (1, dim, side, side): an "image" with `dim` channels for grid_sample.
        grid_image = table.view(side, side, dim).permute(2, 0, 1).unsqueeze(0).to(device=device, dtype=torch.float32)

        def per_patch_extent(column):
            # Repeat each row's extent (column 1 = h, column 2 = w) once per patch of that row.
            pieces = [image_shapes[idx, column].repeat(count) for idx, count in enumerate(lengths)]
            return torch.cat(pieces).to(device=device, dtype=torch.float32)

        extent_h = per_patch_extent(1)
        extent_w = per_patch_extent(2)

        # Map patch centres to [-1, 1]; grid_sample reads the last grid axis as (x, y) = (w, h).
        ys = h_coords.to(device=device, dtype=torch.float32)
        xs = w_coords.to(device=device, dtype=torch.float32)
        norm_x = ((xs + 0.5) / extent_w) * 2 - 1
        norm_y = ((ys + 0.5) / extent_h) * 2 - 1
        sample_grid = torch.stack((norm_x, norm_y), dim=-1)[None, :, None, :]

        sampled = F.grid_sample(grid_image, sample_grid, mode="bicubic", align_corners=False, padding_mode="border")
        # (1, dim, total_seq, 1) -> (total_seq, dim)
        return sampled[0, :, :, 0].transpose(0, 1)

    def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor:
        """
        Add a position encoding obtained by resampling the learned 2D table at each patch location.

        Args:
            embeddings: Patch embeddings that receive the position encoding.
            lengths (`list[int]` or `torch.Tensor`): Number of patches contributed by each row of `image_shapes`.
            image_shapes (`torch.Tensor` or nested sequence): One `(t, h, w)` row per entry of `lengths`; only the
                `h` and `w` columns are read.
            h_coords (`torch.Tensor` of shape `(total_seq,)`): Row index of each patch inside its grid.
            w_coords (`torch.Tensor` of shape `(total_seq,)`): Column index of each patch inside its grid.

        Returns:
            `torch.Tensor`: `embeddings` plus the resampled position encoding.
        """
        table = self.position_embedding.weight
        dim = table.shape[1]
        device = table.device
        h_coords = h_coords.to(device)
        w_coords = w_coords.to(device)

        if h_coords.shape[0] == 0:
            pos_encoding = torch.empty(0, dim, device=device, dtype=table.dtype)
        else:
            pos_encoding = self._resample_positions(table, lengths, image_shapes, h_coords, w_coords)
            pos_encoding = pos_encoding.to(table.dtype).to(embeddings.device)

        return embeddings + pos_encoding
|
|
|
|
|
|
class Glm4vVisionAttention(Qwen2_5_VLVisionAttention):
    """Vision self-attention; `forward` is inherited from `Qwen2_5_VLVisionAttention`.

    The fused `qkv` projection is rebuilt so its bias follows `config.attention_bias`. The output projection is
    rebuilt without bias.
    """

    def __init__(self, config: Glm4vVisionConfig) -> None:
        # The parent constructor requires the config; calling it with no arguments raised a TypeError.
        super().__init__(config)
        self.attention_dropout = config.attention_dropout
        self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
|
|
|
|
|
class Glm4vVisionBlock(Qwen2_5_VLVisionBlock):
    """Vision transformer block. `forward` is inherited from `Qwen2_5_VLVisionBlock`.

    The norms, attention and MLP are replaced with their GLM-4.1V counterparts.
    """

    def __init__(self, config) -> None:
        # The parent constructor requires the config; calling it with no arguments raised a TypeError.
        super().__init__(config)
        self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = Glm4vVisionAttention(config)
        self.mlp = Glm4VisionMlp(config, bias=False)
|
|
|
|
|
|
class Glm4vPreTrainedModel(Qwen2_5_VLPreTrainedModel):
    """Common base class for GLM-4.1V models; everything except `_no_split_modules` is inherited."""

    # Module class names that should not be split when the model is sharded (presumably consumed by the
    # `device_map` dispatch logic — confirm). NOTE(review): `Glm4vTextDecoderLayer` is not defined in the visible
    # part of this file; presumably it is defined further down.
    _no_split_modules = ["Glm4vTextDecoderLayer", "Glm4vVisionBlock"]
|
|
|
|
|
|
class Glm4vVisionModel(Glm4vPreTrainedModel):
    """Vision tower of GLM-4.1V.

    The forward pass runs these stages in order:
    1. 3D patch embedding, then RMSNorm.
    2. Addition of interpolated learned 2D position embeddings.
    3. `depth` vision blocks using 2D (h, w) rotary embeddings.
    4. RMSNorm.
    5. A `spatial_merge_size` x `spatial_merge_size` strided convolution projecting to `out_hidden_size`.
    6. The gated patch merger.
    """

    config: Glm4vVisionConfig
    _no_split_modules = ["Glm4vVisionBlock"]

    def __init__(self, config) -> None:
        super().__init__(config)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size

        self.embeddings = Glm4vVisionEmbeddings(config)
        self.patch_embed = Glm4vVisionPatchEmbed(config)

        head_dim = config.hidden_size // config.num_heads
        self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([Glm4vVisionBlock(config) for _ in range(config.depth)])
        self.merger = Glm4vVisionPatchMerger(
            dim=config.out_hidden_size, context_dim=config.intermediate_size, hidden_act=config.hidden_act
        )

        self.post_conv_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.downsample = nn.Conv2d(
            in_channels=config.hidden_size,
            out_channels=config.out_hidden_size,
            kernel_size=config.spatial_merge_size,
            stride=config.spatial_merge_size,
        )
        self.post_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        self.post_init()

    def rot_pos_emb(self, grid_thw):
        """Build the 2D rotary table and the per-patch (h, w) positions.

        Within each frame, patches are ordered window by window, with `spatial_merge_size` x `spatial_merge_size`
        windows (the reshape/permute below). The positions are repeated once per frame (`t`).

        Returns:
            `(rotary_pos_emb, pos_ids)`, where `pos_ids` has shape `(total_patches, 2)` and holds `(h, w)` for
            every patch.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb, pos_ids

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`):
                Flattened pixel patches, passed straight to `patch_embed`. The assumed layout is
                `(num_patches, in_channels * temporal_patch_size * patch_size**2)` as produced by the image
                processor — TODO confirm.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            `torch.Tensor`: one `out_hidden_size` vector per merged `spatial_merge_size**2` group of patches.
        """
        hidden_states = self.patch_embed(hidden_states)
        hidden_states = self.post_conv_layernorm(hidden_states)

        rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

        # `seqlens` has one entry per *frame* (the repeat_interleave over t above), so the shapes handed to the
        # embeddings must also have one row per frame. Otherwise frame i would be paired with image row i once
        # any t > 1. When every t == 1 this is exactly `grid_thw`.
        frame_grid_thw = torch.repeat_interleave(grid_thw, grid_thw[:, 0], dim=0)
        hidden_states = self.embeddings(
            hidden_states, seqlens, frame_grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
        )

        for blk in self.blocks:
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
            )

        hidden_states = self.post_layernorm(hidden_states)

        # Patches arrive grouped in merge windows (see `rot_pos_emb`), so each group of spatial_merge_size**2
        # consecutive patches forms one window for the strided convolution.
        hidden_states = hidden_states.view(
            -1, self.spatial_merge_size, self.spatial_merge_size, hidden_states.shape[-1]
        )
        hidden_states = hidden_states.permute(0, 3, 1, 2)
        hidden_states = self.downsample(hidden_states).view(-1, self.config.out_hidden_size)

        hidden_states = self.merger(hidden_states)
        return hidden_states
|
|
|
|
|
|
class Glm4vTextRotaryEmbedding(Qwen2_5_VLRotaryEmbedding):
    """Rotary embedding for the text decoder, inherited unchanged from `Qwen2_5_VLRotaryEmbedding`.

    NOTE(review): presumably yields `(cos, sin)` with a leading temporal/height/width axis that
    `apply_multimodal_rotary_pos_emb` selects from per `mrope_section` — confirm against the parent class.
    """

    pass
|
|
|
|
|
|
def rotate_half_llm(x):
    """Rotate interleaved channel pairs of the last dimension.

    Each adjacent pair `(a, b)` becomes `(-b, a)`, so `(x0, x1, x2, x3, ...)` maps to `(-x1, x0, -x3, x2, ...)`.
    The last dimension must have even length.
    """
    pairs = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
    even, odd = pairs.unbind(dim=-1)
    return torch.stack((-odd, even), dim=-1).flatten(-2)
|
|
|
|
|
|
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).

    Explanation:
        The rotary channel dimension is split by `mrope_section` (repeated twice). Chunk `i` of `cos`/`sin` is taken
        from slice `i % 3` of their leading axis (temporal, height, width). For text tokens the three positions
        coincide, so this reduces to ordinary 1D RoPE.

        Only the first half of the merged table is then kept and each value is duplicated in place
        (`repeat_interleave(2)`). This matches the pairwise rotation done by `rotate_half_llm`. Channels of `q`/`k`
        beyond the resulting rotary width are passed through untouched.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        mrope_section(`List(int)`):
            Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into `cos`/`sin` so they broadcast against `q` and `k`. Use 1 when `q`/`k` are
            `[batch, heads, seq, head_dim]` and 2 when they are `[batch, seq, heads, head_dim]`.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    sections = mrope_section * 2

    def _merge_sections(table):
        # Chunk idx of the channel dim is read from axis slice idx % 3 (temporal / height / width).
        chunks = table.split(sections, dim=-1)
        merged = torch.cat([chunk[idx % 3] for idx, chunk in enumerate(chunks)], dim=-1).unsqueeze(unsqueeze_dim)
        # Interleave them instead of usual shape
        half = merged.shape[-1] // 2
        return merged[..., :half].repeat_interleave(2, dim=-1)

    cos = _merge_sections(cos)
    sin = _merge_sections(sin)
    rotary_dim = cos.shape[-1]

    def _rotate(states):
        # Rotate the leading `rotary_dim` channels, keep the rest as-is.
        rotating = states[..., :rotary_dim]
        passing = states[..., rotary_dim:]
        rotated = (rotating * cos) + (rotate_half_llm(rotating) * sin)
        return torch.cat([rotated, passing], dim=-1)

    return _rotate(q), _rotate(k)
|
|
|
|
|
|
class Glm4vTextAttention(nn.Module):
    """
    Multi-headed (grouped-query) self-attention used by the GLM-4V text decoder.

    Queries, keys and values come from biased linear projections and the output projection has no bias.
    Multimodal rotary position embeddings are applied to queries and keys (split by
    `rope_scaling["mrope_section"]`) before the optional cache update, and the attention kernel is
    selected from `config._attn_implementation`, falling back to the eager implementation.
    """

    def __init__(self, config: Glm4vTextConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # Head geometry: head_dim is derived from hidden_size, not read from the config.
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads sharing one key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.is_causal = True
        self.attention_dropout = config.attention_dropout
        self.rope_scaling = config.rope_scaling
        self.scaling = self.head_dim**-0.5

        query_width = self.num_heads * self.head_dim
        kv_width = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, query_width, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, kv_width, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, kv_width, bias=True)
        self.o_proj = nn.Linear(query_width, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """
        Run attention over `hidden_states`.

        Args:
            hidden_states: input activations, unpacked below as `(batch, seq_len, hidden)`.
            position_embeddings: `(cos, sin)` pair handed to `apply_multimodal_rotary_pos_emb`.
            attention_mask: forwarded unchanged to the attention kernel.
            position_ids, output_attentions, use_cache: accepted for interface compatibility; not read here.
            past_key_value: optional cache; when given, new keys/values are merged through its `update`.
            cache_position: forwarded to the cache update.

        Returns:
            `(attn_output, attn_weights, past_key_value)`.
        """
        batch_size, seq_len, _ = hidden_states.size()
        # Split the projected features into heads: (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim).
        per_head_shape = (batch_size, seq_len, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(per_head_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(per_head_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(per_head_shape).transpose(1, 2)

        cos, sin = position_embeddings
        # Multimodal RoPE instead of the plain 1-D rotary embedding.
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        if past_key_value is not None:
            # RoPE metadata is passed along for caches that need it.
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                self.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        implementation = self.config._attn_implementation
        attention_interface: Callable = (
            eager_attention_forward if implementation == "eager" else ALL_ATTENTION_FUNCTIONS[implementation]
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads back into the hidden dimension before the output projection.
        merged = attn_output.reshape(batch_size, seq_len, -1).contiguous()
        return self.o_proj(merged), attn_weights, past_key_value
|
|
|
|
|
|
class Glm4vTextMLP(Glm4MLP):
    """Feed-forward block of the GLM-4V text decoder; reuses `Glm4MLP` without changes."""
|
|
|
|
|
|
class Glm4vTextDecoderLayer(GradientCheckpointingLayer):
    """
    One GLM-4V text decoder layer.

    Each sub-block (self-attention, then MLP) normalizes its input, and its output is normalized
    again before being added back onto the residual stream.
    """

    def __init__(self, config: Glm4vTextConfig, layer_idx: int):
        super().__init__()
        width = config.hidden_size
        eps = config.rms_norm_eps
        self.hidden_size = width
        self.self_attn = Glm4vTextAttention(config, layer_idx)
        self.mlp = Glm4vTextMLP(config)
        # Pre-attention norm.
        self.input_layernorm = Glm4vRMSNorm(width, eps=eps)
        # Despite its name, this is the norm applied *before* the MLP.
        self.post_attention_layernorm = Glm4vRMSNorm(width, eps=eps)
        # Norm applied to the attention output before the residual add.
        self.post_self_attn_layernorm = Glm4vRMSNorm(width, eps=eps)
        # Norm applied to the MLP output before the residual add.
        self.post_mlp_layernorm = Glm4vRMSNorm(width, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Apply the attention and MLP sub-blocks.

        Returns:
            A tuple starting with the new hidden states, followed by the attention weights when
            `output_attentions` is set, followed by the cache object returned by the attention
            module when `use_cache` is set.
        """
        # Attention sub-block: norm -> attention -> norm -> residual add.
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = hidden_states + self.post_self_attn_layernorm(attn_out)

        # MLP sub-block: norm -> MLP -> norm -> residual add.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + self.post_mlp_layernorm(mlp_out)

        optional_outputs = []
        if output_attentions:
            optional_outputs.append(self_attn_weights)
        if use_cache:
            optional_outputs.append(present_key_value)
        return (hidden_states, *optional_outputs)
|
|
|
|
|
|
# Output container returned by `Glm4vModel.forward`. It defines no fields of its own and inherits all
# of them (including `rope_deltas`) unchanged from the Qwen2.5-VL model output.
class Glm4vModelOutputWithPast(Qwen2_5_VLModelOutputWithPast):
    pass
|
|
|
|
|
|
# GLM-4V text decoder stack. It reuses the Qwen2.5-VL text model and replaces its decoder layers,
# final norm and rotary embedding with the GLM-4V versions.
class Glm4vTextModel(Qwen2_5_VLTextModel):
    def __init__(self, config: Glm4vTextConfig):
        super().__init__(config)
        # Override the submodules created by the parent constructor.
        self.layers = nn.ModuleList(
            [Glm4vTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Glm4vTextRotaryEmbedding(config=config)
        # NOTE(review): removes attributes presumably set by the parent `__init__` (likely
        # modular-converter directives) — confirm against Qwen2_5_VLTextModel.
        del self._attn_implementation
        del self.has_sliding_layers

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        # Unset flags fall back to the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # XOR: raises when both or neither of input_ids / inputs_embeds are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # torch.jit.trace() doesn't support cache objects in the output
        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Default cache positions continue from the number of tokens already stored in the cache.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # the hard coded `3` is for temporal, height and width.
        # Missing or 2-D position ids are broadcast to the (3, batch, seq) multimodal-RoPE layout.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.dim() == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        # NOTE(review): `position_ids` is 3-D at this point; confirm `create_causal_mask` handles that
        # layout (its packed-sequence detection, if active when attention_mask is None, may expect (batch, seq)).
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            # Hidden states are recorded *before* each layer; the final normed output is appended after the loop.
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            # Attention weights sit at index 1 of the layer output tuple when requested.
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
|
|
|
|
|
class Glm4vModel(Qwen2_5_VLModel):
    _checkpoint_conversion_mapping = {}
    _no_split_modules = ["Glm4vTextDecoderLayer", "Glm4vVisionBlock"]

    def __init__(self, config):
        super().__init__(config)
        # Replace the parent's vision tower with the GLM-4V one.
        self.visual = Glm4vVisionModel._from_config(config.vision_config)

    def get_rope_index(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate the 3D rope index based on image and video's temporal, height and width in LLM.

        Explanation:
            Each embedding sequence contains vision embedding and text embedding or just contains text embedding.

            For pure text embedding sequence (no `image_grid_thw` / `video_grid_thw` given), the rotary position
            embedding has no difference with modern LLMs: the three axes share the same 1D positions.
            Examples:
                input_ids: [T T T T T], here T is for text.
                temporal position_ids: [0, 1, 2, 3, 4]
                height position_ids: [0, 1, 2, 3, 4]
                width position_ids: [0, 1, 2, 3, 4]

            For vision and text embedding sequences, every non-padding token is first classified:
            `image_token_id` tokens outside a `video_start_token_id` ... `video_end_token_id` span are image
            tokens, `image_token_id` tokens inside such a span are video-frame tokens, and all other tokens are
            text. Consecutive tokens of the same kind form a group, and each group starts at
            `st_idx = max(all previous positions) + 1` (0 for the first group):
            - text group of length L: `st_idx + [0, ..., L - 1]` on all three axes;
            - image group with grid `(t, h, w)` from `image_grid_thw` (`h` and `w` divided by
              `spatial_merge_size`): the temporal, height and width indices are the frame, row and column of
              each token, offset by `st_idx`;
            - video group: one frame whose `h`, `w` come from the current `video_grid_thw` row (divided by
              `spatial_merge_size`), laid out like a single-frame image. Frames therefore advance through
              `st_idx` (i.e. past whatever tokens separate them) rather than by a fixed temporal stride. After
              `video_grid_thw[k][0]` frame groups, the next video row is used.

            Examples:
                One image whose merged grid is (t=1, h=2, w=2), surrounded by text (I is an image placeholder):
                input_ids: [T T I I I I T T T]
                temporal position_ids: [0, 1, 2, 2, 2, 2, 4, 5, 6]
                height position_ids:   [0, 1, 2, 2, 3, 3, 4, 5, 6]
                width position_ids:    [0, 1, 2, 3, 2, 3, 4, 5, 6]
                The text after the image starts at the max vision position plus 1 (3 + 1 = 4).

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
                it.
            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                The temporal, height and width of feature shape of each image in LLM.
            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
                The temporal, height and width of feature shape of each video in LLM.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

        Returns:
            position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`):
                Positions of masked (padding) tokens are set to 1.
            mrope_position_deltas (`torch.Tensor` of shape `(batch_size, 1)`):
                `max(position) + 1 - sequence_length` per sample; `forward` adds it to `cache_position` when
                building positions for later decoding steps.
        """

        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_start_token_id = self.config.video_start_token_id
        video_end_token_id = self.config.video_end_token_id

        mrope_position_deltas = []
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            # Padding slots are never overwritten below, so they keep the value 1.
            position_ids = torch.ones(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            # Cursors into image_grid_thw / video_grid_thw, shared across the whole batch.
            image_index, video_index = 0, 0
            # Frame groups of the current video consumed so far.
            video_group_index = 0
            attention_mask = attention_mask.to(total_input_ids.device)
            # A distinct loop name keeps `input_ids` referring to the full batch argument.
            for i, sample_ids in enumerate(total_input_ids):
                # Positions are assigned to attended (non-padding) tokens only.
                sample_ids = sample_ids[attention_mask[i] == 1]
                input_tokens = sample_ids.tolist()

                # Image placeholders inside a video start/end span are video-frame tokens.
                input_token_type = []
                video_check_flg = False
                for token in input_tokens:
                    if token == video_start_token_id:
                        video_check_flg = True
                    elif token == video_end_token_id:
                        video_check_flg = False

                    if token == image_token_id and not video_check_flg:
                        input_token_type.append("image")
                    elif token == image_token_id and video_check_flg:
                        input_token_type.append("video")
                    else:
                        input_token_type.append("text")

                # Collapse runs of equal type into (type, start, end) groups.
                input_type_group = []
                for key, group in itertools.groupby(enumerate(input_token_type), lambda x: x[1]):
                    group = list(group)
                    start_index = group[0][0]
                    end_index = group[-1][0] + 1
                    input_type_group.append((key, start_index, end_index))

                llm_pos_ids_list = []
                video_frame_num = 1
                for modality_type, start_idx, end_idx in input_type_group:
                    # Every group starts right after the largest position used so far.
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0

                    if modality_type == "image":
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        llm_grid_t, llm_grid_h, llm_grid_w = (
                            t.item(),
                            h.item() // spatial_merge_size,
                            w.item() // spatial_merge_size,
                        )

                        t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                        h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                        w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                        llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)

                        image_index += 1
                        video_frame_num = 1

                    elif modality_type == "video":
                        t, h, w = (
                            video_frame_num,
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )

                        llm_grid_t, llm_grid_h, llm_grid_w = (
                            t,
                            h.item() // spatial_merge_size,
                            w.item() // spatial_merge_size,
                        )

                        for t_idx in range(llm_grid_t):
                            t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()

                            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()
                            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
                            llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)

                        # Move on to the next video once all of its frame groups have been consumed.
                        video_group_index += 1

                        if video_group_index >= video_grid_thw[video_index][0]:
                            video_index += 1
                            video_group_index = 0

                        video_frame_num += 1

                    else:
                        text_len = end_idx - start_idx
                        llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                        video_frame_num = 1

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            if attention_mask is not None:
                # 1-D positions counted over attended tokens; padding slots are filled with 1.
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(3, input_ids.shape[0], -1)
                )
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )

            return position_ids, mrope_position_deltas

    def get_video_features(
        self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
    ):
        """
        Encodes videos into continuous embeddings that can be forwarded to the language model.

        Each `(t, h, w)` row of `video_grid_thw` is expanded into `t` rows of `(1, h, w)` before calling the
        vision tower, and the resulting embeddings are split back into one tensor per video.

        Args:
            pixel_values_videos (`torch.FloatTensor`):
                The tensors corresponding to the input videos. NOTE(review): presumably the flattened patch
                tensor produced by the video processor; its exact shape is not visible here.
            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
                The temporal, height and width of feature shape of each video in LLM.

        Returns:
            `tuple[torch.Tensor]`: one embedding tensor per video with `t * h * w // spatial_merge_size**2` rows.
        """
        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
        # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
        temp_frames_hw = []
        for t, h, w in video_grid_thw:
            repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
            temp_frames_hw.append(repeated_row)
        flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
        video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw)
        split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        video_embeds = torch.split(video_embeds, split_sizes)
        return video_embeds

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, Glm4vModelOutputWithPast]:
        r"""
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
            The rope index difference between sequence length and multimodal rope. Not read by this method:
            the deltas computed during pre-fill are cached on `self.rope_deltas` and reused instead.
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            image_embeds = self.get_image_features(pixel_values, image_grid_thw)
            image_embeds = torch.cat(image_embeds, dim=0)

            if input_ids is None:
                # Without ids, placeholder positions are found by matching the full placeholder embedding.
                image_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
                image_mask = image_mask.all(-1)
            else:
                image_mask = input_ids == self.config.image_token_id

            n_image_tokens = image_mask.sum()
            image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
            n_image_features = image_embeds.shape[0]
            if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )

            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

        if pixel_values_videos is not None:
            video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
            video_embeds = torch.cat(video_embeds, dim=0)

            # Video frames use the image placeholder token (see `get_rope_index`).
            # NOTE(review): when images and videos are both given, the image and video masks select the same
            # placeholder positions, so the token/feature counts cannot both match — confirm mixed inputs are
            # unsupported.
            if input_ids is None:
                video_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
                video_mask = video_mask.all(-1)
            else:
                video_mask = input_ids == self.config.image_token_id

            n_video_tokens = video_mask.sum()
            n_video_features = video_embeds.shape[0]
            video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
            if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
                raise ValueError(
                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                )

            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        if position_ids is None:
            attention_mask_tensor = (
                attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
            )
            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
                # Recover a per-token 2-D mask from the diagonal of a 4-D mask.
                attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
                # Only apply conversion for floating point tensors (inverted masks)
                if attention_mask_tensor.dtype.is_floating_point:
                    attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
                    attention_mask_tensor = (1.0 - attention_mask_tensor).int()

            # Calculate RoPE index once per generation in the pre-fill stage only.
            # When compiling, we can't check tensor values thus we check only input length
            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
            # models currently cannot do asssisted decoding
            prefill_compiled_stage = is_torchdynamo_compiling() and (
                (input_ids is not None and input_ids.shape[1] != 1)
                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
            )
            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
                (cache_position is not None and cache_position[0] == 0)
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            )
            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    attention_mask=attention_mask_tensor,
                )
                self.rope_deltas = rope_deltas
            else:
                # then use the prev pre-calculated rope-deltas to get the correct position ids
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        return Glm4vModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )
|
|
|
|
|
|
# Output container returned by `Glm4vForConditionalGeneration.forward`. It defines no fields of its own and
# inherits all of them (loss, logits, caches, hidden states, attentions, rope_deltas) from the Qwen2.5-VL output.
class Glm4vCausalLMOutputWithPast(Qwen2_5_VLCausalLMOutputWithPast):
    pass
|
|
|
|
|
|
class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    _checkpoint_conversion_mapping = {}

    # `can_return_tuple` makes the `return_dict` argument effective; without it the override always returned
    # a ModelOutput even when `return_dict=False` was requested.
    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, Glm4vCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
            The rope index difference between sequence length and multimodal rope.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Glm4vForConditionalGeneration

        >>> model = Glm4vForConditionalGeneration.from_pretrained("THUDM/GLM-4.1V-9B-Thinking")
        >>> processor = AutoProcessor.from_pretrained("THUDM/GLM-4.1V-9B-Thinking")

        >>> messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        >>> inputs = processor(text=[text], images=[image], return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_new_tokens=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)

        return Glm4vCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=outputs.rope_deltas,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            use_cache=use_cache,
            **kwargs,
        )

        # GLM-4.1V position_ids are prepared with rope_deltas in forward
        model_inputs["position_ids"] = None

        # Vision inputs are only consumed at pre-fill; later steps reuse the cache.
        if cache_position[0] != 0:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None

        return model_inputs

    def _get_image_nums_and_video_nums(
        self,
        input_ids: Optional[torch.LongTensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
        These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.

        Images are counted by their `image_start_token_id` markers, ignoring markers that fall inside a
        `video_start_token_id` ... `video_end_token_id` span; videos are counted by `video_start_token_id` markers.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Only read when `inputs_embeds` is `None`.
            inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Input embeddings. When given, special tokens are located by comparing whole embedding vectors with
                the embeddings of the special token ids.

        Returns:
            image_nums (`torch.LongTensor` of shape `(batch_size,)`)
            video_nums (`torch.LongTensor` of shape `(batch_size,)`)
        """

        if inputs_embeds is not None:
            embed_tokens = self.get_input_embeddings()

            def _matches_token(token_id):
                # Require every hidden coordinate to match; comparing only the first one can flag unrelated tokens.
                token_embed = embed_tokens(torch.tensor(token_id, dtype=torch.long, device=inputs_embeds.device))
                return (inputs_embeds == token_embed).all(-1)

            is_image = _matches_token(self.config.image_start_token_id)
            is_video_start = _matches_token(self.config.video_start_token_id)
            is_video_end = _matches_token(self.config.video_end_token_id)
        else:
            is_image = input_ids == self.config.image_start_token_id
            is_video_start = input_ids == self.config.video_start_token_id
            is_video_end = input_ids == self.config.video_end_token_id

        # Cumulative sum to track if we're inside a video span
        # We'll assume well-formed video tags (i.e. matching starts and ends)
        video_level = torch.cumsum(is_video_start.int() - is_video_end.int(), dim=1)
        inside_video = video_level > 0  # shape (batch_size, seq_length)

        # Mask out image tokens that are inside video spans
        standalone_images = is_image & (~inside_video)

        # Count per batch
        image_counts = standalone_images.sum(dim=1)
        video_counts = is_video_start.sum(dim=1)

        return image_counts, video_counts
|
|
|
|
|
|
class Glm4vVideosProcessorKwargs(Qwen2_5_VLVideosProcessorKwargs):
    """Video-processing keyword arguments for `Glm4vProcessor`; inherits the Qwen2.5-VL set unchanged."""
|
|
|
|
|
|
class Glm4vImagesKwargs(ImagesKwargs):
    """
    Image-processing keyword arguments for `Glm4vProcessor`.

    Extends the generic `ImagesKwargs` with three optional integer fields related to patching.
    """

    # Size of one vision patch (presumably in pixels -- confirm against the image processor).
    patch_size: Optional[int]
    # Presumably the number of consecutive frames grouped into one temporal patch -- TODO confirm.
    temporal_patch_size: Optional[int]
    # Presumably the spatial merge factor; the processor divides patch counts by `merge_size**2`.
    merge_size: Optional[int]
|
|
|
|
|
|
class Glm4vProcessorKwargs(Qwen2_5_VLProcessorKwargs):
    """
    Keyword-argument schema consumed by `Glm4vProcessor.__call__`.

    Narrows the image/video kwarg groups to their GLM-4V variants and declares text defaults
    (no padding, no `mm_token_type_ids`), presumably applied when the caller does not override them.
    """

    images_kwargs: Glm4vImagesKwargs
    videos_kwargs: Glm4vVideosProcessorKwargs

    _defaults = {"text_kwargs": {"padding": False, "return_mm_token_type_ids": False}}
|
|
|
|
|
|
class Glm4vProcessor(Qwen2_5_VLProcessor):
    r"""
    Constructs a GLM-4V processor which wraps a GLM-4V image processor and a GLM-4 tokenizer into a single processor.
    See [`~Glm4vProcessor.__call__`] and [`~Glm4vProcessor.decode`] for more information.

    Args:
        image_processor ([`Glm4vImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`PreTrainedTokenizerFast`], *optional*):
            The tokenizer is a required input.
        video_processor ([`Glm4vVideoProcessor`], *optional*):
            The video processor is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    tokenizer_class = ("PreTrainedTokenizer", "PreTrainedTokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
        super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
        self.image_token = "<|image|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
        self.video_token = "<|video|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
        # The token strings were just overridden, so re-derive the matching ids from them. `image_token_id`
        # is compared against the tokenized ids in `__call__`; an id resolved for a different token string
        # would silently produce a wrong `mm_token_type_ids` mask.
        # NOTE(review): assumes the parent constructor resolves these ids from its own default token strings.
        if tokenizer is not None:
            self.image_token_id = self._resolve_token_id(tokenizer, "image_token_id", self.image_token)
            self.video_token_id = self._resolve_token_id(tokenizer, "video_token_id", self.video_token)

    @staticmethod
    def _resolve_token_id(tokenizer, attr_name, token):
        """Return `tokenizer.<attr_name>` when it is set, otherwise the vocabulary id of `token`."""
        token_id = getattr(tokenizer, attr_name, None)
        if token_id is not None:
            return token_id
        return tokenizer.convert_tokens_to_ids(token)

    @staticmethod
    def _get_video_timestamps(timestamps, video_index):
        """
        Return a fresh list holding the frame timestamps of the video at position `video_index`.

        `timestamps` is the value the video processor returned under the `"timestamps"` key. Objects exposing
        `tolist()` (tensors, arrays) are converted first. A nested sequence is treated as one timestamp list per
        video; a flat sequence is treated as timestamps shared by every video.
        """
        if hasattr(timestamps, "tolist"):
            timestamps = timestamps.tolist()
        if timestamps and isinstance(timestamps[0], (list, tuple)):
            return list(timestamps[video_index])
        return list(timestamps)

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
        videos: VideoInput = None,
        **kwargs: Unpack[Glm4vProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
        and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
        the text.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
            - **mm_token_type_ids** -- 1 at image-token positions, 0 elsewhere. Returned when
              `return_mm_token_type_ids=True`.
        """
        output_kwargs = self._merge_kwargs(
            Glm4vProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        if images is not None:
            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        if videos is not None:
            videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
            # Timestamps are only used to build the prompt text; they are not a model input.
            timestamps = videos_inputs.pop("timestamps")
            video_grid_thw = videos_inputs["video_grid_thw"]
        else:
            videos_inputs = {}
            timestamps = []
            video_grid_thw = None

        if not isinstance(text, list):
            text = [text]

        text = text.copy()  # below lines change text in-place
        if image_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.image_token in text[i]:
                    num_image_tokens = int(image_grid_thw[index].prod()) // merge_length
                    # A temporary placeholder keeps the expanded tokens from matching this `while` again.
                    text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.image_token)

        if video_grid_thw is not None:
            merge_length = self.video_processor.merge_size**2
            video_index = 0
            for i in range(len(text)):
                while self.video_token in text[i]:
                    num_frames = int(video_grid_thw[video_index][0])
                    # LLM tokens per temporal frame group after spatial merging.
                    frame_token_count = int(video_grid_thw[video_index].prod()) // merge_length // num_frames

                    # Each video uses its own timestamps (previously every video reused the first video's).
                    frame_timestamps = self._get_video_timestamps(timestamps, video_index)[:num_frames]
                    # Pad by repeating the last timestamp (or 0) when there are fewer timestamps than frames.
                    while len(frame_timestamps) < num_frames:
                        frame_timestamps.append(frame_timestamps[-1] if frame_timestamps else 0)

                    # Emit the frame placeholders directly instead of inserting image tokens and then replacing
                    # the first image tokens found in the text: when an image precedes this video, those first
                    # tokens belong to the image, not to this video's frames.
                    frame_placeholder = "<|placeholder|>" * frame_token_count
                    video_structure = "".join(
                        f"<|begin_of_image|>{frame_placeholder}<|end_of_image|>{timestamp_sec}"
                        for timestamp_sec in frame_timestamps
                    )
                    text[i] = text[i].replace(self.video_token, video_structure, 1)
                    video_index += 1

                text[i] = text[i].replace("<|placeholder|>", self.image_token)

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])

        if return_mm_token_type_ids:
            array_ids = np.array(text_inputs["input_ids"])
            mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
            mm_token_type_ids[array_ids == self.image_token_id] = 1
            text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()

        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
|
|
|
|
|
|
# Public names exported by this module (what `from <module> import *` exposes).
__all__ = [
    "Glm4vConfig", "Glm4vTextConfig", "Glm4vForConditionalGeneration", "Glm4vModel",
    "Glm4vPreTrainedModel", "Glm4vProcessor", "Glm4vTextModel",
]
|