1707 lines
77 KiB
Python
1707 lines
77 KiB
Python
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# This file was automatically generated from src/transformers/models/glm4v/modular_glm4v.py.
|
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
|
# the file from the modular. If any change should be done, please apply the change to the
|
|
# modular_glm4v.py file directly. One of our CI enforces this.
|
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# coding=utf-8
|
|
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import itertools
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable, Optional, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.nn import LayerNorm
|
|
|
|
from ...activations import ACT2FN
|
|
from ...cache_utils import Cache, DynamicCache
|
|
from ...generation import GenerationMixin
|
|
from ...integrations import use_kernel_forward_from_hub
|
|
from ...masking_utils import create_causal_mask
|
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
from ...processing_utils import Unpack
|
|
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
|
|
from .configuration_glm4v import Glm4vConfig, Glm4vTextConfig, Glm4vVisionConfig
|
|
|
|
|
|
# Module-level logger; used in this file for one-time warnings (e.g. the RoPE deprecation notice in
# Glm4vVisionAttention).
logger = logging.get_logger(__name__)
|
|
|
|
|
|
@use_kernel_forward_from_hub("RMSNorm")
class Glm4vRMSNorm(nn.Module):
    """Root-mean-square normalisation (the same computation as T5LayerNorm): no mean subtraction, no bias."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Args:
            hidden_size: Size of the normalised (last) dimension; one learnable scale per channel, initialised to 1.
            eps: Constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalise in float32 for numerical stability, cast back, and only then apply the learned scale.
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
|
|
|
|
|
class Glm4VisionMlp(nn.Module):
    """Gated MLP used inside each vision block: ``down(act(gate(x)) * up(x))``.

    Note that the inner width is taken from ``config.out_hidden_size`` rather than ``intermediate_size``.
    """

    def __init__(self, config, bias: bool = False):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.out_hidden_size
        # Registration order (gate, up, down) is kept so parameter order matches existing checkpoints.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        gated = self.act_fn(self.gate_proj(hidden_state))
        projected_up = self.up_proj(hidden_state)
        return self.down_proj(gated * projected_up)
|
|
|
|
|
|
class Glm4vVisionPatchEmbed(nn.Module):
    """Embeds spatio-temporal pixel patches with one Conv3d whose stride equals its kernel size."""

    def __init__(self, config: Glm4vVisionConfig) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size

        patch_volume = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=patch_volume, stride=patch_volume)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Every in_channels * temporal_patch_size * patch_size**2 consecutive values form one patch volume.
        # The conv reduces each volume to a single embed_dim vector.
        patches = hidden_states.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
        weight_dtype = self.proj.weight.dtype
        embedded = self.proj(patches.to(dtype=weight_dtype))
        return embedded.view(-1, self.embed_dim)
|
|
|
|
|
|
class Glm4vVisionRotaryEmbedding(nn.Module):
    """Builds the raw rotary angle table (position x inverse frequency) used by the vision tower."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        inv_freq = 1.0 / (theta**exponents)
        # Non-persistent: recomputed at construction, never stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return angles of shape ``(seqlen, len(inv_freq))``; row ``p`` is ``p * inv_freq``."""
        positions = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        return torch.outer(positions, self.inv_freq)
|
|
|
|
|
|
class Glm4vVisionPatchMerger(nn.Module):
    """Output head for merged vision tokens: Linear -> LayerNorm -> GELU -> gated MLP (dim -> context_dim -> dim)."""

    def __init__(self, dim: int, context_dim: int, hidden_act: str, bias: bool = False) -> None:
        super().__init__()
        # Registration order is kept identical so parameter order matches existing checkpoints.
        self.proj = nn.Linear(dim, dim, bias=bias)
        self.post_projection_norm = LayerNorm(dim)
        self.gate_proj = nn.Linear(dim, context_dim, bias=bias)
        self.up_proj = nn.Linear(dim, context_dim, bias=bias)
        self.down_proj = nn.Linear(context_dim, dim, bias=bias)
        self.act1 = nn.GELU()
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        normed = self.post_projection_norm(self.proj(hidden_state))
        activated = self.act1(normed)
        gate = self.act_fn(self.gate_proj(activated))
        return self.down_proj(gate * self.up_proj(activated))
