876 lines
32 KiB
Python
876 lines
32 KiB
Python
from functools import partial
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class Upsample2D(nn.Module):
    """
    A 2x upsampling layer with an optional convolution.

    :param channels: number of channels expected in the input.
    :param use_conv: if True, a 3x3 convolution is applied after nearest-neighbor upsampling.
    :param use_conv_transpose: if True, upsampling is done by a strided transposed convolution instead of
        interpolation (takes precedence over `use_conv`).
    :param out_channels: number of output channels of the convolution; defaults to `channels`.
    :param name: attribute name under which the convolution is registered. `"conv"` registers it as
        `self.conv`; any other value registers it as `self.Conv2d_0` (legacy checkpoint naming).
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def _get_conv(self):
        """Return the convolution from whichever attribute `__init__` registered it under."""
        return self.conv if self.name == "conv" else self.Conv2d_0

    def forward(self, x):
        """Upsample `x` by a factor of two along its last two dimensions."""
        assert x.shape[1] == self.channels
        if self.use_conv_transpose:
            # Look the layer up by its registered name: with a legacy `name` there is no `self.conv`.
            return self._get_conv()(x)

        x = F.interpolate(x, scale_factor=2.0, mode="nearest")

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            x = self._get_conv()(x)

        return x
|
|
|
|
|
|
class Downsample2D(nn.Module):
    """
    A 2x downsampling layer: either a stride-2 3x3 convolution or 2x2 average pooling.

    :param channels: number of channels expected in the input.
    :param use_conv: if True a strided convolution is used, otherwise average pooling.
    :param out_channels: number of output channels; defaults to `channels` and must equal `channels`
        when pooling is used.
    :param padding: padding of the convolution. With `padding == 0` the input is zero-padded by one
        pixel on the right and bottom before the convolution.
    :param name: legacy naming switch; `"conv"` additionally registers the layer as `self.Conv2d_0`.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        stride = 2
        if not use_conv:
            assert self.channels == self.out_channels
            op = nn.AvgPool2d(kernel_size=stride, stride=stride)
        else:
            op = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        # The layer is always reachable as `self.conv`; the default name also exposes the legacy alias
        # (registered first so the state-dict key order is unchanged).
        if name == "conv":
            self.Conv2d_0 = op
        self.conv = op

    def forward(self, x):
        """Downsample `x` by a factor of two along its last two dimensions."""
        assert x.shape[1] == self.channels

        needs_manual_pad = self.use_conv and self.padding == 0
        if needs_manual_pad:
            # (left, right, top, bottom): one extra zero column/row on the right/bottom.
            x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)

        return self.conv(x)
