217 lines
8.3 KiB
Python
217 lines
8.3 KiB
Python
![]() |
"""
|
||
|
This file is part of the private API. Please do not use directly these classes as they will be modified on
|
||
|
future versions without warning. The classes should be accessed only via the transforms argument of Weights.
|
||
|
"""
|
||
|
from typing import Optional, Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
from torch import nn, Tensor
|
||
|
|
||
|
from . import functional as F, InterpolationMode
|
||
|
|
||
|
|
||
|
# Names exported by this module: the preset transform classes defined below.
__all__ = [
    "ObjectDetection",
    "ImageClassification",
    "VideoClassification",
    "SemanticSegmentation",
    "OpticalFlow",
]
|
||
|
|
||
|
|
||
|
class ObjectDetection(nn.Module):
    """Preset transform that turns its input into a float image tensor."""

    def forward(self, img: Tensor) -> Tensor:
        """Convert ``img`` to a tensor if needed, then to ``torch.float``.

        Inputs that are not already a ``Tensor`` go through ``F.pil_to_tensor``
        before the dtype conversion performed by ``F.convert_image_dtype``.
        """
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        as_float = F.convert_image_dtype(img, torch.float)
        return as_float

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def describe(self) -> str:
        """Return a short human-readable description of this preset."""
        accepted = (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects."
        )
        rescaling = "The images are rescaled to ``[0.0, 1.0]``."
        return " ".join([accepted, rescaling])
|
||
|
|
||
|
|
||
|
class ImageClassification(nn.Module):
    """Preset transform: resize, center-crop, convert to float and normalize.

    Args:
        crop_size: Size handed to ``F.center_crop`` (stored as a one-element list).
        resize_size: Size handed to ``F.resize`` (stored as a one-element list).
        mean: Per-channel mean passed to ``F.normalize``.
        std: Per-channel standard deviation passed to ``F.normalize``.
        interpolation: Interpolation mode passed to ``F.resize``.
        antialias: Antialias flag passed to ``F.resize``.
    """

    def __init__(
        self,
        *,
        crop_size: int,
        resize_size: int = 256,
        mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
        std: Tuple[float, ...] = (0.229, 0.224, 0.225),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        antialias: Optional[bool] = True,
    ) -> None:
        super().__init__()
        # Sizes are wrapped in one-element lists and forwarded as-is to the functional ops.
        self.crop_size = [crop_size]
        self.resize_size = [resize_size]
        # Statistics are kept as plain lists.
        self.mean = [*mean]
        self.std = [*std]
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img: Tensor) -> Tensor:
        """Apply resize -> center crop -> tensor/float conversion -> normalization."""
        resized = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias)
        img = F.center_crop(resized, self.crop_size)
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        scaled = F.convert_image_dtype(img, torch.float)
        return F.normalize(scaled, mean=self.mean, std=self.std)

    def __repr__(self) -> str:
        shown = (
            f"crop_size={self.crop_size}",
            f"resize_size={self.resize_size}",
            f"mean={self.mean}",
            f"std={self.std}",
            f"interpolation={self.interpolation}",
        )
        # One attribute per line, each introduced by a newline and a single space.
        body = "".join(f"\n {entry}" for entry in shown)
        return f"{self.__class__.__name__}({body}\n)"

    def describe(self) -> str:
        """Return a human-readable description of the configured pipeline."""
        pieces = [
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects.",
            f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``,",
            f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to",
            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``.",
        ]
        return " ".join(pieces)
