# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Tuple, Union

import torch
from PIL import Image

from ...image_processor import VaeImageProcessor


class VisualClozeProcessor(VaeImageProcessor):
    """
    Image processor for the VisualCloze pipeline.

    This processor handles the preprocessing of images for visual cloze tasks, including resizing, normalization, and
    mask generation.

    Args:
        resolution (`int`, *optional*, defaults to 384):
            Target resolution for processing images. Each image is resized to this resolution before being
            concatenated, to avoid out-of-memory errors.
        *args: Additional arguments passed to [`~image_processor.VaeImageProcessor`].
        **kwargs: Additional keyword arguments passed to [`~image_processor.VaeImageProcessor`].
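
    Example:
        A minimal sketch of the first (generation) stage. The blank images and prompts below are hypothetical
        stand-ins for real in-context condition/target pairs and a query image:

        ```py
        >>> from PIL import Image

        >>> processor = VisualClozeProcessor(resolution=384)
        >>> # One in-context example row (condition, target) and one query row (condition, None)
        >>> rows = [
        ...     [Image.new("RGB", (512, 512)), Image.new("RGB", (512, 512))],
        ...     [Image.new("RGB", (512, 512)), None],
        ... ]
        >>> out = processor.preprocess(
        ...     task_prompt="Each row: a condition image and its stylized rendition.",
        ...     content_prompt="A stylized rendition of the query image.",
        ...     input_images=rows,
        ...     vae_scale_factor=16,
        ... )
        >>> sorted(out.keys())
        ['content_prompt', 'image_size', 'init_image', 'layout_prompt', 'mask', 'target_position', 'task_prompt']
        ```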
"""
|
|
|
|
def __init__(self, *args, resolution: int = 384, **kwargs):
|
|
super().__init__(*args, **kwargs)
|
|
self.resolution = resolution
|
|
|
|
    def preprocess_image(
        self, input_images: List[List[Optional[Image.Image]]], vae_scale_factor: int
    ) -> Tuple[List[List[torch.Tensor]], List[List[List[int]]], List[int]]:
        """
        Preprocesses input images for the VisualCloze pipeline.

        This function handles the preprocessing of input images by:
        1. Resizing and cropping images to maintain consistent dimensions
        2. Converting images to the tensor format expected by the VAE
        3. Normalizing pixel values
        4. Tracking image sizes and the positions of target images

        Args:
            input_images (`List[List[Optional[Image.Image]]]`):
                A nested list of PIL images where:
                - the outer list holds the rows, i.e. the in-context examples followed by the query
                - each inner list holds the images of one row
                - in the last (query) row, condition images are provided and the target images are set to `None`
            vae_scale_factor (`int`):
                The scale factor used by the VAE when resizing images.

        Returns:
            Tuple containing:
            - `List[List[torch.Tensor]]`: Preprocessed images in tensor format
            - `List[List[List[int]]]`: Dimensions of each processed image as `[height, width]`
            - `List[int]`: Binary flags for the last row indicating which images are to be generated
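
        Example:
            A minimal sketch using blank in-memory images. With `resolution=384` and `vae_scale_factor=16`, a
            1024x768 image is resized to roughly 384² pixels of area, then rounded down to multiples of 32:

            ```py
            >>> from PIL import Image

            >>> processor = VisualClozeProcessor(resolution=384)
            >>> rows = [
            ...     [Image.new("RGB", (1024, 768)), Image.new("RGB", (1024, 768))],  # in-context example
            ...     [Image.new("RGB", (1024, 768)), None],  # query row: the target is None
            ... ]
            >>> images, sizes, target_position = processor.preprocess_image(rows, vae_scale_factor=16)
            >>> sizes[0][0]
            [320, 416]
            >>> target_position
            [0, 1]
            ```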
"""
|
|
n_samples, n_task_images = len(input_images), len(input_images[0])
|
|
divisible = 2 * vae_scale_factor
|
|
|
|
processed_images: List[List[Image.Image]] = [[] for _ in range(n_samples)]
|
|
resize_size: List[Optional[Tuple[int, int]]] = [None for _ in range(n_samples)]
|
|
target_position: List[int] = []
|
|
|
|
# Process each sample
|
|
for i in range(n_samples):
|
|
# Determine size from first non-None image
|
|
for j in range(n_task_images):
|
|
if input_images[i][j] is not None:
|
|
aspect_ratio = input_images[i][j].width / input_images[i][j].height
|
|
target_area = self.resolution * self.resolution
|
|
new_h = int((target_area / aspect_ratio) ** 0.5)
|
|
new_w = int(new_h * aspect_ratio)
|
|
|
|
new_w = max(new_w // divisible, 1) * divisible
|
|
new_h = max(new_h // divisible, 1) * divisible
|
|
resize_size[i] = (new_w, new_h)
|
|
break
|
|
|
|
# Process all images in the sample
|
|
for j in range(n_task_images):
|
|
if input_images[i][j] is not None:
|
|
target = self._resize_and_crop(input_images[i][j], resize_size[i][0], resize_size[i][1])
|
|
processed_images[i].append(target)
|
|
if i == n_samples - 1:
|
|
target_position.append(0)
|
|
else:
|
|
blank = Image.new("RGB", resize_size[i] or (self.resolution, self.resolution), (0, 0, 0))
|
|
processed_images[i].append(blank)
|
|
if i == n_samples - 1:
|
|
target_position.append(1)
|
|
|
|
        # When there are multiple target images, resize all images to share a consistent width
        if len(target_position) > 1 and sum(target_position) > 1:
            new_w = resize_size[n_samples - 1][0] or 384
            for i in range(len(processed_images)):
                for j in range(len(processed_images[i])):
                    if processed_images[i][j] is not None:
                        new_h = int(processed_images[i][j].height * (new_w / processed_images[i][j].width))
                        new_w = int(new_w / 16) * 16
                        new_h = int(new_h / 16) * 16
                        processed_images[i][j] = self.resize(processed_images[i][j], new_h, new_w)

        # Convert to tensors and normalize
        image_sizes = []
        for i in range(len(processed_images)):
            image_sizes.append([[img.height, img.width] for img in processed_images[i]])
            for j, image in enumerate(processed_images[i]):
                image = self.pil_to_numpy(image)
                image = self.numpy_to_pt(image)
                image = self.normalize(image)
                processed_images[i][j] = image

        return processed_images, image_sizes, target_position

    def preprocess_mask(
        self, input_images: List[List[torch.Tensor]], target_position: List[int]
    ) -> List[List[torch.Tensor]]:
        """
        Generate masks for the VisualCloze pipeline.

        Args:
            input_images (`List[List[torch.Tensor]]`):
                Processed image tensors from `preprocess_image`
            target_position (`List[int]`):
                Binary list marking the positions of target images (1 for target, 0 for condition)

        Returns:
            `List[List[torch.Tensor]]`:
                A nested list of mask tensors (1 for target positions, 0 for condition images)
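
        Example:
            A minimal sketch with dummy tensors standing in for the output of `preprocess_image`:

            ```py
            >>> import torch

            >>> processor = VisualClozeProcessor()
            >>> rows = [[torch.zeros(1, 3, 320, 416)] * 2, [torch.zeros(1, 3, 320, 416)] * 2]
            >>> masks = processor.preprocess_mask(rows, target_position=[0, 1])
            >>> [int(m.max()) for m in masks[-1]]  # query row: keep the condition (0), generate the target (1)
            [0, 1]
            ```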
"""
|
|
mask = []
|
|
for i, row in enumerate(input_images):
|
|
if i == len(input_images) - 1: # Query row
|
|
row_masks = [
|
|
torch.full((1, 1, row[0].shape[2], row[0].shape[3]), fill_value=m) for m in target_position
|
|
]
|
|
else: # In-context examples
|
|
row_masks = [
|
|
torch.full((1, 1, row[0].shape[2], row[0].shape[3]), fill_value=0) for _ in target_position
|
|
]
|
|
mask.append(row_masks)
|
|
return mask
|
|
|
|
    def preprocess_image_upsampling(
        self,
        input_images: List[List[Image.Image]],
        height: int,
        width: int,
    ) -> Tuple[List[List[torch.Tensor]], List[List[List[int]]]]:
        """Process images for the upsampling stage in the VisualCloze pipeline.

        Args:
            input_images (`List[List[Image.Image]]`): Nested list containing the single image to upsample
            height (`int`): Target height
            width (`int`): Target width

        Returns:
            Tuple of the processed image (as a normalized tensor) and its size
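
        Example:
            A minimal sketch with a blank in-memory image:

            ```py
            >>> from PIL import Image

            >>> processor = VisualClozeProcessor()
            >>> images, sizes = processor.preprocess_image_upsampling(
            ...     [[Image.new("RGB", (256, 256))]], height=512, width=512
            ... )
            >>> images[0][0].shape, sizes
            (torch.Size([1, 3, 512, 512]), [[[512, 512]]])
            ```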
"""
|
|
image = self.resize(input_images[0][0], height, width)
|
|
image = self.pil_to_numpy(image) # to np
|
|
image = self.numpy_to_pt(image) # to pt
|
|
image = self.normalize(image)
|
|
|
|
input_images[0][0] = image
|
|
image_sizes = [[[height, width]]]
|
|
return input_images, image_sizes
|
|
|
|
    def preprocess_mask_upsampling(self, input_images: List[List[torch.Tensor]]) -> List[List[torch.Tensor]]:
        """Generate an all-ones mask matching the upsampled image, since the entire image is regenerated."""
        return [[torch.ones((1, 1, input_images[0][0].shape[2], input_images[0][0].shape[3]))]]

    def get_layout_prompt(self, size: Tuple[int, int]) -> str:
        """Generate a description of the grid layout for a `(n_rows, n_cols)` arrangement.
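
        Example:
            A doctest of the layout instruction string:

            ```py
            >>> VisualClozeProcessor().get_layout_prompt((2, 2))
            'A grid layout with 2 rows and 2 columns, displaying 4 images arranged side by side.'
            ```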
        """
        layout_instruction = (
            f"A grid layout with {size[0]} rows and {size[1]} columns, "
            f"displaying {size[0] * size[1]} images arranged side by side."
        )
        return layout_instruction

    def preprocess(
        self,
        task_prompt: Union[str, List[str]],
        content_prompt: Union[str, List[str]],
        input_images: Optional[List[List[List[Optional[Image.Image]]]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        upsampling: bool = False,
        vae_scale_factor: int = 16,
    ) -> Dict:
        """Process visual cloze inputs.

        Args:
            task_prompt (`str` or `List[str]`): Task description(s)
            content_prompt (`str` or `List[str]`): Content description(s)
            input_images: Nested list of images, with `None` in place of each target image to generate
            height (`int`, *optional*): Target height for the upsampling stage
            width (`int`, *optional*): Target width for the upsampling stage
            upsampling (`bool`): Whether this call is for the upsampling stage
            vae_scale_factor (`int`): The scale factor used by the VAE when resizing images

        Returns:
            Dictionary containing the processed images, masks, prompts, and metadata
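
        Example:
            A minimal sketch of the upsampling stage, using a blank image as a stand-in for the image generated in
            the first stage:

            ```py
            >>> from PIL import Image

            >>> processor = VisualClozeProcessor(resolution=384)
            >>> out = processor.preprocess(
            ...     task_prompt="Each row: a condition image and its stylized rendition.",
            ...     content_prompt="A stylized rendition of the query image.",
            ...     input_images=[[Image.new("RGB", (384, 384))]],  # single row holding the first-stage result
            ...     height=1024,
            ...     width=1024,
            ...     upsampling=True,
            ... )
            >>> out["task_prompt"], out["mask"][0][0][0].shape  # task prompt is dropped during upsampling
            ([None], torch.Size([1, 1, 1024, 1024]))
            ```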
"""
|
|
if isinstance(task_prompt, str):
|
|
task_prompt = [task_prompt]
|
|
content_prompt = [content_prompt]
|
|
input_images = [input_images]
|
|
|
|
output = {
|
|
"init_image": [],
|
|
"mask": [],
|
|
"task_prompt": task_prompt if not upsampling else [None for _ in range(len(task_prompt))],
|
|
"content_prompt": content_prompt,
|
|
"layout_prompt": [],
|
|
"target_position": [],
|
|
"image_size": [],
|
|
}
|
|
for i in range(len(task_prompt)):
|
|
if upsampling:
|
|
layout_prompt = None
|
|
else:
|
|
layout_prompt = self.get_layout_prompt((len(input_images[i]), len(input_images[i][0])))
|
|
|
|
if upsampling:
|
|
cur_processed_images, cur_image_size = self.preprocess_image_upsampling(
|
|
input_images[i], height=height, width=width
|
|
)
|
|
cur_mask = self.preprocess_mask_upsampling(cur_processed_images)
|
|
else:
|
|
cur_processed_images, cur_image_size, cur_target_position = self.preprocess_image(
|
|
input_images[i], vae_scale_factor=vae_scale_factor
|
|
)
|
|
cur_mask = self.preprocess_mask(cur_processed_images, cur_target_position)
|
|
|
|
output["target_position"].append(cur_target_position)
|
|
|
|
output["image_size"].append(cur_image_size)
|
|
output["init_image"].append(cur_processed_images)
|
|
output["mask"].append(cur_mask)
|
|
output["layout_prompt"].append(layout_prompt)
|
|
|
|
return output
|