import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, NamedTuple, Optional

import torch
import torch.nn as nn

from ..ops.misc import Conv2dNormActivation, MLP
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface


__all__ = [
    "VisionTransformer",
    "ViT_B_16_Weights",
    "ViT_B_32_Weights",
    "ViT_L_16_Weights",
    "ViT_L_32_Weights",
    "ViT_H_14_Weights",
    "vit_b_16",
    "vit_b_32",
    "vit_l_16",
    "vit_l_32",
    "vit_h_14",
]

class ConvStemConfig(NamedTuple):
    out_channels: int
    kernel_size: int
    stride: int
    norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
    activation_layer: Callable[..., nn.Module] = nn.ReLU

class MLPBlock(MLP):
    """Transformer MLP block."""

    _version = 2

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
            for i in range(2):
                for param_type in ["weight", "bias"]:  # renamed from `type` to avoid shadowing the builtin
                    old_key = f"{prefix}linear_{i+1}.{param_type}"
                    new_key = f"{prefix}{3*i}.{param_type}"
                    if old_key in state_dict:
                        state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

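# Illustrative note (not from the original source): for a checkpoint saved by the
# legacy MLPBlock (version < 2), `_load_from_state_dict` above renames parameters
# to the flat indices used by the MLP base class:
#     "<prefix>linear_1.weight" -> "<prefix>0.weight"  (first nn.Linear)
#     "<prefix>linear_1.bias"   -> "<prefix>0.bias"
#     "<prefix>linear_2.weight" -> "<prefix>3.weight"  (second nn.Linear; index 3
#     "<prefix>linear_2.bias"   -> "<prefix>3.bias"     skips the GELU and Dropout)
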
class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(x, x, x, need_weights=False)
        x = self.dropout(x)
        x = x + input  # residual connection around the attention block

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y  # residual connection around the MLP block

class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Note that batch_size is on the first dim because
        # we pass batch_first=True to nn.MultiheadAttention() in EncoderBlock
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        return self.ln(self.layers(self.dropout(input)))

class VisionTransformer(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929."""

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
    ):
        super().__init__()
        _log_api_usage_once(self)
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer
        if conv_stem_configs is not None:
            # As per https://arxiv.org/abs/2106.14881
            seq_proj = nn.Sequential()
            prev_channels = 3
            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
                seq_proj.add_module(
                    f"conv_bn_relu_{i}",
                    Conv2dNormActivation(
                        in_channels=prev_channels,
                        out_channels=conv_stem_layer_config.out_channels,
                        kernel_size=conv_stem_layer_config.kernel_size,
                        stride=conv_stem_layer_config.stride,
                        norm_layer=conv_stem_layer_config.norm_layer,
                        activation_layer=conv_stem_layer_config.activation_layer,
                    ),
                )
                prev_channels = conv_stem_layer_config.out_channels
            seq_proj.add_module(
                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
            )
            self.conv_proj: nn.Module = seq_proj
        else:
            self.conv_proj = nn.Conv2d(
                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
            )

        seq_length = (image_size // patch_size) ** 2

        # Add a class token
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        if isinstance(self.conv_proj, nn.Conv2d):
            # Init the patchify stem
            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
            if self.conv_proj.bias is not None:
                nn.init.zeros_(self.conv_proj.bias)
        elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
            # Init the last 1x1 conv of the conv stem
            nn.init.normal_(
                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
            )
            if self.conv_proj.conv_last.bias is not None:
                nn.init.zeros_(self.conv_proj.conv_last.bias)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            nn.init.zeros_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
        torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

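    # Worked example (illustrative): with image_size=224 and patch_size=16,
    # _process_input maps (n, 3, 224, 224) -> conv_proj -> (n, hidden_dim, 14, 14)
    # -> reshape/permute -> (n, 196, hidden_dim); forward() below then prepends
    # the class token, giving a sequence length of 197.
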
    def forward(self, x: torch.Tensor):
        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]

        x = self.heads(x)

        return x

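# Usage sketch (illustrative, not part of the original module): building a
# ViT-B/16-sized model directly and running a forward pass. The hyperparameters
# mirror those used by vit_b_16() further down; num_classes defaults to 1000.
#
#   model = VisionTransformer(
#       image_size=224, patch_size=16, num_layers=12, num_heads=12,
#       hidden_dim=768, mlp_dim=3072,
#   )
#   logits = model(torch.randn(1, 3, 224, 224))  # -> shape (1, 1000)
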
def _vision_transformer(
    patch_size: int,
    num_layers: int,
    num_heads: int,
    hidden_dim: int,
    mlp_dim: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> VisionTransformer:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
        _ovewrite_named_param(kwargs, "image_size", weights.meta["min_size"][0])
    image_size = kwargs.pop("image_size", 224)

    model = VisionTransformer(
        image_size=image_size,
        patch_size=patch_size,
        num_layers=num_layers,
        num_heads=num_heads,
        hidden_dim=hidden_dim,
        mlp_dim=mlp_dim,
        **kwargs,
    )

    if weights:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model

_COMMON_META: Dict[str, Any] = {
    "categories": _IMAGENET_CATEGORIES,
}

_COMMON_SWAG_META = {
    **_COMMON_META,
    "recipe": "https://github.com/facebookresearch/SWAG",
    "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
}

class ViT_B_16_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_16-c867db91.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 86567656,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_16",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.072,
                    "acc@5": 95.318,
                }
            },
            "_ops": 17.564,
            "_file_size": 330.285,
            "_docs": """
                These weights were trained from scratch by using a modified version of `DeiT
                <https://arxiv.org/abs/2012.12877>`_'s training recipe.
            """,
        },
    )
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth",
        transforms=partial(
            ImageClassification,
            crop_size=384,
            resize_size=384,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 86859496,
            "min_size": (384, 384),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.304,
                    "acc@5": 97.650,
                }
            },
            "_ops": 55.484,
            "_file_size": 331.398,
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_16_lc_swag-4e70ced5.pth",
        transforms=partial(
            ImageClassification,
            crop_size=224,
            resize_size=224,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 86567656,
            "min_size": (224, 224),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.886,
                    "acc@5": 96.180,
                }
            },
            "_ops": 17.564,
            "_file_size": 330.285,
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1

class ViT_B_32_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 88224232,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_32",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.912,
                    "acc@5": 92.466,
                }
            },
            "_ops": 4.409,
            "_file_size": 336.604,
            "_docs": """
                These weights were trained from scratch by using a modified version of `DeiT
                <https://arxiv.org/abs/2012.12877>`_'s training recipe.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1

class ViT_L_16_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=242),
        meta={
            **_COMMON_META,
            "num_params": 304326632,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_16",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 79.662,
                    "acc@5": 94.638,
                }
            },
            "_ops": 61.555,
            "_file_size": 1161.023,
            "_docs": """
                These weights were trained from scratch by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth",
        transforms=partial(
            ImageClassification,
            crop_size=512,
            resize_size=512,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 305174504,
            "min_size": (512, 512),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 88.064,
                    "acc@5": 98.512,
                }
            },
            "_ops": 361.986,
            "_file_size": 1164.258,
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/vit_l_16_lc_swag-4d563306.pth",
        transforms=partial(
            ImageClassification,
            crop_size=224,
            resize_size=224,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 304326632,
            "min_size": (224, 224),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.146,
                    "acc@5": 97.422,
                }
            },
            "_ops": 61.555,
            "_file_size": 1161.023,
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1

class ViT_L_32_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vit_l_32-c7638314.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 306535400,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_32",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.972,
                    "acc@5": 93.07,
                }
            },
            "_ops": 15.378,
            "_file_size": 1169.449,
            "_docs": """
                These weights were trained from scratch by using a modified version of `DeiT
                <https://arxiv.org/abs/2012.12877>`_'s training recipe.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1

class ViT_H_14_Weights(WeightsEnum):
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/vit_h_14_swag-80465313.pth",
        transforms=partial(
            ImageClassification,
            crop_size=518,
            resize_size=518,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 633470440,
            "min_size": (518, 518),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 88.552,
                    "acc@5": 98.694,
                }
            },
            "_ops": 1016.717,
            "_file_size": 2416.643,
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/vit_h_14_lc_swag-c1eb923e.pth",
        transforms=partial(
            ImageClassification,
            crop_size=224,
            resize_size=224,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 632045800,
            "min_size": (224, 224),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.708,
                    "acc@5": 97.730,
                }
            },
            "_ops": 167.295,
            "_file_size": 2411.209,
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_SWAG_E2E_V1

@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_b_16 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_B_16_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_B_16_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_B_16_Weights
        :members:
    """
    weights = ViT_B_16_Weights.verify(weights)

    return _vision_transformer(
        patch_size=16,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        weights=weights,
        progress=progress,
        **kwargs,
    )

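# Usage sketch (illustrative, not part of the original module): loading pretrained
# weights and the matching preprocessing transforms via the torchvision weights API.
#
#   weights = ViT_B_16_Weights.IMAGENET1K_V1
#   model = vit_b_16(weights=weights).eval()
#   preprocess = weights.transforms()
#   logits = model(preprocess(img).unsqueeze(0))  # img: a PIL Image or tensor
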
@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.IMAGENET1K_V1))
def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_b_32 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_B_32_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_B_32_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_B_32_Weights
        :members:
    """
    weights = ViT_B_32_Weights.verify(weights)

    return _vision_transformer(
        patch_size=32,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        weights=weights,
        progress=progress,
        **kwargs,
    )

@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.IMAGENET1K_V1))
def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_l_16 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_L_16_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_L_16_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_L_16_Weights
        :members:
    """
    weights = ViT_L_16_Weights.verify(weights)

    return _vision_transformer(
        patch_size=16,
        num_layers=24,
        num_heads=16,
        hidden_dim=1024,
        mlp_dim=4096,
        weights=weights,
        progress=progress,
        **kwargs,
    )

@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.IMAGENET1K_V1))
def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_l_32 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_L_32_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_L_32_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_L_32_Weights
        :members:
    """
    weights = ViT_L_32_Weights.verify(weights)

    return _vision_transformer(
        patch_size=32,
        num_layers=24,
        num_heads=16,
        hidden_dim=1024,
        mlp_dim=4096,
        weights=weights,
        progress=progress,
        **kwargs,
    )

@register_model()
@handle_legacy_interface(weights=("pretrained", None))
def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_h_14 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_H_14_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_H_14_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_H_14_Weights
        :members:
    """
    weights = ViT_H_14_Weights.verify(weights)

    return _vision_transformer(
        patch_size=14,
        num_layers=32,
        num_heads=16,
        hidden_dim=1280,
        mlp_dim=5120,
        weights=weights,
        progress=progress,
        **kwargs,
    )

def interpolate_embeddings(
    image_size: int,
    patch_size: int,
    model_state: "OrderedDict[str, torch.Tensor]",
    interpolation_mode: str = "bicubic",
    reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]":
    """This function helps interpolate positional embeddings during checkpoint loading,
    especially when you want to apply a pre-trained model on images with different resolution.

    Args:
        image_size (int): Image size of the new model.
        patch_size (int): Patch size of the new model.
        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
        interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
        reset_heads (bool): If true, the state of the classifier heads is not copied. Default: False.

    Returns:
        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.
    """
    # Shape of pos_embedding is (1, seq_length, hidden_dim)
    pos_embedding = model_state["encoder.pos_embedding"]
    n, seq_length, hidden_dim = pos_embedding.shape
    if n != 1:
        raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")

    new_seq_length = (image_size // patch_size) ** 2 + 1

    # Need to interpolate the weights for the position embedding.
    # We do this by reshaping the position embeddings to a 2d grid, performing
    # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
    if new_seq_length != seq_length:
        # The class token embedding shouldn't be interpolated, so we split it up.
        seq_length -= 1
        new_seq_length -= 1
        pos_embedding_token = pos_embedding[:, :1, :]
        pos_embedding_img = pos_embedding[:, 1:, :]

        # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
        pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
        seq_length_1d = int(math.sqrt(seq_length))
        if seq_length_1d * seq_length_1d != seq_length:
            raise ValueError(
                f"seq_length is not a perfect square! Instead got seq_length_1d * seq_length_1d = {seq_length_1d * seq_length_1d} and seq_length = {seq_length}"
            )

        # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
        pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
        new_seq_length_1d = image_size // patch_size

        # Perform interpolation.
        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
        new_pos_embedding_img = nn.functional.interpolate(
            pos_embedding_img,
            size=new_seq_length_1d,
            mode=interpolation_mode,
            align_corners=True,
        )

        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
        new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)

        # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
        new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
        new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)

        model_state["encoder.pos_embedding"] = new_pos_embedding

        if reset_heads:
            model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict()
            for k, v in model_state.items():
                if not k.startswith("heads"):
                    model_state_copy[k] = v
            model_state = model_state_copy

    return model_state
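
# Usage sketch (illustrative, not part of the original module): adapting a
# checkpoint trained at 224x224 with patch size 16 to a 384x384 model.
#
#   state = torch.load("vit_b_16_224.pth")  # hypothetical checkpoint path
#   state = interpolate_embeddings(image_size=384, patch_size=16, model_state=state)
#   model = VisionTransformer(
#       image_size=384, patch_size=16, num_layers=12, num_heads=12,
#       hidden_dim=768, mlp_dim=3072,
#   )
#   model.load_state_dict(state)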