|
||
|
|
||
|
|
||
|
class VideoClassification(nn.Module):
    """Preset transform for video clips: resize, center-crop, normalize and permute.

    Args:
        crop_size: ``(height, width)`` handed to ``F.center_crop``.
        resize_size: Size handed to ``F.resize``.
        mean: Per-channel mean passed to ``F.normalize``.
        std: Per-channel standard deviation passed to ``F.normalize``.
        interpolation: Interpolation mode passed to ``F.resize``.
    """

    def __init__(
        self,
        *,
        crop_size: Tuple[int, int],
        resize_size: Union[Tuple[int], Tuple[int, int]],
        mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645),
        std: Tuple[float, ...] = (0.22803, 0.22145, 0.216989),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ) -> None:
        super().__init__()
        self.crop_size = list(crop_size)
        self.resize_size = list(resize_size)
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation

    def forward(self, vid: Tensor) -> Tensor:
        """Preprocess a clip or a batch of clips.

        Args:
            vid: Tensor laid out as ``(T, C, H, W)`` or ``(B, T, C, H, W)``.

        Returns:
            The processed frames permuted to ``(B, C, T, H, W)``; the leading
            batch dimension is dropped again when the input had fewer than 5 dims.
        """
        need_squeeze = False
        if vid.ndim < 5:
            # Temporarily add a batch dimension so both layouts share one code path.
            vid = vid.unsqueeze(dim=0)
            need_squeeze = True

        N, T, C, H, W = vid.shape
        # Fold batch and time together so the per-image ops see (N*T, C, H, W).
        # ``reshape`` is used instead of ``view`` so that inputs whose first two
        # dimensions cannot be merged without a copy (e.g. ``batch[:, 1:]`` or a
        # batch/time transpose) are accepted; it still returns a view whenever
        # ``view`` would have succeeded.
        vid = vid.reshape(-1, C, H, W)
        # We hard-code antialias=False to preserve results after we changed
        # its default from None to True (see
        # https://github.com/pytorch/vision/pull/7160)
        # TODO: we could re-train the video models with antialias=True?
        vid = F.resize(vid, self.resize_size, interpolation=self.interpolation, antialias=False)
        vid = F.center_crop(vid, self.crop_size)
        vid = F.convert_image_dtype(vid, torch.float)
        vid = F.normalize(vid, mean=self.mean, std=self.std)
        # Spatial size is now the crop size; restore the separate batch and time dims.
        H, W = self.crop_size
        vid = vid.view(N, T, C, H, W)
        vid = vid.permute(0, 2, 1, 3, 4)  # (N, T, C, H, W) => (N, C, T, H, W)

        if need_squeeze:
            vid = vid.squeeze(dim=0)
        return vid

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        format_string += f"\n crop_size={self.crop_size}"
        format_string += f"\n resize_size={self.resize_size}"
        format_string += f"\n mean={self.mean}"
        format_string += f"\n std={self.std}"
        format_string += f"\n interpolation={self.interpolation}"
        format_string += "\n)"
        return format_string

    def describe(self) -> str:
        """Return a human-readable description of the configured pipeline."""
        return (
            "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. "
            f"The frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
            f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``. Finally the output "
            "dimensions are permuted to ``(..., C, T, H, W)`` tensors."
        )
|
||
|
|
||
|
|
||
|
class SemanticSegmentation(nn.Module):
    """Preset transform: optional resize, conversion to float and normalization.

    Args:
        resize_size: Size handed to ``F.resize``; ``None`` disables the resize step.
        mean: Per-channel mean passed to ``F.normalize``.
        std: Per-channel standard deviation passed to ``F.normalize``.
        interpolation: Interpolation mode passed to ``F.resize``.
        antialias: Antialias flag passed to ``F.resize``.
    """

    def __init__(
        self,
        *,
        resize_size: Optional[int],
        mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
        std: Tuple[float, ...] = (0.229, 0.224, 0.225),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        antialias: Optional[bool] = True,
    ) -> None:
        super().__init__()
        # Either a one-element list or None; forward() only resizes when it is a list.
        self.resize_size = None if resize_size is None else [resize_size]
        self.mean = [*mean]
        self.std = [*std]
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img: Tensor) -> Tensor:
        """Apply the optional resize, then tensor/float conversion and normalization."""
        if isinstance(self.resize_size, list):
            img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias)
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        scaled = F.convert_image_dtype(img, torch.float)
        return F.normalize(scaled, mean=self.mean, std=self.std)

    def __repr__(self) -> str:
        shown = (
            f"resize_size={self.resize_size}",
            f"mean={self.mean}",
            f"std={self.std}",
            f"interpolation={self.interpolation}",
        )
        # One attribute per line, each introduced by a newline and a single space.
        body = "".join(f"\n {entry}" for entry in shown)
        return f"{self.__class__.__name__}({body}\n)"

    def describe(self) -> str:
        """Return a human-readable description of the configured pipeline."""
        pieces = [
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects.",
            f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``.",
            f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and",
            f"``std={self.std}``.",
        ]
        return " ".join(pieces)
|
||
|
|
||
|
|
||
|
class OpticalFlow(nn.Module):
    """Preset transform applied to a pair of images."""

    def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
        """Convert both images to float tensors, normalize them and make them contiguous.

        Returns:
            The processed ``(img1, img2)`` pair.
        """
        if not isinstance(img1, Tensor):
            img1 = F.pil_to_tensor(img1)
        if not isinstance(img2, Tensor):
            img2 = F.pil_to_tensor(img2)

        img1 = F.convert_image_dtype(img1, torch.float)
        img2 = F.convert_image_dtype(img2, torch.float)

        # map [0, 1] into [-1, 1]
        half = [0.5, 0.5, 0.5]
        img1 = F.normalize(img1, mean=half, std=half)
        img2 = F.normalize(img2, mean=half, std=half)

        return img1.contiguous(), img2.contiguous()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def describe(self) -> str:
        """Return a short human-readable description of this preset."""
        accepted = (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects."
        )
        rescaling = "The images are rescaled to ``[-1.0, 1.0]``."
        return " ".join([accepted, rescaling])
|