# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union

import torch
import torch.nn as nn

from ...generation import (
    GenerateDecoderOnlyOutput,
    GenerationConfig,
    GenerationMixin,
    GenerationMode,
)
from ...generation.logits_process import LogitsProcessorList
from ...generation.stopping_criteria import MaxLengthCriteria, StoppingCriteriaList
from ...generation.utils import GenerateNonBeamOutput
from ...utils import logging


if TYPE_CHECKING:
    from ...generation.streamers import BaseStreamer

logger = logging.get_logger(__name__)


@dataclass
class CsmGenerateOutput(GenerateDecoderOnlyOutput):
    """
    Outputs of CsmForConditionalGeneration.generate.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`):
            The generated codebook frames, one column per codebook. When the prompt passed to `generate` is a 2D
            tensor of text tokens, the prompt is not included and only generated frames are returned. The second
            dimension (sequence_length) is either equal to `max_length` (minus the text prompt length in that case)
            or shorter if all batch entries finished early because a codebook EOS frame was generated.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True`):
            Returns the model cache, used to speed up decoding. Different models have a different cache format,
            check the model's documentation.
        audio (`list(torch.FloatTensor)` of length `batch_size`):
            The generated audio, one waveform per batch entry. Only populated when `generate` is called with
            `output_audio=True`; `None` otherwise.
    """

    audio: Optional[list[torch.Tensor]] = None


