team-10/venv/Lib/site-packages/transformers/models/llava/processing_llava.py
2025-08-02 02:00:33 +02:00

235 lines
11 KiB
Python

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Llava.
"""
from typing import Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, get_image_size, to_numpy_array
from ...processing_utils import (
MultiModalData,
ProcessingKwargs,
ProcessorMixin,
Unpack,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
logger = logging.get_logger(__name__)
class LlavaProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {"padding": False, "return_mm_token_type_ids": False},
"images_kwargs": {},
}
class LlavaProcessor(ProcessorMixin):
r"""
Constructs a LLaVa processor which wraps a LLaVa image processor and a LLaMa tokenizer into a single processor.
[`LlavaProcessor`] offers all the functionalities of [`LlavaImageProcessor`] and [`LlamaTokenizerFast`]. See the
[`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information.
Args:
image_processor ([`LlavaImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
patch_size (`int`, *optional*):
Patch size from the vision tower.
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
Should be same as in model's config
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
image_token (`str`, *optional*, defaults to `"<image>"`):
Special token used to denote image location.
num_additional_image_tokens (`int`, *optional*, defaults to 0):
Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other
extra tokens appended, no need to set this arg.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor=None,
tokenizer=None,
patch_size=None,
vision_feature_select_strategy=None,
chat_template=None,
image_token="<image>", # set the default and let users change if they have peculiar special tokens in rare cases
num_additional_image_tokens=0,
**kwargs,
):
self.patch_size = patch_size
self.num_additional_image_tokens = num_additional_image_tokens
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
self.image_token_id = tokenizer.encode(self.image_token, add_special_tokens=False)[0]
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
audio=None,
videos=None,
**kwargs: Unpack[LlavaProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `list[str]`, `list[list[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
if images is None and text is None:
raise ValueError("You have to specify at least one of `images` or `text`.")
output_kwargs = self._merge_kwargs(
LlavaProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if images is not None:
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
else:
image_inputs = {}
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise TypeError("Invalid input text. Please provide a string, or a list of strings")
# try to expand inputs in processing if we have the necessary parts
prompt_strings = text
if image_inputs.get("pixel_values") is not None:
# Replace the image token with the expanded image token sequence
pixel_values = image_inputs["pixel_values"]
height, width = get_image_size(to_numpy_array(pixel_values[0]))
num_image_tokens = (height // self.patch_size) * (
width // self.patch_size
) + self.num_additional_image_tokens
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
prompt_strings = []
for sample in text:
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
prompt_strings.append(sample)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`list[list[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = LlavaProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
crop_size = images_kwargs.get("crop_size", None) or self.image_processor.crop_size
resized_height, resized_width = crop_size["height"], crop_size["width"]
num_image_tokens = (resized_height // self.patch_size) * (resized_width // self.patch_size)
num_image_tokens += self.num_additional_image_tokens
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
num_image_tokens = [num_image_tokens] * len(image_sizes)
num_image_patches = [1] * len(image_sizes)
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
__all__ = ["LlavaProcessor"]