434 lines
17 KiB
Python
434 lines
17 KiB
Python
import math
|
|
from inspect import isfunction
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
|
|
|
|
class AttentionBlockNew(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

    Uses three separate q, k, v linear layers to compute attention. `forward` unpacks a 4-D
    `(batch, channel, height, width)` input, so only 2-D feature maps are supported.

    Args:
        channels: number of channels of the input (and output) feature map.
        num_head_channels: channels per attention head; `None` means a single head.
        num_groups: number of groups of the input GroupNorm.
        rescale_output_factor: the residual sum is divided by this value.
        eps: epsilon of the input GroupNorm.

    Raises:
        ValueError: if `channels` is not divisible by `num_head_channels`.
    """

    def __init__(
        self,
        channels,
        num_head_channels=None,
        num_groups=32,
        rescale_output_factor=1.0,
        eps=1e-5,
    ):
        super().__init__()
        self.channels = channels

        # Same constraint the legacy `AttentionBlock` enforces: heads must tile the channels exactly.
        if num_head_channels is not None and channels % num_head_channels != 0:
            raise ValueError(
                f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            )

        self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
        self.num_head_size = num_head_channels
        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)

        # define q,k,v as linear layers
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)

        self.rescale_output_factor = rescale_output_factor
        # The third positional argument of nn.Linear is `bias`, not a kernel size; the stray `1`
        # only worked because it is truthy. The default (bias=True) yields the same layer.
        self.proj_attn = nn.Linear(channels, channels)

    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
        """Split the last dim into heads: (B, T, H * D) -> (B, H, T, D)."""
        new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
        return new_projection

    def forward(self, hidden_states):
        """Apply self-attention over the spatial positions of a (B, C, H, W) tensor.

        Returns a tensor of the same shape: `(attention(x) + x) / rescale_output_factor`.
        """
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        # `reshape` rather than `view`: the normalized tensor may not be contiguous
        # (e.g. for channels_last inputs), in which case `view` raises.
        hidden_states = hidden_states.reshape(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

        # transpose
        query_states = self.transpose_for_scores(query_proj)
        key_states = self.transpose_for_scores(key_proj)
        value_states = self.transpose_for_scores(value_proj)

        # get scores; the 1/sqrt(head_dim) factor is split across q and k before the matmul
        scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
        attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
        # softmax is computed in float32 and cast back to the score dtype
        attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)

        # compute attention output
        context_states = torch.matmul(attention_probs, value_states)

        # (B, H, T, D) -> (B, T, H, D) -> (B, T, C)
        context_states = context_states.permute(0, 2, 1, 3).contiguous()
        new_context_states_shape = context_states.size()[:-2] + (self.channels,)
        context_states = context_states.view(new_context_states_shape)

        # compute next hidden_states
        hidden_states = self.proj_attn(context_states)
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states

    def set_weight(self, attn_layer):
        """Take over the weights of a legacy attention layer.

        Three source layouts are handled, chosen by the attributes present on `attn_layer`:
        separate 1x1 Conv2d `q`/`k`/`v`/`proj_out`, `NIN_0..NIN_3` parameter modules, or a
        fused `qkv` Conv1d plus `proj`. The `.data` tensors are assigned directly (no copy).
        """
        self.group_norm.weight.data = attn_layer.norm.weight.data
        self.group_norm.bias.data = attn_layer.norm.bias.data

        if hasattr(attn_layer, "q"):
            # 1x1 Conv2d weights: drop the two trailing kernel dims to get Linear weights
            self.query.weight.data = attn_layer.q.weight.data[:, :, 0, 0]
            self.key.weight.data = attn_layer.k.weight.data[:, :, 0, 0]
            self.value.weight.data = attn_layer.v.weight.data[:, :, 0, 0]

            self.query.bias.data = attn_layer.q.bias.data
            self.key.bias.data = attn_layer.k.bias.data
            self.value.bias.data = attn_layer.v.bias.data

            self.proj_attn.weight.data = attn_layer.proj_out.weight.data[:, :, 0, 0]
            self.proj_attn.bias.data = attn_layer.proj_out.bias.data
        elif hasattr(attn_layer, "NIN_0"):
            # NIN stores W as (in_dim, num_units); Linear expects (out, in), hence the transpose
            self.query.weight.data = attn_layer.NIN_0.W.data.T
            self.key.weight.data = attn_layer.NIN_1.W.data.T
            self.value.weight.data = attn_layer.NIN_2.W.data.T

            self.query.bias.data = attn_layer.NIN_0.b.data
            self.key.bias.data = attn_layer.NIN_1.b.data
            self.value.bias.data = attn_layer.NIN_2.b.data

            self.proj_attn.weight.data = attn_layer.NIN_3.W.data.T
            self.proj_attn.bias.data = attn_layer.NIN_3.b.data

            # this layout keeps its norm parameters in `GroupNorm_0`
            self.group_norm.weight.data = attn_layer.GroupNorm_0.weight.data
            self.group_norm.bias.data = attn_layer.GroupNorm_0.bias.data
        else:
            # fused qkv Conv1d: output channels are grouped per head as [q | k | v]
            qkv_weight = attn_layer.qkv.weight.data.reshape(
                self.num_heads, 3 * self.channels // self.num_heads, self.channels
            )
            qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)

            q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
            q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)

            self.query.weight.data = q_w.reshape(-1, self.channels)
            self.key.weight.data = k_w.reshape(-1, self.channels)
            self.value.weight.data = v_w.reshape(-1, self.channels)

            self.query.bias.data = q_b.reshape(-1)
            self.key.bias.data = k_b.reshape(-1)
            self.value.bias.data = v_b.reshape(-1)

            self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
            self.proj_attn.bias.data = attn_layer.proj.bias.data
|
|
|
|
|
|
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
    standard transformer action. Finally, reshape to image

    Args:
        in_channels: channels of the input feature map (GroupNorm uses 32 groups over them).
        n_heads: number of attention heads of every transformer block.
        d_head: channels per head; the transformer width is `n_heads * d_head`.
        depth: number of stacked `BasicTransformerBlock`s.
        dropout: dropout probability passed to the transformer blocks.
        context_dim: feature size of the optional cross-attention context.
    """

    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_head
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
                for _ in range(depth)
            ]
        )

        self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, context=None):
        """Run the transformer stack over the flattened spatial positions of a (B, C, H, W) input.

        Returns a tensor with the same shape as `x` (residual connection around the whole block).
        """
        # note: if no context is given, cross-attention defaults to self-attention
        batch, _, height, width = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        # After `proj_in` the channel count is n_heads * d_head, which need not equal the input
        # channel count; the token reshape has to use the projected width.
        inner_dim = x.shape[1]
        x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2)
        x = self.proj_out(x)
        return x + x_in

    def set_weight(self, layer):
        """Adopt the sub-modules of `layer` by reference (the modules are shared, not copied)."""
        self.norm = layer.norm
        self.proj_in = layer.proj_in
        self.transformer_blocks = layer.transformer_blocks
        self.proj_out = layer.proj_out