|
|
|
|
|
|
class FirUpsample2D(nn.Module):
    """2x upsampling with an FIR filter, optionally fused with a 3x3 convolution.

    When `use_conv` is True the module owns a `Conv2d_0` layer whose weight and bias are used by the
    fused transposed-convolution path; otherwise only the FIR filter is applied.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, x, w=None, k=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations.

        Args:
            x: Input tensor; indexed as `[N, C, H, W]` by this implementation.
            w: Convolution weight, only used when `self.use_conv` is True. It is read as
                `w.shape[1] == inC`, `w.shape[2] == convH`, `w.shape[3] == convW`
                (`forward` passes `self.Conv2d_0.weight`).
            k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is
                `[1] * factor`, which corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            The upsampled tensor as returned by `upfirdn2d_native`.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if k is None:
            k = [1] * factor

        k = np.asarray(k, dtype=np.float32)
        if k.ndim == 1:
            k = np.outer(k, k)
        k /= np.sum(k)

        # Compensate for the zeros inserted by upsampling so a constant input keeps its magnitude.
        k = k * (gain * (factor**2))

        if self.use_conv:
            convH = w.shape[2]
            convW = w.shape[3]
            inC = w.shape[1]

            p = (k.shape[0] - factor) - (convW - 1)

            # `F.conv_transpose2d` takes one stride per spatial dimension (height, width).
            stride = (factor, factor)

            # Determine data dimensions.
            output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
            output_padding = (
                output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = x.shape[1] // inC

            # Transpose weights. Torch tensors do not support negative-step slicing, so the spatial
            # flip is done with `torch.flip` on the two trailing (height, width) axes.
            w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
            w = torch.flip(w, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            w = torch.reshape(w, (num_groups * inC, -1, convH, convW))

            x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)

            x = upfirdn2d_native(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
        else:
            p = k.shape[0] - factor
            x = upfirdn2d_native(
                x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
            )

        return x

    def forward(self, x):
        """Upsample `x`; on the convolution path the bias of `Conv2d_0` is added after the fused op."""
        if self.use_conv:
            h = self._upsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
            h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            h = self._upsample_2d(x, k=self.fir_kernel, factor=2)

        return h
|
|
|
|
|
|
class FirDownsample2D(nn.Module):
    """2x downsampling with an FIR filter, optionally fused with a strided 3x3 convolution."""

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels or channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, x, w=None, k=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Padding is performed only once at the beginning, not between the operations.

        Args:
            x: Input tensor; passed to `upfirdn2d_native`, which unpacks it as `[N, C, H, W]`.
            w: 4-D convolution weight, only used when `self.use_conv` is True (`forward` passes
                `self.Conv2d_0.weight`); its last axis is taken as the kernel width.
            k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is
                `[1] * factor`, which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            The downsampled tensor.
        """
        assert isinstance(factor, int) and factor >= 1

        # Build the normalized 2-D filter taps.
        taps = np.asarray([1] * factor if k is None else k, dtype=np.float32)
        if taps.ndim == 1:
            taps = np.outer(taps, taps)
        taps /= np.sum(taps)
        taps = taps * gain

        if not self.use_conv:
            p = taps.shape[0] - factor
            return upfirdn2d_native(
                x, torch.tensor(taps, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)
            )

        # Filter at full resolution first, then let the strided convolution do the decimation.
        _, _, _, conv_width = w.shape
        p = (taps.shape[0] - factor) + (conv_width - 1)
        filtered = upfirdn2d_native(x, torch.tensor(taps, device=x.device), pad=((p + 1) // 2, p // 2))
        return F.conv2d(filtered, w, stride=[factor, factor], padding=0)

    def forward(self, x):
        """Downsample `x`; on the convolution path the bias of `Conv2d_0` is added after the fused op."""
        if not self.use_conv:
            return self._downsample_2d(x, k=self.fir_kernel, factor=2)

        out = self._downsample_2d(x, w=self.Conv2d_0.weight, k=self.fir_kernel)
        return out + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
|
|
|
|
|
class ResnetBlock(nn.Module):
    """Pre-norm residual block: (GroupNorm -> activation -> conv) twice, plus a skip connection.

    An optional time embedding is projected and added between the two convolutions, and the block can
    resample both the residual and the skip path by a factor of two (`up` / `down`).
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        kernel=None,
        output_scale_factor=1.0,
        use_nin_shortcut=None,
        up=False,
        down=False,
    ):
        super().__init__()
        # The `pre_norm` argument is accepted but ignored: this block is always pre-norm.
        self.pre_norm = True
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor

        groups_out = groups if groups_out is None else groups_out

        # First half: norm -> conv.
        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Optional projection of the time embedding onto the channel dimension.
        self.time_emb_proj = None if temb_channels is None else torch.nn.Linear(temb_channels, out_channels)

        # Second half: norm -> dropout -> conv.
        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Any other `non_linearity` value leaves `self.nonlinearity` unset.
        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        # By default a 1x1 shortcut convolution is used exactly when the channel count changes.
        if use_nin_shortcut is None:
            self.use_nin_shortcut = self.in_channels != self.out_channels
        else:
            self.use_nin_shortcut = use_nin_shortcut

        self.conv_shortcut = None
        if self.use_nin_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb, hey=False):
        """Run the block on `x`; `temb` may be None to skip the time embedding. `hey` is unused."""
        # Group norm is evaluated in float32 and cast back, to stay stable in half precision.
        hidden = self.norm1(x.float()).type(x.dtype)
        hidden = self.nonlinearity(hidden)

        # Resample the skip path and the residual path with the same operator.
        resample = self.upsample if self.upsample is not None else self.downsample
        if resample is not None:
            x = resample(x)
            hidden = resample(hidden)

        hidden = self.conv1(hidden)

        if temb is not None:
            # Broadcast the projected embedding over the two spatial dimensions.
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            hidden = hidden + temb

        # Same float32 round trip as for the first norm.
        hidden = self.norm2(hidden.float()).type(hidden.dtype)
        hidden = self.nonlinearity(hidden)

        hidden = self.conv2(self.dropout(hidden))

        if self.conv_shortcut is not None:
            x = self.conv_shortcut(x)

        return (x + hidden) / self.output_scale_factor

    def set_weight(self, resnet):
        """Take over the weight and bias tensors of a legacy `resnet` block.

        `resnet` is expected to expose `norm1`, `conv1`, `norm2`, `conv2` and, when needed here,
        `temb_proj` and `nin_shortcut`.
        """
        # (layer of this block, attribute name on the legacy block), in the original copy order.
        transfers = [(self.norm1, "norm1"), (self.conv1, "conv1")]
        if self.time_emb_proj is not None:
            transfers.append((self.time_emb_proj, "temb_proj"))
        transfers.extend([(self.norm2, "norm2"), (self.conv2, "conv2")])
        if self.use_nin_shortcut:
            transfers.append((self.conv_shortcut, "nin_shortcut"))

        for target, source_name in transfers:
            source = getattr(resnet, source_name)
            target.weight.data = source.weight.data
            target.bias.data = source.bias.data
|
|
|
|
|
|
# THE FOLLOWING SHOULD BE DELETED ONCE ALL CHECKPOINTS ARE CONVERTED
|
|
|
|
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
|
|
# => All 2D-Resnets are included here now!
|
|
class ResnetBlock2D(nn.Module):
    """Legacy catch-all 2D ResNet block used while old checkpoints are being converted.

    The block always builds the canonical layers (`norm1`, `conv1`, `temb_proj`, `norm2`, `conv2`,
    `nin_shortcut`). Depending on the `overwrite_for_*` flags it additionally builds a second set of
    layers under the attribute names of one legacy implementation and rebinds the canonical layers'
    `.data` to those legacy layers, so that a legacy state dict can be loaded by name.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        kernel=None,
        output_scale_factor=1.0,
        use_nin_shortcut=None,
        up=False,
        down=False,
        overwrite_for_grad_tts=False,
        overwrite_for_ldm=False,
        overwrite_for_glide=False,
        overwrite_for_score_vde=False,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        # With pre-norm, norm1 runs before conv1 (on in_channels); otherwise after it (on out_channels).
        if self.pre_norm:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # "scale_shift" needs twice the channels because forward() chunks the projection in two.
        # NOTE(review): no `temb_proj` is created for any other `time_embedding_norm` value or for
        # `temb_channels <= 0`, yet the set_weights_* helpers and forward() read it — confirm callers.
        if time_embedding_norm == "default" and temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        elif time_embedding_norm == "scale_shift" and temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Any other `non_linearity` value leaves `self.nonlinearity` unset.
        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

        # At most one of upsample / downsample is set; `up` wins when both flags are given.
        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        # Default: use a 1x1 shortcut convolution exactly when the channel count changes.
        self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut

        self.nin_shortcut = None
        if self.use_nin_shortcut:
            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

        # TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
        self.is_overwritten = False
        self.overwrite_for_glide = overwrite_for_glide
        self.overwrite_for_grad_tts = overwrite_for_grad_tts
        # The glide flag takes the same branch as the ldm flag below.
        self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
        self.overwrite_for_score_vde = overwrite_for_score_vde
        if self.overwrite_for_grad_tts:
            # Legacy grad-tts layout: mlp / block1 / block2 / res_conv.
            # The canonical layers are rebound lazily, on the first forward() call.
            dim = in_channels
            dim_out = out_channels
            time_emb_dim = temb_channels
            self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
            self.pre_norm = pre_norm

            self.block1 = Block(dim, dim_out, groups=groups)
            self.block2 = Block(dim_out, dim_out, groups=groups)
            if dim != dim_out:
                self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
            else:
                self.res_conv = torch.nn.Identity()
        elif self.overwrite_for_ldm:
            # Legacy ldm/glide layout: in_layers / emb_layers / out_layers / skip_connection.
            channels = in_channels
            emb_channels = temb_channels
            use_scale_shift_norm = False
            # Only rebinds the local name; `self.nonlinearity` was already chosen above.
            non_linearity = "silu"

            self.in_layers = nn.Sequential(
                normalization(channels, swish=1.0),
                nn.Identity(),
                nn.Conv2d(channels, self.out_channels, 3, padding=1),
            )
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                linear(
                    emb_channels,
                    2 * self.out_channels if self.time_embedding_norm == "scale_shift" else self.out_channels,
                ),
            )
            self.out_layers = nn.Sequential(
                normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
                nn.SiLU() if use_scale_shift_norm else nn.Identity(),
                nn.Dropout(p=dropout),
                zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
            )
            if self.out_channels == in_channels:
                self.skip_connection = nn.Identity()
            else:
                self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
            # Rebind the canonical layers immediately.
            self.set_weights_ldm()
        elif self.overwrite_for_score_vde:
            # Legacy score-vde layout: GroupNorm_0 / Conv_0 / Dense_0 / GroupNorm_1 / Conv_1 / Conv_2.
            in_ch = in_channels
            out_ch = out_channels

            eps = 1e-6
            num_groups = min(in_ch // 4, 32)
            num_groups_out = min(out_ch // 4, 32)
            temb_dim = temb_channels

            self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
            self.up = up
            self.down = down
            self.Conv_0 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
            # NOTE(review): `Dense_0` only exists when temb_dim is not None, but
            # set_weights_score_vde() reads it unconditionally.
            if temb_dim is not None:
                self.Dense_0 = nn.Linear(temb_dim, out_ch)
                nn.init.zeros_(self.Dense_0.bias)

            self.GroupNorm_1 = nn.GroupNorm(num_groups=num_groups_out, num_channels=out_ch, eps=eps)
            self.Dropout_0 = nn.Dropout(dropout)
            self.Conv_1 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
            if in_ch != out_ch or up or down:
                # 1x1 convolution with DDPM initialization.
                self.Conv_2 = nn.Conv2d(in_ch, out_ch, kernel_size=1, padding=0)

            self.in_ch = in_ch
            self.out_ch = out_ch
            # Rebind the canonical layers immediately.
            self.set_weights_score_vde()

    def set_weights_grad_tts(self):
        """Rebind the canonical layers' `.data` to the legacy grad-tts layers.

        In each `Block`, index 0 of `block` is the convolution and index 1 the group norm.
        """
        self.conv1.weight.data = self.block1.block[0].weight.data
        self.conv1.bias.data = self.block1.block[0].bias.data
        self.norm1.weight.data = self.block1.block[1].weight.data
        self.norm1.bias.data = self.block1.block[1].bias.data

        self.conv2.weight.data = self.block2.block[0].weight.data
        self.conv2.bias.data = self.block2.block[0].bias.data
        self.norm2.weight.data = self.block2.block[1].weight.data
        self.norm2.bias.data = self.block2.block[1].bias.data

        # mlp[0] is the Mish activation, mlp[1] the linear projection.
        self.temb_proj.weight.data = self.mlp[1].weight.data
        self.temb_proj.bias.data = self.mlp[1].bias.data

        # NOTE(review): `nin_shortcut` is None when `use_nin_shortcut=False` was passed explicitly.
        if self.in_channels != self.out_channels:
            self.nin_shortcut.weight.data = self.res_conv.weight.data
            self.nin_shortcut.bias.data = self.res_conv.bias.data

    def set_weights_ldm(self):
        """Rebind the canonical layers' `.data` to the legacy ldm/glide layers.

        In `in_layers` and `out_layers`, index 0 is the normalization and index -1 the convolution;
        `emb_layers[-1]` is the linear projection.
        """
        self.norm1.weight.data = self.in_layers[0].weight.data
        self.norm1.bias.data = self.in_layers[0].bias.data

        self.conv1.weight.data = self.in_layers[-1].weight.data
        self.conv1.bias.data = self.in_layers[-1].bias.data

        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
        self.temb_proj.bias.data = self.emb_layers[-1].bias.data

        self.norm2.weight.data = self.out_layers[0].weight.data
        self.norm2.bias.data = self.out_layers[0].bias.data

        self.conv2.weight.data = self.out_layers[-1].weight.data
        self.conv2.bias.data = self.out_layers[-1].bias.data

        if self.in_channels != self.out_channels:
            self.nin_shortcut.weight.data = self.skip_connection.weight.data
            self.nin_shortcut.bias.data = self.skip_connection.bias.data

    def set_weights_score_vde(self):
        """Rebind the canonical layers' `.data` to the legacy score-vde layers."""
        self.conv1.weight.data = self.Conv_0.weight.data
        self.conv1.bias.data = self.Conv_0.bias.data
        self.norm1.weight.data = self.GroupNorm_0.weight.data
        self.norm1.bias.data = self.GroupNorm_0.bias.data

        self.conv2.weight.data = self.Conv_1.weight.data
        self.conv2.bias.data = self.Conv_1.bias.data
        self.norm2.weight.data = self.GroupNorm_1.weight.data
        self.norm2.bias.data = self.GroupNorm_1.bias.data

        self.temb_proj.weight.data = self.Dense_0.weight.data
        self.temb_proj.bias.data = self.Dense_0.bias.data

        # NOTE(review): `nin_shortcut` is only created when `use_nin_shortcut` is truthy; with equal
        # channel counts and `up`/`down` set it stays None unless the caller passes
        # `use_nin_shortcut=True` — confirm against callers.
        if self.in_channels != self.out_channels or self.up or self.down:
            self.nin_shortcut.weight.data = self.Conv_2.weight.data
            self.nin_shortcut.bias.data = self.Conv_2.bias.data

    def forward(self, x, temb, hey=False, mask=1.0):
        """Run the residual block.

        `temb` may be None, in which case no time embedding is added. `mask` is multiplied into the
        hidden states and the skip path at several points (the default 1.0 is a no-op). `hey` is
        never read.
        """
        # TODO(Patrick) eventually this class should be split into multiple classes
        # too many if else statements
        # The grad-tts layout is rebound lazily, once, on the first call.
        if self.overwrite_for_grad_tts and not self.is_overwritten:
            self.set_weights_grad_tts()
            self.is_overwritten = True
        # elif self.overwrite_for_score_vde and not self.is_overwritten:
        #     self.set_weights_score_vde()
        #     self.is_overwritten = True

        h = x

        h = h * mask
        if self.pre_norm:
            h = self.norm1(h)
            h = self.nonlinearity(h)

        # The skip path and the residual path are resampled with the same operator.
        if self.upsample is not None:
            x = self.upsample(x)
            h = self.upsample(h)
        elif self.downsample is not None:
            x = self.downsample(x)
            h = self.downsample(h)

        h = self.conv1(h)

        if not self.pre_norm:
            h = self.norm1(h)
            h = self.nonlinearity(h)
            h = h * mask

        if temb is not None:
            # Trailing singleton axes let the projection broadcast over the two spatial dimensions.
            temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
        else:
            temb = 0

        if self.time_embedding_norm == "scale_shift":
            # NOTE(review): with `temb is None` this receives the int 0 set above, which
            # `torch.chunk` presumably cannot split — confirm this mode is never used without temb.
            scale, shift = torch.chunk(temb, 2, dim=1)

            h = self.norm2(h)
            h = h + h * scale + shift
            h = self.nonlinearity(h)
        elif self.time_embedding_norm == "default":
            h = h + temb
            h = h * mask
            if self.pre_norm:
                h = self.norm2(h)
                h = self.nonlinearity(h)

        h = self.dropout(h)
        h = self.conv2(h)

        if not self.pre_norm:
            h = self.norm2(h)
            h = self.nonlinearity(h)
            h = h * mask

        x = x * mask
        if self.nin_shortcut is not None:
            x = self.nin_shortcut(x)

        out = (x + h) / self.output_scale_factor

        return out
|
|
|
|
|
|
# TODO(Patrick) - just there to convert the weights; can delete afterward
|
|
class Block(torch.nn.Module):
    """Container for a legacy conv -> group norm -> Mish stack; it only holds weights (no forward)."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        conv = torch.nn.Conv2d(dim, dim_out, 3, padding=1)
        norm = torch.nn.GroupNorm(groups, dim_out)
        # Index 0 is the convolution, index 1 the norm, index 2 the activation.
        self.block = torch.nn.Sequential(conv, norm, Mish())
|
|
|
|
|
|
# HELPER Modules
|
|
|
|
|
|
def normalization(channels, swish=0.0):
    """
    Build a 32-group `GroupNorm32` layer, optionally followed by a swish activation.

    :param channels: number of input channels.
    :param swish: swish factor forwarded to `GroupNorm32` (0.0 disables the activation).
    :return: an nn.Module for normalization.
    """
    norm_layer = GroupNorm32(num_groups=32, num_channels=channels, swish=swish)
    return norm_layer
|
|
|
|
|
|
class GroupNorm32(nn.GroupNorm):
    """Group norm evaluated in float32, with an optional fused swish activation.

    `swish == 1.0` applies SiLU, any other truthy value applies `y * sigmoid(swish * y)`, and a falsy
    value applies no activation.
    """

    def __init__(self, num_groups, num_channels, swish, eps=1e-5):
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
        self.swish = swish

    def forward(self, x):
        # Normalize in float32, then cast back to the input dtype.
        normed = super().forward(x.float()).to(x.dtype)
        if self.swish == 1.0:
            return F.silu(normed)
        if self.swish:
            gate = torch.sigmoid(normed * float(self.swish))
            return normed * gate
        return normed
|
|
|
|
|
|
def linear(*args, **kwargs):
    """
    Create an `nn.Linear` module, forwarding all arguments unchanged.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer
|
|
|
|
|
|
def zero_module(module):
    """
    Set every parameter of `module` to zero in place and return the same module.
    """
    # Zeroing happens outside of autograd tracking.
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
|
|
|
|
|
|
class Mish(torch.nn.Module):
    """Mish activation: `x * tanh(softplus(x))`."""

    def forward(self, x):
        gate = torch.tanh(F.softplus(x))
        return x * gate
|
|
|
|
|
|
class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        layers = [
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            # batch channels horizon -> batch channels 1 horizon
            RearrangeDim(),
            nn.GroupNorm(n_groups, out_channels),
            # batch channels 1 horizon -> batch channels horizon
            RearrangeDim(),
            nn.Mish(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out
|
|
|
|
|
|
class RearrangeDim(nn.Module):
    """Insert or remove a singleton axis, depending on the rank of the input.

    * 2-D input `[a, b]` -> `[a, b, 1]`
    * 3-D input `[a, b, c]` -> `[a, b, 1, c]`
    * 4-D input `[a, b, x, c]` -> `[a, b, c]` (index 0 is taken along the third axis)

    Raises:
        ValueError: if the input is not 2-, 3- or 4-dimensional.
    """

    def forward(self, tensor):
        ndim = tensor.dim()
        if ndim == 2:
            return tensor[:, :, None]
        if ndim == 3:
            return tensor[:, :, None, :]
        if ndim == 4:
            return tensor[:, :, 0, :]
        # Report the rank: `len(tensor)` would be the size of the first axis (and fails for 0-d input).
        raise ValueError(f"`tensor.dim()`: {ndim} has to be 2, 3 or 4.")
|
|
|
|
|
|
def upsample_2d(x, k=None, factor=2, gain=1):
    r"""Upsample a batch of 2D images with the given FIR filter.

    The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero.

    Args:
        x: Input tensor; passed to `upfirdn2d_native`, which unpacks it as `[N, C, H, W]`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is
            `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1

    taps = np.asarray([1] * factor if k is None else k, dtype=np.float32)
    if taps.ndim == 1:
        taps = np.outer(taps, taps)
    taps /= np.sum(taps)

    # Compensate for the zeros inserted between samples.
    taps = taps * (gain * (factor**2))

    p = taps.shape[0] - factor
    padding = ((p + 1) // 2 + factor - 1, p // 2)
    fir = torch.tensor(taps, device=x.device)
    return upfirdn2d_native(x, fir, up=factor, pad=padding)
|
|
|
|
|
|
def downsample_2d(x, k=None, factor=2, gain=1):
    r"""Downsample a batch of 2D images with the given FIR filter.

    The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero.

    Args:
        x: Input tensor; passed to `upfirdn2d_native`, which unpacks it as `[N, C, H, W]`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is
            `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]`
    """
    assert isinstance(factor, int) and factor >= 1

    taps = np.asarray([1] * factor if k is None else k, dtype=np.float32)
    if taps.ndim == 1:
        taps = np.outer(taps, taps)
    taps /= np.sum(taps)
    taps = taps * gain

    p = taps.shape[0] - factor
    padding = ((p + 1) // 2, p // 2)
    fir = torch.tensor(taps, device=x.device)
    return upfirdn2d_native(x, fir, down=factor, pad=padding)
|
|
|
|
|
|
def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample (zero insertion), pad/crop, FIR-filter and downsample a 4-D tensor.

    Args:
        input: tensor unpacked as `[N, C, H, W]`; every (N, C) plane is filtered independently.
        kernel: 2-D filter tensor `[kernel_h, kernel_w]`.
        up: integer zero-insertion factor, applied to both spatial axes.
        down: integer decimation step, applied to both spatial axes.
        pad: `(before, after)` padding used for both spatial axes; negative values crop instead.

    Returns:
        Tensor viewed as `[-1, C, out_h, out_w]`.
    """
    pad_before, pad_after = pad[0], pad[1]

    _, channel, in_h, in_w = input.shape
    kernel_h, kernel_w = kernel.shape

    # Fold batch and channel together; the trailing axis ("minor") has length 1.
    planes = input.reshape(-1, in_h, in_w, 1)
    minor = planes.shape[3]

    # Zero insertion: give every pixel its own length-1 axes and pad them up to `up`.
    out = planes.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up - 1, 0, 0, 0, up - 1])
    up_h, up_w = in_h * up, in_w * up
    out = out.view(-1, up_h, up_w, minor)

    # Positive padding grows the image with zeros, negative padding crops it.
    grow_before, grow_after = max(pad_before, 0), max(pad_after, 0)
    crop_before, crop_after = max(-pad_before, 0), max(-pad_after, 0)
    out = F.pad(out, [0, 0, grow_before, grow_after, grow_before, grow_after])
    out = out[
        :,
        crop_before : out.shape[1] - crop_after,
        crop_before : out.shape[2] - crop_after,
        :,
    ]

    padded_h = up_h + pad_before + pad_after
    padded_w = up_w + pad_before + pad_after
    out = out.permute(0, 3, 1, 2).reshape([-1, 1, padded_h, padded_w])

    # `F.conv2d` computes a correlation, so the kernel is flipped to get a true convolution.
    flipped = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, flipped)

    conv_h = padded_h - kernel_h + 1
    conv_w = padded_w - kernel_w + 1
    out = out.reshape(-1, minor, conv_h, conv_w).permute(0, 2, 3, 1)

    # Decimate by keeping every `down`-th row and column.
    out = out[:, ::down, ::down, :]

    out_h = (padded_h - kernel_h) // down + 1
    out_w = (padded_w - kernel_w) // down + 1

    return out.view(-1, channel, out_h, out_w)
|