from typing import List

import PIL.Image
import PIL.ImageOps
from packaging import version
from PIL import Image

# Pillow 9.1.0 moved the resampling filters into the `Image.Resampling` enum and
# deprecated the old module-level constants, so build the lookup table per version.
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }

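# Usage sketch for PIL_INTERPOLATION (illustrative only; "input.png" is a
# placeholder path): resize with a filter chosen by name so the same string
# works across Pillow versions.
#
#     image = Image.open("input.png")
#     image = image.resize((512, 512), resample=PIL_INTERPOLATION["lanczos"])
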
def pt_to_pil(images):
    """
    Convert a batch of torch images in [-1, 1] (shape `(N, C, H, W)`) to a list of PIL images.
    """
    images = (images / 2 + 0.5).clamp(0, 1)
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
    images = numpy_to_pil(images)
    return images

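# Usage sketch for `pt_to_pil` (illustrative only; requires torch): a random
# batch of two 64x64 RGB tensors scaled into [-1, 1] becomes two PIL images.
#
#     import torch
#     batch = torch.rand(2, 3, 64, 64) * 2 - 1
#     pil_images = pt_to_pil(batch)   # list of two PIL.Image.Image objects
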
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a list of PIL images.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images

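# Usage sketch for `numpy_to_pil` (illustrative only): float arrays in [0, 1]
# are scaled to uint8; a trailing channel dimension of 1 yields grayscale images.
#
#     import numpy as np
#     rgb = np.random.rand(2, 64, 64, 3)      # batch of two RGB images
#     gray = np.random.rand(64, 64, 1)        # single grayscale image
#     rgb_pils = numpy_to_pil(rgb)            # two PIL images, mode "RGB"
#     gray_pils = numpy_to_pil(gray)          # one PIL image, mode "L"
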
def make_image_grid(images: List[PIL.Image.Image], rows: int, cols: int, resize: int = None) -> PIL.Image.Image:
    """
    Prepares a single grid of images. Useful for visualization purposes.
    """
    assert len(images) == rows * cols

    if resize is not None:
        images = [img.resize((resize, resize)) for img in images]

    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(images):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
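
# Usage sketch for `make_image_grid` (illustrative only): tile six same-sized
# images into a 2x3 grid, optionally resizing each tile to 256x256 first.
#
#     tiles = [Image.new("RGB", (256, 256), color) for color in
#              ["red", "green", "blue", "cyan", "magenta", "yellow"]]
#     grid = make_image_grid(tiles, rows=2, cols=3, resize=256)
#     grid.size  # (768, 512)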