team-10/venv/Lib/site-packages/transformers/models/colqwen2/modular_colqwen2.py
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Union

from transformers.models.colpali.modeling_colpali import ColPaliForRetrieval, ColPaliPreTrainedModel
from transformers.models.colpali.processing_colpali import ColPaliProcessor

from ...cache_utils import Cache
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image
from ...processing_utils import ProcessingKwargs, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torch_available, logging
from .configuration_colqwen2 import ColQwen2Config


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)


class ColQwen2ProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {
            "padding": "longest",
        },
        "images_kwargs": {
            "data_format": "channels_first",
            "do_convert_rgb": True,
        },
        "common_kwargs": {"return_tensors": "pt"},
    }


class ColQwen2Processor(ColPaliProcessor):
    r"""
    Constructs a ColQwen2 processor which wraps a Qwen2VLProcessor and adds special methods to process images and
    queries, as well as to compute the late-interaction retrieval score.

    [`ColQwen2Processor`] offers all the functionalities of [`Qwen2VLProcessor`]. See the [`~Qwen2VLProcessor.__call__`]
    for more information.

    Args:
        image_processor ([`Qwen2VLImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*):
            A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string.
        visual_prompt_prefix (`str`, *optional*):
            A string that gets tokenized and prepended to the image tokens.
        query_prefix (`str`, *optional*):
            A prefix to be used for the query.
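
    Example (a minimal usage sketch; the checkpoint name and the blank test image are illustrative, not prescriptive):

    ```python
    >>> from PIL import Image
    >>> from transformers import ColQwen2Processor

    >>> processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0-hf")

    >>> # Documents (images) and queries (text) are processed in two separate calls.
    >>> image_inputs = processor(images=[Image.new("RGB", (448, 448), "white")])
    >>> query_inputs = processor(text=["What is the total revenue reported in this document?"])
    ```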
"""
image_processor_class = "AutoImageProcessor"
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
def __init__(
self,
image_processor=None,
tokenizer=None,
chat_template=None,
visual_prompt_prefix: Optional[str] = None,
query_prefix: Optional[str] = None,
**kwargs,
):
        super().__init__(image_processor, tokenizer, chat_template=chat_template)
        self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
        self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token

        if visual_prompt_prefix is None:
            visual_prompt_prefix = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
        self.visual_prompt_prefix = visual_prompt_prefix

        if query_prefix is None:
            query_prefix = "Query: "
        self.query_prefix = query_prefix

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
        audio=None,
        videos=None,
        **kwargs: Unpack[ColQwen2ProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare either (1) one or several texts or (2) one or several image(s) for the model. This
        method is a custom wrapper around the Qwen2VLProcessor's [`~Qwen2VLProcessor.__call__`] method adapted for the
        ColQwen2 model. It cannot process both text and images at the same time.

        When preparing the text(s), this method forwards the `text` and `kwargs` arguments to Qwen2TokenizerFast's
        [`~Qwen2TokenizerFast.__call__`].
        When preparing the image(s), this method forwards the `images` and `kwargs` arguments to Qwen2VLImageProcessor's
        [`~Qwen2VLImageProcessor.__call__`].
        Please refer to the docstring of the above two methods for more information.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is
                the number of channels, H and W are image height and width.
            text (`str`, `list[str]`, `list[list[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
output_kwargs = self._merge_kwargs(
ColQwen2ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
suffix = output_kwargs["text_kwargs"].pop("suffix", None)
return_token_type_ids = True if suffix is not None else False
if text is None and images is None:
raise ValueError("Either text or images must be provided")
if text is not None and images is not None:
raise ValueError("Only one of text or images can be processed at a time")

        if images is not None:
            if is_valid_image(images):
                images = [images]
            elif isinstance(images, list) and is_valid_image(images[0]):
                pass
            elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
                raise ValueError("images must be an image, list of images or list of list of images")

            texts_doc = [self.visual_prompt_prefix] * len(images)

            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]

            if image_grid_thw is not None:
                merge_length = self.image_processor.merge_size**2
                index = 0
                for i in range(len(texts_doc)):
                    while self.image_token in texts_doc[i]:
                        texts_doc[i] = texts_doc[i].replace(
                            self.image_token, "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1
                        )
                        index += 1
                    texts_doc[i] = texts_doc[i].replace("<|placeholder|>", self.image_token)

            text_inputs = self.tokenizer(
                texts_doc,
                return_token_type_ids=False,
                **output_kwargs["text_kwargs"],
            )

            return_data = BatchFeature(data={**text_inputs, **image_inputs})

            # NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs.
            offsets = return_data["image_grid_thw"][:, 1] * return_data["image_grid_thw"][:, 2]  # (batch_size,)

            # Split the pixel_values tensor into a list of tensors, one per image
            pixel_values = list(
                torch.split(return_data["pixel_values"], offsets.tolist())
            )  # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)]

            # Pad the list of pixel_value tensors to the same length along the sequence dimension
            return_data["pixel_values"] = torch.nn.utils.rnn.pad_sequence(
                pixel_values, batch_first=True
            )  # (batch_size, max_num_patches, pixel_values)
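            # NOTE: the padding above is undone in `ColQwen2ForRetrieval.forward`, which uses `image_grid_thw` to
            # recover the per-image patch counts and re-concatenate the unpadded patches before the vision tower.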

            if return_token_type_ids:
                labels = return_data["input_ids"].masked_fill(return_data["token_type_ids"] == 0, -100)
                return_data.update({"labels": labels})

            return return_data

        elif text is not None:
            if isinstance(text, str):
                text = [text]
            elif not (isinstance(text, list) and isinstance(text[0], str)):
                raise ValueError("Text must be a string or a list of strings")

            if suffix is None:
                suffix = self.query_augmentation_token * 10

            texts_query: list[str] = []

            for query in text:
                augmented_query = self.query_prefix + query + suffix
                texts_query.append(augmented_query)

            batch_query = self.tokenizer(
                texts_query,
                return_token_type_ids=False,
                **output_kwargs["text_kwargs"],
            )

            return batch_query


class ColQwen2PreTrainedModel(ColPaliPreTrainedModel):
    pass


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for ColQwen2 embeddings output.
    """
)
class ColQwen2ForRetrievalOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        The embeddings of the model.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    """

    loss: Optional[torch.FloatTensor] = None
    embeddings: Optional[torch.Tensor] = None
    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None


@auto_docstring(
    custom_intro="""
    Following the ColPali approach, ColQwen2 leverages VLMs to construct efficient multi-vector embeddings directly
    from document images (“screenshots”) for document retrieval. The model is trained to maximize the similarity
    between these document embeddings and the corresponding query embeddings, using the late interaction method
    introduced in ColBERT.

    Using ColQwen2 removes the need for potentially complex and brittle layout recognition and OCR pipelines with
    a single model that can take into account both the textual and visual content (layout, charts, ...) of a document.

    ColQwen2 is part of the ColVision model family, which was introduced with ColPali in the following paper:
    [*ColPali: Efficient Document Retrieval with Vision Language Models*](https://huggingface.co/papers/2407.01449).
    """
)
class ColQwen2ForRetrieval(ColPaliForRetrieval):
    _checkpoint_conversion_mapping = {}

    def __init__(self, config: ColQwen2Config):
        super().__init__(config)
        del self._tied_weights_keys
        self._tied_weights_keys = [f"vlm.{k}" for k in (self.vlm._tied_weights_keys or [])]

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> ColQwen2ForRetrievalOutput:
        r"""
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
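
        Example (a minimal retrieval sketch; the checkpoint name and the blank test image are illustrative, not prescriptive):

        ```python
        >>> import torch
        >>> from PIL import Image
        >>> from transformers import ColQwen2ForRetrieval, ColQwen2Processor

        >>> model_name = "vidore/colqwen2-v1.0-hf"
        >>> processor = ColQwen2Processor.from_pretrained(model_name)
        >>> model = ColQwen2ForRetrieval.from_pretrained(model_name)

        >>> images = [Image.new("RGB", (448, 448), "white")]
        >>> queries = ["Which chart shows the quarterly revenue?"]

        >>> # Documents and queries are encoded separately; each forward pass returns one multi-vector embedding per input.
        >>> with torch.no_grad():
        ...     image_embeddings = model(**processor(images=images)).embeddings
        ...     query_embeddings = model(**processor(text=queries)).embeddings

        >>> # Late-interaction (MaxSim) scores, computed by the scoring helper inherited from `ColPaliProcessor`.
        >>> scores = processor.score_retrieval(query_embeddings, image_embeddings)  # (n_queries, n_images)
        ```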
"""
if pixel_values is not None:
pixel_values = pixel_values.to(dtype=self.dtype) # (batch_size, max_num_patches, pixel_values)
# Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
if pixel_values is not None and image_grid_thw is not None:
# NOTE: image_grid_thw: (batch_size, 3) where image_grid_thw[i] = (num_patches_h, num_patches_w, temporal_patch_size)
offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2] # (num_patches_h, num_patches_w)
pixel_values = torch.cat(
[pixel_sequence[:offset] for pixel_sequence, offset in zip(pixel_values, offsets)],
dim=0,
) # (num_patches_h * num_patches_w, pixel_values)
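            # NOTE: this reverses the per-batch padding applied in `ColQwen2Processor.__call__`, restoring the flat
            # patch sequence expected by the vision tower.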

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        position_ids, rope_deltas = self.vlm.model.get_rope_index(
            input_ids=input_ids,
            image_grid_thw=image_grid_thw,
            video_grid_thw=None,
            attention_mask=attention_mask,
        )

        # Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs.
        if inputs_embeds is None:
            inputs_embeds = self.vlm.language_model.embed_tokens(input_ids)

            if pixel_values is not None:
                pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
                image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
                image_mask = (
                    (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
                )
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        vlm_output = self.vlm.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None

        last_hidden_states = vlm_output[0]  # (batch_size, sequence_length, hidden_size)
        embeddings = self.embedding_proj_layer(last_hidden_states)  # (batch_size, sequence_length, dim)

        # L2 normalization
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)

        if attention_mask is not None:
            embeddings = embeddings * attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)

        return ColQwen2ForRetrievalOutput(
            embeddings=embeddings,
            past_key_values=vlm_output.past_key_values,
            hidden_states=vlm_hidden_states,
            attentions=vlm_output.attentions,
        )


__all__ = [
    "ColQwen2ForRetrieval",
    "ColQwen2PreTrainedModel",
    "ColQwen2Processor",
]