class CsmGenerationMixin(GenerationMixin):
    def _get_stopping_criteria(
        self,
        *args,
        **kwargs,
    ) -> StoppingCriteriaList:
        """
        Build the default stopping criteria and keep only the `MaxLengthCriteria` instances.

        Every other criterion is dropped with a warning; the codebook EOS stop condition is instead checked
        directly inside `_sample`.
        """
        criteria = super()._get_stopping_criteria(*args, **kwargs)

        kept_criteria = StoppingCriteriaList()
        for criterion in criteria:
            if not isinstance(criterion, MaxLengthCriteria):
                logger.warning(
                    f"Csm does not support {criterion.__class__.__name__} stopping criteria, it will be ignored."
                )
            else:
                kept_criteria.append(criterion)
        return kept_criteria

    def _prepare_generation_config(
        self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: dict
    ) -> tuple[GenerationConfig, dict]:
        """
        This method overrides [~generation.utils.GenerationMixin._prepare_generation_config].

        It ensures that the depth decoder generation config is initialized and that args passed as
        `depth_decoder_*` are properly handled:
        - the `depth_decoder_` prefix is stripped and the remaining keys are applied to
          `self.depth_decoder.generation_config` in place, so they persist across later `generate` calls;
        - the depth decoder must generate exactly `config.num_codebooks - 1` tokens, otherwise a `ValueError` is raised;
        - the backbone generation mode is restricted to greedy search or sampling.

        Returns:
            The prepared backbone generation config and the remaining model kwargs.
        """
        prefix = "depth_decoder_"

        # extract depth decoder kwargs (prefix stripped) and remove them from the main kwargs
        depth_decoder_kwargs = {k[len(prefix) :]: v for k, v in kwargs.items() if k.startswith(prefix)}
        kwargs = {k: v for k, v in kwargs.items() if not k.startswith(prefix)}

        # initialize the generation config
        generation_config, model_kwargs = super()._prepare_generation_config(
            generation_config, use_model_defaults, **kwargs
        )

        # NOTE: this mutates the depth decoder's own generation config, so the update outlives this call
        depth_decoder_generation_config = self.depth_decoder.generation_config
        depth_decoder_generation_config.update(**depth_decoder_kwargs)

        # ensure the depth decoder generation config is valid: it must produce one token per remaining codebook
        expected_new_tokens = self.config.num_codebooks - 1
        depth_decoder_min_new_tokens = depth_decoder_generation_config.min_new_tokens or expected_new_tokens
        depth_decoder_max_new_tokens = depth_decoder_generation_config.max_new_tokens or expected_new_tokens

        if {depth_decoder_min_new_tokens, depth_decoder_max_new_tokens} != {expected_new_tokens}:
            raise ValueError(
                f"depth_decoder_generation_config's min_new_tokens ({depth_decoder_min_new_tokens}) and max_new_tokens ({depth_decoder_max_new_tokens}) must be equal to self.config.num_codebooks - 1 ({self.config.num_codebooks - 1})"
            )
        elif depth_decoder_generation_config.return_dict_in_generate:
            logger.warning(
                "depth_decoder_generation_config.return_dict_in_generate is set to True, but this will be ignored as the depth decoder model does not return a dictionary in generate"
            )
            depth_decoder_generation_config.return_dict_in_generate = False

        depth_decoder_generation_config.min_new_tokens = depth_decoder_min_new_tokens
        depth_decoder_generation_config.max_new_tokens = depth_decoder_max_new_tokens

        # Monkey patch the get_generation_mode method to support CSM model:
        # `_sample` only implements greedy search and multinomial sampling, so any other mode is rejected.
        original_get_generation_mode = generation_config.get_generation_mode

        def patched_get_generation_mode(assistant_model=None):
            generation_mode = original_get_generation_mode(assistant_model)
            if generation_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE]:
                raise ValueError(
                    f"Generation mode {generation_mode} is not supported for CSM model. Please set generation parameters to use greedy or sampling generation."
                )

            return generation_mode

        generation_config.get_generation_mode = patched_get_generation_mode

        return generation_config, model_kwargs

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        """
        This method overrides [~generation.utils.GenerationMixin._sample].
        To ease maintenance, modifications are marked with the comment "Csm specific".

        Indeed, Csm model requires a custom generation sampling step:
        1. Infer the backbone model to sample the first codebook token
        2. Call generate on the depth decoder with the first codebook token as input_ids to sample the next codebook tokens
        3. Use these generated codebook tokens as input_ids to sample the next first codebook token using the backbone model
        4. Repeat until stopping criteria is met

        Csm supports two stopping criteria:
        - stop when the generated sequence is at max_length
        - stop when, in the last generated frame, every codebook token except the last one equals
          `config.codebook_eos_token_id`

        When `input_ids` is a 2D tensor (text tokens), it is replaced by the first generated frame, so the returned
        sequences only contain generated codebook frames, shaped `(batch_size, generated_length, n_codebooks)`.
        When the generation config defines an EOS token, finished sequences are filled with
        `config.codebook_pad_token_id` while the rest of the batch keeps generating.
        """
        # init values
        # *************** Csm specific ***************
        pad_token_id = self.config.codebook_pad_token_id
        has_eos_stopping_criteria = generation_config._eos_token_tensor is not None
        # ============================================
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        do_sample = generation_config.do_sample

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # keep track of which sequences are already finished
        batch_size, cur_len = input_ids.shape[:2]
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

        # *************** Csm specific ***************
        if input_ids.ndim == 2 and model_kwargs.get("inputs_embeds") is None:
            # the passed input_ids are text tokens (no third dimension for codebook ids) and are not part of the
            # returned sequences, so their length is subtracted from MaxLengthCriteria.max_length
            for criterion in stopping_criteria:
                if isinstance(criterion, MaxLengthCriteria):
                    criterion.max_length -= cur_len
        # ============================================

        model_forward = self.__call__
        compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
        if compile_forward:
            os.environ["TOKENIZERS_PARALLELISM"] = "0"
            model_forward = self.get_compiled_call(generation_config.compile_config)

        is_prefill = True
        while self._has_unfinished_sequences(
            this_peer_finished,
            synced_gpus,
            device=input_ids.device,
        ):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # prepare variable output controls (note: some models won't accept all output controls)
            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
            # *************** Csm specific ***************
            # hidden states are always needed: the backbone's last hidden state conditions the depth decoder
            model_inputs.update({"output_hidden_states": True})
            # ============================================

            # the prefill step runs uncompiled; later steps may use the compiled forward
            if is_prefill:
                outputs = self(**model_inputs, return_dict=True)
                is_prefill = False
            else:
                outputs = model_forward(**model_inputs, return_dict=True)

            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
            )
            if synced_gpus and this_peer_finished:
                continue

            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
            # (the clone itself is always small)
            next_token_logits = outputs.logits[:, -1, :].clone().float()
            next_token_logits = next_token_logits.to(input_ids.device)

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (outputs.attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (outputs.hidden_states,)

            # token selection
            if do_sample:
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

            # *************** Csm specific ***************
            # infer the depth decoder
            first_codebook_ids = next_tokens[:, None]
            # adds place holder in position 0 that will be replaced by the backbone_last_hidden_state
            depth_decoder_input_ids = nn.functional.pad(first_codebook_ids, (1, 0), value=0)
            backbone_last_hidden_state = outputs.hidden_states[-1][:, -1, :]

            depth_decoder_outputs = self.depth_decoder.generate(
                input_ids=depth_decoder_input_ids, backbone_last_hidden_state=backbone_last_hidden_state.clone()
            )
            codebook_ids = (
                depth_decoder_outputs
                if isinstance(depth_decoder_outputs, torch.Tensor)
                else depth_decoder_outputs.sequences
            )
            # remove the place holder in position 0
            codebook_ids = codebook_ids[:, 1:]
            next_tokens = codebook_ids

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences.unsqueeze(-1) + pad_token_id * (
                    1 - unfinished_sequences.unsqueeze(-1)
                )

            # update generated ids, model inputs, and length for next step
            if input_ids.ndim == 2:
                # text prompt: the first generated frame replaces it (the prompt is not returned)
                input_ids = next_tokens[:, None, :]
            else:
                input_ids = torch.cat([input_ids, next_tokens[:, None, :]], dim=1)
            # ============================================

            if streamer is not None:
                streamer.put(next_tokens.cpu())

            # *************** Csm specific ***************
            # for the eos stopping criterion, it is expected that the eos token is the same for each codebook;
            # note that the last codebook of the frame is excluded from the check
            unfinished_sequences = unfinished_sequences & ~(
                input_ids[:, -1, :-1] == self.config.codebook_eos_token_id
            ).all(-1)
            # ============================================
            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0
            cur_len += 1

            # This is needed to properly delete outputs.logits which may be very large for first iteration
            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
            del outputs
            # *************** Csm specific ***************
            del depth_decoder_outputs
            # ============================================

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return input_ids

    def generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        input_values: Optional[torch.Tensor] = None,
        input_values_cutoffs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        synced_gpus: Optional[bool] = None,
        streamer: Optional["BaseStreamer"] = None,
        output_audio: Optional[bool] = False,
        **kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor, list[torch.Tensor]]:
        r"""
        This method overrides [`~generation.utils.GenerationMixin.generate`] to match the specifics of the Csm model.
        Indeed, Csm model requires a custom generation sampling step:
        1. Infer the backbone model to sample the first codebook token
        2. Call generate on the depth decoder with the first codebook token as `input_ids` to sample the next codebook tokens
        3. Use these generated codebook tokens as `input_ids` to sample the next first codebook token using the backbone model
        4. Repeat until stopping criteria is met

        <Tip warning={true}>

        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
        model's default generation configuration. You can override any `generation_config` by passing the corresponding
        parameters to generate(), e.g. `.generate(inputs, do_sample=True)`.
        </Tip>

        Parameters:
            input_ids (`torch.Tensor` of shape (batch_size, seq_length), *optional*):
                The sequence used as a prompt for the backbone model.
            input_values (`torch.Tensor` of shape (batch_size, channels, max_concatenated_audio_length), *optional*):
                The batched audio input values, where each batch entry contains the concatenation of all audio segments for that entry.
                These values will be encoded into codebook tokens using the codec model and merged with the text input ids provided in `input_ids`.
            input_values_cutoffs (`torch.Tensor` of shape (batch_size, max_num_audio), *optional*):
                Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
                If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
                where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
                the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
            generation_config ([`~generation.GenerationConfig`], *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which has the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            logits_processor (`LogitsProcessorList`, *optional*):
                Custom logits processors that complement the default logits processors built from arguments and
                generation config. If a logit processor is passed that is already created with the arguments or a
                generation config an error is thrown. This feature is intended for advanced users.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                Custom stopping criteria that complements the default stopping criteria built from arguments and a
                generation config. If a stopping criteria is passed that is already created with the arguments or a
                generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
                sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
                intended for advanced users. Note that Csm only keeps `MaxLengthCriteria`; other criteria are ignored
                with a warning.
            synced_gpus (`bool`, *optional*):
                Whether to continue running the while loop until max_length. Unless overridden, this flag will be set
                to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid
                deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            output_audio (`bool`, *optional*, defaults to `False`):
                Whether to decode the generated codebook tokens with the codec model and return the audio.
            kwargs (`dict[str, Any]`, *optional*):
                Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
                forwarded to the `forward` function of the model. Depth decoder specific kwargs should be prefixed
                with *depth_decoder_*.

        Return:
            [`CsmGenerateOutput`] or `torch.LongTensor` or `list[torch.FloatTensor]`: A [`CsmGenerateOutput`]
            (if `return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`
            when `output_audio=False` or a `list[torch.FloatTensor]` otherwise.

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, CsmForConditionalGeneration
        >>> from datasets import load_dataset, Audio

        >>> model_id = "sesame/csm-1b"
        >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        >>> processor = AutoProcessor.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
        >>> # ensure the audio is 24kHz
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))

        >>> conversation = []
        >>> # prepare a conversation with text and corresponding audio
        >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
        ...     conversation.append(
        ...         {
        ...             "role": f"{speaker_id}",
        ...             "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
        ...         }
        ...     )

        >>> # text prompt
        >>> conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})

        >>> inputs = processor.apply_chat_template(
        ...     conversation,
        ...     tokenize=True,
        ...     return_dict=True,
        ... ).to(torch_device)

        >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
        >>> audio = model.generate(**inputs, output_audio=True)
        >>> processor.save_audio(audio, "output.wav")
        ```
        """
        generate_output = super().generate(
            input_ids=input_ids,
            input_values=input_values,
            input_values_cutoffs=input_values_cutoffs,
            generation_config=generation_config,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            synced_gpus=synced_gpus,
            streamer=streamer,
            **kwargs,
        )

        generate_returned_dict = not isinstance(generate_output, torch.Tensor)
        audio = None
        if output_audio:
            generated_audio_codes = generate_output.sequences if generate_returned_dict else generate_output

            # infer the codec model
            audio = []
            with torch.no_grad():
                # =======================================
                # TODO: @eustlb, this should be batched !!!
                # but requires making sure batched inference of the codec model works as intended
                for audio_codes_batch in generated_audio_codes:
                    # cut the codes at the first frame whose codebooks all equal the codebook EOS token
                    # NOTE(review): `_sample` flags EOS using every codebook except the last one (`[..., :-1]`),
                    # while this check uses all codebooks; if the last codebook of an EOS frame is not EOS, that frame
                    # (and any padding frames after it) reach the codec. Confirm which criterion is intended.
                    eos_idxs = (audio_codes_batch == self.config.codebook_eos_token_id).all(dim=-1).nonzero()
                    if eos_idxs.numel() != 0:
                        cutoff_idx = eos_idxs.min()
                    else:
                        cutoff_idx = audio_codes_batch.shape[0]

                    audio_codes_batch = audio_codes_batch[:cutoff_idx]
                    # the codec expects codebooks on dim 1 and a leading batch dim: (1, n_codebooks, frames)
                    codec_decode_output = self.codec_model.decode(audio_codes_batch.transpose(0, 1).unsqueeze(0))
                    audio.append(codec_decode_output.audio_values[0, 0])
                # =======================================

        if generate_returned_dict:
            return CsmGenerateOutput(audio=audio, **generate_output)
        elif output_audio:
            return audio
        else:
            return generate_output