team-10/venv/Lib/site-packages/transformers/models/dia/processing_dia.py
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processor class for Dia"""
import math
from pathlib import Path
from typing import Optional, Union
from ...audio_utils import AudioInput, make_list_of_audio
from ...feature_extraction_utils import BatchFeature
from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...utils import is_soundfile_available, is_torch_available
if is_torch_available():
import torch
if is_soundfile_available():
import soundfile as sf
class DiaAudioKwargs(AudioKwargs, total=False):
bos_token_id: int
eos_token_id: int
pad_token_id: int
delay_pattern: list[int]
generation: bool
class DiaProcessorKwargs(ProcessingKwargs, total=False):
audio_kwargs: DiaAudioKwargs
_defaults = {
"text_kwargs": {
"padding": True,
"padding_side": "right",
"add_special_tokens": False,
},
"audio_kwargs": {
"eos_token_id": 1024,
"pad_token_id": 1025,
"bos_token_id": 1026,
"delay_pattern": [0, 8, 9, 10, 11, 12, 13, 14, 15],
"generation": True,
"sampling_rate": 44100,
},
"common_kwargs": {"return_tensors": "pt"},
}
class DiaProcessor(ProcessorMixin):
r"""
    Constructs a Dia processor which wraps a [`DiaFeatureExtractor`], [`DiaTokenizer`], and a [`DacModel`] into
    a single processor. It inherits the audio feature extraction, tokenizer, and audio encode/decode
    functionalities. See [`~DiaProcessor.__call__`], [`~DiaProcessor.batch_decode`], and [`~DiaProcessor.decode`]
    for more information.
Args:
feature_extractor (`DiaFeatureExtractor`):
An instance of [`DiaFeatureExtractor`]. The feature extractor is a required input.
tokenizer (`DiaTokenizer`):
An instance of [`DiaTokenizer`]. The tokenizer is a required input.
audio_tokenizer (`DacModel`):
            An instance of [`DacModel`] used to encode/decode audio into/from codebooks. It is a required input.
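
    Example (illustrative sketch; the checkpoint id below is a placeholder, not a verified repo id):

    ```python
    >>> from transformers import DiaProcessor

    >>> # Loads the feature extractor, tokenizer, and DAC audio tokenizer in one call.
    >>> processor = DiaProcessor.from_pretrained("path/to/dia-checkpoint")

    >>> # Prepare text-only inputs for TTS generation.
    >>> inputs = processor(text="Hello there!", return_tensors="pt")
    ```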
"""
feature_extractor_class = "DiaFeatureExtractor"
tokenizer_class = "DiaTokenizer"
audio_tokenizer_class = "DacModel"
def __init__(self, feature_extractor, tokenizer, audio_tokenizer):
super().__init__(feature_extractor, tokenizer, audio_tokenizer=audio_tokenizer)
@property
def model_input_names(self):
"""
We no longer pass the raw audio values but the codebooks encoded by the `audio_tokenizer`.
Conventions may differ between audio models due to architectural choices.
"""
tokenizer_input_names = self.tokenizer.model_input_names
audio_tokenizer_input_names = ["decoder_input_ids", "decoder_attention_mask"]
return list(dict.fromkeys(tokenizer_input_names + audio_tokenizer_input_names))
def __call__(
self,
text: Union[str, list[str]],
audio: Optional[AudioInput] = None,
output_labels: Optional[bool] = False,
**kwargs: Unpack[DiaProcessorKwargs],
):
"""
Main method to prepare text(s) and audio to be fed as input to the model. The `audio` argument is
forwarded to the DiaFeatureExtractor's [`~DiaFeatureExtractor.__call__`] and subsequently to the
        DacModel's [`~DacModel.encode`]. The `text` argument is forwarded to [`~DiaTokenizer.__call__`].
        Please refer to the docstrings of the above methods for more information.
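
        Example (illustrative sketch; `processor` is assumed to be an already loaded `DiaProcessor`):

        ```python
        >>> # Text-only call: the decoder side is prefilled with BOS tokens and the delay pattern applied,
        >>> # i.e. `decoder_input_ids` has shape (batch, 1 + max(delay_pattern), num_channels).
        >>> inputs = processor(text="Hello there!", return_tensors="pt")

        >>> # For training, pass matching audio, disable generation mode and request labels, e.g.:
        >>> # batch = processor(text="Hello there!", audio=waveform, generation=False, output_labels=True)
        ```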
"""
if not is_torch_available():
raise ValueError(
"The `DiaProcessor` relies on the `audio_tokenizer` which requires `torch` but we couldn't "
"find it in your environment. You can install torch via `pip install torch`."
)
if text is None:
raise ValueError("You need to specify the `text` input to process.")
output_kwargs = self._merge_kwargs(
DiaProcessorKwargs,
**kwargs,
)
text_kwargs = output_kwargs["text_kwargs"]
audio_kwargs = output_kwargs["audio_kwargs"]
common_kwargs = output_kwargs["common_kwargs"]
return_tensors = common_kwargs.pop("return_tensors", None)
if return_tensors != "pt":
raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
data = {}
# Text
if isinstance(text, str):
text = [text]
elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
encodings = self.tokenizer(text, **text_kwargs)
data.update(encodings)
# Audio
delay_pattern = audio_kwargs.pop("delay_pattern", None)
audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
audio_eos_token_id = audio_kwargs.pop("eos_token_id", None)
audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
generation = audio_kwargs.pop("generation", True)
if (
audio_bos_token_id is None
or audio_eos_token_id is None
or audio_pad_token_id is None
or delay_pattern is None
):
raise ValueError(
"To enable processing for Dia, we need the `bos_token_id`, `eos_token_id`, "
"`pad_token_id`, and `delay_pattern`. You may have accidentally overwritten one of those."
)
if generation and output_labels:
raise ValueError(
f"Labels with `generation` is incompatible, got generation={generation}, output_labels={output_labels}."
)
batch_size = data["input_ids"].shape[0]
num_channels = len(delay_pattern)
max_delay = max(delay_pattern)
# Voice cloning generation / general training
if audio is not None:
audio = make_list_of_audio(audio)
input_audios = self.feature_extractor(audio, **audio_kwargs)
compression_rate = math.prod(self.audio_tokenizer.config.downsampling_ratios)
max_encoded_sequence_len = input_audios["padding_mask"][0].shape[-1] // compression_rate
decoder_input_ids = []
decoder_attention_mask = []
# TODO: dac with batching is currently broken, but non-batch is working
# refer to https://gist.github.com/vasqu/643a45b680cf39fd7467271ee2eb6f80 for a validation script
for padding_mask, audio in zip(input_audios["padding_mask"], input_audios["input_values"]):
                # get the current length rounded up to a multiple of the hop length
                # (as if the audio had been processed on its own rather than as part of a batch)
base_pad_len = self.feature_extractor.hop_length
current_audio_len = math.ceil(padding_mask.sum(dim=-1) / base_pad_len) * base_pad_len
encoded_sequence_len = current_audio_len // compression_rate
padding_len = max_encoded_sequence_len - encoded_sequence_len
# compute non-padded forward pass; one extra bos (and eos if training) is added
with torch.no_grad():
audio = audio[None, ..., :current_audio_len].to(self.audio_tokenizer.device)
input_ids = self.audio_tokenizer.encode(audio).audio_codes.transpose(1, 2)
if not generation:
input_ids = torch.nn.functional.pad(
input_ids, pad=(0, 0, 0, 1, 0, 0), mode="constant", value=audio_eos_token_id
)
# apply padding
# +1 for the bos within the real sequence
input_ids = torch.nn.functional.pad(
input_ids, pad=(0, 0, padding_len + 1, 0, 0, 0), mode="constant", value=audio_bos_token_id
)
num_valid_inputs = encoded_sequence_len + 1 + max_delay # sequence + bos + delay
num_valid_inputs += 0 if generation else 1 # eos if training
attention_mask = torch.tensor([0] * padding_len + [1] * num_valid_inputs, dtype=torch.long)[None, :]
decoder_input_ids.append(input_ids)
decoder_attention_mask.append(attention_mask)
decoder_input_ids = torch.cat(decoder_input_ids, dim=0)
decoder_attention_mask = torch.cat(decoder_attention_mask, dim=0)
# TTS generation
elif generation:
# all bos to start with TTS
decoder_input_ids = torch.full((batch_size, 1, num_channels), audio_bos_token_id, dtype=torch.long)
# we preemptively add the delay
decoder_attention_mask = torch.ones(size=(batch_size, 1 + max_delay), dtype=torch.long)
else:
raise ValueError("If you try to train, you should provide audio data as well.")
if batch_size != decoder_input_ids.shape[0]:
raise ValueError(
f"Need the same amount of samples for both text and audio, but got text samples={batch_size} and "
f"audio samples = {decoder_input_ids.shape[0]} instead."
)
# prepare shift indices per delay
max_seq_len = decoder_attention_mask.shape[-1]
max_audio_len = max_seq_len - max_delay
precomputed_idx = self.build_indices(
bsz=batch_size,
seq_len=max_seq_len,
num_channels=num_channels,
delay_pattern=delay_pattern,
revert=False,
)
# create delay pattern input
# the pad token will be used for masking which input is valid for prediction during generation
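        # After `apply_audio_delay`, channel `c` at step `t` holds the token that originally sat at step
        # `t - delay_pattern[c]`, so the first `delay_pattern[c]` steps of each channel become BOS.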
prefill = torch.full(
(batch_size, max_seq_len, num_channels),
fill_value=audio_pad_token_id,
dtype=torch.int,
)
prefill[:, :max_audio_len] = decoder_input_ids
delayed_decoder_input_ids = self.apply_audio_delay(
audio=prefill,
pad_token_id=audio_pad_token_id,
bos_token_id=audio_bos_token_id,
precomputed_idx=precomputed_idx,
)
data.update({"decoder_input_ids": delayed_decoder_input_ids, "decoder_attention_mask": decoder_attention_mask})
if output_labels:
# Base idea is to shift on the sequence dim
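            # Standard next-token shift: labels at step t are the tokens at step t + 1; BOS/PAD positions are
            # masked to -100 and the channel dim is folded into the batch dim as (batch * num_channels, seq - 1).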
labels = data["decoder_input_ids"].clone()[:, 1:]
labels[labels == audio_pad_token_id] = -100
labels[labels == audio_bos_token_id] = -100
data["labels"] = labels.transpose(1, 2).reshape(batch_size * num_channels, -1).contiguous().long()
data["decoder_input_ids"] = data["decoder_input_ids"][:, :-1]
data["decoder_attention_mask"] = data["decoder_attention_mask"][:, :-1]
return BatchFeature(data=data, tensor_type=return_tensors)
def batch_decode(
self,
decoder_input_ids: "torch.Tensor",
audio_prompt_len: Optional[int] = None,
**kwargs: Unpack[DiaProcessorKwargs],
) -> list["torch.Tensor"]:
"""
Decodes a batch of audio codebook sequences into their respective audio waveforms via the
`audio_tokenizer`. See [`~DacModel.decode`] for more information.
Args:
decoder_input_ids (`torch.Tensor`): The complete output sequence of the decoder.
            audio_prompt_len (`int`, *optional*): The length of the audio prompt/prefix (e.g. when using voice
                cloning); if given, only the part generated after the prompt is decoded.
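
        Example (illustrative sketch; `model` and `inputs` are assumed to come from a matching Dia
        generation setup, see [`~DiaProcessor.__call__`]):

        ```python
        >>> # `output_ids` would be the decoder token ids returned by generation,
        >>> # shaped (batch, sequence, num_channels):
        >>> # output_ids = model.generate(**inputs)
        >>> # waveforms = processor.batch_decode(output_ids)
        >>> # processor.save_audio(waveforms, ["sample_0.wav", "sample_1.wav"])
        ```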
"""
output_kwargs = self._merge_kwargs(
DiaProcessorKwargs,
**kwargs,
)
audio_kwargs = output_kwargs["audio_kwargs"]
delay_pattern = audio_kwargs.pop("delay_pattern", None)
audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
if audio_bos_token_id is None or audio_pad_token_id is None or delay_pattern is None:
raise ValueError(
"To enable decoding for Dia, we need the `bos_token_id`, `pad_token_id`, "
"and `delay_pattern`. You may have accidentally overwritten one of those."
)
# either decode the whole audio sequence or only the generated parts
if audio_prompt_len is not None:
audio_prompt_len = torch.tensor(audio_prompt_len, device=decoder_input_ids.device, dtype=torch.long)
start_of_generation_idx = audio_prompt_len[None].expand(decoder_input_ids.shape[0])
else:
start_of_generation_idx = (decoder_input_ids[:, :, 0] == audio_bos_token_id).sum(dim=-1)
# -1 for the eos token
end_of_generation_idx = (
decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == audio_pad_token_id).sum(dim=-1) - 1
)
# revert delay
bsz, seq_len, num_channels = decoder_input_ids.shape
precomputed_idx = self.build_indices(
bsz=bsz,
seq_len=seq_len,
num_channels=num_channels,
delay_pattern=delay_pattern,
revert=True,
)
output_sequences = self.apply_audio_delay(
audio=decoder_input_ids,
# We do not care about these values as we cut them out
# with `start_of_generation_idx` and `end_of_generation_idx`
pad_token_id=-1,
bos_token_id=-1,
precomputed_idx=precomputed_idx,
).transpose(1, 2)
# retrieve the correct sequences each
audios = []
# TODO: see above, dac doesn't work in batches yet
with torch.no_grad():
for i in range(start_of_generation_idx.shape[0]):
output_i = output_sequences[i, :, start_of_generation_idx[i] : end_of_generation_idx[i]][None, ...]
output_i = output_i.to(self.audio_tokenizer.device)
audio_i = self.audio_tokenizer.decode(audio_codes=output_i).audio_values.cpu().squeeze()
audios.append(audio_i)
return audios
def decode(
self,
decoder_input_ids: "torch.Tensor",
audio_prompt_len: Optional[int] = None,
**kwargs: Unpack[DiaProcessorKwargs],
) -> "torch.Tensor":
"""
Decodes a single sequence of audio codebooks into the respective audio waveform via the
`audio_tokenizer`. See [`~DacModel.decode`] and [`~DiaProcessor.batch_decode`] for more information.
"""
if decoder_input_ids.shape[0] != 1:
raise ValueError(
f"Expecting a single output to be decoded but received {decoder_input_ids.shape[0]} samples instead."
)
return self.batch_decode(decoder_input_ids, audio_prompt_len, **kwargs)[0]
def get_audio_prompt_len(
self,
decoder_attention_mask: "torch.Tensor",
**kwargs: Unpack[DiaProcessorKwargs],
) -> int:
"""Utility function to get the audio prompt length."""
output_kwargs = self._merge_kwargs(
DiaProcessorKwargs,
**kwargs,
)
audio_kwargs = output_kwargs["audio_kwargs"]
delay_pattern = audio_kwargs.pop("delay_pattern", None)
if delay_pattern is None:
raise ValueError(
"To enable the utility of retrieving the prompt length for Dia, we need the "
"`delay_pattern`. You may have accidentally overwritten this."
)
return decoder_attention_mask.shape[1] - max(delay_pattern)
# Copied from transformers.models.csm.processing_csm.CsmProcessor.save_audio with Csm->Dia
def save_audio(
self,
audio: AudioInput,
saving_path: Union[str, Path, list[Union[str, Path]]],
**kwargs: Unpack[DiaProcessorKwargs],
):
# TODO: @eustlb, this should be in AudioProcessor
if not is_soundfile_available():
raise ImportError("Please install `soundfile` to save audio files.")
# ensure correct audio input
audio = make_list_of_audio(audio)
# ensure correct saving path
if isinstance(saving_path, (str, Path)):
saving_path = [saving_path]
elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)):
raise ValueError("Invalid input path. Please provide a string, or a list of strings")
if len(audio) != len(saving_path):
raise ValueError("The number of audio and saving paths must be the same")
output_kwargs = self._merge_kwargs(
DiaProcessorKwargs,
**kwargs,
)
audio_kwargs = output_kwargs["audio_kwargs"]
sampling_rate = audio_kwargs["sampling_rate"]
for audio_value, p in zip(audio, saving_path):
if isinstance(audio_value, torch.Tensor):
audio_value = audio_value.cpu().float().numpy()
sf.write(p, audio_value, sampling_rate)
@staticmethod
def build_indices(
bsz: int,
seq_len: int,
num_channels: int,
delay_pattern: list[int],
revert: bool = False,
) -> tuple["torch.Tensor", "torch.Tensor"]:
"""
        Precompute `(sequence_idx, all_idx)` such that, for the forward delay, out[seq, channel] =
        in[seq - delay[channel], channel], and with `revert=True`, out[seq, channel] =
        in[seq + delay[channel], channel].
        Negative sequence_idx => BOS; sequence_idx >= seq_len => PAD.
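
        Returns a tuple `(sequence_idx, all_idx)`: `sequence_idx` has shape `[bsz, seq_len, num_channels]` and
        may run out of range (it is only used to build the BOS/PAD masks), while `all_idx` has shape
        `[bsz * seq_len * num_channels, 3]` and holds the clamped `(batch, sequence, channel)` gather indices.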
"""
delay_array = torch.tensor(delay_pattern, dtype=torch.int32)
# (0..seq_len-1)
sequence_idx = torch.arange(seq_len, dtype=torch.int32)[None, :].expand(bsz, seq_len)[..., None]
# + or - delay depending if we delay or revert the delay
if not revert:
sequence_idx = sequence_idx - delay_array[None, None, :]
else:
sequence_idx = sequence_idx + delay_array[None, None, :]
# if delay goes over the range we clamp back to valid values
valid_sequence_idx = torch.clamp(sequence_idx, 0, seq_len - 1)
batch_idx = torch.arange(bsz, dtype=torch.int32)[:, None, None].expand(bsz, seq_len, num_channels)
channel_idx = torch.arange(num_channels, dtype=torch.int32)[None, None, :].expand(bsz, seq_len, num_channels)
all_idx = torch.stack(
[batch_idx.reshape(-1), valid_sequence_idx.reshape(-1), channel_idx.reshape(-1)],
dim=1,
).long()
return sequence_idx, all_idx
@staticmethod
def apply_audio_delay(
audio: "torch.Tensor",
pad_token_id: int,
bos_token_id: int,
precomputed_idx: tuple["torch.Tensor", "torch.Tensor"],
) -> "torch.Tensor":
"""
Applies or reverts the delay pattern to batched audio tokens using precomputed indices,
inserting BOS where sequence_idx < 0 and PAD where sequence_idx >= seq_len.
Args:
audio: audio tokens of shape [bsz, seq_len, num_channels]
pad_token_id: the PAD token
bos_token_id: the BOS token
precomputed_idx: from `build_indices`
Returns:
final_audio: delayed or reverted audio tokens of shape [bsz, seq_len, num_channels]
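
        Example (illustrative): with `delay_pattern=[0, 2]` and 4 frames whose channel-1 tokens are
        `[b0, b1, b2, b3]`, applying the delay yields `[BOS, BOS, b0, b1]` for channel 1 while channel 0 is
        unchanged; reverting (with indices built via `revert=True`) shifts them back and fills the tail
        positions that now run past the end of the sequence with PAD.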
"""
# Move everything to the same device
device = audio.device
sequence_idx, all_idx = precomputed_idx
sequence_idx = sequence_idx.to(device)
all_idx = all_idx.to(device)
# Gather per precomputed indices
batch_idx, valid_sequence_idx, channel_idx = torch.unbind(all_idx, dim=-1)
gathered_audio = audio[batch_idx, valid_sequence_idx, channel_idx].view(audio.size())
# Mask according to negative sequence_idx => BOS; sequence_idx >= seq_len => PAD
mask_bos = sequence_idx < 0
mask_pad = sequence_idx >= audio.shape[1]
final_audio = torch.where(mask_bos, bos_token_id, torch.where(mask_pad, pad_token_id, gathered_audio))
return final_audio
__all__ = ["DiaProcessor"]