# coding=utf-8
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Optional, Union

import torch
import torch.distributed as dist

from ...generation.logits_process import (
    DiaClassifierFreeGuidanceLogitsProcessor,
    DiaEOSChannelFilterLogitsProcessor,
    DiaEOSDelayPatternLogitsProcessor,
    LogitsProcessorList,
    TemperatureLogitsWarper,
)
from ...generation.stopping_criteria import StoppingCriteriaList
from ...generation.streamers import BaseStreamer
from ...generation.utils import GenerateOutput, GenerationConfig, GenerationMixin, GenerationMode
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_utils import PreTrainedModel
from ...utils import logging


logger = logging.get_logger(__name__)


class DiaGenerationMixin(GenerationMixin):
    # Indicates whether CFG is used; CFG needs the inputs to be prepared (repeated) to be handled properly
    _uses_cfg = None

    def _get_logits_processor(
        self,
        generation_config: GenerationConfig,
        input_ids_seq_length: Optional[int] = None,
        encoder_input_ids: torch.LongTensor = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        device: Optional[str] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    ) -> LogitsProcessorList:
        # We need either a custom order or custom processors for guidance and temperature
        # (temporarily disabling both so the super function does not add its own versions)
        original_guidance_scale = generation_config.guidance_scale
        original_temperature = generation_config.temperature
        generation_config.guidance_scale = None
        generation_config.temperature = None

        # Get base processors and those we can integrate easily
        custom_processors = LogitsProcessorList()
        if original_temperature is not None and original_temperature != 1.0:
            custom_processors.append(TemperatureLogitsWarper(original_temperature))
        custom_processors.append(
            DiaEOSChannelFilterLogitsProcessor(
                num_channels=len(self.config.delay_pattern),
                eos_token_id=self.config.eos_token_id,
            )
        )

        merged_processors = super()._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=encoder_input_ids,
            prefix_allowed_tokens_fn=None,
            logits_processor=custom_processors,
            device=device,
            model_kwargs=model_kwargs,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
        )

        # Custom processors we need at specific positions
        if original_guidance_scale is not None and original_guidance_scale != 1:
            cfg_processor = DiaClassifierFreeGuidanceLogitsProcessor(
                guidance_scale=original_guidance_scale,
                guidance_top_k=generation_config.top_k,
            )
            merged_processors.insert(0, cfg_processor)

        merged_processors.append(
            DiaEOSDelayPatternLogitsProcessor(
                delay_pattern=self.config.delay_pattern,
                eos_token_id=self.config.eos_token_id,
                max_generation_len=generation_config.max_length,
                device=device,
            )
        )

        # Restore the temporarily disabled values
        generation_config.guidance_scale = original_guidance_scale
        generation_config.temperature = original_temperature

        return merged_processors

    def _prepare_generation_config(
        self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: dict
    ) -> tuple[GenerationConfig, dict]:
        generation_config, model_kwargs = super()._prepare_generation_config(
            generation_config, use_model_defaults, **kwargs
        )

        # We allow generation up to max length + max delay pattern
        # (will revert back to max length after generation)
        generation_config.max_length += max(self.config.delay_pattern)

        # Internal flag indicating CFG, which needs the unconditioned input to be prepared
        self._uses_cfg = generation_config.guidance_scale is not None and generation_config.guidance_scale != 1

        return generation_config, model_kwargs

    def _prepare_model_inputs(
        self,
        inputs: Optional[torch.Tensor] = None,
        bos_token_id: Optional[torch.Tensor] = None,
        model_kwargs: Optional[dict[str, torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
        inputs, input_name, model_kwargs = super()._prepare_model_inputs(
            inputs=inputs,
            bos_token_id=bos_token_id,
            model_kwargs=model_kwargs,
        )

        # If CFG is requested we fill in the unconditioned parts
        if self._uses_cfg:
            unconditioned_inputs = torch.zeros_like(inputs)
            inputs = torch.cat([inputs, unconditioned_inputs], dim=0)

            if model_kwargs.get("attention_mask", None) is not None:
                model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat(2, 1)

        return inputs, input_name, model_kwargs
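
    # Illustrative sketch of the CFG batch layout prepared above (shapes assume a text batch of
    # size B with sequence length S; the values are not taken from a real run):
    #   inputs:         (B, S)  ->  (2 * B, S)   # rows [B:] are zeroed (unconditioned half)
    #   attention_mask: (B, S)  ->  (2 * B, S)   # simply repeated along the batch dimension
    # Both halves go through the same forward pass; their logits are recombined by
    # `DiaClassifierFreeGuidanceLogitsProcessor`, which is inserted first in `_get_logits_processor`.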

    def _prepare_decoder_input_ids_for_generation(
        self,
        batch_size: int,
        model_input_name: str,
        model_kwargs: dict[str, torch.Tensor],
        decoder_start_token_id: torch.Tensor,
        device: Optional[torch.device] = None,
    ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]:
        """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
        # 1. Check whether the user has defined `decoder_input_ids` and `decoder_attention_mask`;
        # if not, fall back to non-delayed defaults
        decoder_input_ids = decoder_attention_mask = None
        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
            decoder_input_ids = model_kwargs.pop("decoder_input_ids")
        if model_kwargs is not None and "decoder_attention_mask" in model_kwargs:
            decoder_attention_mask = model_kwargs.pop("decoder_attention_mask")

        # We allow generating without preparation (no proper delay) but discourage it
        if decoder_input_ids is None or decoder_attention_mask is None:
            logger.warning_once(
                "In order to generate with Dia, the processed audio input is needed: got `decoder_input_ids`="
                f"{decoder_input_ids is not None} and `decoder_attention_mask`={decoder_attention_mask is not None}."
                " These can be obtained via the [`DiaProcessor`]; defaulting to non-delayed generation."
            )

            num_channels = self.config.decoder_config.num_channels
            real_batch_size = batch_size // 2 if self._uses_cfg else batch_size

            if decoder_input_ids is None:
                decoder_input_ids = torch.full(
                    (real_batch_size, 1, num_channels), decoder_start_token_id, dtype=torch.long, device=device
                )
            decoder_attention_mask = torch.ones(
                size=(real_batch_size, decoder_input_ids.shape[1]), dtype=torch.long, device=device
            )

        # 2. Determine the valid (non-padded) input; the full input acts as the delay mask
        delay_mask = decoder_input_ids.long()
        valid_input_size = (
            decoder_input_ids.shape[1]
            - (decoder_input_ids[:, :, 0] == self.config.pad_token_id).sum(dim=-1).max()
        )
        decoder_input_ids = delay_mask[:, :valid_input_size].transpose(1, 2).long()
        decoder_attention_mask = decoder_attention_mask[:, :valid_input_size].long()

        # 3. Overwrite into model kwargs
        model_kwargs["decoder_attention_mask"] = decoder_attention_mask
        model_kwargs["decoder_delay_mask"] = delay_mask

        return decoder_input_ids, model_kwargs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        encoder_outputs=None,  # Using this to easily get the batch size
        decoder_delay_mask=None,
        **kwargs,
    ):
        # Reshape decoder input_ids to 3D to be compile friendly and to fit the expected model input shape
        batch_size = encoder_outputs[0].shape[0] // 2 if self._uses_cfg else encoder_outputs[0].shape[0]
        input_ids = input_ids.reshape(batch_size, self.config.decoder_config.num_channels, -1).transpose(1, 2)

        # Base method handles most things except CFG and the delay pattern mask
        model_inputs = super().prepare_inputs_for_generation(input_ids, encoder_outputs=encoder_outputs, **kwargs)

        # Post processing for CFG and overwriting via delay pattern mask
        # 1. Delay pattern mask -- force tokens if not allowed to predict (!= pad_token in mask)
        model_inputs["decoder_input_ids"] = self.apply_delay_mask(
            input_ids, self.config.pad_token_id, decoder_delay_mask
        )
        # Depending on cache usage we need to pass all or just one
        if model_inputs.get("use_cache", False) and model_inputs["cache_position"][0] > 0:
            model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"][:, -1, :][:, None, :]
        # Be compile friendly
        model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"].contiguous()

        # 2. Apply CFG duplication if needed
        if self._uses_cfg:
            for key in ["decoder_input_ids", "decoder_attention_mask", "decoder_position_ids"]:
                if model_inputs.get(key, None) is not None:
                    # double first dimension and keep everything else the same
                    repeat_pattern = tuple([2] + [1] * (model_inputs[key].ndim - 1))
                    model_inputs[key] = model_inputs[key].repeat(*repeat_pattern)

        return model_inputs

    @staticmethod
    def apply_delay_mask(input_ids: torch.Tensor, pad_id: int, delay_mask: Optional[torch.Tensor]) -> torch.Tensor:
        if delay_mask is None:
            return input_ids

        mask_len = min(input_ids.shape[1], delay_mask.shape[1])
        valid_mask = delay_mask[:, :mask_len, :]
        valid_input = input_ids[:, :mask_len, :]

        # Overwrite the respective parts of the input
        input_ids[:, :mask_len, :] = torch.where(valid_mask == pad_id, valid_input, valid_mask)

        return input_ids
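
    # Illustrative example for `apply_delay_mask` above (made-up token ids, channel dim omitted):
    # positions where the delay mask differs from `pad_id` are forced back to the mask value,
    # while positions equal to `pad_id` keep the generated token. With, say, pad_id = 1025:
    #   delay_mask = [[1026, 1026, 1025, 1025]]   # e.g. BOS prefix imposed by the delay pattern
    #   input_ids  = [[  17,   42,  311,  512]]   # freshly generated tokens
    #   result     = [[1026, 1026,  311,  512]]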

    def _main_generate_loop(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        use_model_defaults: Optional[bool] = None,
        custom_generate: Optional[str] = None,
        **kwargs,
    ):
        # ********** mostly taken from main generate function up to calling the different methods (see NOTE) **********
        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
        assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)  # only used for assisted generation

        generation_config, model_kwargs = self._prepare_generation_config(
            generation_config, use_model_defaults, **kwargs
        )
        self._validate_model_kwargs(model_kwargs.copy())
        self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)

        # 2. Set generation parameters if not already defined
        if synced_gpus is None:
            synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1

        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        # 3. Define model inputs
        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        batch_size = inputs_tensor.shape[0]
        device = inputs_tensor.device
        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

        # 4. Define other model kwargs
        if "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name, generation_config
            )

        # 5. Prepare `input_ids` which will be used for auto-regressive generation
        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
            batch_size=batch_size,
            model_input_name=model_input_name,
            model_kwargs=model_kwargs,
            decoder_start_token_id=generation_config._decoder_start_token_tensor,
            device=inputs_tensor.device,
        )

        if generation_config.token_healing:
            input_ids = self.heal_tokens(input_ids, tokenizer)

        if streamer is not None:
            streamer.put(input_ids.cpu())

        # 6. Prepare `max_length` depending on other stopping criteria.
        # NOTE: `input_ids` is 3D here, so `input_ids.shape[1]` (the channel dim) would be incorrect
        input_ids_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            has_default_min_length=has_default_min_length,
            model_input_name=model_input_name,
            inputs_tensor=inputs_tensor,
            input_ids_length=input_ids_length,
        )

        # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
        # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
        # dynamically overrides this value as it can need more than the last token logits
        if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
            model_kwargs["logits_to_keep"] = 1

        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

        # 7. Prepare the cache.
        # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
        # - different models have a different cache name expected by the model (default = "past_key_values")
        # - `max_length`, prepared above, is used to determine the maximum cache length
        max_cache_length = generation_config.max_length - 1
        if (
            inputs_tensor.shape[1] != input_ids_length
            and model_input_name == "inputs_embeds"
            and not self.config.is_encoder_decoder
        ):
            max_cache_length += inputs_tensor.shape[1]
        self._prepare_cache_for_generation(
            generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
        )

        # 8. determine generation mode
        generation_mode = generation_config.get_generation_mode(assistant_model)

        if streamer is not None and (generation_config.num_beams > 1):
            raise ValueError(
                "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
            )

        # 9. prepare logits processors and stopping criteria
        prepared_logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
            device=inputs_tensor.device,
            model_kwargs=model_kwargs,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
        )
        prepared_stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
        )

        # Set model_kwargs `use_cache` so we can use it later in forward runs
        model_kwargs["use_cache"] = generation_config.use_cache

        # ******************* taken from main generate function up to calling the different methods *******************

        # Flatten decoder input ids to 2D (bsz * channels, seq_len) for the inner generation loop
        input_ids = input_ids.reshape(-1, input_ids.shape[-1])

        # 10. go into different generation modes
        if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
            if generation_config.num_return_sequences > 1:
                raise ValueError("`num_return_sequences>1` is incompatible with Dia.")

            # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
            return self._sample(
                input_ids,
                logits_processor=prepared_logits_processor,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                streamer=streamer,
                **model_kwargs,
            )
        else:
            raise ValueError(
                "Got incompatible mode for generation, should be one of greedy or sampling. "
                "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
            )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        use_model_defaults: Optional[bool] = None,
        custom_generate: Optional[str] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        # We expect the initial input ids to be the complete mask (delayed input)
        delay_mask = kwargs.get("decoder_input_ids", None)
        if delay_mask is not None:
            delay_mask = delay_mask.clone()

        output = self._main_generate_loop(
            inputs=inputs,
            generation_config=generation_config,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            synced_gpus=synced_gpus,
            assistant_model=assistant_model,
            streamer=streamer,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            use_model_defaults=use_model_defaults,
            custom_generate=custom_generate,
            **kwargs,
        )

        return_dict_in_generate = not isinstance(output, torch.Tensor)

        if return_dict_in_generate:
            output_sequences = output.sequences
        else:
            output_sequences = output

        # Reshape from 2D (bsz * channels, seq_len) to 3D (bsz, seq_len, channels)
        num_channels = self.config.decoder_config.num_channels
        bsz = output_sequences.shape[0] // num_channels
        output_sequences = output_sequences.reshape(bsz, num_channels, -1).transpose(1, 2)

        # Apply delay mask
        output_sequences = self.apply_delay_mask(output_sequences, self.config.pad_token_id, delay_mask)

        if return_dict_in_generate:
            output.sequences = output_sequences
        else:
            output = output_sequences

        return output
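
# Minimal usage sketch (illustrative only; the checkpoint name, prompt format and sampling values
# below are assumptions, not part of this module):
#
#     from transformers import DiaForConditionalGeneration, DiaProcessor
#
#     processor = DiaProcessor.from_pretrained("nari-labs/Dia-1.6B-0626")
#     model = DiaForConditionalGeneration.from_pretrained("nari-labs/Dia-1.6B-0626")
#
#     inputs = processor(text=["[S1] Hello there."], padding=True, return_tensors="pt")
#     # `generate` returns 3D codebook sequences of shape (bsz, seq_len, channels) with the
#     # delay mask re-applied; the processor turns them back into audio.
#     audio_codes = model.generate(**inputs, max_new_tokens=256, guidance_scale=3.0)
#     audio = processor.batch_decode(audio_codes)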