380 lines
17 KiB
Python
380 lines
17 KiB
Python
# coding=utf-8
|
|
# Copyright 2025 The HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Optional, Union
|
|
|
|
from transformers.models.colpali.modeling_colpali import ColPaliForRetrieval, ColPaliPreTrainedModel
|
|
from transformers.models.colpali.processing_colpali import ColPaliProcessor
|
|
|
|
from ...cache_utils import Cache
|
|
from ...feature_extraction_utils import BatchFeature
|
|
from ...image_utils import ImageInput, is_valid_image
|
|
from ...processing_utils import ProcessingKwargs, Unpack
|
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
|
from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torch_available, logging
|
|
from .configuration_colqwen2 import ColQwen2Config
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
|
|
# Module-level logger, created through the library's own `logging` helper imported above.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class ColQwen2ProcessorKwargs(ProcessingKwargs, total=False):
    # Default keyword arguments used by `ColQwen2Processor.__call__` when the caller does not override them.
    _defaults = {
        # Forwarded to the tokenizer.
        "text_kwargs": {"padding": "longest"},
        # Forwarded to the image processor.
        "images_kwargs": {"data_format": "channels_first", "do_convert_rgb": True},
        # Common defaults: tensors are returned as PyTorch tensors.
        "common_kwargs": {"return_tensors": "pt"},
    }
