435 lines
20 KiB
Python
435 lines
20 KiB
Python
|
# coding=utf-8
|
||
|
# Copyright 2024 The HuggingFace Inc. team.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""
|
||
|
Processor class for Idefics3.
|
||
|
"""
|
||
|
|
||
|
import re
|
||
|
from itertools import accumulate
|
||
|
from typing import TYPE_CHECKING, Optional, Union
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from ...feature_extraction_utils import BatchFeature
|
||
|
from ...image_utils import ImageInput, is_valid_image, load_image
|
||
|
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
|
||
|
from ...tokenization_utils_base import AddedToken, BatchEncoding, TextInput
|
||
|
from ...utils import logging
|
||
|
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from ...tokenization_utils_base import PreTokenizedInput
|
||
|
|
||
|
logger = logging.get_logger(__name__)
|
||
|
|
||
|
|
||
|
def is_url(val) -> bool:
    """Return `True` when `val` is a string that begins with `"http"`, `False` otherwise."""
    if not isinstance(val, str):
        return False
    return val.startswith("http")
|
||
|
|
||
|
|
||
|
def is_image_or_image_url(elem):
    """Tell whether `elem` is either an image URL (per `is_url`) or a valid image (per `is_valid_image`)."""
    if is_url(elem):
        return True
    return is_valid_image(elem)
|
||
|
|
||
|
|
||
|
def _prompt_split_image(image_seq_len, image_rows, image_cols, fake_token_around_image, image_token, global_img_token):
|
||
|
"""Prompt with expanded image tokens for when the image is split into patches."""
|
||
|
text_split_images = ""
|
||
|
for n_h in range(image_rows):
|
||
|
for n_w in range(image_cols):
|
||
|
text_split_images += (
|
||
|
f"{fake_token_around_image}" + f"<row_{n_h + 1}_col_{n_w + 1}>" + f"{image_token}" * image_seq_len
|
||
|
)
|
||
|
text_split_images += "\n"
|
||
|
|
||
|
text_split_images += (
|
||
|
f"\n{fake_token_around_image}"
|
||
|
+ f"{global_img_token}"
|
||
|
+ f"{image_token}" * image_seq_len
|
||
|
+ f"{fake_token_around_image}"
|
||
|
)
|
||
|
return text_split_images
|
||
|
|
||
|
|
||
|
def _prompt_single_image(image_seq_len, fake_token_around_image, image_token, global_img_token):
|
||
|
"""Prompt with expanded image tokens for a single image."""
|
||
|
return (
|
||
|
f"{fake_token_around_image}"
|
||
|
+ f"{global_img_token}"
|
||
|
+ f"{image_token}" * image_seq_len
|
||
|
+ f"{fake_token_around_image}"
|
||
|
)
|
||
|
|
||
|
|
||
|
def get_image_prompt_string(
    image_rows, image_cols, image_seq_len, fake_token_around_image, image_token, global_img_token
):
    """
    Return the expanded prompt string for one image.

    When both `image_rows` and `image_cols` equal 0 the single-image layout from `_prompt_single_image` is used;
    any other combination goes through `_prompt_split_image`.
    """
    is_split = not (image_rows == 0 and image_cols == 0)
    if is_split:
        return _prompt_split_image(
            image_seq_len, image_rows, image_cols, fake_token_around_image, image_token, global_img_token
        )

    return _prompt_single_image(
        image_seq_len,
        fake_token_around_image=fake_token_around_image,
        image_token=image_token,
        global_img_token=global_img_token,
    )
|
||
|
|
||
|
|
||
|
class Idefics3ImagesKwargs(ImagesKwargs, total=False):
    """Image keyword arguments specific to Idefics3, declared on top of the generic `ImagesKwargs`."""

    # Defaults to `True` in `Idefics3ProcessorKwargs._defaults`.
    # NOTE(review): presumably asks the image processor to also return the per-image `rows`/`cols` split counts
    # that `Idefics3Processor.__call__` pops from its output — confirm against the image processor.
    return_row_col_info: Optional[bool]
    # NOTE(review): a string-to-int mapping; the expected keys are not visible in this file — confirm against the
    # image processor.
    max_image_size: Optional[dict[str, int]]
|
||
|
|
||
|
|
||
|
class Idefics3ProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword-argument schema, with default values, that `Idefics3Processor` hands to `_merge_kwargs`."""

    # Idefics3-specific image kwargs (adds `return_row_col_info` and `max_image_size`).
    images_kwargs: Idefics3ImagesKwargs

    # Default values grouped by the kwargs group they belong to.
    _defaults = {
        "text_kwargs": {
            "add_special_tokens": True,
            "padding": False,
            "is_split_into_words": False,
            "return_mm_token_type_ids": False,
        },
        "images_kwargs": {
            "return_row_col_info": True,
        },
    }


