# coding=utf-8
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers DAC model."""
import math
from dataclasses import dataclass
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...modeling_utils import PreTrainedAudioTokenizerBase
from ...utils import ModelOutput, auto_docstring
from .configuration_dac import DacConfig
@dataclass
@auto_docstring
class DacOutput(ModelOutput):
r"""
loss (`torch.Tensor`):
Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
audio_values (`torch.Tensor` of shape `(batch_size, input_length)`):
Reconstructed audio data.
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
Quantized continuous representation of input.
audio_codes (`torch.LongTensor` of shape `(batch_size, num_codebooks, time_steps)`):
Codebook indices for each codebook (quantized discrete representation of input).
projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
Projected latents (continuous representation of input before quantization).
"""
loss: Optional[torch.FloatTensor] = None
audio_values: Optional[torch.FloatTensor] = None
quantized_representation: Optional[torch.FloatTensor] = None
audio_codes: Optional[torch.LongTensor] = None
projected_latents: Optional[torch.FloatTensor] = None
@dataclass
@auto_docstring
class DacEncoderOutput(ModelOutput):
r"""
loss (`torch.Tensor`):
Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`, *optional*):
Quantized continuous representation of input.
audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
Codebook indices for each codebook (quantized discrete representation of input).
projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`, *optional*):
Projected latents (continuous representation of input before quantization).
"""
loss: Optional[torch.FloatTensor] = None
quantized_representation: Optional[torch.FloatTensor] = None
audio_codes: Optional[torch.LongTensor] = None
projected_latents: Optional[torch.FloatTensor] = None
@dataclass
@auto_docstring
# Copied from transformers.models.encodec.modeling_encodec.EncodecDecoderOutput with Encodec->Dac, segment_length->input_length
class DacDecoderOutput(ModelOutput):
r"""
audio_values (`torch.FloatTensor` of shape `(batch_size, input_length)`, *optional*):
Decoded audio values, obtained using the decoder part of Dac.
"""
audio_values: Optional[torch.FloatTensor] = None
class Snake1d(nn.Module):
"""
A 1-dimensional Snake activation function module.
"""
def __init__(self, hidden_dim):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, hidden_dim, 1))
def forward(self, hidden_states):
shape = hidden_states.shape
hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
hidden_states = hidden_states + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * hidden_states).pow(2)
hidden_states = hidden_states.reshape(shape)
return hidden_states
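# --- Illustrative sketch, not part of the original model code ---
# Snake computes x + (1/alpha) * sin^2(alpha * x) per channel; its derivative is
# 1 + sin(2 * alpha * x), which stays in [0, 2], so the activation is non-decreasing
# while injecting a periodic bias that suits audio waveforms. Hypothetical shape check:
def _example_snake1d_activation():
    snake = Snake1d(hidden_dim=16)
    hidden_states = torch.randn(2, 16, 100)  # (batch, channels, time)
    output = snake(hidden_states)
    assert output.shape == hidden_states.shape  # Snake is shape-preserving
    return output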
class DacVectorQuantize(nn.Module):
"""
Implementation of VQ similar to Karpathy's repo (https://github.com/karpathy/deep-vector-quantization)
Additionally uses the following tricks from the improved VQGAN
(https://huggingface.co/papers/2110.04627):
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
for improved codebook usage
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
improves training stability
"""
def __init__(self, config: DacConfig):
super().__init__()
self.codebook_dim = config.codebook_dim  # used by `from_latents` to slice the concatenated latents
self.in_proj = nn.Conv1d(config.hidden_size, config.codebook_dim, kernel_size=1)
self.out_proj = nn.Conv1d(config.codebook_dim, config.hidden_size, kernel_size=1)
self.codebook = nn.Embedding(config.codebook_size, config.codebook_dim)
def forward(self, hidden_state):
"""
Quantizes the input tensor using a fixed codebook and returns the corresponding codebook vectors.
Args:
hidden_state (`torch.FloatTensor` of shape `(batch_size, dimension, time_steps)`):
Input tensor.
Returns:
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
Quantized continuous representation of input.
commitment_loss (`torch.FloatTensor` of shape `(1)`):
Commitment loss to train encoder to predict vectors closer to codebook entries.
codebook_loss (`torch.FloatTensor` of shape `(1)`):
Codebook loss to update the codebook.
audio_codes (`torch.LongTensor` of shape `(batch_size, time_steps)`):
Codebook indices for each codebook, quantized discrete representation of input.
projected_latents (`torch.FloatTensor` of shape `(batch_size, codebook_dim, time_steps)`):
Projected latents (continuous representation of input before quantization).
"""
projected_latents = self.in_proj(hidden_state)
quantized_representation, audio_codes = self.decode_latents(projected_latents)
commitment_loss = F.mse_loss(projected_latents, quantized_representation.detach(), reduction="mean")
codebook_loss = F.mse_loss(quantized_representation, projected_latents.detach(), reduction="mean")
# noop in forward pass, straight-through gradient estimator in backward pass
quantized_representation = projected_latents + (quantized_representation - projected_latents).detach()
quantized_representation = self.out_proj(quantized_representation)
return quantized_representation, commitment_loss, codebook_loss, audio_codes, projected_latents
def decode_latents(self, hidden_states):
batch_size, hidden_dim, sequence_length = hidden_states.shape
encodings = hidden_states.permute(0, 2, 1).reshape(batch_size * sequence_length, hidden_dim)
codebook = self.codebook.weight # codebook: (N x D)
# L2 normalize encodings and codebook (ViT-VQGAN)
encodings = F.normalize(encodings)
codebook = F.normalize(codebook)
# Distance score to each codebook entry; after L2 normalization this is proportional to cosine similarity, so argmax picks the nearest code
l2_norm = encodings.pow(2).sum(1, keepdim=True)
dist = -(l2_norm - 2 * encodings @ codebook.t()) + codebook.pow(2).sum(1, keepdim=True).t()
indices = dist.max(1)[1]
indices = indices.reshape(hidden_states.size(0), -1)
quantized_representation = self.codebook(indices).transpose(1, 2)
return quantized_representation, indices
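# --- Illustrative sketch, not part of the original model code ---
# Because both the encodings and the codebook are L2-normalized, the distance-based
# lookup in `decode_latents` selects the same indices as a plain cosine-similarity
# argmax. A hypothetical standalone version of that nearest-code search:
def _example_nearest_code(encodings: torch.Tensor, codebook_weight: torch.Tensor) -> torch.Tensor:
    # encodings: (num_vectors, codebook_dim); codebook_weight: (codebook_size, codebook_dim)
    encodings = F.normalize(encodings)
    codebook_weight = F.normalize(codebook_weight)
    return (encodings @ codebook_weight.t()).argmax(dim=1)  # index of the closest code per vector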
class DacResidualUnit(nn.Module):
"""
A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
"""
def __init__(self, dimension: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.snake1 = Snake1d(dimension)
self.conv1 = nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)
self.snake2 = Snake1d(dimension)
self.conv2 = nn.Conv1d(dimension, dimension, kernel_size=1)
def forward(self, hidden_state):
"""
Forward pass through the residual unit.
Args:
hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
Input tensor.
Returns:
output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
Input tensor after passing through the residual unit.
"""
output_tensor = hidden_state
output_tensor = self.conv1(self.snake1(output_tensor))
output_tensor = self.conv2(self.snake2(output_tensor))
padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
if padding > 0:
hidden_state = hidden_state[..., padding:-padding]
output_tensor = hidden_state + output_tensor
return output_tensor
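# --- Illustrative sketch, not part of the original model code ---
# The dilated 7-tap convolution is padded with ((7 - 1) * dilation) // 2 samples, so the
# residual branch keeps the time axis unchanged and the centre crop in `forward` only
# triggers for settings where the two lengths drift apart. Hypothetical shape check:
def _example_residual_unit_shapes():
    unit = DacResidualUnit(dimension=16, dilation=9)
    hidden_state = torch.randn(1, 16, 200)
    assert unit(hidden_state).shape == hidden_state.shape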
class DacEncoderBlock(nn.Module):
"""Encoder block used in DAC encoder."""
def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
super().__init__()
dimension = config.encoder_hidden_size * 2**stride_index
self.res_unit1 = DacResidualUnit(dimension // 2, dilation=1)
self.res_unit2 = DacResidualUnit(dimension // 2, dilation=3)
self.res_unit3 = DacResidualUnit(dimension // 2, dilation=9)
self.snake1 = Snake1d(dimension // 2)
self.conv1 = nn.Conv1d(
dimension // 2, dimension, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)
)
def forward(self, hidden_state):
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
hidden_state = self.snake1(self.res_unit3(hidden_state))
hidden_state = self.conv1(hidden_state)
return hidden_state
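# --- Illustrative sketch, not part of the original model code ---
# Each encoder block doubles the channel count (encoder_hidden_size * 2**stride_index)
# and downsamples time by `stride` via a conv with kernel_size=2*stride and
# padding=ceil(stride/2). Hypothetical length check, assuming default DacConfig values:
def _example_encoder_block_downsampling():
    config = DacConfig()  # assumption: default hyperparameters
    stride = 4
    block = DacEncoderBlock(config, stride=stride, stride_index=1)
    hidden_state = torch.randn(1, config.encoder_hidden_size, 400)
    output = block(hidden_state)
    # Standard conv length formula: floor((length + 2*padding - kernel) / stride) + 1
    expected = (400 + 2 * math.ceil(stride / 2) - 2 * stride) // stride + 1
    assert output.shape == (1, config.encoder_hidden_size * 2, expected)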
class DacDecoderBlock(nn.Module):
"""Decoder block used in DAC decoder."""
def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
super().__init__()
input_dim = config.decoder_hidden_size // 2**stride_index
output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
self.snake1 = Snake1d(input_dim)
self.conv_t1 = nn.ConvTranspose1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
)
self.res_unit1 = DacResidualUnit(output_dim, dilation=1)
self.res_unit2 = DacResidualUnit(output_dim, dilation=3)
self.res_unit3 = DacResidualUnit(output_dim, dilation=9)
def forward(self, hidden_state):
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv_t1(hidden_state)
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
hidden_state = self.res_unit3(hidden_state)
return hidden_state
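# --- Illustrative sketch, not part of the original model code ---
# The decoder block mirrors the encoder block: a transposed conv with kernel_size=2*stride
# and padding=ceil(stride/2) upsamples time by `stride` (exactly, for even strides) while
# halving the channel count. Hypothetical length check, assuming default DacConfig values:
def _example_decoder_block_upsampling():
    config = DacConfig()  # assumption: default hyperparameters
    stride = 8
    block = DacDecoderBlock(config, stride=stride, stride_index=1)
    hidden_state = torch.randn(1, config.decoder_hidden_size // 2, 25)
    assert block(hidden_state).shape[-1] == 25 * stride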
class DacResidualVectorQuantize(nn.Module):
"""
ResidualVectorQuantize block - Introduced in SoundStream: An End-to-End Neural Audio Codec (https://huggingface.co/papers/2107.03312)
"""
def __init__(self, config: DacConfig):
super().__init__()
n_codebooks = config.n_codebooks
quantizer_dropout = config.quantizer_dropout
self.n_codebooks = n_codebooks
self.quantizers = nn.ModuleList([DacVectorQuantize(config) for i in range(config.n_codebooks)])
self.quantizer_dropout = quantizer_dropout
def forward(self, hidden_state, n_quantizers: Optional[int] = None):
"""
Quantizes the input tensor using a fixed set of codebooks and returns corresponding codebook vectors.
Args:
hidden_state (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
Input tensor to be quantized.
n_quantizers (`int`, *optional*):
Number of quantizers to use. If specified and `self.quantizer_dropout` is True,
this argument is ignored during training, and a random number of quantizers is used.
Returns:
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
Quantized continuous representation of input.
audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
Codebook indices for each codebook (quantized discrete representation of input).
projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
Projected latents (continuous representation of input before quantization).
commitment_loss (`torch.Tensor` of shape `(1)`):
Commitment loss to train the encoder to predict vectors closer to codebook entries.
codebook_loss (`torch.Tensor` of shape `(1)`):
Codebook loss to update the codebook.
"""
quantized_representation = 0
residual = hidden_state
commitment_loss = 0
codebook_loss = 0
audio_codes = []
projected_latents = []
n_quantizers = n_quantizers if n_quantizers is not None else self.n_codebooks
if self.training:
n_quantizers = torch.ones((hidden_state.shape[0],)) * self.n_codebooks + 1
dropout = torch.randint(1, self.n_codebooks + 1, (hidden_state.shape[0],))
n_dropout = int(hidden_state.shape[0] * self.quantizer_dropout)
n_quantizers[:n_dropout] = dropout[:n_dropout]
n_quantizers = n_quantizers.to(hidden_state.device)
for i, quantizer in enumerate(self.quantizers):
if self.training is False and i >= n_quantizers:
break
quantized_representation_i, commitment_loss_i, codebook_loss_i, indices_i, projected_latents_i = quantizer(
residual
)
# Create mask to apply quantizer dropout
mask = torch.full((hidden_state.shape[0],), fill_value=i, device=hidden_state.device) < n_quantizers
quantized_representation = quantized_representation + quantized_representation_i * mask[:, None, None]
residual = residual - quantized_representation_i
# Sum losses
commitment_loss += commitment_loss_i * mask
codebook_loss += codebook_loss_i * mask
audio_codes.append(indices_i)
projected_latents.append(projected_latents_i)
audio_codes = torch.stack(audio_codes, dim=1)
projected_latents = torch.cat(projected_latents, dim=1)
return quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss
def from_codes(self, audio_codes: torch.Tensor):
"""
Reconstructs the continuous representation from quantized codes.
Args:
audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
Quantized discrete representation of input.
Returns:
quantized_representation (`torch.Tensor`):
Quantized continuous representation of input.
projected_latents (`torch.Tensor`):
List of projected latents (continuous representations of input before quantization)
for each codebook.
audio_codes (`torch.Tensor`):
Codebook indices for each codebook.
"""
quantized_representation = 0.0
projected_latents = []
n_codebooks = audio_codes.shape[1]
for i in range(n_codebooks):
projected_latents_i = self.quantizers[i].codebook(audio_codes[:, i, :]).transpose(1, 2)
projected_latents.append(projected_latents_i)
quantized_representation += self.quantizers[i].out_proj(projected_latents_i)
return quantized_representation, torch.cat(projected_latents, dim=1), audio_codes
def from_latents(self, latents: torch.Tensor):
"""Reconstructs the quantized representation from unquantized latents.
Args:
latents (`torch.Tensor` of shape `(batch_size, total_latent_dimension, time_steps)`):
Continuous representation of input after projection.
Returns:
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
Quantized representation of the full-projected space.
quantized_latents (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
Quantized representation of the latent space (continuous representation before quantization).
"""
quantized_representation = 0
quantized_latents = []
codes = []
codebook_dims_tensor = torch.tensor([0] + [q.codebook_dim for q in self.quantizers])
dims = torch.cumsum(codebook_dims_tensor, dim=0)
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
for i in range(n_codebooks):
hidden_dim_j, hidden_dim_k = dims[i], dims[i + 1]
quantized_latents_i, codes_i = self.quantizers[i].decode_latents(latents[:, hidden_dim_j:hidden_dim_k, :])
quantized_latents.append(quantized_latents_i)
codes.append(codes_i)
quantized_representation_i = self.quantizers[i].out_proj(quantized_latents_i)
quantized_representation = quantized_representation + quantized_representation_i
return quantized_representation, torch.cat(quantized_latents, dim=1)
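# --- Illustrative sketch, not part of the original model code ---
# Residual VQ: every stage quantizes whatever the previous stages missed, and the final
# representation is the sum of the per-stage `out_proj` outputs. `from_codes` rebuilds that
# same sum from the integer codes alone, so in eval mode a round trip is (numerically) exact.
# Hypothetical round trip, assuming default DacConfig values:
def _example_residual_vq_roundtrip():
    config = DacConfig()  # assumption: default hyperparameters
    rvq = DacResidualVectorQuantize(config).eval()
    hidden_state = torch.randn(1, config.hidden_size, 50)
    with torch.no_grad():
        quantized, audio_codes, _, _, _ = rvq(hidden_state)
        rebuilt, _, _ = rvq.from_codes(audio_codes)
    assert torch.allclose(quantized, rebuilt, atol=1e-5)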
class DacDecoder(nn.Module):
"""DAC Decoder"""
def __init__(self, config: DacConfig):
super().__init__()
input_channel = config.hidden_size
channels = config.decoder_hidden_size
strides = config.upsampling_ratios
# Add first conv layer
self.conv1 = nn.Conv1d(input_channel, channels, kernel_size=7, padding=3)
# Add upsampling + MRF blocks
block = []
for stride_index, stride in enumerate(strides):
block += [DacDecoderBlock(config, stride, stride_index)]
self.block = nn.ModuleList(block)
output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
self.snake1 = Snake1d(output_dim)
self.conv2 = nn.Conv1d(output_dim, 1, kernel_size=7, padding=3)
self.tanh = nn.Tanh()
def forward(self, hidden_state):
hidden_state = self.conv1(hidden_state)
for layer in self.block:
hidden_state = layer(hidden_state)
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv2(hidden_state)
hidden_state = self.tanh(hidden_state)
return hidden_state
class DacEncoder(nn.Module):
"""DAC Encoder"""
def __init__(self, config: DacConfig):
super().__init__()
strides = config.downsampling_ratios
# Create first convolution
self.conv1 = nn.Conv1d(1, config.encoder_hidden_size, kernel_size=7, padding=3)
self.block = []
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride_index, stride in enumerate(strides):
stride_index = stride_index + 1
self.block += [DacEncoderBlock(config, stride=stride, stride_index=stride_index)]
self.block = nn.ModuleList(self.block)
d_model = config.encoder_hidden_size * 2**stride_index
self.snake1 = Snake1d(d_model)
self.conv2 = nn.Conv1d(d_model, config.hidden_size, kernel_size=3, padding=1)
def forward(self, hidden_state):
hidden_state = self.conv1(hidden_state)
for module in self.block:
hidden_state = module(hidden_state)
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv2(hidden_state)
return hidden_state
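# --- Illustrative sketch, not part of the original model code ---
# The 7-tap input conv and the 3-tap output conv keep the time axis, so the encoder's
# overall hop length is the product of the block strides: one latent frame per
# prod(downsampling_ratios) input samples. Hypothetical check, assuming default DacConfig values:
def _example_encoder_hop_length():
    config = DacConfig()  # assumption: default hyperparameters
    encoder = DacEncoder(config).eval()
    hop_length = int(np.prod(config.downsampling_ratios))
    audio = torch.randn(1, 1, 10 * hop_length)  # ten latent frames' worth of samples
    with torch.no_grad():
        latent = encoder(audio)
    assert latent.shape == (1, config.hidden_size, 10)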
@auto_docstring
class DacPreTrainedModel(PreTrainedAudioTokenizerBase):
config: DacConfig
base_model_prefix = "dac"
main_input_name = "input_values"
def _init_weights(self, module):
if isinstance(module, nn.Conv1d):
nn.init.trunc_normal_(module.weight, std=0.02)
nn.init.constant_(module.bias, 0)
elif isinstance(module, Snake1d):
module.alpha.data.fill_(1.0)
elif isinstance(module, nn.ConvTranspose1d):
module.reset_parameters()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02)
def apply_weight_norm(self):
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
for layer in self.quantizer.quantizers:
weight_norm(layer.in_proj)
weight_norm(layer.out_proj)
weight_norm(self.encoder.conv1)
weight_norm(self.encoder.conv2)
for layer in self.encoder.block:
weight_norm(layer.conv1)
weight_norm(layer.res_unit1.conv1)
weight_norm(layer.res_unit1.conv2)
weight_norm(layer.res_unit2.conv1)
weight_norm(layer.res_unit2.conv2)
weight_norm(layer.res_unit3.conv1)
weight_norm(layer.res_unit3.conv2)
weight_norm(self.decoder.conv1)
weight_norm(self.decoder.conv2)
for layer in self.decoder.block:
weight_norm(layer.conv_t1)
weight_norm(layer.res_unit1.conv1)
weight_norm(layer.res_unit1.conv2)
weight_norm(layer.res_unit2.conv1)
weight_norm(layer.res_unit2.conv2)
weight_norm(layer.res_unit3.conv1)
weight_norm(layer.res_unit3.conv2)
def remove_weight_norm(self):
for layer in self.quantizer.quantizers:
nn.utils.remove_weight_norm(layer.in_proj)
nn.utils.remove_weight_norm(layer.out_proj)
nn.utils.remove_weight_norm(self.encoder.conv1)
nn.utils.remove_weight_norm(self.encoder.conv2)
for layer in self.encoder.block:
nn.utils.remove_weight_norm(layer.conv1)
nn.utils.remove_weight_norm(layer.res_unit1.conv1)
nn.utils.remove_weight_norm(layer.res_unit1.conv2)
nn.utils.remove_weight_norm(layer.res_unit2.conv1)
nn.utils.remove_weight_norm(layer.res_unit2.conv2)
nn.utils.remove_weight_norm(layer.res_unit3.conv1)
nn.utils.remove_weight_norm(layer.res_unit3.conv2)
nn.utils.remove_weight_norm(self.decoder.conv1)
nn.utils.remove_weight_norm(self.decoder.conv2)
for layer in self.decoder.block:
nn.utils.remove_weight_norm(layer.conv_t1)
nn.utils.remove_weight_norm(layer.res_unit1.conv1)
nn.utils.remove_weight_norm(layer.res_unit1.conv2)
nn.utils.remove_weight_norm(layer.res_unit2.conv1)
nn.utils.remove_weight_norm(layer.res_unit2.conv2)
nn.utils.remove_weight_norm(layer.res_unit3.conv1)
nn.utils.remove_weight_norm(layer.res_unit3.conv2)
@auto_docstring(
custom_intro="""
The DAC (Descript Audio Codec) model.
"""
)
class DacModel(DacPreTrainedModel):
def __init__(self, config: DacConfig):
super().__init__(config)
self.config = config
self.encoder = DacEncoder(config)
self.decoder = DacDecoder(config)
self.quantizer = DacResidualVectorQuantize(config)
self.bits_per_codebook = int(math.log2(self.config.codebook_size))
if 2**self.bits_per_codebook != self.config.codebook_size:
raise ValueError("The codebook_size must be a power of 2.")
# Initialize weights and apply final processing
self.post_init()
@auto_docstring
def encode(
self,
input_values: torch.Tensor,
n_quantizers: Optional[int] = None,
return_dict: Optional[bool] = None,
):
r"""
input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`):
Input audio data to encode.
n_quantizers (`int`, *optional*):
Number of quantizers to use. If `None`, all quantizers are used. Default is `None`.
"""
return_dict = return_dict if return_dict is not None else self.config.return_dict
quantized_representation = self.encoder(input_values)
quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss = self.quantizer(
quantized_representation, n_quantizers
)
loss = self.config.commitment_loss_weight * commitment_loss + self.config.codebook_loss_weight * codebook_loss
if not return_dict:
return (loss, quantized_representation, audio_codes, projected_latents)
return DacEncoderOutput(loss, quantized_representation, audio_codes, projected_latents)
@auto_docstring
def decode(
self,
quantized_representation: Optional[torch.Tensor] = None,
audio_codes: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
):
r"""
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`, *optional*):
Quantized continuous representation of input.
audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
The codebook indices for each codebook, representing the quantized discrete
representation of the input. This parameter should be provided if you want
to decode directly from the audio codes (it will overwrite `quantized_representation`).
"""
if quantized_representation is None and audio_codes is None:
raise ValueError("Either `quantized_representation` or `audio_codes` must be provided.")
return_dict = return_dict if return_dict is not None else self.config.return_dict
if audio_codes is not None:
quantized_representation = self.quantizer.from_codes(audio_codes)[0]
audio_values = self.decoder(quantized_representation).squeeze(1)
if not return_dict:
return (audio_values,)
return DacDecoderOutput(audio_values)
@auto_docstring
def forward(
self,
input_values: torch.Tensor,
n_quantizers: Optional[int] = None,
return_dict: Optional[bool] = None,
):
r"""
input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`):
Audio data to encode.
n_quantizers (`int`, *optional*):
Number of quantizers to use. If `None`, all quantizers are used. Default is `None`.
Examples:
```python
>>> from datasets import load_dataset, Audio
>>> from transformers import DacModel, AutoProcessor
>>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> model = DacModel.from_pretrained("descript/dac_16khz")
>>> processor = AutoProcessor.from_pretrained("descript/dac_16khz")
>>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
>>> audio_sample = librispeech_dummy[-1]["audio"]["array"]
>>> inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")
>>> encoder_outputs = model.encode(inputs["input_values"])
>>> # Get the intermediate audio codes
>>> audio_codes = encoder_outputs.audio_codes
>>> # Reconstruct the audio from its quantized representation
>>> audio_values = model.decode(encoder_outputs.quantized_representation)
>>> # or the equivalent with a forward pass
>>> audio_values = model(inputs["input_values"]).audio_values
```"""
return_dict = return_dict if return_dict is not None else self.config.return_dict
length = input_values.shape[-1]
loss, quantized_representation, audio_codes, projected_latents = self.encode(
input_values, n_quantizers, return_dict=False
)
audio_values = self.decode(quantized_representation, return_dict=False)[0][..., :length]
if not return_dict:
return (loss, audio_values, quantized_representation, audio_codes, projected_latents)
return DacOutput(loss, audio_values, quantized_representation, audio_codes, projected_latents)
__all__ = ["DacModel", "DacPreTrainedModel"]
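# --- Illustrative sketch, not part of the original model code ---
# When DAC is used as an audio tokenizer, the integer `audio_codes` returned by `encode`
# are usually what gets stored or fed to a language model; they can be turned back into a
# waveform with `decode(audio_codes=...)`. Hypothetical helper (assumes a loaded `DacModel`,
# e.g. the "descript/dac_16khz" checkpoint from the docstring example above):
def _example_decode_from_codes(model: DacModel, audio_codes: torch.Tensor) -> torch.Tensor:
    # audio_codes: (batch_size, num_codebooks, frames) integer codebook indices
    with torch.no_grad():
        return model.decode(audio_codes=audio_codes).audio_values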