|
|
|
|
|
|
class ColQwen2Processor(ColPaliProcessor):
    r"""
    Constructs a ColQwen2 processor which wraps a Qwen2VLProcessor and special methods to process images and queries, as
    well as to compute the late-interaction retrieval score.

    [`ColQwen2Processor`] offers all the functionalities of [`Qwen2VLProcessor`]. See the [`~Qwen2VLProcessor.__call__`]
    for more information.

    Args:
        image_processor ([`Qwen2VLImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
        visual_prompt_prefix (`str`, *optional*): A string that gets tokenized and prepended to the image tokens.
            When omitted, a built-in "Describe the image." chat prompt is used.
        query_prefix (`str`, *optional*): A prefix to be used for the query. Defaults to `"Query: "`.
    """

    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        visual_prompt_prefix: Optional[str] = None,
        query_prefix: Optional[str] = None,
        **kwargs,
    ):
        # Imported locally so that this block does not depend on a change to the module-level imports.
        from ...processing_utils import ProcessorMixin

        # Initialise *this* instance. The previous `ColPaliProcessor().__init__(...)` built a throwaway
        # `ColPaliProcessor` object and re-initialised that one, leaving `self` without its
        # `image_processor` / `tokenizer` attributes.
        # NOTE(review): assumes `ColPaliProcessor` derives from `ProcessorMixin` - confirm.
        ProcessorMixin.__init__(self, image_processor, tokenizer, chat_template=chat_template)

        # Prefer the special tokens declared by the tokenizer, falling back to the Qwen2-VL defaults.
        self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
        self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token

        if visual_prompt_prefix is None:
            visual_prompt_prefix = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
        self.visual_prompt_prefix = visual_prompt_prefix

        if query_prefix is None:
            query_prefix = "Query: "
        self.query_prefix = query_prefix

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
        audio=None,
        videos=None,
        **kwargs: Unpack[ColQwen2ProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare for the model either (1) one or several texts, or (2) one or several image(s). This
        method is a custom wrapper around the Qwen2VLProcessor's [`~Qwen2VLProcessor.__call__`] method adapted for the
        ColQwen2 model. It cannot process both text and images at the same time.

        When preparing the text(s), this method forwards the `text` and `kwargs` arguments to Qwen2TokenizerFast's
        [`~Qwen2TokenizerFast.__call__`].
        When preparing the image(s), this method forwards the `images` and `kwargs` arguments to Qwen2VLImageProcessor's
        [`~Qwen2VLImageProcessor.__call__`].
        Please refer to the docstring of the above two methods for more information.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width.
            text (`str`, `list[str]`, `list[list[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. The
              per-image patch sequences are padded to a common length along the sequence dimension.

        Raises:
            ValueError: If neither or both of `text` and `images` are given, or if either has an unsupported
                structure (including an empty list).
        """
        output_kwargs = self._merge_kwargs(
            ColQwen2ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        # `suffix` is consumed here and must not reach the tokenizer.
        suffix = output_kwargs["text_kwargs"].pop("suffix", None)
        return_token_type_ids = suffix is not None

        if text is None and images is None:
            raise ValueError("Either text or images must be provided")
        if text is not None and images is not None:
            raise ValueError("Only one of text or images can be processed at a time")

        if images is not None:
            # Accept a single image, a list of images or a list of lists of images. Empty lists are rejected
            # with the same ValueError instead of failing with an IndexError on `images[0]`.
            if is_valid_image(images):
                images = [images]
            elif isinstance(images, list) and images and is_valid_image(images[0]):
                pass
            elif not (
                isinstance(images, list)
                and images
                and isinstance(images[0], list)
                and images[0]
                and is_valid_image(images[0][0])
            ):
                raise ValueError("images must be an image, list of images or list of list of images")

            # One copy of the visual prompt per document image.
            texts_doc = [self.visual_prompt_prefix] * len(images)

            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]

            if image_grid_thw is not None:
                merge_length = self.image_processor.merge_size**2
                index = 0  # running index over images, shared across all prompts
                for i, doc in enumerate(texts_doc):
                    # Expand each image token to one token per merged patch. A temporary placeholder is used
                    # so that already expanded tokens are not matched again by the `while` condition.
                    while self.image_token in doc:
                        num_image_tokens = image_grid_thw[index].prod() // merge_length
                        doc = doc.replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
                        index += 1
                    texts_doc[i] = doc.replace("<|placeholder|>", self.image_token)

            text_inputs = self.tokenizer(
                texts_doc,
                return_token_type_ids=False,
                **output_kwargs["text_kwargs"],
            )

            return_data = BatchFeature(data={**text_inputs, **image_inputs})

            # NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs.
            offsets = return_data["image_grid_thw"][:, 1] * return_data["image_grid_thw"][:, 2]  # (batch_size,)

            # Split the pixel_values tensor into a list of tensors, one per image
            pixel_values = list(
                torch.split(return_data["pixel_values"], offsets.tolist())
            )  # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)]

            # Pad the list of pixel_value tensors to the same length along the sequence dimension
            return_data["pixel_values"] = torch.nn.utils.rnn.pad_sequence(
                pixel_values, batch_first=True
            )  # (batch_size, max_num_patches, pixel_values)

            if return_token_type_ids:
                # NOTE(review): the tokenizer above is called with `return_token_type_ids=False`, so
                # `token_type_ids` may be absent from `return_data` here - confirm this branch is usable.
                labels = return_data["input_ids"].masked_fill(return_data["token_type_ids"] == 0, -100)
                return_data.update({"labels": labels})

            return return_data

        # Only `text` was provided (the checks above rule out every other combination).
        if isinstance(text, str):
            text = [text]
        elif not (isinstance(text, list) and text and isinstance(text[0], str)):
            # Also covers the empty list, which previously failed with an IndexError on `text[0]`.
            raise ValueError("Text must be a string or a list of strings")

        if suffix is None:
            # Default suffix: ten query-augmentation tokens appended to every query.
            suffix = self.query_augmentation_token * 10

        texts_query: list[str] = [self.query_prefix + query + suffix for query in text]

        batch_query = self.tokenizer(
            texts_query,
            return_token_type_ids=False,
            **output_kwargs["text_kwargs"],
        )

        return batch_query
|
|
|
|
|
|
# ColQwen2 flavour of the pretrained-model base class: it subclasses `ColPaliPreTrainedModel`
# and adds or overrides nothing.
class ColQwen2PreTrainedModel(ColPaliPreTrainedModel):
    pass
|
|
|
|
|
|
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for ColQwen2 embeddings output.
    """
)
class ColQwen2ForRetrievalOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        The embeddings of the model.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    """

    # NOTE(review): `ColQwen2ForRetrieval.forward` in this file never sets `loss`; it keeps its default there.
    loss: Optional[torch.FloatTensor] = None
    # L2-normalised projected embeddings, as produced by `ColQwen2ForRetrieval.forward` in this file.
    embeddings: Optional[torch.Tensor] = None
    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None
    # Hidden states of the wrapped VLM; `forward` only fills this when `output_hidden_states` is truthy.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    # Attentions taken from the wrapped VLM's output.
    attentions: Optional[tuple[torch.FloatTensor]] = None
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    Following the ColPali approach, ColQwen2 leverages VLMs to construct efficient multi-vector embeddings directly
    from document images (“screenshots”) for document retrieval. The model is trained to maximize the similarity
    between these document embeddings and the corresponding query embeddings, using the late interaction method
    introduced in ColBERT.

    Using ColQwen2 removes the need for potentially complex and brittle layout recognition and OCR pipelines with
    a single model that can take into account both the textual and visual content (layout, charts, ...) of a document.

    ColQwen2 is part of the ColVision model family, which was introduced with ColPali in the following paper:
    [*ColPali: Efficient Document Retrieval with Vision Language Models*](https://huggingface.co/papers/2407.01449).
    """
)
class ColQwen2ForRetrieval(ColPaliForRetrieval):
    _checkpoint_conversion_mapping = {}

    def __init__(self, config: ColQwen2Config):
        super().__init__(config)
        # Replace the tied-weight keys set up by the parent with the keys of the wrapped VLM,
        # prefixed with the `vlm.` attribute name under which that model is stored.
        del self._tied_weights_keys
        self._tied_weights_keys = [f"vlm.{k}" for k in (self.vlm._tied_weights_keys or [])]

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> ColQwen2ForRetrievalOutput:
        r"""
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        """
        if pixel_values is not None:
            pixel_values = pixel_values.to(dtype=self.dtype)  # (batch_size, max_num_patches, pixel_values)

            # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
            if image_grid_thw is not None:
                # NOTE: image_grid_thw: (num_images, 3) where image_grid_thw[i] = (temporal, num_patches_h, num_patches_w)
                offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]  # (num_images,) = num_patches_h * num_patches_w
                # Drop the padding rows of every image and concatenate the remaining patches.
                pixel_values = torch.cat(
                    [pixel_sequence[:offset] for pixel_sequence, offset in zip(pixel_values, offsets)],
                    dim=0,
                )  # (total_num_patches, pixel_values)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # NOTE: any `position_ids` passed by the caller is overridden by the recomputed rope index.
        position_ids, _ = self.vlm.model.get_rope_index(
            input_ids=input_ids,
            image_grid_thw=image_grid_thw,
            video_grid_thw=None,
            attention_mask=attention_mask,
        )

        # Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs.
        if inputs_embeds is None:
            inputs_embeds = self.vlm.language_model.embed_tokens(input_ids)

            if pixel_values is not None:
                pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
                image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
                # Scatter the image embeddings into the positions occupied by image tokens.
                image_mask = (
                    (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
                )
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        # Always request a `ModelOutput` from the wrapped model: the code below reads its fields by name,
        # which fails on a plain tuple. Tuple conversion of the final result for `return_dict=False`
        # is left to the `can_return_tuple` decorator.
        vlm_output = self.vlm.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
        )

        vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None

        last_hidden_states = vlm_output[0]  # (batch_size, sequence_length, hidden_size)
        embeddings = self.embedding_proj_layer(last_hidden_states)  # (batch_size, sequence_length, dim)

        # L2 normalization
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
        if attention_mask is not None:
            # Zero out the embeddings of masked (padding) positions.
            embeddings = embeddings * attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)

        return ColQwen2ForRetrievalOutput(
            embeddings=embeddings,
            past_key_values=vlm_output.past_key_values,
            hidden_states=vlm_hidden_states,
            attentions=vlm_output.attentions,
        )
|
|
|
|
|
|
# Public names exported by this module.
__all__ = [
    "ColQwen2ForRetrieval",
    "ColQwen2PreTrainedModel",
    "ColQwen2Processor",
]
|