136 lines
5.9 KiB
Python
136 lines
5.9 KiB
Python
# coding=utf-8
|
|
# Copyright 2022 The HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
Image/Text processor class for GIT
|
|
"""
|
|
|
|
from typing import Optional, Union
|
|
|
|
from ...feature_extraction_utils import BatchFeature
|
|
from ...image_utils import ImageInput
|
|
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
|
from ...utils import logging
|
|
|
|
|
|
class GitProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword arguments accepted by `GitProcessor.__call__`.

    GIT declares no processor-specific default values: `_defaults` is empty.
    """

    # Deliberately left un-annotated: on a TypedDict-style class an annotation would
    # declare `_defaults` as a dictionary key instead of a plain class attribute.
    _defaults = {}
|
|
|
|
|
|
# Module-level logger created through the library's own `logging` utility module.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class GitProcessor(ProcessorMixin):
    r"""
    Bundles a CLIP image processor and a BERT tokenizer into a single GIT processor.

    [`GitProcessor`] exposes the functionality of both [`CLIPImageProcessor`] and [`BertTokenizerFast`]. See
    [`~GitProcessor.__call__`] and [`~GitProcessor.decode`] for details.

    Args:
        image_processor ([`AutoImageProcessor`]):
            The image processor is a required input.
        tokenizer ([`AutoTokenizer`]):
            The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor, tokenizer):
        super().__init__(image_processor, tokenizer)
        # NOTE(review): `current_processor` is never read in this file; presumably it is
        # consulted by ProcessorMixin or by legacy callers — confirm before removing.
        self.current_processor = self.image_processor

    def __call__(
        self,
        images: Optional[ImageInput] = None,
        text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
        audio=None,
        videos=None,
        **kwargs: Unpack[GitProcessorKwargs],
    ) -> BatchFeature:
        """
        Prepare one or several text sequence(s) and/or image(s) for the model.

        When `text` is not `None`, it is encoded by forwarding it, along with the `text_kwargs` portion of the
        merged `kwargs`, to BertTokenizerFast's [`~BertTokenizerFast.__call__`]. When `images` is not `None`, they
        are forwarded, along with the `images_kwargs` portion of the merged `kwargs`, to CLIPImageProcessor's
        [`~CLIPImageProcessor.__call__`]. Refer to the docstrings of those two methods for more information.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`, *optional*):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`TextInput`, `PreTokenizedInput`, `list[TextInput]`, `list[PreTokenizedInput]`, *optional*):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            audio (*optional*):
                Not used by this processor; accepted and ignored.
            videos (*optional*):
                Not used by this processor; accepted and ignored.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.

        Raises:
            ValueError: If both `text` and `images` are `None`.
        """
        if images is None and text is None:
            raise ValueError("You have to specify either text or images. Both cannot be none.")

        merged = self._merge_kwargs(
            GitProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # Text features are merged first and image features second, so on a key clash
        # the image processor's value is the one that survives.
        encoded = {}
        if text is not None:
            encoded.update(self.tokenizer(text, **merged["text_kwargs"]))
        if images is not None:
            encoded.update(self.image_processor(images, **merged["images_kwargs"]))

        # `return_tensors` is read from the merged common kwargs and may be absent (None).
        tensor_type = merged["common_kwargs"].get("return_tensors")
        return BatchFeature(data=encoded, tensor_type=tensor_type)

    def batch_decode(self, *args, **kwargs):
        """
        Decode a batch of token id sequences by delegating every argument unchanged to the tokenizer's
        [`~PreTrainedTokenizer.batch_decode`] (BertTokenizerFast for GIT). See that method's docstring for details.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        Decode a single token id sequence by delegating every argument unchanged to the tokenizer's
        [`~PreTrainedTokenizer.decode`] (BertTokenizerFast for GIT). See that method's docstring for details.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Model input names declared by this processor.

        A fixed list; it is not derived from the tokenizer's or image processor's own input names.
        """
        return ["input_ids", "attention_mask", "pixel_values"]
|
|
|
|
|
|
# Public API of this module: only the processor class is exported by star-imports.
__all__ = ["GitProcessor"]
|