148 lines
6.8 KiB
Python
148 lines
6.8 KiB
Python
# coding=utf-8
|
|
# Copyright 2022 WenXiang ZhongzhiCheng LedellWu LiuGuang BoWenZhang The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
Image/Text processor class for AltCLIP
|
|
"""
|
|
|
|
from typing import Union
|
|
|
|
from ...image_utils import ImageInput
|
|
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
|
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
|
|
from ...utils.deprecation import deprecate_kwarg
|
|
|
|
|
|
class AltClipProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword-argument schema used by `AltCLIPProcessor.__call__`.

    No AltCLIP-specific default values are declared: `_defaults` is empty.
    """

    # Empty on purpose: this processor declares no per-modality default kwargs.
    _defaults = {}
|
|
|
|
|
|
class AltCLIPProcessor(ProcessorMixin):
    r"""
    Constructs an AltCLIP processor which wraps a CLIP image processor and a XLM-Roberta tokenizer into a single
    processor.

    [`AltCLIPProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`XLMRobertaTokenizerFast`]. See
    the [`~AltCLIPProcessor.__call__`] and [`~AltCLIPProcessor.decode`] for more information.

    Args:
        image_processor ([`CLIPImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`XLMRobertaTokenizerFast`], *optional*):
            The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = ("CLIPImageProcessor", "CLIPImageProcessorFast")
    tokenizer_class = ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast")

    @deprecate_kwarg(old_name="feature_extractor", version="5.0.0", new_name="image_processor")
    def __init__(self, image_processor=None, tokenizer=None):
        # Both components default to `None` in the signature but are mandatory in practice.
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
        audio=None,
        videos=None,
        **kwargs: Unpack[AltClipProcessorKwargs],
    ) -> BatchEncoding:
        """
        Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
        and `kwargs` arguments to XLMRobertaTokenizerFast's [`~XLMRobertaTokenizerFast.__call__`] if `text` is not
        `None` to encode the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
        CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
        of the above two methods for more information.

        Args:

            images (`ImageInput`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`TextInput`, `PreTokenizedInput`, `list[TextInput]`, `list[PreTokenizedInput]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            audio (*optional*):
                Not used by this processor.
            videos (*optional*):
                Not used by this processor.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.
        Returns:
            [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.

        Raises:
            ValueError: If both `text` and `images` are `None`.
        """

        if text is None and images is None:
            raise ValueError("You must specify either text or images.")

        output_kwargs = self._merge_kwargs(
            AltClipProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if text is not None:
            encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
        if images is not None:
            image_features = self.image_processor(images, **output_kwargs["images_kwargs"])

        # BC for explicit return_tensors.
        # Popped unconditionally with a `None` default so that `return_tensors` is always bound
        # when the image-only branch below reads it, even if the caller did not pass it.
        return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)

        if text is None:
            # Image-only input: wrap the image processor output in a `BatchEncoding`.
            return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)

        if images is not None:
            # Text and images: attach the pixel values to the tokenizer output.
            encoding["pixel_values"] = image_features.pixel_values
        return encoding

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to XLMRobertaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`].
        Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to XLMRobertaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """
        Input names of the tokenizer followed by those of the image processor, with duplicates removed while
        keeping first-seen order.
        """
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        # `dict.fromkeys` de-duplicates while preserving insertion order.
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
|
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = ["AltCLIPProcessor"]
|