|
|
|
|
|
|
class Glm4vVisionEmbeddings(nn.Module):
    """Learned square grid of position embeddings, resampled per image to its own (h, w) patch grid."""

    def __init__(self, config: Glm4vVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor:
        """
        Add bicubically-interpolated learned 2D position embeddings to patch embeddings.

        The learned table is viewed as a `side x side` grid (`side = sqrt(num_positions)`). It is sampled with
        `F.grid_sample` at each patch's normalised (h, w) location, so every image receives position
        embeddings matching its own patch grid.

        Args:
            embeddings (`torch.Tensor`): Patch embeddings, one row per patch (`[total_seq, hidden_size]`).
            lengths (`list[int]` or `torch.Tensor`): Number of patches for each row of `image_shapes`.
                A per-frame list is also accepted, e.g. one derived from `cu_seqlens` with one entry per
                temporal frame. In that case rows of `image_shapes` with `t > 1` are expanded `t` times.
            image_shapes (`torch.Tensor` or nested list): `[num_items, 3]` grid sizes `(t, h, w)` in patches.
            h_coords (`torch.Tensor`): `[total_seq]` row index of each patch within its grid.
            w_coords (`torch.Tensor`): `[total_seq]` column index of each patch within its grid.

        Returns:
            `torch.Tensor`: `embeddings` with the adapted position encoding added.

        Raises:
            ValueError: If `lengths` matches neither the number of rows of `image_shapes` nor its total
                number of temporal frames.
        """
        # Get position embedding parameters
        pos_embed_weight = self.position_embedding.weight
        hidden_size = pos_embed_weight.shape[1]
        total_seq = h_coords.shape[0]
        device = pos_embed_weight.device

        # Move coordinates to correct device
        h_coords, w_coords = h_coords.to(device), w_coords.to(device)

        # Handle empty sequence case
        if total_seq == 0:
            adapted_pos_embed = torch.empty(0, hidden_size, device=device, dtype=pos_embed_weight.dtype)
        else:
            lengths = torch.as_tensor(lengths, device=device, dtype=torch.long)
            image_shapes = torch.as_tensor(image_shapes, device=device, dtype=torch.long)

            # `lengths` may be per temporal frame (cu_seqlens repeats h*w `t` times) while `image_shapes`
            # has one row per image/video. Expand the rows so that row i describes segment i.
            if lengths.numel() != image_shapes.shape[0]:
                per_frame_shapes = image_shapes.repeat_interleave(image_shapes[:, 0], dim=0)
                if per_frame_shapes.shape[0] != lengths.numel():
                    raise ValueError(
                        f"`lengths` has {lengths.numel()} entries, but `image_shapes` describes "
                        f"{image_shapes.shape[0]} items ({per_frame_shapes.shape[0]} frames)."
                    )
                image_shapes = per_frame_shapes

            # Reshape the learned table into a [1, hidden_size, side, side] "image" for grid_sample.
            orig_size_sq = pos_embed_weight.shape[0]
            orig_size = int(orig_size_sq**0.5)
            pos_embed_2d = (
                pos_embed_weight.view(orig_size, orig_size, hidden_size)
                .permute(2, 0, 1)
                .unsqueeze(0)
                .to(device=device, dtype=torch.float32)
            )

            # Grid height / width of the item each patch belongs to (vectorised; no per-item Python loop).
            target_h = torch.repeat_interleave(image_shapes[:, 1], lengths).to(dtype=torch.float32)
            target_w = torch.repeat_interleave(image_shapes[:, 2], lengths).to(dtype=torch.float32)

            # Map patch centres into grid_sample's normalised [-1, 1] space (consistent with align_corners=False).
            h_coords = h_coords.to(device=device, dtype=torch.float32)
            w_coords = w_coords.to(device=device, dtype=torch.float32)
            norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
            norm_h = ((h_coords + 0.5) / target_h) * 2 - 1

            # grid_sample expects (x, y) = (w, h) pairs in a [N, H_out, W_out, 2] grid; here [1, total_seq, 1, 2].
            grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)

            # Perform bicubic interpolation
            interpolated_embed_fp32 = F.grid_sample(
                pos_embed_2d, grid, mode="bicubic", align_corners=False, padding_mode="border"
            )

            # [1, hidden_size, total_seq, 1] -> [total_seq, hidden_size], then back to the table's dtype.
            adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
            adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device)

        # Add adapted position encoding to embeddings
        embeddings = embeddings + adapted_pos_embed
        return embeddings
|
|
|
|
|
|
def rotate_half(x):
    """Rotate half the hidden dims: split the last axis into ``[a, b]`` (``a`` = first ``n // 2``) and return ``[-b, a]``."""
    width = x.shape[-1]
    half = width // 2
    lower, upper = x.split([half, width - half], dim=-1)
    return torch.cat((upper.neg(), lower), dim=-1)
|
|
|
|
|
|
def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to vision queries and keys.

    The rotation is computed in float32 and each result is cast back to its own input dtype.
    `cos`/`sin` get a singleton axis inserted at position -2 so they broadcast over the head axis
    (this matches q/k laid out as `[seq, heads, head_dim]`, as produced in `Glm4vVisionAttention`).
    """
    cos_f = cos.unsqueeze(-2).float()
    sin_f = sin.unsqueeze(-2).float()

    def _rotate(tensor: torch.Tensor) -> torch.Tensor:
        as_fp32 = tensor.float()
        return ((as_fp32 * cos_f) + (rotate_half(as_fp32) * sin_f)).to(tensor.dtype)

    return _rotate(q), _rotate(k)
|
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head `n_rep` times along dim 1, matching `torch.repeat_interleave(x, n_rep, dim=1)`:
    `(batch, num_key_value_heads, seqlen, head_dim)` -> `(batch, num_key_value_heads * n_rep, seqlen, head_dim)`.
    When `n_rep == 1` the input tensor itself is returned.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (non-fused) scaled dot-product attention.

    Keys/values are expanded by `module.num_key_value_groups` heads. Scores are `query @ key^T * scaling`,
    plus the additive mask if given. They go through a float32 softmax cast back to `query.dtype`, then
    dropout (active only while `module.training`).

    Returns:
        `(attn_output, attn_weights)`. `attn_output` has axes 1 and 2 swapped relative to the inputs
        (e.g. `[batch, seq, heads, head_dim]` for `[batch, heads, seq, head_dim]` inputs — assumed 4-D).
        `attn_weights` are the post-dropout probabilities.
    """
    expanded_keys = repeat_kv(key, module.num_key_value_groups)
    expanded_values = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Crop the mask's last axis to the number of keys actually present.
        scores = scores + attention_mask[:, :, :, : expanded_keys.shape[-2]]

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = torch.matmul(probs, expanded_values).transpose(1, 2).contiguous()
    return context, probs
