# coding=utf-8
# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.nn as nn
from transformers.utils.generic import check_model_inputs
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from ..auto import AutoModel
from ..llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaMLP,
LlamaModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
TransformersKwargs,
)
from .configuration_csm import CsmConfig, CsmDepthDecoderConfig
from .generation_csm import CsmGenerationMixin
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
custom_intro="""
    Base class for the model's autoregressive outputs.
"""
)
class CsmOutputWithPast(ModelOutput):
r"""
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
depth_decoder_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction) of the depth decoder model.
depth_decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the depth decoder (scores for each vocabulary token before SoftMax).
depth_decoder_past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
depth_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
depth_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
backbone_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction) of the backbone model.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None
hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
attentions: Optional[tuple[torch.FloatTensor, ...]] = None
depth_decoder_loss: Optional[torch.FloatTensor] = None
depth_decoder_logits: torch.FloatTensor = None
depth_decoder_past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None
depth_decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
depth_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
backbone_loss: Optional[torch.FloatTensor] = None
# manually specify names for correct naming when converting from modular
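# (Subclassing the Llama modules under Csm names lets the modular converter emit standalone Csm* classes
# in the generated modeling_csm.py; the aliases below carry no behavioral changes.)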
class CsmRMSNorm(LlamaRMSNorm):
pass
class CsmRotaryEmbedding(LlamaRotaryEmbedding):
pass
class CsmMLP(LlamaMLP):
pass
class CsmAttention(LlamaAttention):
pass
class CsmDecoderLayer(LlamaDecoderLayer):
pass
@auto_docstring(
    custom_intro="""
    The bare Csm Model outputting raw hidden-states without any specific head on top.
    """
)
class CsmPreTrainedModel(PreTrainedModel):
config: CsmConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["CsmDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = True
_supports_sdpa = True
    # flex attention is not supported because of the Mimi codec model
# _supports_flex_attn = True
_can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": CsmDecoderLayer,
"attentions": CsmAttention,
}
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, CsmCodebooksHead):
num_codebooks = module.num_codebooks
for i in range(num_codebooks - 1):
module.weight.data[i].normal_(mean=0.0, std=self.config.initializer_range)
@auto_docstring
class CsmDepthDecoderModel(LlamaModel, CsmPreTrainedModel):
config: CsmDepthDecoderConfig
def __init__(self, config):
super().__init__(config)
self.embed_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.backbone_hidden_size)
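        # Note: the audio token embeddings live in the backbone's hidden space (`backbone_hidden_size`);
        # `inputs_embeds_projector` maps them (and the backbone's last hidden state placed at position 0)
        # to the depth decoder's own `hidden_size` before they enter the decoder layers.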
self.inputs_embeds_projector = nn.Linear(config.backbone_hidden_size, config.hidden_size, bias=False)
@check_model_inputs
@auto_docstring
def forward(
self,
input_ids: torch.LongTensor = None,
backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, BaseModelOutputWithPast]:
r"""
backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
is provided in the `input_ids` argument.
"""
if position_ids is not None and not torch.compiler.is_compiling():
logger.warning_once(
"Custom `position_ids` were provided but will be ignored. CSM depth decoder automatically determines position_ids "
"from `cache_position` and as it requires them to be identical across the batch, the provided position_ids will be ignored."
)
position_ids = None
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
inputs_seq_length = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
device = inputs_embeds.device if inputs_embeds is not None else input_ids.device
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_seq_length, device=device)
if inputs_embeds is None:
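            # The depth decoder shares one embedding table of size (num_codebooks * vocab_size): token ids of
            # codebook k are offset by k * vocab_size so each codebook reads from its own slice. `cache_position - 1`
            # recovers the codebook index, since position 0 of the sequence carries the backbone's last hidden state.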
codebook_idxs = torch.clamp(cache_position - 1, min=0)
offset = codebook_idxs * self.vocab_size
inputs_embeds = self.embed_tokens(input_ids + offset)
input_ids_are_first_codebook = cache_position[0] == 0
if backbone_last_hidden_state is not None:
inputs_embeds[:, 0] = backbone_last_hidden_state
else:
if not torch.compiler.is_compiling() and input_ids_are_first_codebook:
logger.warning(
"When the first codebook token is provided, `backbone_last_hidden_state` should also be provided for correct inference."
)
inputs_embeds = self.inputs_embeds_projector(inputs_embeds)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_ids = cache_position.unsqueeze(0)
position_embeddings = self.rotary_emb(hidden_states, position_ids)
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
)
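# Note on CsmCodebooksHead (below): it holds one (hidden_size, vocab_size) projection per codebook position
# 1..num_codebooks-1; the first codebook is predicted by the backbone's `lm_head`. When `cache_position` is None
# (full-frame training) the projection index is simply the sequence position of the already-sliced hidden states,
# while during generation `cache_position - 1` selects the projection for the codebook being decoded.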
class CsmCodebooksHead(nn.Module):
def __init__(self, hidden_size, num_codebooks, vocab_size):
super().__init__()
self.num_codebooks = num_codebooks
self.weight = nn.Parameter(torch.empty(self.num_codebooks - 1, hidden_size, vocab_size))
def forward(self, hidden_states, cache_position=None):
if cache_position is None:
seq_length = hidden_states.shape[1]
codebook_weight = self.weight[torch.arange(seq_length)]
else:
codebook_idxs = cache_position - 1
codebook_weight = self.weight[codebook_idxs]
hidden_states = [
nn.functional.linear(hidden_states[:, codebook_idx, :], codebook_weight[codebook_idx].T)
for codebook_idx in range(codebook_weight.shape[0])
]
hidden_states = torch.stack(hidden_states, dim=1)
return hidden_states
@auto_docstring(
custom_intro="""
The CsmDepthDecoder Model transformer, with a [`CsmCodebooksHead`] on top,
    which can be seen as a position-specific language modeling head, allowing the use of a different linear layer for each codebook
(e.g. position 0 is the first codebook and uses the first codebook head, etc.)
"""
)
class CsmDepthDecoderForCausalLM(LlamaForCausalLM, GenerationMixin):
_tied_weights_keys = None
_tp_plan = None
_pp_plan = None
def __init__(self, config):
super().__init__(config)
del self.lm_head
self.codebooks_head = CsmCodebooksHead(config.hidden_size, config.num_codebooks, config.vocab_size)
self.model = CsmDepthDecoderModel(config)
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[Cache] = None,
attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
):
model_inputs = super().prepare_inputs_for_generation(
input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, **kwargs
)
is_first_generation_step = model_inputs["cache_position"][0] == 0
if not is_first_generation_step:
model_inputs.pop("backbone_last_hidden_state")
# csm depth decoder does not use position_ids
model_inputs.pop("position_ids")
return model_inputs
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: torch.LongTensor = None,
backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, CausalLMOutputWithPast]:
r"""
backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
is provided in the `input_ids` argument.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
"""
outputs = self.model(
input_ids=input_ids,
backbone_last_hidden_state=backbone_last_hidden_state,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
if isinstance(logits_to_keep, int):
if logits_to_keep == 0:
# skip idx 0 logits since it's for the concatenated backbone last hidden state
slice_indices = slice(1, None)
else:
slice_indices = slice(-logits_to_keep, None)
else:
slice_indices = logits_to_keep
logits = self.codebooks_head(
hidden_states[:, slice_indices, :], cache_position[slice_indices] if cache_position is not None else None
)
logits = logits.contiguous()
loss = None
if labels is not None:
shift_labels = labels[..., 1:].contiguous()
loss = self.loss_function(
logits=logits, labels=None, vocab_size=self.config.vocab_size, shift_labels=shift_labels, **kwargs
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
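# CsmBackboneModelEmbeddings embeds one audio frame of shape (batch, seq, num_codebooks): each codebook k is
# offset by k * vocab_size into a shared embedding table and the per-codebook embeddings are summed over the
# codebook dimension, giving a single hidden vector per frame.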
class CsmBackboneModelEmbeddings(nn.Module):
def __init__(self, config):
super().__init__()
self.embed_audio_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.hidden_size)
self.register_buffer(
"audio_tokens_offsets", torch.arange(config.num_codebooks) * config.vocab_size, persistent=False
)
def forward(self, input_ids):
input_embeds = self.embed_audio_tokens(input_ids + self.audio_tokens_offsets)
input_embeds = input_embeds.sum(dim=2)
return input_embeds
@auto_docstring
class CsmBackboneModel(LlamaModel):
def __init__(self, config):
super().__init__(config)
self.embed_tokens = CsmBackboneModelEmbeddings(config)
@check_model_inputs
@auto_docstring
def forward(self, **super_kwargs):
r"""
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`):
            1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
            requires `input_values` to be provided so that the audio can be encoded into codebook tokens and then merged with the text tokens.
2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
"""
return super().forward(**super_kwargs)
@auto_docstring(
custom_intro="""
    The Csm model consists of two Llama-like autoregressive transformer models: a backbone model that predicts the first codebook token and a depth decoder that predicts the other codebook tokens.
"""
)
class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
_tied_weights_keys = [
"backbone_model.embed_tokens.embed_audio_tokens.weight",
"depth_decoder.model.embed_tokens.weight",
]
def __init__(self, config):
super().__init__(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.embed_text_tokens = nn.Embedding(config.text_vocab_size, config.hidden_size)
self.backbone_model = CsmBackboneModel._from_config(config)
self.depth_decoder = CsmDepthDecoderForCausalLM._from_config(config.depth_decoder_config)
self.codec_model = AutoModel.from_config(config.codec_config)
self.post_init()
def get_input_embeddings(self):
return self.backbone_model.embed_tokens
def set_input_embeddings(self, value):
self.backbone_model.embed_tokens = value
def _tie_weights(self):
if self.config.tie_codebooks_embeddings:
self._tie_or_clone_weights(
self.backbone_model.embed_tokens.embed_audio_tokens,
self.depth_decoder.model.embed_tokens,
)
@classmethod
def from_pretrained(cls, *args, **kwargs):
if kwargs.get("output_loading_info", False):
model, loading_info = super().from_pretrained(*args, **kwargs)
else:
model = super().from_pretrained(*args, **kwargs)
        # copy the `depth_decoder_*` attributes of the model generation config to the depth decoder generation config
prefix = "depth_decoder_"
prefix_len = len(prefix)
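        # e.g. a hypothetical `depth_decoder_temperature=0.7` stored on the model generation config is exposed
        # as `temperature=0.7` on `model.depth_decoder.generation_config`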
depth_decoder_attrs = {
attr[prefix_len:]: value
for attr, value in vars(model.generation_config).items()
if attr.startswith(prefix)
}
vars(model.depth_decoder.generation_config).update({"_from_model_config": False, **depth_decoder_attrs})
        # remove the `depth_decoder_*` attributes from the model generation config
for attr in depth_decoder_attrs:
delattr(model.generation_config, prefix + attr)
if "output_loading_info" in kwargs:
return model, loading_info
else:
return model
def save_pretrained(self, *args, **kwargs):
# copy the depth decoder generation config attributes to the model generation config
prefix = "depth_decoder_"
depth_decoder_attrs = self.depth_decoder.generation_config.to_diff_dict()
depth_decoder_attrs.pop("transformers_version", None)
for attr, value in depth_decoder_attrs.items():
setattr(self.generation_config, prefix + attr, value)
super().save_pretrained(*args, **kwargs)
def _merge_input_ids_with_input_values(
self,
input_ids: Optional[torch.Tensor] = None,
input_values: Optional[torch.Tensor] = None,
input_values_cutoffs: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
"""
Merges the input_ids and input_values to produce a single inputs_embeds tensor:
        1 - Runs the codec model on the input_values to retrieve codebook tokens.
        2 - Embeds codebook tokens and places them at the correct positions in the inputs_embeds tensor.
        3 - If labels are provided, expands them to match codebook dimensions and positions the target codebook tokens in the expanded labels tensor.
Args:
input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`):
The input ids to embed.
input_values (`torch.Tensor` of shape `(batch_size, channels, audio_sequence_length)`):
The audio input values to embed.
input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`):
                The end positions of the audio segments within each batch entry, relative to that entry's concatenated audio, padded with -1 when an entry has fewer audio segments.
"""
inputs_embeds = self.embed_text_tokens(input_ids)
if input_values is not None:
# infer input_values_mask
input_values_cutoffs = nn.functional.pad(input_values_cutoffs, (1, 0))
audio_lengths = input_values_cutoffs[input_values_cutoffs >= 0].diff()
audio_lengths = audio_lengths[audio_lengths > 0]
input_values_mask = torch.arange(input_values_cutoffs.max(), device=input_values.device).expand(
len(audio_lengths), -1
)
input_values_mask = input_values_mask < audio_lengths.unsqueeze(1)
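            # `input_values_mask` flags the valid (non-padded) audio samples; the codec model turns it into a
            # mask over codec frames via `get_audio_codes_mask`, used below to drop padded frames.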
# =======================================
# TODO: @eustlb, this should be batched !!!
# but requires making sure batched inference of the codec model works as intended
with torch.no_grad():
audio_tokens_list = []
for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
for i in range(batch_input_values_cutoffs.shape[0] - 1):
start_idx = batch_input_values_cutoffs[i]
end_idx = batch_input_values_cutoffs[i + 1]
audio_batch = batch_input_values[..., start_idx:end_idx]
codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
audio_tokens_list.append(codebook_ids[0])
max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
batched_audio_token_ids = torch.stack(
[nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
)
audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
# =======================================
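            # scatter the audio-frame embeddings into the text sequence: every `audio_token_id` placeholder in
            # `input_ids` is replaced by the embedding of the corresponding (non-padded) audio frame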
audio_token_id = self.config.audio_token_id
audio_token_mask = input_ids == audio_token_id
audio_embeds = self.backbone_model.embed_tokens(batched_audio_token_ids)
inputs_embeds[audio_token_mask] = audio_embeds[audio_codes_mask]
# same for the audio eos token
audio_eos_frame_ids = (
torch.ones((1, 1, self.config.num_codebooks), device=input_ids.device, dtype=torch.long)
* self.config.codebook_eos_token_id
)
audio_eos_embeds = self.backbone_model.embed_tokens(audio_eos_frame_ids).squeeze(1)
audio_eos_token_mask = input_ids == self.config.audio_eos_token_id
inputs_embeds[audio_eos_token_mask] = audio_eos_embeds.repeat(audio_eos_token_mask.sum(), 1)
# if the labels are provided, we need to expand the labels to (batch_size, seq_length, num_codebooks)
if labels is not None:
labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
                # frames labeled -101 are used only for the backbone loss: ignore their remaining codebooks in the depth decoder loss
depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
labels = labels_expanded
return {"inputs_embeds": inputs_embeds, "labels": labels}
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[Cache] = None,
attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
):
model_inputs = super().prepare_inputs_for_generation(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
**kwargs,
)
if input_ids is not None and input_ids.ndim == 2 and model_inputs.get("inputs_embeds") is None:
merged_inputs = self._merge_input_ids_with_input_values(
input_ids=input_ids,
input_values=kwargs.get("input_values"),
input_values_cutoffs=kwargs.get("input_values_cutoffs"),
labels=kwargs.get("labels"),
)
model_inputs.update(
{"inputs_embeds": merged_inputs["inputs_embeds"], "labels": merged_inputs["labels"], "input_ids": None}
)
return model_inputs
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: torch.LongTensor = None,
input_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
input_values_cutoffs: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, CsmOutputWithPast]:
r"""
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`):
            1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
            requires `input_values` to be provided so that the audio can be encoded into codebook tokens and then merged with the text tokens.
2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`, *optional*):
Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[config.audio_token_id, -100, -101]`.
            Requires the corresponding `input_values` to be provided, as audio tokens will be inferred from them using the `codec_model`.
            - `config.audio_token_id` indicates an audio frame (considering sequence length elements as frames)
- `-100` will be ignored in the loss computation
- `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)
Such labels can be prepared using `output_labels=True` when calling [`CsmProcessor`].
logits_to_keep (`int` or `torch.Tensor`, *optional*):
            Kept for compatibility. Only the following values are supported:
1. `0`, which is equivalent to keeping all logits, used in the training regime
2. `1`, which is equivalent to keeping only the last logit, used in the generation regime
Example:
```python
>>> import torch
>>> from transformers import CsmForConditionalGeneration, AutoProcessor
>>> from datasets import load_dataset, Audio
>>> model_id = "sesame/csm-1b"
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
>>> processor = AutoProcessor.from_pretrained(model_id)
>>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
>>> # ensure the audio is 24kHz
>>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
>>> conversation = []
>>> # prepare a conversation with text and corresponding audio
>>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
... conversation.append(
... {
... "role": f"{speaker_id}",
... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
... }
... )
>>> inputs = processor.apply_chat_template(
... conversation,
... tokenize=True,
... return_dict=True,
... output_labels=True,
... ).to(torch_device)
>>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
>>> output = model(**inputs)
>>> output.loss.backward()
```"""
if input_ids is not None and input_ids.ndim == 2:
merged_inputs = self._merge_input_ids_with_input_values(
input_ids, input_values, input_values_cutoffs, labels
)
inputs_embeds = merged_inputs["inputs_embeds"]
labels = merged_inputs["labels"]
input_ids = None
backbone_outputs = self.backbone_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
backbone_hidden_states = backbone_outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
backbone_logits = self.lm_head(backbone_hidden_states[:, slice_indices, :])
loss = None
backbone_loss = None
depth_decoder_loss = None
depth_decoder_outputs = None
if labels is not None:
# select first codebook as labels for the backbone model
backbone_labels = labels[:, :, 0]
backbone_loss = self.loss_function(
logits=backbone_logits, labels=backbone_labels, vocab_size=self.config.vocab_size, **kwargs
)
# for the depth decoder, we need to select the frames to train on
# those are frames where the label is not uniformly `ignore_index` along the codebook dimension
train_mask = ~(labels[:, :, 1:] == -100).all(dim=-1)
depth_decoder_input_ids = labels[train_mask][..., : self.config.num_codebooks - 1]
            # add a placeholder at position 0 that will be replaced by the backbone_last_hidden_state
depth_decoder_input_ids = nn.functional.pad(depth_decoder_input_ids, (1, 0), value=0)
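            # condition the depth decoder on the backbone hidden state that predicted each trained frame,
            # i.e. the hidden state one position before the frame's label position (hence the -1 offset below)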
train_idxs = train_mask.nonzero(as_tuple=True)
backbone_last_hidden_states = backbone_hidden_states[train_idxs[0], train_idxs[1] - 1, :]
depth_decoder_labels = labels[train_mask]
depth_decoder_outputs = self.depth_decoder(
input_ids=depth_decoder_input_ids,
backbone_last_hidden_state=backbone_last_hidden_states,
use_cache=use_cache,
return_dict=True,
labels=depth_decoder_labels,
**kwargs,
)
depth_decoder_loss = depth_decoder_outputs.loss
loss = backbone_loss + depth_decoder_loss
return CsmOutputWithPast(
loss=loss,
backbone_loss=backbone_loss,
depth_decoder_loss=depth_decoder_loss,
logits=backbone_logits,
past_key_values=backbone_outputs.past_key_values,
hidden_states=backbone_outputs.hidden_states,
attentions=backbone_outputs.attentions,
depth_decoder_logits=depth_decoder_outputs.logits if depth_decoder_outputs is not None else None,
depth_decoder_past_key_values=depth_decoder_outputs.past_key_values
if depth_decoder_outputs is not None
else None,
depth_decoder_hidden_states=depth_decoder_outputs.hidden_states
if depth_decoder_outputs is not None
else None,
depth_decoder_attentions=depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None,
)
__all__ = [
"CsmPreTrainedModel",
"CsmBackboneModel",
"CsmDepthDecoderModel",
"CsmDepthDecoderForCausalLM",
"CsmForConditionalGeneration",
]