|
|
|
|
|
|
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer layer: self-attention, (cross-)attention, then feed-forward.

    Every sub-layer is wrapped in a residual connection. `checkpoint` is only stored on the
    instance; nothing in this class reads it.
    """

    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        # Sub-modules are created in a fixed order (attn1, ff, attn2, norms) so parameter
        # initialisation and state-dict ordering stay stable.
        shared_attn_args = dict(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)

        # attn1 never receives a context, i.e. it is always self-attention
        self.attn1 = CrossAttention(**shared_attn_args)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # attn2 falls back to self-attention when no context is passed to forward
        self.attn2 = CrossAttention(context_dim=context_dim, **shared_attn_args)

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        """Apply the three residual sub-layers to `x`; `context` is forwarded to `attn2` only."""
        self_attended = self.attn1(self.norm1(x))
        x = self_attended + x

        cross_attended = self.attn2(self.norm2(x), context=context)
        x = cross_attended + x

        fed_forward = self.ff(self.norm3(x))
        return fed_forward + x
|
|
|
|
|
|
class CrossAttention(nn.Module):
    """Multi-head attention of a query sequence over a context sequence.

    Args:
        query_dim: feature size of the query sequence (and of the output).
        context_dim: feature size of the context; defaults to `query_dim`.
        heads: number of attention heads.
        dim_head: channels per head.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        if context_dim is None:
            context_dim = query_dim

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def reshape_heads_to_batch_dim(self, tensor):
        """(B, T, H * D) -> (B * H, T, D); rows are ordered batch-major (all heads of sample 0 first)."""
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        """Inverse of `reshape_heads_to_batch_dim`: (B * H, T, D) -> (B, T, H * D)."""
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def forward(self, x, context=None, mask=None):
        """Attend from `x` (B, T, query_dim) over `context` (self-attention when `context` is None).

        `mask`, if given, is reshaped to (B, -1) and interpreted as truthy = "key may be attended";
        its second dim must match the context length. Returns a (B, T, query_dim) tensor.
        """
        batch_size = x.shape[0]
        heads = self.heads

        q = self.to_q(x)
        if context is None:
            context = x
        k = self.to_k(context)
        v = self.to_v(context)

        q = self.reshape_heads_to_batch_dim(q)
        k = self.reshape_heads_to_batch_dim(k)
        v = self.reshape_heads_to_batch_dim(v)

        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

        if mask is not None:
            # `~mask` is only a logical NOT for boolean tensors
            mask = mask.reshape(batch_size, -1).to(torch.bool)
            max_neg_value = -torch.finfo(sim.dtype).max
            # q/k/v rows are batch-major (b0h0, b0h1, ..., b1h0, ...), so each sample's mask must be
            # repeated consecutively. `.repeat(heads, 1, 1)` would tile whole-batch copies instead
            # and pair masks with the wrong samples whenever batch_size > 1.
            mask = mask[:, None, :].repeat_interleave(heads, dim=0)
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = torch.einsum("b i j, b j d -> b i d", attn, v)
        out = self.reshape_batch_dim_to_heads(out)
        return self.to_out(out)
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise feed-forward network: input projection, dropout, output projection.

    The input projection is `Linear + GELU`, or a `GEGLU` layer when `glu` is set. The hidden
    width is `int(dim * mult)` and the output width is `dim_out`, falling back to `dim`.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        hidden_dim = int(dim * mult)
        out_dim = dim if dim_out is None else dim_out

        if glu:
            project_in = GEGLU(dim, hidden_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU())

        layers = [project_in, nn.Dropout(dropout), nn.Linear(hidden_dim, out_dim)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to the last dimension of `x`."""
        return self.net(x)
|
|
|
|
|
|
# GELU-gated linear unit, used as the input projection of FeedForward when glu=True
|
|
class GEGLU(nn.Module):
    """Gated linear unit with a GELU gate.

    One linear layer produces `2 * dim_out` features; the first half is the value and the
    second half, passed through GELU, is the gate that multiplies it.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        """Return `value * gelu(gate)` with shape `(..., dim_out)`."""
        projected = self.proj(x)
        value, gate = torch.chunk(projected, 2, dim=-1)
        activated_gate = F.gelu(gate)
        return value * activated_gate
|
|
|
|
|
|
# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
|
|
class NIN(nn.Module):
    """Parameter container with a weight `W` of shape (in_dim, num_units) and a bias `b`.

    Both parameters start at zero and are trainable. `init_scale` is accepted but not used,
    and the class defines no `forward`.
    """

    def __init__(self, in_dim, num_units, init_scale=0.1):
        super().__init__()
        weight_shape = (in_dim, num_units)
        bias_shape = (num_units,)
        # nn.Parameter tracks gradients by default
        self.W = nn.Parameter(torch.zeros(weight_shape))
        self.b = nn.Parameter(torch.zeros(bias_shape))
|
|
|
|
|
|
def exists(val):
    """Return True for every value except `None` (falsy values such as 0 or "" count as existing)."""
    if val is None:
        return False
    return True
|
|
|
|
|
|
def default(val, d):
    """Return `val` unless it is `None`; otherwise return the fallback `d`.

    When `d` is a plain Python function it is called with no arguments and its result is
    returned, so the fallback is only computed when needed.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
|
|
|
|
|
|
# the main attention block that is used for all models
|
|
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

    Args:
        channels: number of channels of the input (and output).
        num_heads: number of attention heads; ignored when `num_head_channels` is given.
        num_head_channels: channels per head; must divide `channels`.
        num_groups: number of groups of the input GroupNorm.
        encoder_channels: channels of an optional encoder sequence whose keys/values are
            prepended to the self-attention keys/values.
        overwrite_qkv: additionally register 1x1 Conv2d `q`/`k`/`v`/`proj_out` layers whose
            weights are folded into `qkv`/`proj` on the first forward pass.
        overwrite_linear: additionally register `NIN_0..NIN_3` and `GroupNorm_0`, folded into
            `qkv`/`proj`/`norm` on the first forward pass.
        rescale_output_factor: the residual sum is divided by this value.
        eps: epsilon of the default GroupNorm (the overwrite variants use 1e-6).
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=None,
        num_groups=32,
        encoder_channels=None,
        overwrite_qkv=False,
        overwrite_linear=False,
        rescale_output_factor=1.0,
        eps=1e-5,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels is None:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels

        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.n_heads = self.num_heads
        self.rescale_output_factor = rescale_output_factor

        if encoder_channels is not None:
            self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)

        self.proj = nn.Conv1d(channels, channels, 1)

        self.overwrite_qkv = overwrite_qkv
        self.overwrite_linear = overwrite_linear

        # The extra modules below only exist so that legacy checkpoints can be loaded by name;
        # `set_weights` folds them into `qkv` / `proj` / `norm`.
        if overwrite_qkv:
            in_channels = channels
            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
            self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
            self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
            self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
            self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.overwrite_linear:
            num_groups = min(channels // 4, 32)
            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
            self.NIN_0 = NIN(channels, channels)
            self.NIN_1 = NIN(channels, channels)
            self.NIN_2 = NIN(channels, channels)
            self.NIN_3 = NIN(channels, channels)

            self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
        else:
            self.proj_out = nn.Conv1d(channels, channels, 1)
            self.set_weights(self)

        # set to True by `forward` once the overwrite weights have been folded in
        self.is_overwritten = False

    def set_weights(self, module):
        """Fold the legacy-layout weights into `qkv`, `proj` (and `norm`).

        Only the `overwrite_qkv` branch reads from `module`; the other branches read this
        instance's own attributes. Every call in this class passes `self`.
        """
        if self.overwrite_qkv:
            # Conv2d 1x1 weights (out, in, 1, 1) -> Conv1d weights (out, in, 1)
            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
                :, :, :, 0
            ]
            qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)

            self.qkv.weight.data = qkv_weight
            self.qkv.bias.data = qkv_bias

            proj_out = nn.Conv1d(self.channels, self.channels, 1)
            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
            proj_out.bias.data = module.proj_out.bias.data

            self.proj = proj_out
        elif self.overwrite_linear:
            # NIN weights are (in, out): transpose and add the trailing kernel dim for Conv1d
            self.qkv.weight.data = torch.cat(
                [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
            )[:, :, None]
            self.qkv.bias.data = torch.cat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)

            self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
            self.proj.bias.data = self.NIN_3.b.data

            self.norm.weight.data = self.GroupNorm_0.weight.data
            self.norm.bias.data = self.GroupNorm_0.bias.data
        else:
            self.proj.weight.data = self.proj_out.weight.data
            self.proj.bias.data = self.proj_out.bias.data

    def forward(self, x, encoder_out=None):
        """Apply attention over all spatial positions of `x` with shape (B, C, *spatial).

        Args:
            x: input tensor with any number of trailing spatial dims.
            encoder_out: optional (B, encoder_channels, S) sequence; requires the block to have
                been built with `encoder_channels`.

        Returns:
            `(x + attention(x)) / rescale_output_factor`, same shape as `x`.

        Raises:
            ValueError: if `encoder_out` is given but no `encoder_kv` layer exists.
        """
        # Fold in the overwrite weights lazily, once, on the first call.
        if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear):
            self.set_weights(self)
            self.is_overwritten = True

        b, c, *spatial = x.shape
        # `reshape` rather than `view`: the normalized tensor may be non-contiguous
        # (e.g. for channels_last inputs), in which case `view` raises.
        hid_states = self.norm(x).reshape(b, c, -1)

        qkv = self.qkv(hid_states)
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        # per head the fused channels are laid out as [q | k | v]
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)

        if encoder_out is not None:
            if not hasattr(self, "encoder_kv"):
                raise ValueError("`encoder_out` was passed but the block was built without `encoder_channels`")
            encoder_kv = self.encoder_kv(encoder_out)
            assert encoder_kv.shape[1] == self.n_heads * ch * 2
            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
            # encoder keys/values are prepended along the sequence axis
            k = torch.cat([ek, k], dim=-1)
            v = torch.cat([ev, v], dim=-1)

        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

        a = torch.einsum("bts,bcs->bct", weight, v)
        h = a.reshape(bs, -1, length)

        h = self.proj(h)
        h = h.reshape(b, c, *spatial)

        result = x + h

        result = result / self.rescale_output_factor

        return result
|