|
|
|
|
|
|
class Glm4vVisionAttention(nn.Module):
    """
    Multi-head self-attention for the vision tower over a packed (unbatched) sequence of patches.

    Patches from several images/frames are concatenated along one axis. `cu_seqlens` marks segment
    boundaries, and attention is computed independently inside each segment (varlen for FA2, explicit
    per-segment calls otherwise).
    """

    def __init__(self, config: Glm4vVisionConfig) -> None:
        super().__init__()
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.num_key_value_groups = 1  # needed for eager attention
        # Fused projection producing queries, keys and values with a single matmul.
        self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = config.attention_dropout
        # Bidirectional attention within each image/frame segment.
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Packed patch features; axis 0 is the total sequence length (there is no batch axis).
            cu_seqlens: Cumulative segment boundaries; segment i spans `cu_seqlens[i]:cu_seqlens[i + 1]`.
            rotary_pos_emb: Deprecated raw RoPE angles; only used when `position_embeddings` is None.
            position_embeddings: `(cos, sin)` tables for the vision rotary embedding.

        Returns:
            `torch.Tensor` of shape `[seq_length, hidden_size]`.
        """
        seq_length = hidden_states.shape[0]
        # [seq, 3 * hidden] -> [3, seq, heads, head_dim] -> three [seq, heads, head_dim] tensors.
        query_states, key_states, value_states = (
            self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        )
        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            # Legacy path: build cos/sin locally (duplicated angles span the full head_dim).
            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

        # [seq, heads, head_dim] -> [1, heads, seq, head_dim]: a single pseudo-batch holding every segment.
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        if self.config._attn_implementation == "flash_attention_2":
            # Flash Attention 2: Use cu_seqlens for variable length attention
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # Other implementations: Process each chunk separately
            # Split the seq axis (dim 2) into one chunk per segment so no attention crosses a boundary.
            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            splits = [
                torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
            ]

            # Each backend call returns its output with seq on dim 1, so segments are re-joined along dim 1.
            attn_outputs = [
                attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                for q, k, v in zip(*splits)
            ]
            attn_output = torch.cat(attn_outputs, dim=1)

        # Flatten heads back into the hidden dimension and drop the pseudo-batch axis.
        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
        return attn_output
|
|
|
|
|
|
class Glm4vVisionBlock(GradientCheckpointingLayer):
    """Pre-norm vision transformer block: x + attn(norm1(x)), then x + mlp(norm2(x))."""

    def __init__(self, config) -> None:
        super().__init__()
        self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = Glm4vVisionAttention(config)
        self.mlp = Glm4VisionMlp(config, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        residual_stream = hidden_states + attn_out
        mlp_out = self.mlp(self.norm2(residual_stream))
        return residual_stream + mlp_out
|
|
|
|
|
|
@auto_docstring
class Glm4vPreTrainedModel(PreTrainedModel):
    # Common base class for the GLM-4V vision and text models in this file. The attributes below are
    # class-level settings read by `PreTrainedModel` (defined outside this file; see its docs for exact semantics).
    config: Glm4vConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Layer classes that should be kept whole (presumably when placing the model across devices — confirm).
    _no_split_modules = ["Glm4vTextDecoderLayer", "Glm4vVisionBlock"]
    _skip_keys_device_placement = "past_key_values"
    # Attention backends this model declares support for.
    _supports_flash_attn = True
    _supports_sdpa = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
|
|
|
|
|
|
class Glm4vVisionModel(Glm4vPreTrainedModel):
    """
    GLM-4V vision tower.

    Pipeline: Conv3d patch embedding -> RMSNorm -> interpolated learned 2D position embedding ->
    `depth` transformer blocks with 2D RoPE -> RMSNorm -> Conv2d spatial merge -> patch merger MLP.
    """

    config: Glm4vVisionConfig
    _no_split_modules = ["Glm4vVisionBlock"]

    def __init__(self, config) -> None:
        super().__init__(config)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size

        self.embeddings = Glm4vVisionEmbeddings(config)
        self.patch_embed = Glm4vVisionPatchEmbed(config)

        head_dim = config.hidden_size // config.num_heads
        # Each rotary row holds h angles then w angles (head_dim // 4 each when head_dim is divisible by 4).
        # forward() duplicates the row so the cos/sin tables span the full head_dim.
        self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([Glm4vVisionBlock(config) for _ in range(config.depth)])
        self.merger = Glm4vVisionPatchMerger(
            dim=config.out_hidden_size, context_dim=config.intermediate_size, hidden_act=config.hidden_act
        )

        self.post_conv_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Collapses each spatial_merge_size x spatial_merge_size window of patch tokens into one token.
        self.downsample = nn.Conv2d(
            in_channels=config.hidden_size,
            out_channels=config.out_hidden_size,
            kernel_size=config.spatial_merge_size,
            stride=config.spatial_merge_size,
        )
        self.post_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        self.post_init()

    def rot_pos_emb(self, grid_thw):
        """
        Compute 2D rotary angles and (h, w) grid coordinates for every patch.

        Patches are enumerated in spatial-merge order. The (h, w) grid is tiled into
        `spatial_merge_size x spatial_merge_size` windows, and the patches of each window are consecutive.
        The same layout is repeated for each of the `t` frames of an item.

        Returns:
            `(rotary_pos_emb, pos_ids)`:
                - angles of shape `[num_patches, 2 * len(inv_freq)]` (h angles followed by w angles);
                - integer coordinates of shape `[num_patches, 2]` (h, w).
        """
        pos_ids = []
        for t, h, w in grid_thw:
            # Row index of every patch, reordered as (block_row, block_col, inner_row, inner_col).
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            # Column index of every patch, with the same window ordering.
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # One angle table large enough for the biggest h or w, then gathered per (h, w) coordinate.
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb, pos_ids

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`):
                Flattened pixel patches. Every `in_channels * temporal_patch_size * patch_size**2` consecutive
                values form one patch (see `Glm4vVisionPatchEmbed`); presumably shaped
                `(num_patches, in_channels * temporal_patch_size * patch_size**2)` — confirm against the processor.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            `torch.Tensor`: hidden_states of shape `(num_patches // spatial_merge_size**2, out_hidden_size)`.
        """
        hidden_states = self.patch_embed(hidden_states)
        hidden_states = self.post_conv_layernorm(hidden_states)

        # image_type_ids are the (h, w) grid coordinates of each patch.
        rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # One attention segment per temporal frame: each frame of an item contributes h * w patches.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        # NOTE(review): `seqlens` has one entry per temporal frame, whereas `grid_thw` has one row per
        # image/video. The two only line up one-to-one when every row has t == 1.
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])

        for blk in self.blocks:
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
            )

        hidden_states = self.post_layernorm(hidden_states)

        # Consecutive spatial_merge_size**2 tokens form one merge window (see rot_pos_emb ordering).
        # Lay each window out as a tiny [C, m, m] image for the strided Conv2d.
        hidden_states = hidden_states.view(
            -1, self.spatial_merge_size, self.spatial_merge_size, hidden_states.shape[-1]
        )
        hidden_states = hidden_states.permute(0, 3, 1, 2)
        hidden_states = self.downsample(hidden_states).view(-1, self.config.out_hidden_size)

        hidden_states = self.merger(hidden_states)
        return hidden_states