# Re-assign the annotation explicitly so it resolves to the Idefics3-specific kwargs class.
Idefics3ProcessorKwargs.__annotations__["images_kwargs"] = Idefics3ImagesKwargs  # python 3.8 compatibility
|
||
|
|
||
|
|
||
|
class Idefics3Processor(ProcessorMixin):
    r"""
    Constructs an Idefics3 processor which wraps a tokenizer and an Idefics3 image processor into a single processor.

    [`Idefics3Processor`] offers all the functionalities of [`Idefics3ImageProcessor`] and of the wrapped tokenizer.
    See the docstring of [`~Idefics3Processor.__call__`] and [`~Idefics3Processor.decode`] for more information.

    Args:
        image_processor (`Idefics3ImageProcessor`):
            An instance of [`Idefics3ImageProcessor`]. The image processor is a required input.
        tokenizer (`PreTrainedTokenizerBase`, *optional*):
            An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. Although
            the parameter defaults to `None`, the tokenizer is a required input and `None` raises a `ValueError`.
        image_seq_len (`int`, *optional*, defaults to 169):
            The length of the image sequence i.e. the number of <image> tokens per image in the input.
            This parameter is used to build the string from the input prompt and image tokens and should match the
            value the model used. It is computed as: image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "Idefics3ImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self, image_processor, tokenizer=None, image_seq_len: int = 169, chat_template: Optional[str] = None, **kwargs
    ):
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        self.fake_image_token = AddedToken("<fake_token_around_image>", normalized=False, special=True).content
        self.image_token = AddedToken("<image>", normalized=False, special=True).content
        self.end_of_utterance_token = AddedToken("<end_of_utterance>", normalized=False, special=True).content
        self.global_image_tag = "<global-img>"  # https://github.com/huggingface/transformers/pull/32473/files/8063e5e17362571b693f1db95167f5443a3be1b2#r1734825341
        self.image_seq_len = image_seq_len

        # This regex matches one or more consecutive occurrences of <global-img> tags (optionally surrounded by
        # newline characters) or <row_x_col_y> tags (where x and y are digits, optionally followed by a newline).
        self._regex_to_remove_extra_special_tokens = re.compile(r"(\n?<global-img>\n?|<row_\d+_col_\d+>\n?)+")

        tokens_to_add = {
            "additional_special_tokens": [
                self.fake_image_token,
                self.image_token,
                self.end_of_utterance_token,
            ]
        }
        tokenizer.add_special_tokens(tokens_to_add)

        # Resolve the ids only after the special tokens have been registered, so that none of them is looked up
        # while its token may still be missing from the vocabulary.
        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        self.fake_image_token_id = tokenizer.convert_tokens_to_ids(self.fake_image_token)
        self.global_image_token_id = tokenizer.convert_tokens_to_ids(self.global_image_tag)
        self.row_col_ids = [
            tokenizer.convert_tokens_to_ids(f"<row_{i + 1}_col_{j + 1}>") for i in range(6) for j in range(6)
        ]

        super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)

    def _extract_images_from_prompts(self, prompts):
        """
        Collect the images found in each prompt.

        Args:
            prompts: An iterable of prompts, each itself an iterable of elements.

        Returns:
            `list[list]`: One list per prompt holding its valid images in order; URL elements are fetched with
            `load_image`, and every other element is ignored.
        """
        prompt_images = []
        for prompt in prompts:
            images = []
            for elem in prompt:
                if is_valid_image(elem):
                    images.append(elem)
                elif is_url(elem):
                    images.append(load_image(elem))
            prompt_images.append(images)
        return prompt_images

    def __call__(
        self,
        images: Union[ImageInput, list[ImageInput], list[list[ImageInput]]] = None,
        text: Union[TextInput, "PreTokenizedInput", list[TextInput], list["PreTokenizedInput"]] = None,
        audio=None,
        videos=None,
        image_seq_len: Optional[int] = None,
        **kwargs: Unpack[Idefics3ProcessorKwargs],
    ) -> BatchEncoding:
        """
        Processes the input prompts and images and returns the model inputs.

        Example:

        ```python
        >>> import requests
        >>> from transformers import Idefics3Processor
        >>> from transformers.image_utils import load_image

        >>> processor = Idefics3Processor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
        >>> processor.image_processor.do_image_splitting = False  # Force as False to simplify the example

        >>> url1 = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
        >>> url2 = "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg"

        >>> image1, image2 = load_image(url1), load_image(url2)
        >>> images = [[image1], [image2]]

        >>> text = [
        ...     "<image>In this image, we see",
        ...     "bla bla bla<image>",
        ... ]
        >>> outputs = processor(images=images, text=text, return_tensors="pt", padding=True)
        >>> input_ids = outputs.input_ids
        >>> input_tokens = processor.tokenizer.batch_decode(input_ids)
        >>> print(input_tokens)
        ['<|begin_of_text|><fake_token_around_image><global-img>((<image>)*169)<fake_token_around_image> In this image, we see', '<|reserved_special_token_0|><|reserved_special_token_0|><|reserved_special_token_0|><|begin_of_text|>bla bla bla<fake_token_around_image><global-img>((<image>)*169)<fake_token_around_image>']
        ```

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`, *optional*):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array, PyTorch
                tensor or a URL string. A flat list of images is distributed over the prompts according to the
                number of `<image>` tokens in each prompt when `text` is given; without `text` it is treated as the
                images of a single prompt, i.e. batch size 1.
            text (`Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]`, *optional*):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
                Wherever an image token, `<image>` is encountered it is expanded to
                `<fake_token_around_image>` + `<row_x_col_y>` + `<image>` * `image_seq_len` + `<fake_token_around_image>`.
            audio (*optional*):
                Not used by this processor.
            videos (*optional*):
                Not used by this processor.
            image_seq_len (`int`, *optional*):
                The length of the image sequence. If not provided, the default value of self.image_seq_len is used.
                image_seq_len should be equal to int(((image_size // patch_size) ** 2) / (scale_factor**2))
            return_tensors (`Union[str, TensorType]`, *optional*):
                If set, will return tensors of a particular framework. See [`PreTrainedTokenizerFast.__call__`] for more
                information.

        Returns:
            [`BatchFeature`]: The image-processor outputs (without `rows`/`cols` when text is given) merged with the
            tokenizer outputs, plus `mm_token_type_ids` when `return_mm_token_type_ids=True` and text was tokenized.

        Raises:
            ValueError: If neither `text` nor `images` is given, if `text` or `images` has an unsupported type, or if
                the number of `<image>` tokens does not match the number of images.
        """
        if text is None and images is None:
            raise ValueError("You must provide either `text` or `images`.")

        output_kwargs = self._merge_kwargs(
            Idefics3ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        image_seq_len = image_seq_len if image_seq_len is not None else self.image_seq_len
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)

        n_images_in_text = []
        n_images_in_images = []
        # Bound up front so the `mm_token_type_ids` step below also works for text-only inputs.
        batch_image_seq_lengths = []
        inputs = {}

        if text is not None:
            if isinstance(text, str):
                text = [text]
            elif not isinstance(text, (list, tuple)) or (
                len(text) > 0 and not isinstance(text[0], (str, list, tuple))
            ):
                # Checked without indexing into arbitrary objects, so an unsupported type gets this error instead
                # of a stray `TypeError`. Nested sequences stay accepted for pretokenized input.
                raise ValueError("Invalid input text. Please provide a string, or a list of strings")
            n_images_in_text = [sample.count(self.image_token) for sample in text]

        if images is not None:
            if is_image_or_image_url(images):
                images = [[images]]
            elif isinstance(images, (list, tuple)) and is_image_or_image_url(images[0]):
                if text is not None:
                    if sum(n_images_in_text) != len(images):
                        raise ValueError(
                            f"The total number of {self.image_token} tokens in the prompts should be the same as the number of images passed."
                            f" Found {sum(n_images_in_text)} {self.image_token} tokens and {len(images)} images."
                        )
                    # Reorganize the images to match the prompts
                    cumsum_images_in_text = [0] + list(accumulate(n_images_in_text))
                    images = [
                        images[cumsum_images_in_text[i] : cumsum_images_in_text[i + 1]]
                        for i in range(len(n_images_in_text))
                    ]
                else:
                    images = [images]
            elif not isinstance(images, (list, tuple)):
                # Neither an image, a URL nor a (nested) list: reject it here rather than failing on `images[0]`.
                raise ValueError(
                    "Invalid input images. Please provide a single image or a list of images or a list of list of images."
                )
            n_images_in_images = [len(sample) for sample in images]

            # Load images if they are URLs
            images = [[load_image(im) if is_url(im) else im for im in sample] for sample in images]

            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
            inputs.update(image_inputs)

            if text is not None:
                if n_images_in_images != n_images_in_text:
                    raise ValueError(
                        f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
                    )

                # Without row/col info every image is treated as unsplit: one `0` per image, one list per prompt.
                image_rows = inputs.pop("rows", [[0] * n_images for n_images in n_images_in_text])
                image_cols = inputs.pop("cols", [[0] * n_images for n_images in n_images_in_text])

                fake_image_token = self.fake_image_token
                image_token = self.image_token
                global_img_token = self.global_image_tag

                prompt_strings = []
                for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
                    # Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
                    image_prompt_strings = []
                    image_seq_lengths = []
                    for n_rows, n_cols in zip(sample_rows, sample_cols):
                        image_prompt_string = get_image_prompt_string(
                            n_rows,
                            n_cols,
                            image_seq_len,
                            image_token=image_token,
                            fake_token_around_image=fake_image_token,
                            global_img_token=global_img_token,
                        )
                        # Add +2 and +3 for special BOI/EOI/fake_image_wrapper tokens. The effective
                        # `image_seq_len` is used so the lengths agree with the prompt string built above.
                        row_length = (image_seq_len + 2) * n_cols + 1
                        image_seq_lengths.append((image_seq_len + 3) + row_length * n_rows)
                        image_prompt_strings.append(image_prompt_string)

                    batch_image_seq_lengths.append(image_seq_lengths)

                    # `str.split` always yields at least one element, so `split_sample[0]` is safe.
                    split_sample = sample.split(image_token)

                    # Place in the image prompt strings where the image tokens are
                    sample = split_sample[0]
                    for i, image_prompt_string in enumerate(image_prompt_strings):
                        sample += image_prompt_string + split_sample[i + 1]
                    prompt_strings.append(sample)

                text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
                self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
                inputs.update(text_inputs)

        elif text is not None:
            if any(n_images_in_text):
                raise ValueError(
                    f"Found {sum(n_images_in_text)} {self.image_token} tokens in the text but no images were passed."
                )
            text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
            inputs.update(text_inputs)

        # Token type ids only exist when text was tokenized; images-only calls have no `input_ids`.
        if return_mm_token_type_ids and "input_ids" in inputs:
            array_ids = np.array(inputs["input_ids"])
            mm_token_type_ids = np.zeros_like(array_ids)
            for i, seq_lengths in enumerate(batch_image_seq_lengths):
                image_start_positions = np.where(array_ids[i] == self.fake_image_token_id)[0]
                j = 0
                for seq_len in seq_lengths:
                    if j >= len(image_start_positions):
                        break
                    start = image_start_positions[j]
                    end = start + seq_len
                    mm_token_type_ids[i, start:end] = 1
                    # Jump to the first fake-image token located at or after the end of the span just marked.
                    j = np.searchsorted(image_start_positions, end)

            inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()

        return BatchFeature(data=inputs, tensor_type=return_tensors)

    def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
        """
        Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.

        Args:
            image_sizes (`list[list[int]]`, *optional*):
                The input sizes formatted as (height, width) per each image.
            **kwargs:
                Image-processing overrides layered on top of the default `images_kwargs`.

        Returns:
            `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
            input modalities, along with other useful data.
        """

        vision_data = {}
        if image_sizes is not None:
            # Merge into a fresh dict: updating the dict stored in `_defaults` in place would leak this call's
            # overrides into the class-level defaults shared by every later call.
            images_kwargs = {**Idefics3ProcessorKwargs._defaults.get("images_kwargs", {}), **kwargs}

            num_image_row_cols = [
                self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
                for image_size in image_sizes
            ]

            base_image_length = self.image_seq_len + 3
            col_length = self.image_seq_len + 2
            num_image_tokens = []
            num_image_patches = []

            for num_patches, num_rows, num_cols in num_image_row_cols:
                row_length = col_length * num_cols + 1
                num_image_tokens.append(base_image_length + (row_length * num_rows))
                num_image_patches.append(num_patches)

            vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})

        return MultiModalData(**vision_data)

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to the tokenizer's [`~PreTrainedTokenizer.batch_decode`] and then
        replaces every run of `<global-img>` / `<row_x_col_y>` tags in each decoded string with a single `<image>`.
        Please refer to the docstring of the tokenizer method for more information.
        """
        batched_decode_output = self.tokenizer.batch_decode(*args, **kwargs)
        return [self._regex_to_remove_extra_special_tokens.sub("<image>", s) for s in batched_decode_output]

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to the tokenizer's [`~PreTrainedTokenizer.decode`] and then replaces
        every run of `<global-img>` / `<row_x_col_y>` tags in the decoded string with a single `<image>`. Please
        refer to the docstring of the tokenizer method for more information.
        """
        decode_output = self.tokenizer.decode(*args, **kwargs)
        return self._regex_to_remove_extra_special_tokens.sub("<image>", decode_output)

    @property
    def model_input_names(self):
        """Image-processor input names followed by tokenizer input names, de-duplicated with order preserved."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(image_processor_input_names + tokenizer_input_names))
|
||
|
|
||
|
|
||
|
# Public API of this module: only the processor class is exported.
__all__ = ["Idefics3Processor"]
|