# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Optional, Union

import numpy as np

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, make_nested_list_of_images
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import to_py_obj


class Gemma3ImagesKwargs(ImagesKwargs):
    do_pan_and_scan: Optional[bool]
    pan_and_scan_min_crop_size: Optional[int]
    pan_and_scan_max_num_crops: Optional[int]
    pan_and_scan_min_ratio_to_activate: Optional[float]
    do_convert_rgb: Optional[bool]


class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
    images_kwargs: Gemma3ImagesKwargs
    _defaults = {
        "text_kwargs": {
            "padding": False,
            "return_mm_token_type_ids": True,
        },
        "images_kwargs": {
            "do_convert_rgb": True,
            "do_pan_and_scan": False,
            "pan_and_scan_min_crop_size": 256,
            "pan_and_scan_max_num_crops": 4,
            "pan_and_scan_min_ratio_to_activate": 1.2,
        },
    }


class Gemma3Processor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor,
        tokenizer,
        chat_template=None,
        image_seq_length: int = 256,
        **kwargs,
    ):
        self.image_seq_length = image_seq_length
        self.image_token_id = tokenizer.image_token_id
        self.boi_token = tokenizer.boi_token
        self.image_token = tokenizer.image_token
        # Pre-build the sequence each image placeholder expands to: the boi token,
        # `image_seq_length` image tokens, then the eoi token.
        image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
        self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"

        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            chat_template=chat_template,
            **kwargs,
        )

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
        videos=None,
        audio=None,
        **kwargs: Unpack[Gemma3ProcessorKwargs],
    ) -> BatchFeature:
        if text is None and images is None:
            raise ValueError("Provide at least one of `text` or `images`.")

        output_kwargs = self._merge_kwargs(
            Gemma3ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if isinstance(text, str):
            text = [text]
        elif text is not None and (not isinstance(text, list) or not isinstance(text[0], str)):
            raise TypeError("Invalid input text. Please provide a string, or a list of strings.")

        image_inputs = {}
        if images is not None:
            batched_images = make_nested_list_of_images(images)
            image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])

            # Create empty text to be replaced with placeholders
            if not text:
                text = [" ".join([self.boi_token] * len(images)) for images in batched_images]

            if len(batched_images) != len(text):
                raise ValueError(
                    f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
                )

            # Replace image tokens by the full expanded sequence
            num_crops = to_py_obj(image_inputs.pop("num_crops"))
            batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]

            for batch_idx, (prompt, images, crop_counts) in enumerate(zip(text, batched_images, batch_num_crops)):
                image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]

                if len(images) != len(image_indexes):
                    raise ValueError(
                        f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images."
                    )

                # Insert additional image tokens for Pan-and-Scan crops, walking the matches in
                # reverse so earlier offsets stay valid as the prompt grows
                for num, idx in reversed(list(zip(crop_counts, image_indexes))):
                    if num:
                        formatted_image_text = (
                            f"Here is the original image {self.boi_token} and here are some crops to help you see better "
                            + " ".join([self.boi_token] * num)
                        )
                        prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token) :]
                        text[batch_idx] = prompt
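            # For illustration (hypothetical values, assuming pan-and-scan produced two extra
            # crops for a single-image prompt): "<boi> Caption this." has now become
            # "Here is the original image <boi> and here are some crops to help you see better
            # <boi> <boi> Caption this.", where <boi> stands for the tokenizer's boi_token.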
        # Expand placeholder image tokens to the full image token sequence
        text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
        text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
        self._check_special_mm_tokens(text, text_inputs, modalities=["image"])

        # Add token type ids manually, as the tokenizer can't assign token types at arbitrary positions
        if return_mm_token_type_ids:
            array_ids = np.array(text_inputs["input_ids"])
            mm_token_type_ids = np.zeros_like(array_ids)
            mm_token_type_ids[array_ids == self.image_token_id] = 1
            text_inputs["token_type_ids"] = mm_token_type_ids.tolist()

        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)

    def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
        """
        Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.

        Args:
            image_sizes (`list[list[int]]`, *optional*):
                The input sizes formatted as (height, width) for each image.

        Returns:
            `MultiModalData`: A `MultiModalData` object holding the number of tokens for each of the provided input
            modalities, along with other useful data.
        """
        vision_data = {}
        if image_sizes is not None:
            # NOTE: no image cropping supported yet
            num_image_tokens = [self.image_seq_length] * len(image_sizes)
            num_image_patches = [1] * len(image_sizes)
            vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
        return MultiModalData(**vision_data)

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer
        to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))


__all__ = ["Gemma3Processor"]
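
# A minimal usage sketch (illustrative, not part of the module): it assumes a Gemma 3
# checkpoint such as "google/gemma-3-4b-it" is reachable and that Pillow is installed.
# Each boi_token placeholder in the prompt is expanded into the full image-token
# sequence built in `__init__` (256 image tokens by default).
#
#     from PIL import Image
#     from transformers import AutoProcessor
#
#     processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
#     image = Image.new("RGB", (896, 896))
#     inputs = processor(
#         images=image,
#         text=f"{processor.boi_token} Describe this image.",
#         return_tensors="pt",
#     )
#     print(inputs["input_ids"].shape)  # prompt tokens plus the expanded image sequence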