|
|
|
|
|
|
class Glm4vTextRotaryEmbedding(nn.Module):
    """Rotary embedding for the text decoder; produces separate cos/sin tables for the 3 position axes (t, h, w)."""

    def __init__(self, config: Glm4vTextConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        # Inverse-frequency initializer registered for this rope type (from modeling_rope_utils).
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept alongside inv_freq; presumably used by `dynamic_rope_update` to restore the original frequencies — confirm.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """
        Args:
            x: In this method, only its device and dtype are read.
            position_ids: `[3, batch, seq]` position ids for the temporal, height and width axes.

        Returns:
            `(cos, sin)`, each `[3, batch, seq, 2 * len(inv_freq)]` in `x.dtype`, scaled by `attention_scaling`.
        """
        # In contrast to other models, Glm4vText has different position ids for the grids
        # So we expand the inv_freq to shape (3, ...)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

        # autocast is disabled so the angle computation stays in float32; "mps" falls back to the "cpu" context.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
|
|
def rotate_half_llm(x):
    """Interleaved rotate-half: pairs ``(x0, x1), (x2, x3), ...`` along the last axis become ``(-x1, x0), (-x3, x2), ...``."""
    evens = x[..., ::2]
    odds = x[..., 1::2]
    rotated_pairs = torch.stack((odds.neg(), evens), dim=-1)
    return rotated_pairs.flatten(-2)
|
|
|
|
|
|
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """Apply multimodal (3D) rotary position embeddings to `q` and `k` (https://qwenlm.github.io/blog/qwen2-vl/).

    `cos`/`sin` carry one table per position axis (temporal, height, width) stacked on their leading dimension.
    The channel dimension is cut into chunks of sizes `mrope_section`, repeated twice to cover both halves of
    the duplicated-frequency layout, and chunk `i` is taken from axis `i % 3`. For text tokens the three
    position ids coincide, so this reduces to ordinary 1D RoPE.

    GLM-4V then switches to an interleaved layout: only the first half of the selected channels is kept and
    each value is duplicated, pairing with `rotate_half_llm`. Only the first `rotary_dim` channels of `q`/`k`
    are rotated; any remaining channels are passed through unchanged.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): Cosine tables with the 3 position axes first.
        sin (`torch.Tensor`): Sine tables, same layout as `cos`.
        mrope_section (`List(int)`): Channel counts assigned to the temporal, height and width axes.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into the selected cos/sin so they broadcast against `q`/`k`. Use 1 for
            `[batch, heads, seq, head_dim]` inputs, 2 for `[batch, seq, heads, head_dim]`.

    Returns:
        `tuple(torch.Tensor)`: The rotated query and key tensors.
    """
    sections = mrope_section * 2

    def _select_axes(table):
        # Take channel chunk i from position axis i % 3, then add the broadcast axis.
        chunks = table.split(sections, dim=-1)
        merged = torch.cat([chunk[idx % 3] for idx, chunk in enumerate(chunks)], dim=-1)
        merged = merged.unsqueeze(unsqueeze_dim)
        # Keep the first half and duplicate each value to match the interleaved rotation layout.
        return merged[..., : merged.shape[-1] // 2].repeat_interleave(2, dim=-1)

    cos = _select_axes(cos)
    sin = _select_axes(sin)

    rotary_dim = cos.shape[-1]
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]

    q_out = torch.cat([(q_rot * cos) + (rotate_half_llm(q_rot) * sin), q_pass], dim=-1)
    k_out = torch.cat([(k_rot * cos) + (rotate_half_llm(k_rot) * sin), k_pass], dim=-1)
    return q_out, k_out
|
|
|
|
|
|
class Glm4vTextAttention(nn.Module):
    """
    Causal grouped-query self-attention for the GLM-4V text decoder.

    Queries and keys receive multimodal rotary embeddings (`apply_multimodal_rotary_pos_emb`); the actual
    attention kernel is chosen from `config._attn_implementation`.
    """

    def __init__(self, config: Glm4vTextConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.is_causal = True
        self.attention_dropout = config.attention_dropout
        self.rope_scaling = config.rope_scaling
        self.scaling = self.head_dim**-0.5

        query_width = self.num_heads * self.head_dim
        kv_width = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, query_width, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, kv_width, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, kv_width, bias=True)
        self.o_proj = nn.Linear(query_width, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """
        Returns:
            `(attn_output, attn_weights, past_key_value)`. `attn_output` is `[batch, seq, hidden_size]`;
            `attn_weights` is whatever the selected attention backend returns.
        """
        batch_size, seq_len, _ = hidden_states.size()
        per_head_shape = (batch_size, seq_len, -1, self.head_dim)

        # Project, split into heads, and move heads ahead of the sequence axis: [batch, heads, seq, head_dim].
        query_states = self.q_proj(hidden_states).view(per_head_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(per_head_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(per_head_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(  # diff with Llama
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        if past_key_value is not None:
            # RoPE-specific extras handed to the cache alongside the new key/value states.
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        implementation = self.config._attn_implementation
        attention_interface: Callable = (
            eager_attention_forward if implementation == "eager" else ALL_ATTENTION_FUNCTIONS[implementation]
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )

        merged_heads = attn_output.reshape(batch_size, seq_len, -1).contiguous()
        return self.o_proj(merged_heads), attn_weights, past_key_value
|
|
|
|
|
|
class Glm4vTextMLP(nn.Module):
    """Gated MLP of the text decoder with a fused gate/up projection: ``down(up * act(gate))``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        # The first half of the fused projection is the gate, the second half the "up" branch.
        gate, up = self.gate_up_proj(hidden_states).chunk(2, dim=-1)
        return self.down_proj(up * self.activation_fn(gate))
|
|
|
|
|
|
class Glm4vTextDecoderLayer(GradientCheckpointingLayer):
    """
    GLM-4V text decoder layer.

    Both sub-layers (self-attention, MLP) take an RMS-normalised input, and their outputs are themselves
    RMS-normalised before being added back to the residual stream.
    """

    def __init__(self, config: Glm4vTextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Glm4vTextAttention(config, layer_idx)
        self.mlp = Glm4vTextMLP(config)
        self.input_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Despite its name, this norm is applied to the MLP *input* (after the attention residual).
        self.post_attention_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_self_attn_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_mlp_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Returns:
            A tuple starting with the new hidden states. It is followed by the attention weights if
            `output_attentions`, then by the cache object returned by attention if `use_cache`.
        """
        # Self-attention sub-layer.
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = hidden_states + self.post_self_attn_layernorm(attn_out)

        # Feed-forward sub-layer.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + self.post_mlp_layernorm(mlp_out)

        extras = [self_attn_weights] if output_attentions else []
        if use_cache:
            extras.append(present_key_value)
        return (hidden_states, *extras)
|
|
|
|
|
|
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Glm4v outputs, with hidden states and attentions.
    """
)
class Glm4vModelOutputWithPast(ModelOutput):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Cache object (e.g. a `DynamicCache`) holding the pre-computed hidden-states (key and values in the
        self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
    rope_deltas (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
        The rope index difference between sequence length and multimodal rope.
    """

    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    rope_deltas: Optional[torch.LongTensor] = None
|
|
|
|
|
|
@auto_docstring
class Glm4vTextModel(Glm4vPreTrainedModel):
    # Decoder-only text backbone: token embeddings -> stack of Glm4vTextDecoderLayer -> final RMSNorm,
    # with multimodal (3-axis) rotary position embeddings shared by all layers.
    config: Glm4vTextConfig

    def __init__(self, config: Glm4vTextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Glm4vTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Glm4vTextRotaryEmbedding(config=config)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # torch.jit.trace() doesn't support cache objects in the output
        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # `create_causal_mask` documents `position_ids` as a 2D `(batch_size, seq_len)` tensor; it uses them to
        # detect packed sequences when neither an attention mask nor a cache is given. The 3D multimodal-RoPE ids
        # built below do not fit that contract and break that path, so only caller-supplied 2D text positions
        # are forwarded to the mask helper.
        text_position_ids = position_ids if position_ids is not None and position_ids.dim() == 2 else None

        # the hard coded `3` is for temporal, height and width.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.dim() == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=text_position_ids,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
|
|
|
|
|
@auto_docstring
class Glm4vModel(Glm4vPreTrainedModel):
    base_model_prefix = ""
    _checkpoint_conversion_mapping = {}
    config: Glm4vConfig
    _no_split_modules = ["Glm4vTextDecoderLayer", "Glm4vVisionBlock"]

    def __init__(self, config):
        super().__init__(config)
        self.visual = Glm4vVisionModel._from_config(config.vision_config)
        self.language_model = Glm4vTextModel._from_config(config.text_config)
        # Rope deltas computed by `get_rope_index` during pre-fill; `forward` reuses them on later decoding steps.
        self.rope_deltas = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

    def get_rope_index(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate the 3D (temporal, height, width) rope index for the language model.

        Explanation:
            Each (non-padding) sequence is split into consecutive runs of text, image and video tokens. Images and
            videos both use `config.image_token_id` as placeholder: a placeholder is counted as a video token when
            it follows a `config.video_start_token_id` that has not yet been closed by `config.video_end_token_id`,
            and as an image token otherwise.

            Text tokens get the same, sequentially increasing position id on all three axes, exactly like a
            regular 1D-RoPE LLM:
                input_ids: [T T T T T], here T is for text.
                temporal position_ids: [0, 1, 2, 3, 4]
                height position_ids: [0, 1, 2, 3, 4]
                width position_ids: [0, 1, 2, 3, 4]

            An image run with grid `(t, h, w)` (next row of `image_grid_thw`) gets the indices of a
            `t x (h // spatial_merge_size) x (w // spatial_merge_size)` grid. Each run of video tokens is handled
            as one frame group with grid `(1, h // spatial_merge_size, w // spatial_merge_size)` taken from the
            current row of `video_grid_thw`; after `video_grid_thw[i][0]` such runs the next video row is used.
            Image/video counters advance across the whole batch, not per sample.
            Every run is offset by a start index equal to the largest position id of the previous run plus 1
            (0 for the first run).

            Example, with one image whose merged grid is 1 x 2 x 2:
                input_ids: [T T V V V V T T], here V is for vision.
                temporal position_ids: [0, 1, 2, 2, 2, 2, 4, 5]
                height position_ids:   [0, 1, 2, 2, 3, 3, 4, 5]
                width position_ids:    [0, 1, 2, 3, 2, 3, 4, 5]

            Padding positions (where `attention_mask == 0`) are filled with the value 1.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
                it.
            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                The temporal, height and width of feature shape of each image in LLM.
            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
                The temporal, height and width of feature shape of each video in LLM.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

        Returns:
            position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
            mrope_position_deltas (`torch.Tensor` of shape `(batch_size, 1)`)
        """

        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_start_token_id = self.config.video_start_token_id
        video_end_token_id = self.config.video_end_token_id

        mrope_position_deltas = []
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            # Positions not overwritten below (padding) keep the placeholder value 1.
            position_ids = torch.ones(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            # These counters are shared across the batch: grid rows are consumed in order over all samples.
            image_index, video_index = 0, 0
            video_group_index = 0
            attention_mask = attention_mask.to(total_input_ids.device)
            for i, input_ids in enumerate(total_input_ids):
                # Only non-padding tokens receive real position ids.
                input_ids = input_ids[attention_mask[i] == 1]
                input_tokens = input_ids.tolist()

                # Label every token: placeholders inside a video start/end span are video, others are image.
                input_token_type = []
                video_check_flg = False
                for token in input_tokens:
                    if token == video_start_token_id:
                        video_check_flg = True
                    elif token == video_end_token_id:
                        video_check_flg = False

                    if token == image_token_id and not video_check_flg:
                        input_token_type.append("image")
                    elif token == image_token_id and video_check_flg:
                        input_token_type.append("video")
                    else:
                        input_token_type.append("text")

                # Collapse consecutive tokens of the same modality into (modality, start, end) runs.
                input_type_group = []
                for key, group in itertools.groupby(enumerate(input_token_type), lambda x: x[1]):
                    group = list(group)
                    start_index = group[0][0]
                    end_index = group[-1][0] + 1
                    input_type_group.append((key, start_index, end_index))

                llm_pos_ids_list = []
                video_frame_num = 1
                for modality_type, start_idx, end_idx in input_type_group:
                    # Each run starts right after the largest position id used by the previous run.
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0

                    if modality_type == "image":
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        llm_grid_t, llm_grid_h, llm_grid_w = (
                            t.item(),
                            h.item() // spatial_merge_size,
                            w.item() // spatial_merge_size,
                        )

                        t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                        h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                        w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                        llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)

                        image_index += 1
                        video_frame_num = 1

                    elif modality_type == "video":
                        t, h, w = (
                            video_frame_num,
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )

                        llm_grid_t, llm_grid_h, llm_grid_w = (
                            t,
                            h.item() // spatial_merge_size,
                            w.item() // spatial_merge_size,
                        )

                        for t_idx in range(llm_grid_t):
                            t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()

                            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()
                            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
                            llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)

                        # Count this run as one frame group; move to the next video once all of its groups are used.
                        video_group_index += 1

                        if video_group_index >= video_grid_thw[video_index][0]:
                            video_index += 1
                            video_group_index = 0

                        video_frame_num += 1

                    else:
                        text_len = end_idx - start_idx
                        llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                        video_frame_num = 1

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            if attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(3, input_ids.shape[0], -1)
                )
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )

            return position_ids, mrope_position_deltas

    def get_video_features(
        self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
    ):
        """
        Encodes videos into continuous embeddings that can be forwarded to the language model.

        Each `(t, h, w)` row of `video_grid_thw` is expanded into `t` rows of `(1, h, w)` before calling the vision
        tower, so every temporal patch is encoded as its own single-frame grid.

        Args:
            pixel_values_videos (`torch.FloatTensor`):
                The tensors corresponding to the input videos, in the layout expected by `self.visual`
                (presumably the flattened patches produced by the video processor — verify).
            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
                The temporal, height and width of feature shape of each video in LLM. Despite the default, it is
                required: it is iterated over unconditionally.

        Returns:
            `tuple[torch.Tensor]`: one embedding tensor per video, with `t * h * w // spatial_merge_size**2` rows.
        """
        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
        # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
        temp_frames_hw = []
        for t, h, w in video_grid_thw:
            repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
            temp_frames_hw.append(repeated_row)
        flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
        video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw)
        split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        video_embeds = torch.split(video_embeds, split_sizes)
        return video_embeds

    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
        """
        Encodes images into continuous embeddings that can be forwarded to the language model.

        Args:
            pixel_values (`torch.FloatTensor`):
                The tensors corresponding to the input images, in the layout expected by `self.visual`
                (presumably the flattened patches produced by the image processor — verify).
            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                The temporal, height and width of feature shape of each image in LLM. Despite the default, it is
                required to compute the per-image split sizes.

        Returns:
            `tuple[torch.Tensor]`: one embedding tensor per image, with `t * h * w // spatial_merge_size**2` rows.
        """
        pixel_values = pixel_values.type(self.visual.dtype)
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        image_embeds = torch.split(image_embeds, split_sizes)
        return image_embeds

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, Glm4vModelOutputWithPast]:
        r"""
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        rope_deltas (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
            The rope index difference between sequence length and multimodal rope. Accepted for API compatibility
            but not read: the deltas cached on the model during pre-fill (`self.rope_deltas`) are used instead.
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            image_embeds = self.get_image_features(pixel_values, image_grid_thw)
            image_embeds = torch.cat(image_embeds, dim=0)

            if input_ids is None:
                image_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
                image_mask = image_mask.all(-1)
            else:
                image_mask = input_ids == self.config.image_token_id

            n_image_tokens = image_mask.sum()
            image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
            n_image_features = image_embeds.shape[0]
            if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )

            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

        if pixel_values_videos is not None:
            video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
            video_embeds = torch.cat(video_embeds, dim=0)

            # Video frames are marked with the same placeholder token as images (`config.image_token_id`).
            if input_ids is None:
                video_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
                video_mask = video_mask.all(-1)
            else:
                video_mask = input_ids == self.config.image_token_id

            n_video_tokens = video_mask.sum()
            n_video_features = video_embeds.shape[0]
            video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
            if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
                raise ValueError(
                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                )

            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        if position_ids is None:
            attention_mask_tensor = (
                attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
            )
            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
                # Recover a 2D (batch, seq) mask from the diagonal of a 4D mask.
                attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
                # Only apply conversion for floating point tensors (inverted masks)
                if attention_mask_tensor.dtype.is_floating_point:
                    attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
                    attention_mask_tensor = (1.0 - attention_mask_tensor).int()

            # Calculate RoPE index once per generation in the pre-fill stage only.
            # When compiling, we can't check tensor values thus we check only input length
            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
            # models currently cannot do assisted decoding
            prefill_compiled_stage = is_torchdynamo_compiling() and (
                (input_ids is not None and input_ids.shape[1] != 1)
                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
            )
            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
                (cache_position is not None and cache_position[0] == 0)
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            )
            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    attention_mask=attention_mask_tensor,
                )
                self.rope_deltas = rope_deltas
            else:
                # then use the prev pre-calculated rope-deltas to get the correct position ids
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        return Glm4vModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )
|
|
|
|
|
|
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Glm4v causal language model (or autoregressive) outputs.
    """
)
class Glm4vCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Cache object (e.g. a `DynamicCache`) holding the pre-computed hidden-states (key and values in the
        self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
    rope_deltas (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
        The rope index difference between sequence length and multimodal rope.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    rope_deltas: Optional[torch.LongTensor] = None
|
|
|
|
|
|
class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
|
|
_checkpoint_conversion_mapping = {}
|
|
_tied_weights_keys = ["lm_head.weight"]
|
|
|
|
def __init__(self, config):
    """Build the GLM-4V base model and a bias-free LM head mapping text hidden states to vocabulary logits."""
    super().__init__(config)
    self.model = Glm4vModel(config)
    text_cfg = config.text_config
    self.lm_head = nn.Linear(text_cfg.hidden_size, text_cfg.vocab_size, bias=False)

    self.post_init()
|
|
|
|
def get_input_embeddings(self):
    """Return the input embedding module of the wrapped base model."""
    base_model = self.model
    return base_model.get_input_embeddings()
|
|
|
|
def set_input_embeddings(self, value):
    """Install `value` as the input embedding module of the wrapped base model."""
    base_model = self.model
    base_model.set_input_embeddings(value)
|
|
|
|
def set_decoder(self, decoder):
    """Hand `decoder` to the wrapped base model's `set_decoder`."""
    base_model = self.model
    base_model.set_decoder(decoder)
|
|
|
|
def get_decoder(self):
    """Return whatever the wrapped base model reports as its decoder."""
    base_model = self.model
    return base_model.get_decoder()
|
|
|
|
def get_video_features(
    self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
):
    """Delegate video encoding to the wrapped base model's `get_video_features`."""
    encode_videos = self.model.get_video_features
    return encode_videos(pixel_values_videos, video_grid_thw)
|
|
|
|
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
    """Delegate image encoding to the wrapped base model's `get_image_features`."""
    encode_images = self.model.get_image_features
    return encode_images(pixel_values, image_grid_thw)
|
|
|
|
# Expose the inner sub-modules on the conditional-generation class for backward compatibility.
@property
def language_model(self):
    """The text decoder owned by the wrapped base model."""
    base_model = self.model
    return base_model.language_model
|
|
|
|
@property
def visual(self):
    """The vision tower owned by the wrapped base model."""
    base_model = self.model
    return base_model.visual
|
|
|
|
@can_return_tuple
@auto_docstring
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, Glm4vCausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
        The temporal, height and width of feature shape of each image in LLM.
    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
        The temporal, height and width of feature shape of each video in LLM.
    rope_deltas (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
        The rope index difference between sequence length and multimodal rope. Accepted for API compatibility
        but not forwarded to the base model, which uses the deltas it cached during pre-fill.

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, Glm4vForConditionalGeneration

    >>> model = Glm4vForConditionalGeneration.from_pretrained("THUDM/GLM-4.1V-9B-Thinking")
    >>> processor = AutoProcessor.from_pretrained("THUDM/GLM-4.1V-9B-Thinking")

    >>> messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        },
    ]
    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    >>> inputs = processor(text=[text], images=[image], return_tensors="pt")

    >>> # Generate (pass all processor outputs so the image tensors reach the model)
    >>> generate_ids = model.generate(**inputs, max_new_tokens=30)
    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
    ```"""

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )

    outputs = self.model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    logits = self.lm_head(hidden_states[:, slice_indices, :])

    loss = None
    if labels is not None:
        loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)

    return Glm4vCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=outputs.rope_deltas,
    )
|
|
|
|
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    use_cache=True,
    pixel_values=None,
    pixel_values_videos=None,
    image_grid_thw=None,
    video_grid_thw=None,
    **kwargs,
):
    """
    Build the per-step model inputs for `generate`.

    Delegates to the parent implementation, then always clears `position_ids` (they are derived from the cached
    rope deltas inside `forward`). It also drops the image/video pixel tensors on every step whose
    `cache_position` does not start at 0, so the vision tower only runs on the first step.
    """
    # Overridden: under some conditions the image inputs must not reach the model.
    step_inputs = super().prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        cache_position=cache_position,
        position_ids=position_ids,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        use_cache=use_cache,
        **kwargs,
    )

    # GLM-4.1V position ids are prepared from rope_deltas inside `forward`.
    step_inputs["position_ids"] = None

    is_first_step = cache_position[0] == 0
    if not is_first_step:
        step_inputs["pixel_values"] = None
        step_inputs["pixel_values_videos"] = None

    return step_inputs
|
|
|
|
def _get_image_nums_and_video_nums(
    self,
    input_ids: Optional[torch.LongTensor],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
    These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.

    An image is an occurrence of `config.image_start_token_id` that lies outside every video span. A video
    span runs from a `config.video_start_token_id` token up to its matching `config.video_end_token_id`;
    well-formed (balanced) video tags are assumed. A video is an occurrence of `config.video_start_token_id`.

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Only used when `inputs_embeds` is `None`.
        inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embedding_dim)`, *optional*):
            Input embeddings. When given, special tokens are located by comparing every position against the
            embedding of the corresponding token id (all features must match), and `input_ids` is ignored.

    Returns:
        image_nums (`torch.LongTensor` of shape `(batch_size,)`)
        video_nums (`torch.LongTensor` of shape `(batch_size,)`)

    Raises:
        ValueError: if both `input_ids` and `inputs_embeds` are `None`.
    """
    image_token_id = self.config.image_start_token_id
    video_start_token_id = self.config.video_start_token_id
    video_end_token_id = self.config.video_end_token_id

    if inputs_embeds is not None:
        embed_tokens = self.get_input_embeddings()

        def _matches_token(token_id):
            token_embed = embed_tokens(torch.tensor(token_id, dtype=torch.long, device=inputs_embeds.device))
            # Require the whole embedding vector to match. Comparing a single feature gives false
            # positives for any token that happens to share that one value.
            return (inputs_embeds == token_embed).all(dim=-1)

        is_image = _matches_token(image_token_id)
        is_video_start = _matches_token(video_start_token_id)
        is_video_end = _matches_token(video_end_token_id)
    elif input_ids is not None:
        is_image = input_ids == image_token_id
        is_video_start = input_ids == video_start_token_id
        is_video_end = input_ids == video_end_token_id
    else:
        raise ValueError("Either `input_ids` or `inputs_embeds` must be provided to count images and videos.")

    # Cumulative sum to track if we're inside a video span
    # We'll assume well-formed video tags (i.e. matching starts and ends)
    video_level = torch.cumsum(is_video_start.int() - is_video_end.int(), dim=1)
    inside_video = video_level > 0  # shape (batch_size, seq_length)

    # Mask out image tokens that are inside video spans
    standalone_images = is_image & (~inside_video)

    # Count per batch
    image_counts = standalone_images.sum(dim=1)
    video_counts = is_video_start.sum(dim=1)

    return image_counts, video_counts
|
|
|
|
def _expand_inputs_for_generation(
    self,
    expand_size: int = 1,
    is_encoder_decoder: bool = False,
    input_ids: Optional[torch.LongTensor] = None,
    **model_kwargs,
) -> tuple[torch.LongTensor, dict[str, Any]]:
    """
    Repeat every sample of the batch `expand_size` times.

    Overwritten -- Support for expanding tensors without a batch size dimension, e.g. `pixel_values`,
    `image_grid_thw`, `pixel_values_videos`, `video_grid_thw` and `second_per_grid_ts`:
      - `pixel_values.shape[0]` is sum(seqlen_images for samples)
      - `image_grid_thw.shape[0]` is sum(num_images for samples)

    Each of these is split into per-sample chunks using the per-sample image/video counts from
    `self._get_image_nums_and_video_nums`. Each chunk is then repeated `expand_size` times, keeping it
    aligned with the expanded `input_ids`. Visual keys whose value is `None` are left untouched. Every
    other tensor kwarg (except `cache_position`) is expanded with `repeat_interleave` along dim 0.

    Args:
        expand_size (`int`): number of copies of each sample. `1` returns the inputs unchanged.
        is_encoder_decoder (`bool`): whether `model_kwargs["encoder_outputs"]` must be expanded too.
        input_ids (`torch.LongTensor`, *optional*): token ids, expanded along dim 0.
        **model_kwargs: remaining generation kwargs.

    Returns:
        `tuple(input_ids, model_kwargs)` with every batch-aligned entry expanded.

    Raises:
        TypeError: if `second_per_grid_ts` is given (not `None`) and is not a list.
        ValueError: if `is_encoder_decoder` is True and `encoder_outputs` is missing.
    """
    if expand_size == 1:
        return input_ids, model_kwargs

    visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]

    def _repeat_interleave_samples(x, lengths, repeat_times):
        # Split `x` along dim 0 into per-sample chunks and tile each chunk `repeat_times` times in place.
        samples = torch.split(x, lengths)
        repeat_args = [repeat_times] + [1] * (x.dim() - 1)
        return torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)

    def _patches_per_sample(grid_thw, counts):
        # Rows of pixel values contributed by each sample: the sum of t*h*w over that sample's grids.
        return [int(torch.prod(grid, dim=1).sum()) for grid in torch.split(grid_thw, counts)]

    def _expand_dict_for_generation_visual(dict_to_expand):
        present = [key for key in visual_keys if dict_to_expand.get(key) is not None]
        if not present:
            # No visual inputs: nothing to split, and no need to scan the prompt for visual tokens.
            return dict_to_expand

        # Read the grids before any of them is expanded below.
        image_grid_thw = dict_to_expand.get("image_grid_thw")
        video_grid_thw = dict_to_expand.get("video_grid_thw")
        image_nums, video_nums = self._get_image_nums_and_video_nums(
            input_ids, inputs_embeds=dict_to_expand.get("inputs_embeds", None)
        )
        image_nums = image_nums.tolist()
        video_nums = video_nums.tolist()

        for key in present:
            value = dict_to_expand[key]
            if key == "second_per_grid_ts":
                if not isinstance(value, list):
                    raise TypeError(f"Expected value for key '{key}' to be a list, but got {type(value)} instead.")
                expanded = _repeat_interleave_samples(torch.tensor(value), lengths=video_nums, repeat_times=expand_size)
                dict_to_expand[key] = expanded.tolist()
                continue

            if key == "pixel_values":
                lengths = _patches_per_sample(image_grid_thw, image_nums)
            elif key == "image_grid_thw":
                lengths = image_nums
            elif key == "pixel_values_videos":
                lengths = _patches_per_sample(video_grid_thw, video_nums)
            else:  # "video_grid_thw"
                lengths = video_nums
            dict_to_expand[key] = _repeat_interleave_samples(value, lengths=lengths, repeat_times=expand_size)
        return dict_to_expand

    def _expand_dict_for_generation(dict_to_expand):
        # Batch-first tensors: repeat each row `expand_size` times consecutively.
        for key in dict_to_expand:
            value = dict_to_expand[key]
            if key != "cache_position" and key not in visual_keys and isinstance(value, torch.Tensor):
                dict_to_expand[key] = value.repeat_interleave(expand_size, dim=0)
        return dict_to_expand

    # Visual tensors must be expanded before `input_ids`, since the per-sample counts are read from them.
    model_kwargs = _expand_dict_for_generation_visual(model_kwargs)

    if input_ids is not None:
        input_ids = input_ids.repeat_interleave(expand_size, dim=0)

    model_kwargs = _expand_dict_for_generation(model_kwargs)

    if is_encoder_decoder:
        if model_kwargs.get("encoder_outputs") is None:
            raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
        model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])

    return input_ids, model_kwargs
|
|
|
|
|
|
# Public names exported by this module.
__all__ = [
    "Glm4vForConditionalGeneration",
    "Glm4vModel",
    "Glm4vPreTrainedModel",
    "Glm4vTextModel",
]
|