# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
# Copyright 2024-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# More information and citation instructions are available on the
# Marigold project website: https://marigoldcomputervision.github.io
# --------------------------------------------------------------------------
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
import torch.nn.functional as F
from PIL import Image

from ... import ConfigMixin
from ...configuration_utils import register_to_config
from ...image_processor import PipelineImageInput
from ...utils import CONFIG_NAME, logging
from ...utils.import_utils import is_matplotlib_available


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

class MarigoldImageProcessor(ConfigMixin):
    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        vae_scale_factor: int = 8,
        do_normalize: bool = True,
        do_range_check: bool = True,
    ):
        super().__init__()

    @staticmethod
    def expand_tensor_or_array(images: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]:
        """
        Expand a 2D or 3D image tensor or array to a 4-dimensional batch format.
        """
        if isinstance(images, np.ndarray):
            if images.ndim == 2:  # [H,W] -> [1,H,W,1]
                images = images[None, ..., None]
            if images.ndim == 3:  # [H,W,C] -> [1,H,W,C]
                images = images[None]
        elif isinstance(images, torch.Tensor):
            if images.ndim == 2:  # [H,W] -> [1,1,H,W]
                images = images[None, None]
            elif images.ndim == 3:  # [C,H,W] -> [1,C,H,W]
                images = images[None]
        else:
            raise ValueError(f"Unexpected input type: {type(images)}")
        return images
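
    # Usage sketch (illustrative, not part of the library API): expanding plain 2D
    # inputs to the 4D layouts documented in the comments above; shapes are assumed.
    #
    #   depth_np = np.random.rand(480, 640).astype(np.float32)
    #   MarigoldImageProcessor.expand_tensor_or_array(depth_np).shape  # (1, 480, 640, 1)
    #   depth_pt = torch.rand(480, 640)
    #   MarigoldImageProcessor.expand_tensor_or_array(depth_pt).shape  # torch.Size([1, 1, 480, 640])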
|
|
    @staticmethod
    def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
        """
        Convert a batch of PyTorch images [N,C,H,W] to a NumPy batch [N,H,W,C].
        """
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        return images

    @staticmethod
    def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
        """
        Convert a batch of NumPy images [N,H,W,C] to a PyTorch batch [N,C,H,W].
        """
        if np.issubdtype(images.dtype, np.integer) and not np.issubdtype(images.dtype, np.unsignedinteger):
            raise ValueError(f"Input image dtype={images.dtype} cannot be a signed integer.")
        if np.issubdtype(images.dtype, np.complexfloating):
            raise ValueError(f"Input image dtype={images.dtype} cannot be complex.")
        if np.issubdtype(images.dtype, bool):
            raise ValueError(f"Input image dtype={images.dtype} cannot be boolean.")

        images = torch.from_numpy(images.transpose(0, 3, 1, 2))
        return images
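
    # Round-trip sketch (illustrative): the two converters invert each other up to
    # dtype; the batch shape below is an arbitrary example.
    #
    #   batch_np = np.random.rand(2, 480, 640, 3).astype(np.float32)   # [N,H,W,C]
    #   batch_pt = MarigoldImageProcessor.numpy_to_pt(batch_np)        # [N,C,H,W]
    #   assert MarigoldImageProcessor.pt_to_numpy(batch_pt).shape == batch_np.shape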
|
|
    @staticmethod
    def resize_antialias(
        image: torch.Tensor, size: Tuple[int, int], mode: str, is_aa: Optional[bool] = None
    ) -> torch.Tensor:
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        antialias = is_aa and mode in ("bilinear", "bicubic")
        image = F.interpolate(image, size, mode=mode, antialias=antialias)

        return image
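
    # Usage sketch (illustrative): downsampling with antialiasing; `is_aa` only has
    # an effect for the "bilinear" and "bicubic" interpolation modes.
    #
    #   img = torch.rand(1, 3, 480, 640)
    #   small = MarigoldImageProcessor.resize_antialias(img, (240, 320), "bilinear", is_aa=True)
    #   small.shape  # torch.Size([1, 3, 240, 320])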
|
|
    @staticmethod
    def resize_to_max_edge(image: torch.Tensor, max_edge_sz: int, mode: str) -> torch.Tensor:
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        h, w = image.shape[-2:]
        max_orig = max(h, w)
        new_h = h * max_edge_sz // max_orig
        new_w = w * max_edge_sz // max_orig

        if new_h == 0 or new_w == 0:
            raise ValueError(f"Extreme aspect ratio of the input image: [{w} x {h}]")

        image = MarigoldImageProcessor.resize_antialias(image, (new_h, new_w), mode, is_aa=True)

        return image
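
    # Worked example (illustrative): resizing a 480x640 image so its longer edge
    # becomes 768 preserves the aspect ratio: new_h = 480 * 768 // 640 = 576.
    #
    #   img = torch.rand(1, 3, 480, 640)
    #   MarigoldImageProcessor.resize_to_max_edge(img, 768, "bilinear").shape
    #   # torch.Size([1, 3, 576, 768])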
|
|
    @staticmethod
    def pad_image(image: torch.Tensor, align: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        h, w = image.shape[-2:]
        ph, pw = -h % align, -w % align

        image = F.pad(image, (0, pw, 0, ph), mode="replicate")

        return image, (ph, pw)

    @staticmethod
    def unpad_image(image: torch.Tensor, padding: Tuple[int, int]) -> torch.Tensor:
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        ph, pw = padding
        uh = None if ph == 0 else -ph
        uw = None if pw == 0 else -pw

        image = image[:, :, :uh, :uw]

        return image
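
    # Round-trip sketch (illustrative): `-h % align` is the smallest padding that
    # makes each spatial edge a multiple of `align`, and `unpad_image` undoes it.
    #
    #   img = torch.rand(1, 3, 480, 636)
    #   padded, padding = MarigoldImageProcessor.pad_image(img, 8)   # padding == (0, 4)
    #   assert MarigoldImageProcessor.unpad_image(padded, padding).shape == img.shape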
|
|
    @staticmethod
    def load_image_canonical(
        image: Union[torch.Tensor, np.ndarray, Image.Image],
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        if isinstance(image, Image.Image):
            image = np.array(image)

        image_dtype_max = None
        if isinstance(image, (np.ndarray, torch.Tensor)):
            image = MarigoldImageProcessor.expand_tensor_or_array(image)
            if image.ndim != 4:
                raise ValueError("Input image is not 2-, 3-, or 4-dimensional.")
            if isinstance(image, np.ndarray):
                if np.issubdtype(image.dtype, np.integer) and not np.issubdtype(image.dtype, np.unsignedinteger):
                    raise ValueError(f"Input image dtype={image.dtype} cannot be a signed integer.")
                if np.issubdtype(image.dtype, np.complexfloating):
                    raise ValueError(f"Input image dtype={image.dtype} cannot be complex.")
                if np.issubdtype(image.dtype, bool):
                    raise ValueError(f"Input image dtype={image.dtype} cannot be boolean.")
                if np.issubdtype(image.dtype, np.unsignedinteger):
                    image_dtype_max = np.iinfo(image.dtype).max
                    image = image.astype(np.float32)  # because torch does not have unsigned dtypes beyond torch.uint8
                image = MarigoldImageProcessor.numpy_to_pt(image)

        if torch.is_tensor(image) and not torch.is_floating_point(image) and image_dtype_max is None:
            if image.dtype != torch.uint8:
                raise ValueError(f"Image dtype={image.dtype} is not supported.")
            image_dtype_max = 255

        if not torch.is_tensor(image):
            raise ValueError(f"Input type unsupported: {type(image)}.")

        if image.shape[1] == 1:
            image = image.repeat(1, 3, 1, 1)  # [N,1,H,W] -> [N,3,H,W]
        if image.shape[1] != 3:
            raise ValueError(f"Input image is not 1- or 3-channel: {image.shape}.")

        image = image.to(device=device, dtype=dtype)

        if image_dtype_max is not None:
            image = image / image_dtype_max

        return image
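
    # Usage sketch (illustrative): a uint8 PIL image becomes a float32 batch with
    # values scaled to [0, 1]; the image size here is arbitrary.
    #
    #   pil = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
    #   canon = MarigoldImageProcessor.load_image_canonical(pil)
    #   canon.shape, canon.dtype  # (torch.Size([1, 3, 480, 640]), torch.float32)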
|
|
    @staticmethod
    def check_image_values_range(image: torch.Tensor) -> None:
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.min().item() < 0.0 or image.max().item() > 1.0:
            raise ValueError("Input image data is partially outside of the [0,1] range.")

    def preprocess(
        self,
        image: PipelineImageInput,
        processing_resolution: Optional[int] = None,
        resample_method_input: str = "bilinear",
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ):
        if isinstance(image, list):
            images = None
            for i, img in enumerate(image):
                img = self.load_image_canonical(img, device, dtype)  # [N,3,H,W]
                if images is None:
                    images = img
                else:
                    if images.shape[2:] != img.shape[2:]:
                        raise ValueError(
                            f"Input image[{i}] has incompatible dimensions {img.shape[2:]} with the previous images "
                            f"{images.shape[2:]}"
                        )
                    images = torch.cat((images, img), dim=0)
            image = images
            del images
        else:
            image = self.load_image_canonical(image, device, dtype)  # [N,3,H,W]

        original_resolution = image.shape[2:]

        if self.config.do_range_check:
            self.check_image_values_range(image)

        if self.config.do_normalize:
            image = image * 2.0 - 1.0

        if processing_resolution is not None and processing_resolution > 0:
            image = self.resize_to_max_edge(image, processing_resolution, resample_method_input)  # [N,3,PH,PW]

        image, padding = self.pad_image(image, self.config.vae_scale_factor)  # [N,3,PPH,PPW]

        return image, padding, original_resolution
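
    # Pipeline-style sketch (illustrative): `preprocess` resizes to the processing
    # resolution, normalizes to [-1, 1], and pads to the VAE alignment; the returned
    # `padding` and `original_resolution` let a pipeline undo both steps later.
    #
    #   processor = MarigoldImageProcessor(vae_scale_factor=8)
    #   tensor, padding, orig_res = processor.preprocess(torch.rand(1, 3, 480, 640), processing_resolution=768)
    #   tensor.shape, padding, orig_res  # torch.Size([1, 3, 576, 768]), (0, 0), torch.Size([480, 640])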
|
|
    @staticmethod
    def colormap(
        image: Union[np.ndarray, torch.Tensor],
        cmap: str = "Spectral",
        bytes: bool = False,
        _force_method: Optional[str] = None,
    ) -> Union[np.ndarray, torch.Tensor]:
        """
        Converts a monochrome image into an RGB image by applying the specified colormap. This function mimics the
        behavior of matplotlib.colormaps, but allows the user to use the most discriminative color maps ("Spectral",
        "binary") without having to install or import matplotlib. For all other cases, the function will attempt to
        use the native implementation.

        Args:
            image: 2D tensor of values between 0 and 1, either as np.ndarray or torch.Tensor.
            cmap: Colormap name.
            bytes: Whether to return the output as uint8 or floating point image.
            _force_method:
                Can be used to specify whether to use the native implementation (`"matplotlib"`), the efficient custom
                implementation of the selected color maps (`"custom"`), or rely on autodetection (`None`, default).

        Returns:
            An RGB-colorized tensor corresponding to the input image.
        """
        if not (torch.is_tensor(image) or isinstance(image, np.ndarray)):
            raise ValueError("Argument must be a numpy array or torch tensor.")
        if _force_method not in (None, "matplotlib", "custom"):
            raise ValueError("_force_method must be either `None`, `'matplotlib'` or `'custom'`.")

        supported_cmaps = {
            "binary": [
                (1.0, 1.0, 1.0),
                (0.0, 0.0, 0.0),
            ],
            "Spectral": [  # Taken from matplotlib/_cm.py
                (0.61960784313725492, 0.003921568627450980, 0.25882352941176473),  # 0.0 -> [0]
                (0.83529411764705885, 0.24313725490196078, 0.30980392156862746),
                (0.95686274509803926, 0.42745098039215684, 0.2627450980392157),
                (0.99215686274509807, 0.68235294117647061, 0.38039215686274508),
                (0.99607843137254903, 0.8784313725490196, 0.54509803921568623),
                (1.0, 1.0, 0.74901960784313726),
                (0.90196078431372551, 0.96078431372549022, 0.59607843137254901),
                (0.6705882352941176, 0.8666666666666667, 0.64313725490196083),
                (0.4, 0.76078431372549016, 0.6470588235294118),
                (0.19607843137254902, 0.53333333333333333, 0.74117647058823533),
                (0.36862745098039218, 0.30980392156862746, 0.63529411764705879),  # 1.0 -> [K-1]
            ],
        }

        def method_matplotlib(image, cmap, bytes=False):
            if is_matplotlib_available():
                import matplotlib
            else:
                return None

            arg_is_pt, device = torch.is_tensor(image), None
            if arg_is_pt:
                image, device = image.cpu().numpy(), image.device

            if cmap not in matplotlib.colormaps:
                raise ValueError(
                    f"Unexpected color map {cmap}; available options are: {', '.join(list(matplotlib.colormaps.keys()))}"
                )

            cmap = matplotlib.colormaps[cmap]
            out = cmap(image, bytes=bytes)  # [?,4]
            out = out[..., :3]  # [?,3]

            if arg_is_pt:
                out = torch.tensor(out, device=device)

            return out

        def method_custom(image, cmap, bytes=False):
            arg_is_np = isinstance(image, np.ndarray)
            if arg_is_np:
                image = torch.tensor(image)
            if image.dtype == torch.uint8:
                image = image.float() / 255
            else:
                image = image.float()

            is_cmap_reversed = cmap.endswith("_r")
            if is_cmap_reversed:
                cmap = cmap[:-2]

            if cmap not in supported_cmaps:
                raise ValueError(
                    f"Only {list(supported_cmaps.keys())} color maps are available without installing matplotlib."
                )

            cmap = supported_cmaps[cmap]
            if is_cmap_reversed:
                cmap = cmap[::-1]
            cmap = torch.tensor(cmap, dtype=torch.float, device=image.device)  # [K,3]
            K = cmap.shape[0]

            pos = image.clamp(min=0, max=1) * (K - 1)
            left = pos.long()
            right = (left + 1).clamp(max=K - 1)

            d = (pos - left.float()).unsqueeze(-1)
            left_colors = cmap[left]
            right_colors = cmap[right]

            out = (1 - d) * left_colors + d * right_colors

            if bytes:
                out = (out * 255).to(torch.uint8)

            if arg_is_np:
                out = out.numpy()

            return out

        if _force_method is None and torch.is_tensor(image) and cmap == "Spectral":
            return method_custom(image, cmap, bytes)

        out = None
        if _force_method != "custom":
            out = method_matplotlib(image, cmap, bytes)

        if _force_method == "matplotlib" and out is None:
            raise ImportError("Make sure to install matplotlib if you want to use a color map other than 'Spectral'.")

        if out is None:
            out = method_custom(image, cmap, bytes)

        return out
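
    # Usage sketch (illustrative): with a torch tensor and the default "Spectral"
    # map, the efficient custom path is used, so matplotlib is not required.
    #
    #   gradient = torch.linspace(0, 1, 256).unsqueeze(0)  # [1,256] values in [0,1]
    #   rgb = MarigoldImageProcessor.colormap(gradient, cmap="Spectral", bytes=True)
    #   rgb.shape, rgb.dtype  # (torch.Size([1, 256, 3]), torch.uint8)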
|
|
    @staticmethod
    def visualize_depth(
        depth: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.Tensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.Tensor],
        ],
        val_min: float = 0.0,
        val_max: float = 1.0,
        color_map: str = "Spectral",
    ) -> List[PIL.Image.Image]:
        """
        Visualizes depth maps, such as predictions of the `MarigoldDepthPipeline`.

        Args:
            depth (`Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image], List[np.ndarray],
                List[torch.Tensor]]`): Depth maps.
            val_min (`float`, *optional*, defaults to `0.0`): Minimum value of the visualized depth range.
            val_max (`float`, *optional*, defaults to `1.0`): Maximum value of the visualized depth range.
            color_map (`str`, *optional*, defaults to `"Spectral"`): Color map used to convert a single-channel
                depth prediction into colored representation.

        Returns: `List[PIL.Image.Image]` with depth maps visualization.
        """
        if val_max <= val_min:
            raise ValueError(f"Invalid values range: [{val_min}, {val_max}].")

        def visualize_depth_one(img, idx=None):
            prefix = "Depth" + (f"[{idx}]" if idx is not None else "")
            if isinstance(img, PIL.Image.Image):
                if img.mode != "I;16":
                    raise ValueError(f"{prefix}: invalid PIL mode={img.mode}.")
                img = np.array(img).astype(np.float32) / (2**16 - 1)
            if isinstance(img, np.ndarray) or torch.is_tensor(img):
                if img.ndim != 2:
                    raise ValueError(f"{prefix}: unexpected shape={img.shape}.")
                if isinstance(img, np.ndarray):
                    img = torch.from_numpy(img)
                if not torch.is_floating_point(img):
                    raise ValueError(f"{prefix}: unexpected dtype={img.dtype}.")
            else:
                raise ValueError(f"{prefix}: unexpected type={type(img)}.")
            if val_min != 0.0 or val_max != 1.0:
                img = (img - val_min) / (val_max - val_min)
            img = MarigoldImageProcessor.colormap(img, cmap=color_map, bytes=True)  # [H,W,3]
            img = PIL.Image.fromarray(img.cpu().numpy())
            return img

        if depth is None or isinstance(depth, list) and any(o is None for o in depth):
            raise ValueError("Input depth is `None`")
        if isinstance(depth, (np.ndarray, torch.Tensor)):
            depth = MarigoldImageProcessor.expand_tensor_or_array(depth)
            if isinstance(depth, np.ndarray):
                depth = MarigoldImageProcessor.numpy_to_pt(depth)  # [N,H,W,1] -> [N,1,H,W]
            if not (depth.ndim == 4 and depth.shape[1] == 1):  # [N,1,H,W]
                raise ValueError(f"Unexpected input shape={depth.shape}, expecting [N,1,H,W].")
            return [visualize_depth_one(img[0], idx) for idx, img in enumerate(depth)]
        elif isinstance(depth, list):
            return [visualize_depth_one(img, idx) for idx, img in enumerate(depth)]
        else:
            raise ValueError(f"Unexpected input type: {type(depth)}")
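
    # Usage sketch (illustrative): colorizing a batch of predictions in [0, 1];
    # metric depth can be visualized by passing its range via `val_min`/`val_max`.
    #
    #   depth = torch.rand(2, 1, 480, 640)
    #   vis = MarigoldImageProcessor.visualize_depth(depth)  # list of 2 PIL images
    #   vis = MarigoldImageProcessor.visualize_depth(depth * 10.0, val_min=0.0, val_max=10.0)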
|
|
    @staticmethod
    def export_depth_to_16bit_png(
        depth: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]],
        val_min: float = 0.0,
        val_max: float = 1.0,
    ) -> List[PIL.Image.Image]:
        def export_depth_to_16bit_png_one(img, idx=None):
            prefix = "Depth" + (f"[{idx}]" if idx is not None else "")
            if not isinstance(img, np.ndarray) and not torch.is_tensor(img):
                raise ValueError(f"{prefix}: unexpected type={type(img)}.")
            if img.ndim != 2:
                raise ValueError(f"{prefix}: unexpected shape={img.shape}.")
            if torch.is_tensor(img):
                img = img.cpu().numpy()
            if not np.issubdtype(img.dtype, np.floating):
                raise ValueError(f"{prefix}: unexpected dtype={img.dtype}.")
            if val_min != 0.0 or val_max != 1.0:
                img = (img - val_min) / (val_max - val_min)
            img = (img * (2**16 - 1)).astype(np.uint16)
            img = PIL.Image.fromarray(img, mode="I;16")
            return img

        if depth is None or isinstance(depth, list) and any(o is None for o in depth):
            raise ValueError("Input depth is `None`")
        if isinstance(depth, (np.ndarray, torch.Tensor)):
            depth = MarigoldImageProcessor.expand_tensor_or_array(depth)
            if isinstance(depth, np.ndarray):
                depth = MarigoldImageProcessor.numpy_to_pt(depth)  # [N,H,W,1] -> [N,1,H,W]
            if not (depth.ndim == 4 and depth.shape[1] == 1):
                raise ValueError(f"Unexpected input shape={depth.shape}, expecting [N,1,H,W].")
            return [export_depth_to_16bit_png_one(img[0], idx) for idx, img in enumerate(depth)]
        elif isinstance(depth, list):
            return [export_depth_to_16bit_png_one(img, idx) for idx, img in enumerate(depth)]
        else:
            raise ValueError(f"Unexpected input type: {type(depth)}")
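
    # Usage sketch (illustrative): quantizing [0, 1] depth to the full 16-bit range
    # for lossless storage; the output file name is hypothetical.
    #
    #   depth = torch.rand(1, 1, 480, 640)
    #   png16 = MarigoldImageProcessor.export_depth_to_16bit_png(depth)[0]
    #   png16.save("depth_16bit.png")  # PIL image with mode "I;16"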
|
|
    @staticmethod
    def visualize_normals(
        normals: Union[
            np.ndarray,
            torch.Tensor,
            List[np.ndarray],
            List[torch.Tensor],
        ],
        flip_x: bool = False,
        flip_y: bool = False,
        flip_z: bool = False,
    ) -> List[PIL.Image.Image]:
        """
        Visualizes surface normals, such as predictions of the `MarigoldNormalsPipeline`.

        Args:
            normals (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
                Surface normals.
            flip_x (`bool`, *optional*, defaults to `False`): Flips the X axis of the normals frame of reference.
                Default direction is right.
            flip_y (`bool`, *optional*, defaults to `False`): Flips the Y axis of the normals frame of reference.
                Default direction is top.
            flip_z (`bool`, *optional*, defaults to `False`): Flips the Z axis of the normals frame of reference.
                Default direction is facing the observer.

        Returns: `List[PIL.Image.Image]` with surface normals visualization.
        """
        flip_vec = None
        if any((flip_x, flip_y, flip_z)):
            flip_vec = torch.tensor(
                [
                    (-1) ** flip_x,
                    (-1) ** flip_y,
                    (-1) ** flip_z,
                ],
                dtype=torch.float32,
            )

        def visualize_normals_one(img, idx=None):
            img = img.permute(1, 2, 0)
            if flip_vec is not None:
                img *= flip_vec.to(img.device)
            img = (img + 1.0) * 0.5
            img = (img * 255).to(dtype=torch.uint8, device="cpu").numpy()
            img = PIL.Image.fromarray(img)
            return img

        if normals is None or isinstance(normals, list) and any(o is None for o in normals):
            raise ValueError("Input normals is `None`")
        if isinstance(normals, (np.ndarray, torch.Tensor)):
            normals = MarigoldImageProcessor.expand_tensor_or_array(normals)
            if isinstance(normals, np.ndarray):
                normals = MarigoldImageProcessor.numpy_to_pt(normals)  # [N,3,H,W]
            if not (normals.ndim == 4 and normals.shape[1] == 3):
                raise ValueError(f"Unexpected input shape={normals.shape}, expecting [N,3,H,W].")
            return [visualize_normals_one(img, idx) for idx, img in enumerate(normals)]
        elif isinstance(normals, list):
            return [visualize_normals_one(img, idx) for idx, img in enumerate(normals)]
        else:
            raise ValueError(f"Unexpected input type: {type(normals)}")
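
    # Usage sketch (illustrative): unit normals in [-1, 1] are mapped to RGB; flipping
    # an axis accounts for a different frame-of-reference convention.
    #
    #   normals = torch.nn.functional.normalize(torch.randn(1, 3, 480, 640), dim=1)
    #   vis = MarigoldImageProcessor.visualize_normals(normals, flip_z=True)
    #   vis[0].size  # (640, 480)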
|
|
    @staticmethod
    def visualize_intrinsics(
        prediction: Union[
            np.ndarray,
            torch.Tensor,
            List[np.ndarray],
            List[torch.Tensor],
        ],
        target_properties: Dict[str, Any],
        color_map: Union[str, Dict[str, str]] = "binary",
    ) -> List[Dict[str, PIL.Image.Image]]:
        """
        Visualizes intrinsic image decomposition, such as predictions of the `MarigoldIntrinsicsPipeline`.

        Args:
            prediction (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
                Intrinsic image decomposition.
            target_properties (`Dict[str, Any]`):
                Decomposition properties. Expected entries: `target_names: List[str]` and a dictionary with keys
                `prediction_space: str`, `sub_target_names: List[Optional[str]]` (must have 3 entries, `None` for
                missing modalities), `up_to_scale: bool`, one for each target and sub-target.
            color_map (`Union[str, Dict[str, str]]`, *optional*, defaults to `"binary"`):
                Color map used to convert a single-channel predictions into colored representations. When a dictionary
                is passed, each modality can be colored with its own color map.

        Returns: `List[Dict[str, PIL.Image.Image]]` with intrinsic image decomposition visualization.
        """
        if "target_names" not in target_properties:
            raise ValueError("Missing `target_names` in target_properties")
        if not isinstance(color_map, str) and not (
            isinstance(color_map, dict)
            and all(isinstance(k, str) and isinstance(v, str) for k, v in color_map.items())
        ):
            raise ValueError("`color_map` must be a string or a dictionary of strings")
        n_targets = len(target_properties["target_names"])

        def visualize_targets_one(images, idx=None):
            # images: [T, 3, H, W]
            out = {}
            for target_name, img in zip(target_properties["target_names"], images):
                img = img.permute(1, 2, 0)  # [H, W, 3]
                prediction_space = target_properties[target_name].get("prediction_space", "srgb")
                if prediction_space == "stack":
                    sub_target_names = target_properties[target_name]["sub_target_names"]
                    if len(sub_target_names) != 3 or any(
                        not (isinstance(s, str) or s is None) for s in sub_target_names
                    ):
                        raise ValueError(f"Unexpected target sub-names {sub_target_names} in {target_name}")
                    for i, sub_target_name in enumerate(sub_target_names):
                        if sub_target_name is None:
                            continue
                        sub_img = img[:, :, i]
                        sub_prediction_space = target_properties[sub_target_name].get("prediction_space", "srgb")
                        if sub_prediction_space == "linear":
                            sub_up_to_scale = target_properties[sub_target_name].get("up_to_scale", False)
                            if sub_up_to_scale:
                                sub_img = sub_img / max(sub_img.max().item(), 1e-6)
                            sub_img = sub_img ** (1 / 2.2)
                        cmap_name = (
                            color_map if isinstance(color_map, str) else color_map.get(sub_target_name, "binary")
                        )
                        sub_img = MarigoldImageProcessor.colormap(sub_img, cmap=cmap_name, bytes=True)
                        sub_img = PIL.Image.fromarray(sub_img.cpu().numpy())
                        out[sub_target_name] = sub_img
                elif prediction_space == "linear":
                    up_to_scale = target_properties[target_name].get("up_to_scale", False)
                    if up_to_scale:
                        img = img / max(img.max().item(), 1e-6)
                    img = img ** (1 / 2.2)
                elif prediction_space == "srgb":
                    pass
                img = (img * 255).to(dtype=torch.uint8, device="cpu").numpy()
                img = PIL.Image.fromarray(img)
                out[target_name] = img
            return out

        if prediction is None or isinstance(prediction, list) and any(o is None for o in prediction):
            raise ValueError("Input prediction is `None`")
        if isinstance(prediction, (np.ndarray, torch.Tensor)):
            prediction = MarigoldImageProcessor.expand_tensor_or_array(prediction)
            if isinstance(prediction, np.ndarray):
                prediction = MarigoldImageProcessor.numpy_to_pt(prediction)  # [N*T,3,H,W]
            if not (prediction.ndim == 4 and prediction.shape[1] == 3 and prediction.shape[0] % n_targets == 0):
                raise ValueError(f"Unexpected input shape={prediction.shape}, expecting [N*T,3,H,W].")
            N_T, _, H, W = prediction.shape
            N = N_T // n_targets
            prediction = prediction.reshape(N, n_targets, 3, H, W)
            return [visualize_targets_one(img, idx) for idx, img in enumerate(prediction)]
        elif isinstance(prediction, list):
            return [visualize_targets_one(img, idx) for idx, img in enumerate(prediction)]
        else:
            raise ValueError(f"Unexpected input type: {type(prediction)}")
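
    # Usage sketch (illustrative, with hypothetical target names): `target_properties`
    # declares how each target is mapped to sRGB before visualization.
    #
    #   props = {
    #       "target_names": ["albedo", "shading"],
    #       "albedo": {"prediction_space": "srgb"},
    #       "shading": {"prediction_space": "linear", "up_to_scale": True},
    #   }
    #   pred = torch.rand(2, 3, 480, 640)  # N*T with N=1 images and T=2 targets
    #   vis = MarigoldImageProcessor.visualize_intrinsics(pred, props)
    #   sorted(vis[0].keys())  # ['albedo', 'shading']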
|
|
    @staticmethod
    def visualize_uncertainty(
        uncertainty: Union[
            np.ndarray,
            torch.Tensor,
            List[np.ndarray],
            List[torch.Tensor],
        ],
        saturation_percentile=95,
    ) -> List[PIL.Image.Image]:
        """
        Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline`, `MarigoldNormalsPipeline`, or
        `MarigoldIntrinsicsPipeline`.

        Args:
            uncertainty (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
                Uncertainty maps.
            saturation_percentile (`int`, *optional*, defaults to `95`):
                Specifies the percentile uncertainty value visualized with maximum intensity.

        Returns: `List[PIL.Image.Image]` with uncertainty visualization.
        """

        def visualize_uncertainty_one(img, idx=None):
            prefix = "Uncertainty" + (f"[{idx}]" if idx is not None else "")
            if img.min() < 0:
                raise ValueError(f"{prefix}: unexpected data range, min={img.min()}.")
            img = img.permute(1, 2, 0)  # [H,W,C]
            img = img.squeeze(2).cpu().numpy()  # [H,W] or [H,W,3]
            saturation_value = np.percentile(img, saturation_percentile)
            img = np.clip(img * 255 / saturation_value, 0, 255)
            img = img.astype(np.uint8)
            img = PIL.Image.fromarray(img)
            return img

        if uncertainty is None or isinstance(uncertainty, list) and any(o is None for o in uncertainty):
            raise ValueError("Input uncertainty is `None`")
        if isinstance(uncertainty, (np.ndarray, torch.Tensor)):
            uncertainty = MarigoldImageProcessor.expand_tensor_or_array(uncertainty)
            if isinstance(uncertainty, np.ndarray):
                uncertainty = MarigoldImageProcessor.numpy_to_pt(uncertainty)  # [N,C,H,W]
            if not (uncertainty.ndim == 4 and uncertainty.shape[1] in (1, 3)):
                raise ValueError(f"Unexpected input shape={uncertainty.shape}, expecting [N,C,H,W] with C in (1,3).")
            return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)]
        elif isinstance(uncertainty, list):
            return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)]
        else:
            raise ValueError(f"Unexpected input type: {type(uncertainty)}")
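
    # Usage sketch (illustrative): intensity is scaled so the chosen percentile of
    # uncertainty values maps to full brightness, preventing outliers from washing
    # out the visualization.
    #
    #   unc = torch.rand(1, 1, 480, 640)
    #   vis = MarigoldImageProcessor.visualize_uncertainty(unc, saturation_percentile=95)
    #   vis[0].size  # (640, 480)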