# coding=utf-8
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers DAC model."""

import math
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...modeling_utils import PreTrainedAudioTokenizerBase
from ...utils import ModelOutput, auto_docstring
from .configuration_dac import DacConfig


@dataclass
@auto_docstring
class DacOutput(ModelOutput):
    r"""
    loss (`torch.Tensor`):
        Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
    audio_values (`torch.Tensor` of shape `(batch_size, input_length)`):
        Reconstructed audio data.
    quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
        Quantized continuous representation of input.
    audio_codes (`torch.LongTensor` of shape `(batch_size, num_codebooks, time_steps)`):
        Codebook indices for each codebook (quantized discrete representation of input).
    projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
        Projected latents (continuous representation of input before quantization).
    """

    loss: Optional[torch.FloatTensor] = None
    audio_values: Optional[torch.FloatTensor] = None
    quantized_representation: Optional[torch.FloatTensor] = None
    audio_codes: Optional[torch.LongTensor] = None
    projected_latents: Optional[torch.FloatTensor] = None


@dataclass
@auto_docstring
class DacEncoderOutput(ModelOutput):
    r"""
    loss (`torch.Tensor`):
        Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
    quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`, *optional*):
        Quantized continuous representation of input.
    audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
        Codebook indices for each codebook (quantized discrete representation of input).
    projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`, *optional*):
        Projected latents (continuous representation of input before quantization).
    """

    loss: Optional[torch.FloatTensor] = None
    quantized_representation: Optional[torch.FloatTensor] = None
    audio_codes: Optional[torch.FloatTensor] = None
    projected_latents: Optional[torch.FloatTensor] = None


@dataclass
@auto_docstring
# Copied from transformers.models.encodec.modeling_encodec.EncodecDecoderOutput with Encodec->Dac, segment_length->input_length
class DacDecoderOutput(ModelOutput):
    r"""
    audio_values (`torch.FloatTensor` of shape `(batch_size, input_length)`, *optional*):
        Decoded audio values, obtained using the decoder part of Dac.
    """

    audio_values: Optional[torch.FloatTensor] = None


class Snake1d(nn.Module):
    """
    A 1-dimensional Snake activation function module. Applies `x + (1 / alpha) * sin^2(alpha * x)` element-wise,
    with a learnable per-channel parameter `alpha`.
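
    Example (illustrative usage; the activation preserves the input shape):

    ```python
    >>> import torch
    >>> from transformers.models.dac.modeling_dac import Snake1d

    >>> snake = Snake1d(hidden_dim=64)
    >>> hidden_states = torch.randn(1, 64, 50)
    >>> snake(hidden_states).shape
    torch.Size([1, 64, 50])
    ```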
"""
|
|
|
|
def __init__(self, hidden_dim):
|
|
super().__init__()
|
|
self.alpha = nn.Parameter(torch.ones(1, hidden_dim, 1))
|
|
|
|
def forward(self, hidden_states):
|
|
shape = hidden_states.shape
|
|
hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
|
|
hidden_states = hidden_states + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * hidden_states).pow(2)
|
|
hidden_states = hidden_states.reshape(shape)
|
|
return hidden_states
|
|
|
|
|
|
class DacVectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo (https://github.com/karpathy/deep-vector-quantization)

    Additionally uses following tricks from improved VQGAN
    (https://huggingface.co/papers/2110.04627):
        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
            for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
            improves training stability
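
    Example (a minimal sketch assuming a default `DacConfig`, whose `hidden_size` sets the expected input channels):

    ```python
    >>> import torch
    >>> from transformers import DacConfig
    >>> from transformers.models.dac.modeling_dac import DacVectorQuantize

    >>> config = DacConfig()
    >>> vq = DacVectorQuantize(config)
    >>> hidden_state = torch.randn(1, config.hidden_size, 50)
    >>> quantized, commitment_loss, codebook_loss, codes, latents = vq(hidden_state)
    >>> # quantized: (1, config.hidden_size, 50), codes: (1, 50), latents: (1, config.codebook_dim, 50)
    ```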
"""
|
|
|
|
def __init__(self, config: DacConfig):
|
|
super().__init__()
|
|
|
|
self.in_proj = nn.Conv1d(config.hidden_size, config.codebook_dim, kernel_size=1)
|
|
self.out_proj = nn.Conv1d(config.codebook_dim, config.hidden_size, kernel_size=1)
|
|
self.codebook = nn.Embedding(config.codebook_size, config.codebook_dim)
|
|
|
|
    def forward(self, hidden_state):
        """
        Quantizes the input tensor using a fixed codebook and returns the corresponding codebook vectors.

        Args:
            hidden_state (`torch.FloatTensor` of shape `(batch_size, dimension, time_steps)`):
                Input tensor.

        Returns:
            quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Quantized continuous representation of input.
            commitment_loss (`torch.FloatTensor` of shape `(1)`):
                Commitment loss to train the encoder to predict vectors closer to codebook entries.
            codebook_loss (`torch.FloatTensor` of shape `(1)`):
                Codebook loss to update the codebook.
            audio_codes (`torch.LongTensor` of shape `(batch_size, time_steps)`):
                Codebook indices for each codebook, quantized discrete representation of input.
            projected_latents (`torch.FloatTensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
                Projected latents (continuous representation of input before quantization).
        """

        projected_latents = self.in_proj(hidden_state)
        quantized_representation, audio_codes = self.decode_latents(projected_latents)

        commitment_loss = F.mse_loss(projected_latents, quantized_representation.detach(), reduction="mean")
        codebook_loss = F.mse_loss(quantized_representation, projected_latents.detach(), reduction="mean")
        # noop in forward pass, straight-through gradient estimator in backward pass
        quantized_representation = projected_latents + (quantized_representation - projected_latents).detach()
        quantized_representation = self.out_proj(quantized_representation)

        return quantized_representation, commitment_loss, codebook_loss, audio_codes, projected_latents

    def decode_latents(self, hidden_states):
        batch_size, hidden_dim, sequence_length = hidden_states.shape
        encodings = hidden_states.permute(0, 2, 1).reshape(batch_size * sequence_length, hidden_dim)
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        l2_norm = encodings.pow(2).sum(1, keepdim=True)
        dist = -(l2_norm - 2 * encodings @ codebook.t()) + codebook.pow(2).sum(1, keepdim=True).t()

        # with L2-normalized vectors this is proportional to cosine similarity, so argmax picks the nearest entry
        indices = dist.max(1)[1]
        indices = indices.reshape(hidden_states.size(0), -1)
        quantized_representation = self.codebook(indices).transpose(1, 2)
        return quantized_representation, indices


class DacResidualUnit(nn.Module):
    """
    A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
    """

    def __init__(self, dimension: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2

        self.snake1 = Snake1d(dimension)
        self.conv1 = nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)
        self.snake2 = Snake1d(dimension)
        self.conv2 = nn.Conv1d(dimension, dimension, kernel_size=1)

    def forward(self, hidden_state):
        """
        Forward pass through the residual unit.

        Args:
            hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
                Input tensor.

        Returns:
            output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
                Input tensor after passing through the residual unit.
        """
        output_tensor = hidden_state
        output_tensor = self.conv1(self.snake1(output_tensor))
        output_tensor = self.conv2(self.snake2(output_tensor))

        padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
        if padding > 0:
            hidden_state = hidden_state[..., padding:-padding]
        output_tensor = hidden_state + output_tensor
        return output_tensor


class DacEncoderBlock(nn.Module):
    """Encoder block used in DAC encoder."""

    def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
        super().__init__()

        dimension = config.encoder_hidden_size * 2**stride_index
        self.res_unit1 = DacResidualUnit(dimension // 2, dilation=1)
        self.res_unit2 = DacResidualUnit(dimension // 2, dilation=3)
        self.res_unit3 = DacResidualUnit(dimension // 2, dilation=9)
        self.snake1 = Snake1d(dimension // 2)
        self.conv1 = nn.Conv1d(
            dimension // 2, dimension, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)
        )

    def forward(self, hidden_state):
        hidden_state = self.res_unit1(hidden_state)
        hidden_state = self.res_unit2(hidden_state)
        hidden_state = self.snake1(self.res_unit3(hidden_state))
        hidden_state = self.conv1(hidden_state)

        return hidden_state


class DacDecoderBlock(nn.Module):
    """Decoder block used in DAC decoder."""

    def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
        super().__init__()

        input_dim = config.decoder_hidden_size // 2**stride_index
        output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
        self.snake1 = Snake1d(input_dim)
        self.conv_t1 = nn.ConvTranspose1d(
            input_dim,
            output_dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=math.ceil(stride / 2),
        )

        self.res_unit1 = DacResidualUnit(output_dim, dilation=1)
        self.res_unit2 = DacResidualUnit(output_dim, dilation=3)
        self.res_unit3 = DacResidualUnit(output_dim, dilation=9)

    def forward(self, hidden_state):
        hidden_state = self.snake1(hidden_state)
        hidden_state = self.conv_t1(hidden_state)
        hidden_state = self.res_unit1(hidden_state)
        hidden_state = self.res_unit2(hidden_state)
        hidden_state = self.res_unit3(hidden_state)

        return hidden_state


class DacResidualVectorQuantize(nn.Module):
    """
    ResidualVectorQuantize block - Introduced in SoundStream: An end-to-end neural audio codec (https://huggingface.co/papers/2107.03312)
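
    Each quantizer in the cascade encodes the residual left by the previous one, so the per-codebook codes can be
    summed back into a single continuous representation.

    Example (a minimal sketch assuming a default `DacConfig`):

    ```python
    >>> import torch
    >>> from transformers import DacConfig
    >>> from transformers.models.dac.modeling_dac import DacResidualVectorQuantize

    >>> config = DacConfig()
    >>> rvq = DacResidualVectorQuantize(config)
    >>> hidden_state = torch.randn(1, config.hidden_size, 50)
    >>> quantized, codes, latents, commitment_loss, codebook_loss = rvq(hidden_state)
    >>> # codes: (1, config.n_codebooks, 50); round-trip the codes back to a continuous representation
    >>> quantized_from_codes, _, _ = rvq.from_codes(codes)
    ```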
"""
|
|
|
|
def __init__(self, config: DacConfig):
|
|
super().__init__()
|
|
|
|
n_codebooks = config.n_codebooks
|
|
quantizer_dropout = config.quantizer_dropout
|
|
|
|
self.n_codebooks = n_codebooks
|
|
|
|
self.quantizers = nn.ModuleList([DacVectorQuantize(config) for i in range(config.n_codebooks)])
|
|
self.quantizer_dropout = quantizer_dropout
|
|
|
|
    def forward(self, hidden_state, n_quantizers: Optional[int] = None):
        """
        Quantizes the input tensor using a fixed set of codebooks and returns corresponding codebook vectors.
        Args:
            hidden_state (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Input tensor to be quantized.
            n_quantizers (`int`, *optional*):
                Number of quantizers to use. If specified and `self.quantizer_dropout` is True,
                this argument is ignored during training, and a random number of quantizers is used.

        Returns:
            quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Quantized continuous representation of input.
            audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
                Codebook indices for each codebook (quantized discrete representation of input).
            projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
                Projected latents (continuous representation of input before quantization).
            commitment_loss (`torch.Tensor` of shape `(1)`):
                Commitment loss to train the encoder to predict vectors closer to codebook entries.
            codebook_loss (`torch.Tensor` of shape `(1)`):
                Codebook loss to update the codebook.
        """

        quantized_representation = 0
        residual = hidden_state
        commitment_loss = 0
        codebook_loss = 0

        audio_codes = []
        projected_latents = []

        n_quantizers = n_quantizers if n_quantizers is not None else self.n_codebooks
        if self.training:
            n_quantizers = torch.ones((hidden_state.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (hidden_state.shape[0],))
            n_dropout = int(hidden_state.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(hidden_state.device)

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break

            quantized_representation_i, commitment_loss_i, codebook_loss_i, indices_i, projected_latents_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = torch.full((hidden_state.shape[0],), fill_value=i, device=hidden_state.device) < n_quantizers
            quantized_representation = quantized_representation + quantized_representation_i * mask[:, None, None]
            residual = residual - quantized_representation_i

            # Sum losses
            commitment_loss += commitment_loss_i * mask
            codebook_loss += codebook_loss_i * mask

            audio_codes.append(indices_i)
            projected_latents.append(projected_latents_i)

        audio_codes = torch.stack(audio_codes, dim=1)
        projected_latents = torch.cat(projected_latents, dim=1)

        return quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss

    def from_codes(self, audio_codes: torch.Tensor):
        """
        Reconstructs the continuous representation from quantized codes.

        Args:
            audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
                Quantized discrete representation of input.

        Returns:
            quantized_representation (`torch.Tensor`):
                Quantized continuous representation of input.
            projected_latents (`torch.Tensor`):
                List of projected latents (continuous representations of input before quantization)
                for each codebook.
            audio_codes (`torch.Tensor`):
                Codebook indices for each codebook.
        """
        quantized_representation = 0.0
        projected_latents = []
        n_codebooks = audio_codes.shape[1]
        for i in range(n_codebooks):
            projected_latents_i = self.quantizers[i].codebook(audio_codes[:, i, :]).transpose(1, 2)
            projected_latents.append(projected_latents_i)
            quantized_representation += self.quantizers[i].out_proj(projected_latents_i)
        return quantized_representation, torch.cat(projected_latents, dim=1), audio_codes

    def from_latents(self, latents: torch.Tensor):
        """Reconstructs the quantized representation from unquantized latents.

        Args:
            latents (`torch.Tensor` of shape `(batch_size, total_latent_dimension, time_steps)`):
                Continuous representation of input after projection.

        Returns:
            quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Quantized representation of the full-projected space.
            quantized_latents (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Quantized representation of the latent space (continuous representation before quantization).
        """
        quantized_representation = 0
        quantized_latents = []
        codes = []
        # read each quantizer's codebook dimension from its embedding table
        codebook_dims_tensor = torch.tensor([0] + [q.codebook.weight.shape[1] for q in self.quantizers])
        dims = torch.cumsum(codebook_dims_tensor, dim=0)

        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
        for i in range(n_codebooks):
            hidden_dim_j, hidden_dim_k = dims[i], dims[i + 1]
            quantized_latents_i, codes_i = self.quantizers[i].decode_latents(latents[:, hidden_dim_j:hidden_dim_k, :])
            quantized_latents.append(quantized_latents_i)
            codes.append(codes_i)

            quantized_representation_i = self.quantizers[i].out_proj(quantized_latents_i)
            quantized_representation = quantized_representation + quantized_representation_i

        return quantized_representation, torch.cat(quantized_latents, dim=1)


class DacDecoder(nn.Module):
    """DAC Decoder"""

    def __init__(self, config: DacConfig):
        super().__init__()

        input_channel = config.hidden_size
        channels = config.decoder_hidden_size
        strides = config.upsampling_ratios

        # Add first conv layer
        self.conv1 = nn.Conv1d(input_channel, channels, kernel_size=7, padding=3)

        # Add upsampling + MRF blocks
        block = []
        for stride_index, stride in enumerate(strides):
            block += [DacDecoderBlock(config, stride, stride_index)]

        self.block = nn.ModuleList(block)
        output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
        self.snake1 = Snake1d(output_dim)
        self.conv2 = nn.Conv1d(output_dim, 1, kernel_size=7, padding=3)
        self.tanh = nn.Tanh()

    def forward(self, hidden_state):
        hidden_state = self.conv1(hidden_state)

        for layer in self.block:
            hidden_state = layer(hidden_state)

        hidden_state = self.snake1(hidden_state)
        hidden_state = self.conv2(hidden_state)
        hidden_state = self.tanh(hidden_state)

        return hidden_state


class DacEncoder(nn.Module):
    """DAC Encoder"""

    def __init__(self, config: DacConfig):
        super().__init__()

        strides = config.downsampling_ratios
        # Create first convolution
        self.conv1 = nn.Conv1d(1, config.encoder_hidden_size, kernel_size=7, padding=3)

        self.block = []
        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride_index, stride in enumerate(strides):
            stride_index = stride_index + 1
            self.block += [DacEncoderBlock(config, stride=stride, stride_index=stride_index)]

        self.block = nn.ModuleList(self.block)
        d_model = config.encoder_hidden_size * 2**stride_index
        self.snake1 = Snake1d(d_model)
        self.conv2 = nn.Conv1d(d_model, config.hidden_size, kernel_size=3, padding=1)

    def forward(self, hidden_state):
        hidden_state = self.conv1(hidden_state)

        for module in self.block:
            hidden_state = module(hidden_state)

        hidden_state = self.snake1(hidden_state)
        hidden_state = self.conv2(hidden_state)

        return hidden_state


@auto_docstring
class DacPreTrainedModel(PreTrainedAudioTokenizerBase):
    config: DacConfig
    base_model_prefix = "dac"
    main_input_name = "input_values"

    def _init_weights(self, module):
        if isinstance(module, nn.Conv1d):
            nn.init.trunc_normal_(module.weight, std=0.02)
            nn.init.constant_(module.bias, 0)
        elif isinstance(module, Snake1d):
            module.alpha.data.fill_(1.0)
        elif isinstance(module, nn.ConvTranspose1d):
            module.reset_parameters()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)

    def apply_weight_norm(self):
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm

        for layer in self.quantizer.quantizers:
            weight_norm(layer.in_proj)
            weight_norm(layer.out_proj)

        weight_norm(self.encoder.conv1)
        weight_norm(self.encoder.conv2)

        for layer in self.encoder.block:
            weight_norm(layer.conv1)
            weight_norm(layer.res_unit1.conv1)
            weight_norm(layer.res_unit1.conv2)
            weight_norm(layer.res_unit2.conv1)
            weight_norm(layer.res_unit2.conv2)
            weight_norm(layer.res_unit3.conv1)
            weight_norm(layer.res_unit3.conv2)

        weight_norm(self.decoder.conv1)
        weight_norm(self.decoder.conv2)

        for layer in self.decoder.block:
            weight_norm(layer.conv_t1)
            weight_norm(layer.res_unit1.conv1)
            weight_norm(layer.res_unit1.conv2)
            weight_norm(layer.res_unit2.conv1)
            weight_norm(layer.res_unit2.conv2)
            weight_norm(layer.res_unit3.conv1)
            weight_norm(layer.res_unit3.conv2)

    def remove_weight_norm(self):
        for layer in self.quantizer.quantizers:
            nn.utils.remove_weight_norm(layer.in_proj)
            nn.utils.remove_weight_norm(layer.out_proj)

        nn.utils.remove_weight_norm(self.encoder.conv1)
        nn.utils.remove_weight_norm(self.encoder.conv2)

        for layer in self.encoder.block:
            nn.utils.remove_weight_norm(layer.conv1)
            nn.utils.remove_weight_norm(layer.res_unit1.conv1)
            nn.utils.remove_weight_norm(layer.res_unit1.conv2)
            nn.utils.remove_weight_norm(layer.res_unit2.conv1)
            nn.utils.remove_weight_norm(layer.res_unit2.conv2)
            nn.utils.remove_weight_norm(layer.res_unit3.conv1)
            nn.utils.remove_weight_norm(layer.res_unit3.conv2)

        nn.utils.remove_weight_norm(self.decoder.conv1)
        nn.utils.remove_weight_norm(self.decoder.conv2)

        for layer in self.decoder.block:
            nn.utils.remove_weight_norm(layer.conv_t1)
            nn.utils.remove_weight_norm(layer.res_unit1.conv1)
            nn.utils.remove_weight_norm(layer.res_unit1.conv2)
            nn.utils.remove_weight_norm(layer.res_unit2.conv1)
            nn.utils.remove_weight_norm(layer.res_unit2.conv2)
            nn.utils.remove_weight_norm(layer.res_unit3.conv1)
            nn.utils.remove_weight_norm(layer.res_unit3.conv2)


@auto_docstring(
    custom_intro="""
    The DAC (Descript Audio Codec) model.
    """
)
class DacModel(DacPreTrainedModel):
    def __init__(self, config: DacConfig):
        super().__init__(config)
        self.config = config

        self.encoder = DacEncoder(config)
        self.decoder = DacDecoder(config)

        self.quantizer = DacResidualVectorQuantize(config)

        self.bits_per_codebook = int(math.log2(self.config.codebook_size))
        if 2**self.bits_per_codebook != self.config.codebook_size:
            raise ValueError("The codebook_size must be a power of 2.")

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def encode(
        self,
        input_values: torch.Tensor,
        n_quantizers: Optional[int] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`):
            Input audio data to encode.
        n_quantizers (`int`, *optional*):
            Number of quantizers to use. If `None`, all quantizers are used. Default is `None`.
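
        Example (a minimal sketch using the same checkpoint as the example in `forward`; the input here is random audio):

        ```python
        >>> import torch
        >>> from transformers import DacModel

        >>> model = DacModel.from_pretrained("descript/dac_16khz")
        >>> input_values = torch.randn(1, 1, 16000)  # one second of (random) mono audio at 16 kHz
        >>> encoder_outputs = model.encode(input_values)
        >>> audio_codes = encoder_outputs.audio_codes  # (batch_size, num_codebooks, time_steps)
        ```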
"""
|
|
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
|
|
|
quantized_representation = self.encoder(input_values)
|
|
quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss = self.quantizer(
|
|
quantized_representation, n_quantizers
|
|
)
|
|
|
|
loss = self.config.commitment_loss_weight * commitment_loss + self.config.codebook_loss_weight * codebook_loss
|
|
|
|
if not return_dict:
|
|
return (loss, quantized_representation, audio_codes, projected_latents)
|
|
|
|
return DacEncoderOutput(loss, quantized_representation, audio_codes, projected_latents)
|
|
|
|
    @auto_docstring
    def decode(
        self,
        quantized_representation: Optional[torch.Tensor] = None,
        audio_codes: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`, *optional*):
            Quantized continuous representation of input.
        audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
            The codebook indices for each codebook, representing the quantized discrete
            representation of the input. This parameter should be provided if you want
            to decode directly from the audio codes (it will overwrite `quantized_representation`).
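
        Example (a minimal sketch continuing the `encode` example above):

        ```python
        >>> # decode either from the quantized representation or directly from the codes
        >>> audio_values = model.decode(encoder_outputs.quantized_representation).audio_values
        >>> audio_values = model.decode(audio_codes=encoder_outputs.audio_codes).audio_values
        ```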
"""
|
|
|
|
if quantized_representation is None and audio_codes is None:
|
|
raise ValueError("Either `quantized_representation` or `audio_codes` must be provided.")
|
|
|
|
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
|
|
|
if audio_codes is not None:
|
|
quantized_representation = self.quantizer.from_codes(audio_codes)[0]
|
|
|
|
audio_values = self.decoder(quantized_representation).squeeze(1)
|
|
|
|
if not return_dict:
|
|
return (audio_values,)
|
|
|
|
return DacDecoderOutput(audio_values)
|
|
|
|
    @auto_docstring
    def forward(
        self,
        input_values: torch.Tensor,
        n_quantizers: Optional[int] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`):
            Audio data to encode.
        n_quantizers (`int`, *optional*):
            Number of quantizers to use. If `None`, all quantizers are used. Default is `None`.

        Examples:

        ```python
        >>> from datasets import load_dataset, Audio
        >>> from transformers import DacModel, AutoProcessor
        >>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

        >>> model = DacModel.from_pretrained("descript/dac_16khz")
        >>> processor = AutoProcessor.from_pretrained("descript/dac_16khz")
        >>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
        >>> audio_sample = librispeech_dummy[-1]["audio"]["array"]
        >>> inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")

        >>> encoder_outputs = model.encode(inputs["input_values"])
        >>> # Get the intermediate audio codes
        >>> audio_codes = encoder_outputs.audio_codes
        >>> # Reconstruct the audio from its quantized representation
        >>> audio_values = model.decode(encoder_outputs.quantized_representation)
        >>> # or the equivalent with a forward pass
        >>> audio_values = model(inputs["input_values"]).audio_values
        ```"""

        return_dict = return_dict if return_dict is not None else self.config.return_dict
        length = input_values.shape[-1]
        loss, quantized_representation, audio_codes, projected_latents = self.encode(
            input_values, n_quantizers, return_dict=False
        )
        audio_values = self.decode(quantized_representation, return_dict=False)[0][..., :length]

        if not return_dict:
            return (loss, audio_values, quantized_representation, audio_codes, projected_latents)

        return DacOutput(loss, audio_values, quantized_representation, audio_codes, projected_latents)


__all__ = ["DacModel", "DacPreTrainedModel"]