team-10/venv/Lib/site-packages/transformers/models/shieldgemma2/processing_shieldgemma2.py
2025-08-02 02:00:33 +02:00

195 lines
8.4 KiB
Python

# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Mapping, Sequence
from typing import Optional
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import Unpack
from ...utils import logging
from ..gemma3.processing_gemma3 import Gemma3Processor, Gemma3ProcessorKwargs
logger = logging.get_logger(__name__)
DEFAULT_SHIELDGEMMA2_POLICIES: Mapping[str, str] = {
"dangerous": (
"No Dangerous Content: The image shall not contain content that"
" facilitates or encourages activities that could cause real-world harm"
" (e.g., building firearms and explosive devices, promotion of"
" terrorism, instructions for suicide)."
),
"sexual": (
"No Sexually Explicit content: The image shall not contain content that"
" depicts explicit or graphic sexual acts (e.g., pornography, erotic"
" nudity, depictions of rape or sexual assault)."
),
"violence": (
"No Violence/Gore content: The image shall not contain content that"
" depicts shocking, sensational, or gratuitous violence (e.g.,"
" excessive blood and gore, gratuitous violence against animals,"
" extreme injury or moment of death)."
),
}
class ShieldGemma2ProcessorKwargs(Gemma3ProcessorKwargs, total=False):
policies: Optional[Sequence[str]]
custom_policies: Optional[Mapping[str, str]]
_defaults = {
"text_kwargs": {
"padding": True,
},
"images_kwargs": {
"do_pan_and_scan": False,
},
}
class ShieldGemma2Processor(Gemma3Processor):
def __init__(
self, image_processor, tokenizer, chat_template=None, image_seq_length=256, policy_definitions=None, **kwargs
):
"""A processor for the ShieldGemma 2 model.
Args:
image_processor: The image processor to use, typically a `Gemma3ImageProcessorFast` instance.
tokenizer: The tokenizer to use, typically a `GemmaTokenizerFast` instance.
chat_template: The chat template to use with this processor. Typically, this is unset as the processor
configuration on Hugging Face Hub includes this value already.
image_seq_length: The number of soft tokens per image. Typically, this is unset as the processor
configuration on Hugging Face Hub includes this value already.
policy_definitions: A mapping from policy name to its description in text used as the default policies to
classify images against. The policy descriptions are included in the text of the prompts generated by
this processor. Typically, this is unset as the processor configuration on Hugging Face Hub includes
the base policies ShieldGemma was trained on.
"""
super().__init__(image_processor, tokenizer, chat_template, image_seq_length, **kwargs)
if policy_definitions is None:
self.policy_definitions = DEFAULT_SHIELDGEMMA2_POLICIES
else:
self.policy_definitions = policy_definitions
def __call__(
self,
images: ImageInput = None,
text=None,
videos=None,
audio=None,
**kwargs: Unpack[ShieldGemma2ProcessorKwargs],
) -> BatchFeature:
"""Generates a batch of inputs from the provided images.
ShieldGemma was trained to classify image content for policy compliance using a specific prompt construction.
This processor generates a batch of such prompts from the provided images by:
1. Creating a list of conversations, one for each `<image, policy>` pair;
2. Converting these conversations to text using `self.apply_chat_template()`; and
3. Encoding the conversations and images using the same techniques as `Gemma3Processor`.
Args:
images: A single image or a list of images to include in the batch.
text: Not supported.
videos: Not supported.
audio: Not supported.
kwargs: An optional dictionary of keyword arguments to configure the
processor. Possible values include:
* `custom_policies`: Additional policy definitions that augment the `self.policy_definitions` passed
into the constructor. Note that `custom_policies` that share a key with `self.policy_definitions`
will override the policy description
* `policies`: (Optional) a list of keys in the joint `self.policy_definitions | custom_policies`
dictionary of specific interest for the provided images. If empty or None, prompts will be
generated for every key in the joint dictionary.
Returns:
A `BatchFeature` containing `input_ids`, `pixel_values`, etc. where each Tensor is of shape
`(len(images) * len(policies), )`, and the order within the batch will be
img1_policy1, ... img1_policyN, ... imgM_policyN.
"""
del text, videos, audio
if not images:
raise ValueError("ShieldGemma 2 needs images to classify")
elif not isinstance(images, Sequence):
images = [images]
if not self.chat_template:
raise ValueError("ShieldGemma 2 requires the use of a specific chat template")
# Disable pan and scan
images_kwargs = kwargs.setdefault("images_kwargs", {})
if images_kwargs.get("do_pan_and_scan") is True:
logger.warning_once("ShieldGemma2 does not support pan and scan.")
images_kwargs["do_pan_and_scan"] = False
# Enable padding on the batch during tokenization
text_kwargs = kwargs.setdefault("text_kwargs", {})
if "padding" not in text_kwargs:
text_kwargs["padding"] = kwargs.pop("padding", True)
text_kwargs["padding_side"] = kwargs.pop("padding_side", "left")
policy_definitions: Mapping[str, str] = {
**self.policy_definitions,
**kwargs.get("custom_policies", {}),
}
if (policies := kwargs.get("policies")) is None:
policies = list(policy_definitions.keys())
# TODO(ryanmullins): Support images from PIL or URLs.
messages = []
expanded_images = []
for img in images:
for policy in policies:
messages.append(
[
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": policy_definitions[policy]},
],
}
]
)
expanded_images.append([img])
text = self.apply_chat_template(messages, tokenize=False)
return super().__call__(images=expanded_images, text=text, **kwargs)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
__all__ = ["ShieldGemma2Processor"]