# coding=utf-8
# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn as nn

from transformers.utils.generic import check_model_inputs

from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from ..auto import AutoModel
from ..llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaMLP,
    LlamaModel,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    TransformersKwargs,
)
from .configuration_csm import CsmConfig, CsmDepthDecoderConfig
from .generation_csm import CsmGenerationMixin


logger = logging.get_logger(__name__)


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for the model autoregressive outputs.
    """
)
class CsmOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    depth_decoder_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction) of the depth decoder model.
    depth_decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the depth decoder (scores for each vocabulary token before SoftMax).
    depth_decoder_past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
    depth_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

        Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    depth_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.
    backbone_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction) of the backbone model.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    depth_decoder_loss: Optional[torch.FloatTensor] = None
    depth_decoder_logits: torch.FloatTensor = None
    depth_decoder_past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None
    depth_decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    depth_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    backbone_loss: Optional[torch.FloatTensor] = None
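    # Note: when `labels` are provided, `loss` is the sum of `backbone_loss` and `depth_decoder_loss`
    # (see `CsmForConditionalGeneration.forward`).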


# manually specify names for correct naming when converting from modular
class CsmRMSNorm(LlamaRMSNorm):
    pass


class CsmRotaryEmbedding(LlamaRotaryEmbedding):
    pass


class CsmMLP(LlamaMLP):
    pass


class CsmAttention(LlamaAttention):
    pass


class CsmDecoderLayer(LlamaDecoderLayer):
    pass


@auto_docstring(
    custom_intro="""
    The bare Csm Model outputting raw hidden-states without any specific head on top.
    """
)
@auto_docstring
class CsmPreTrainedModel(PreTrainedModel):
    config: CsmConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["CsmDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    # flex attention is not supported because of the Mimi codec model
    # _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": CsmDecoderLayer,
        "attentions": CsmAttention,
    }

    def _init_weights(self, module):
        super()._init_weights(module)
        if isinstance(module, CsmCodebooksHead):
            num_codebooks = module.num_codebooks
            for i in range(num_codebooks - 1):
                module.weight.data[i].normal_(mean=0.0, std=self.config.initializer_range)


@auto_docstring
class CsmDepthDecoderModel(LlamaModel, CsmPreTrainedModel):
    config: CsmDepthDecoderConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.backbone_hidden_size)
        self.inputs_embeds_projector = nn.Linear(config.backbone_hidden_size, config.hidden_size, bias=False)

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        r"""
        backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
            The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
            is provided in the `input_ids` argument.
        """
        if position_ids is not None and not torch.compiler.is_compiling():
            logger.warning_once(
                "Custom `position_ids` were provided but will be ignored. The CSM depth decoder automatically determines "
                "position_ids from `cache_position`, since they are required to be identical across the batch."
            )
            position_ids = None
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            inputs_seq_length = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
            device = inputs_embeds.device if inputs_embeds is not None else input_ids.device
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_seq_length, device=device)

        if inputs_embeds is None:
            codebook_idxs = torch.clamp(cache_position - 1, min=0)
            offset = codebook_idxs * self.vocab_size
            inputs_embeds = self.embed_tokens(input_ids + offset)
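            # Illustrative note (hypothetical numbers): the flat embedding table packs one vocab-sized
            # block per codebook, so a token id `t` decoded at cache_position `p` (codebook index p - 1)
            # is looked up at row `(p - 1) * vocab_size + t`, e.g. row 2 * 2051 + 7 for t=7, p=3 with a
            # hypothetical vocab_size of 2051.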

            input_ids_are_first_codebook = cache_position[0] == 0
            if backbone_last_hidden_state is not None:
                inputs_embeds[:, 0] = backbone_last_hidden_state
            elif not torch.compiler.is_compiling() and input_ids_are_first_codebook:
                logger.warning(
                    "When the first codebook token is provided, `backbone_last_hidden_state` should also be provided for correct inference."
                )

        inputs_embeds = self.inputs_embeds_projector(inputs_embeds)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_ids = cache_position.unsqueeze(0)
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )


class CsmCodebooksHead(nn.Module):
    def __init__(self, hidden_size, num_codebooks, vocab_size):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.weight = nn.Parameter(torch.empty(self.num_codebooks - 1, hidden_size, vocab_size))

    def forward(self, hidden_states, cache_position=None):
        if cache_position is None:
            seq_length = hidden_states.shape[1]
            codebook_weight = self.weight[torch.arange(seq_length)]
        else:
            codebook_idxs = cache_position - 1
            codebook_weight = self.weight[codebook_idxs]

        hidden_states = [
            nn.functional.linear(hidden_states[:, codebook_idx, :], codebook_weight[codebook_idx].T)
            for codebook_idx in range(codebook_weight.shape[0])
        ]
        hidden_states = torch.stack(hidden_states, dim=1)

        return hidden_states
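
    # Shape sketch (illustrative dimensions, not tied to any checkpoint): with hidden_size=1024,
    # num_codebooks=32 and vocab_size=2051, `weight` has shape (31, 1024, 2051). Without a
    # `cache_position`, one codebook-specific projection is applied per position:
    #
    #   head = CsmCodebooksHead(hidden_size=1024, num_codebooks=32, vocab_size=2051)
    #   hidden = torch.randn(2, 31, 1024)   # (batch, num_codebooks - 1, hidden_size)
    #   logits = head(hidden)               # (2, 31, 2051), one head per codebook position
    #
    # During generation, `cache_position - 1` instead selects the single projection for the codebook
    # currently being decoded.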


@auto_docstring(
    custom_intro="""
    The CsmDepthDecoder Model transformer, with a [`CsmCodebooksHead`] on top,
    which can be seen as a position-specific language modeling head, allowing the use of a different linear layer for each codebook
    (e.g. position 0 is the first codebook and uses the first codebook head, etc.)
    """
)
class CsmDepthDecoderForCausalLM(LlamaForCausalLM, GenerationMixin):
    _tied_weights_keys = None
    _tp_plan = None
    _pp_plan = None

    def __init__(self, config):
        super().__init__(config)
        del self.lm_head
        self.codebooks_head = CsmCodebooksHead(config.hidden_size, config.num_codebooks, config.vocab_size)
        self.model = CsmDepthDecoderModel(config)

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, **kwargs
        )

        is_first_generation_step = model_inputs["cache_position"][0] == 0
        if not is_first_generation_step:
            model_inputs.pop("backbone_last_hidden_state")

        # csm depth decoder does not use position_ids
        model_inputs.pop("position_ids")

        return model_inputs

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, CausalLMOutputWithPast]:
        r"""
        backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
            The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
            is provided in the `input_ids` argument.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        outputs = self.model(
            input_ids=input_ids,
            backbone_last_hidden_state=backbone_last_hidden_state,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        if isinstance(logits_to_keep, int):
            if logits_to_keep == 0:
                # skip idx 0 logits since it's for the concatenated backbone last hidden state
                slice_indices = slice(1, None)
            else:
                slice_indices = slice(-logits_to_keep, None)
        else:
            slice_indices = logits_to_keep

        logits = self.codebooks_head(
            hidden_states[:, slice_indices, :], cache_position[slice_indices] if cache_position is not None else None
        )
        logits = logits.contiguous()

        loss = None
        if labels is not None:
            shift_labels = labels[..., 1:].contiguous()
            loss = self.loss_function(
                logits=logits, labels=None, vocab_size=self.config.vocab_size, shift_labels=shift_labels, **kwargs
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class CsmBackboneModelEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_audio_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.hidden_size)
        self.register_buffer(
            "audio_tokens_offsets", torch.arange(config.num_codebooks) * config.vocab_size, persistent=False
        )

    def forward(self, input_ids):
        input_embeds = self.embed_audio_tokens(input_ids + self.audio_tokens_offsets)
        input_embeds = input_embeds.sum(dim=2)
        return input_embeds
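
    # Shape sketch (illustrative): `input_ids` is (batch, seq_len, num_codebooks); offsetting each
    # codebook into its own slice of the flat embedding table gives per-codebook embeddings of shape
    # (batch, seq_len, num_codebooks, hidden_size), which are summed over the codebook dimension into
    # a single (batch, seq_len, hidden_size) tensor for the backbone.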


@auto_docstring
class CsmBackboneModel(LlamaModel):
    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = CsmBackboneModelEmbeddings(config)

    @check_model_inputs
    @auto_docstring
    def forward(self, **super_kwargs):
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)` or `(batch_size, sequence_length)`):
            1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
                requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens.

            2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        """
        return super().forward(**super_kwargs)


@auto_docstring(
    custom_intro="""
    The Csm model consists of two llama-like auto-regressive transformer models: a backbone model that predicts the first codebook token and a depth decoder that predicts the other codebook tokens.
    """
)
class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
    _tied_weights_keys = [
        "backbone_model.embed_tokens.embed_audio_tokens.weight",
        "depth_decoder.model.embed_tokens.weight",
    ]

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.embed_text_tokens = nn.Embedding(config.text_vocab_size, config.hidden_size)
        self.backbone_model = CsmBackboneModel._from_config(config)
        self.depth_decoder = CsmDepthDecoderForCausalLM._from_config(config.depth_decoder_config)
        self.codec_model = AutoModel.from_config(config.codec_config)
        self.post_init()

    def get_input_embeddings(self):
        return self.backbone_model.embed_tokens

    def set_input_embeddings(self, value):
        self.backbone_model.embed_tokens = value

    def _tie_weights(self):
        if self.config.tie_codebooks_embeddings:
            self._tie_or_clone_weights(
                self.backbone_model.embed_tokens.embed_audio_tokens,
                self.depth_decoder.model.embed_tokens,
            )

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        if kwargs.get("output_loading_info", False):
            model, loading_info = super().from_pretrained(*args, **kwargs)
        else:
            model = super().from_pretrained(*args, **kwargs)

        # copy the `depth_decoder_`-prefixed generation config attributes to the depth decoder generation config
        prefix = "depth_decoder_"
        prefix_len = len(prefix)
        depth_decoder_attrs = {
            attr[prefix_len:]: value
            for attr, value in vars(model.generation_config).items()
            if attr.startswith(prefix)
        }

        vars(model.depth_decoder.generation_config).update({"_from_model_config": False, **depth_decoder_attrs})

        # remove the depth decoder generation config attributes from the model generation config
        for attr in depth_decoder_attrs:
            delattr(model.generation_config, prefix + attr)

        if kwargs.get("output_loading_info", False):
            return model, loading_info
        else:
            return model

    def save_pretrained(self, *args, **kwargs):
        # copy the depth decoder generation config attributes to the model generation config
        prefix = "depth_decoder_"
        depth_decoder_attrs = self.depth_decoder.generation_config.to_diff_dict()
        depth_decoder_attrs.pop("transformers_version", None)
        for attr, value in depth_decoder_attrs.items():
            setattr(self.generation_config, prefix + attr, value)

        super().save_pretrained(*args, **kwargs)
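    # Round-trip sketch (hypothetical attribute name): a depth decoder generation option such as
    # `do_sample` is stored on the top-level generation config as `depth_decoder_do_sample` by
    # `save_pretrained`, and moved back to `model.depth_decoder.generation_config.do_sample`
    # (dropping the `depth_decoder_` prefix) by `from_pretrained`.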

    def _merge_input_ids_with_input_values(
        self,
        input_ids: Optional[torch.Tensor] = None,
        input_values: Optional[torch.Tensor] = None,
        input_values_cutoffs: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict[str, Optional[torch.Tensor]]:
        """
        Merges the input_ids and input_values to produce a single inputs_embeds tensor:
        1 - Infers the codec model on the input_values to retrieve codebook tokens.
        2 - Embeds codebook tokens and places them at the correct positions in the inputs_embeds tensor.
        3 - If labels are provided, expands them to match codebook dimensions and positions the target codebook tokens in the expanded labels tensor.

        Args:
            input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                The input ids to embed.
            input_values (`torch.Tensor` of shape `(batch_size, channels, audio_sequence_length)`):
                The audio input values to embed.
            input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`):
                The cutoffs of the audio input values relative to their batch index, padded with -1 when no audio.
        """
        inputs_embeds = self.embed_text_tokens(input_ids)

        if input_values is not None:
            # infer input_values_mask
            input_values_cutoffs = nn.functional.pad(input_values_cutoffs, (1, 0))
            audio_lengths = input_values_cutoffs[input_values_cutoffs >= 0].diff()
            audio_lengths = audio_lengths[audio_lengths > 0]
            input_values_mask = torch.arange(input_values_cutoffs.max(), device=input_values.device).expand(
                len(audio_lengths), -1
            )
            input_values_mask = input_values_mask < audio_lengths.unsqueeze(1)
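            # Illustrative example (hypothetical lengths): with cutoffs [[l1, 2 * l1], [l2, -1]], the
            # padded cutoffs are [[0, l1, 2 * l1], [0, l2, -1]], the positive diffs give per-segment
            # audio lengths [l1, l1, l2], and `input_values_mask` marks the valid samples of each
            # segment up to its length.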

            # =======================================
            # TODO: @eustlb, this should be batched !!!
            # but requires making sure batched inference of the codec model works as intended
            with torch.no_grad():
                audio_tokens_list = []
                for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
                    batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
                    for i in range(batch_input_values_cutoffs.shape[0] - 1):
                        start_idx = batch_input_values_cutoffs[i]
                        end_idx = batch_input_values_cutoffs[i + 1]
                        audio_batch = batch_input_values[..., start_idx:end_idx]
                        codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
                        codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
                        audio_tokens_list.append(codebook_ids[0])

                max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
                batched_audio_token_ids = torch.stack(
                    [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
                )
                audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
            # =======================================
            audio_token_id = self.config.audio_token_id
            audio_token_mask = input_ids == audio_token_id

            audio_embeds = self.backbone_model.embed_tokens(batched_audio_token_ids)
            inputs_embeds[audio_token_mask] = audio_embeds[audio_codes_mask]

            # same for the audio eos token
            audio_eos_frame_ids = (
                torch.ones((1, 1, self.config.num_codebooks), device=input_ids.device, dtype=torch.long)
                * self.config.codebook_eos_token_id
            )
            audio_eos_embeds = self.backbone_model.embed_tokens(audio_eos_frame_ids).squeeze(1)

            audio_eos_token_mask = input_ids == self.config.audio_eos_token_id
            inputs_embeds[audio_eos_token_mask] = audio_eos_embeds.repeat(audio_eos_token_mask.sum(), 1)

            # if the labels are provided, we need to expand the labels to (batch_size, seq_length, num_codebooks)
            if labels is not None:
                labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
                labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
                labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
                # mask depth decoder
                depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
                labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
                labels = labels_expanded

        return {"inputs_embeds": inputs_embeds, "labels": labels}
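    # Merge sketch (illustrative shapes): for token ids of shape (batch, seq_len) containing
    # `audio_token_id` placeholders, the codec model turns each audio segment into codebook ids of
    # shape (frames, num_codebooks); their summed embeddings replace the placeholder positions of
    # `inputs_embeds` (batch, seq_len, hidden_size), and, when provided, `labels` are expanded to
    # (batch, seq_len, num_codebooks) with the target codebook ids at those same positions.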

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )

        if input_ids is not None and input_ids.ndim == 2 and model_inputs.get("inputs_embeds") is None:
            merged_inputs = self._merge_input_ids_with_input_values(
                input_ids=input_ids,
                input_values=kwargs.get("input_values"),
                input_values_cutoffs=kwargs.get("input_values_cutoffs"),
                labels=kwargs.get("labels"),
            )
            model_inputs.update(
                {"inputs_embeds": merged_inputs["inputs_embeds"], "labels": merged_inputs["labels"], "input_ids": None}
            )

        return model_inputs

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        input_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        input_values_cutoffs: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, CsmOutputWithPast]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)` or `(batch_size, sequence_length)`):
            1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
                requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens.

            2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`, *optional*):
            Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
            If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
            where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
            the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[config.audio_token_id, -100, -101]`.
            Requires targeted `input_values` to be provided, as audio tokens will be inferred from them using the `codec_model`.
            - `config.audio_token_id` indicates an audio frame (considering sequence length elements as frames)
            - `-100` will be ignored in the loss computation
            - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)

            Such labels can be prepared using `output_labels=True` when calling [`CsmProcessor`].
        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            Kept for compatibility. Only the following values are supported:
            1. `0`, which is equivalent to keeping all logits, used in the training regime
            2. `1`, which is equivalent to keeping only the last logit, used in the generation regime

        Example:

        ```python
        >>> import torch
        >>> from transformers import CsmForConditionalGeneration, AutoProcessor
        >>> from datasets import load_dataset, Audio

        >>> model_id = "sesame/csm-1b"
        >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        >>> processor = AutoProcessor.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
        >>> # ensure the audio is 24kHz
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))

        >>> conversation = []
        >>> # prepare a conversation with text and corresponding audio
        >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
        ...     conversation.append(
        ...         {
        ...             "role": f"{speaker_id}",
        ...             "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
        ...         }
        ...     )

        >>> inputs = processor.apply_chat_template(
        ...     conversation,
        ...     tokenize=True,
        ...     return_dict=True,
        ...     output_labels=True,
        ... ).to(torch_device)

        >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
        >>> output = model(**inputs)
        >>> output.loss.backward()
        ```"""
        if input_ids is not None and input_ids.ndim == 2:
            merged_inputs = self._merge_input_ids_with_input_values(
                input_ids, input_values, input_values_cutoffs, labels
            )
            inputs_embeds = merged_inputs["inputs_embeds"]
            labels = merged_inputs["labels"]
            input_ids = None

        backbone_outputs = self.backbone_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        backbone_hidden_states = backbone_outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        backbone_logits = self.lm_head(backbone_hidden_states[:, slice_indices, :])

        loss = None
        backbone_loss = None
        depth_decoder_loss = None
        depth_decoder_outputs = None
        if labels is not None:
            # select first codebook as labels for the backbone model
            backbone_labels = labels[:, :, 0]
            backbone_loss = self.loss_function(
                logits=backbone_logits, labels=backbone_labels, vocab_size=self.config.vocab_size, **kwargs
            )

            # for the depth decoder, we need to select the frames to train on
            # those are frames where the label is not uniformly `ignore_index` along the codebook dimension
            train_mask = ~(labels[:, :, 1:] == -100).all(dim=-1)
            depth_decoder_input_ids = labels[train_mask][..., : self.config.num_codebooks - 1]
            # add a placeholder in position 0 that will be replaced by the backbone_last_hidden_state
            depth_decoder_input_ids = nn.functional.pad(depth_decoder_input_ids, (1, 0), value=0)

            train_idxs = train_mask.nonzero(as_tuple=True)
            backbone_last_hidden_states = backbone_hidden_states[train_idxs[0], train_idxs[1] - 1, :]
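            # the backbone hidden state at position t - 1 is the one that predicts frame t,
            # hence the shift by one when gathering `backbone_last_hidden_states`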
            depth_decoder_labels = labels[train_mask]

            depth_decoder_outputs = self.depth_decoder(
                input_ids=depth_decoder_input_ids,
                backbone_last_hidden_state=backbone_last_hidden_states,
                use_cache=use_cache,
                return_dict=True,
                labels=depth_decoder_labels,
                **kwargs,
            )

            depth_decoder_loss = depth_decoder_outputs.loss
            loss = backbone_loss + depth_decoder_loss

        return CsmOutputWithPast(
            loss=loss,
            backbone_loss=backbone_loss,
            depth_decoder_loss=depth_decoder_loss,
            logits=backbone_logits,
            past_key_values=backbone_outputs.past_key_values,
            hidden_states=backbone_outputs.hidden_states,
            attentions=backbone_outputs.attentions,
            depth_decoder_logits=depth_decoder_outputs.logits if depth_decoder_outputs is not None else None,
            depth_decoder_past_key_values=depth_decoder_outputs.past_key_values
            if depth_decoder_outputs is not None
            else None,
            depth_decoder_hidden_states=depth_decoder_outputs.hidden_states
            if depth_decoder_outputs is not None
            else None,
            depth_decoder_attentions=depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None,
        )


__all__ = [
    "CsmPreTrainedModel",
    "CsmBackboneModel",
    "CsmDepthDecoderModel",
    "CsmDepthDecoderForCausalLM",
    "CsmForConditionalGeneration",
]