# coding=utf-8
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Optional, Union

import torch
import torch.distributed as dist

from ...generation.logits_process import (
    DiaClassifierFreeGuidanceLogitsProcessor,
    DiaEOSChannelFilterLogitsProcessor,
    DiaEOSDelayPatternLogitsProcessor,
    LogitsProcessorList,
    TemperatureLogitsWarper,
)
from ...generation.stopping_criteria import StoppingCriteriaList
from ...generation.streamers import BaseStreamer
from ...generation.utils import GenerateOutput, GenerationConfig, GenerationMixin, GenerationMode
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_utils import PreTrainedModel
from ...utils import logging


logger = logging.get_logger(__name__)


class DiaGenerationMixin(GenerationMixin):
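    """
    Generation mixin adapting `GenerationMixin` to Dia: it prepares the doubled batch needed for
    classifier-free guidance (CFG), applies the audio delay pattern via a delay mask, and flattens
    the audio channels into the 2D token ids tracked by the sampling loop.
    """
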
    # Indicates whether CFG is used, in which case the inputs need to be prepared (repeated) accordingly
    _uses_cfg = None

    def _get_logits_processor(
        self,
        generation_config: GenerationConfig,
        input_ids_seq_length: Optional[int] = None,
        encoder_input_ids: torch.LongTensor = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        device: Optional[str] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    ) -> LogitsProcessorList:
        # These need either a custom order or a custom processor, so temporarily disable them
        # for the super() call and handle them manually below
        original_guidance_scale = generation_config.guidance_scale
        original_temperature = generation_config.temperature
        generation_config.guidance_scale = None
        generation_config.temperature = None

        # Get base processors and those we can integrate easily
        custom_processors = LogitsProcessorList()

        if original_temperature is not None and original_temperature != 1.0:
            custom_processors.append(TemperatureLogitsWarper(original_temperature))

        custom_processors.append(
            DiaEOSChannelFilterLogitsProcessor(
                num_channels=len(self.config.delay_pattern),
                eos_token_id=self.config.eos_token_id,
            )
        )

        merged_processors = super()._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=encoder_input_ids,
            prefix_allowed_tokens_fn=None,
            logits_processor=custom_processors,
            device=device,
            model_kwargs=model_kwargs,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
        )

        # Custom processors we need at specific positions
        if original_guidance_scale is not None and original_guidance_scale != 1:
            cfg_processor = DiaClassifierFreeGuidanceLogitsProcessor(
                guidance_scale=original_guidance_scale,
                guidance_top_k=generation_config.top_k,
            )
            merged_processors.insert(0, cfg_processor)

        merged_processors.append(
            DiaEOSDelayPatternLogitsProcessor(
                delay_pattern=self.config.delay_pattern,
                eos_token_id=self.config.eos_token_id,
                max_generation_len=generation_config.max_length,
                device=device,
            )
        )

        # Restore the temporarily disabled values
        generation_config.guidance_scale = original_guidance_scale
        generation_config.temperature = original_temperature

        return merged_processors
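
    # Net effect (sketch): the CFG processor is inserted first so it sees the doubled (cond/uncond)
    # batch, the temperature warper and EOS channel filter are merged with the default processors by
    # the parent implementation, and the EOS delay-pattern processor is appended so it runs last.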

    def _prepare_generation_config(
        self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: dict
    ) -> tuple[GenerationConfig, dict]:
        generation_config, model_kwargs = super()._prepare_generation_config(
            generation_config, use_model_defaults, **kwargs
        )

        # We allow generation up to max length + max delay pattern
        # (will revert back to max length after generation)
        generation_config.max_length += max(self.config.delay_pattern)
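        # For example, with a delay pattern whose largest entry is 15, up to 15 extra decoding steps
        # are allowed so that the most delayed channels can still finish their streams.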

        # Internal flag indicating CFG, which requires preparing the unconditioned inputs
        self._uses_cfg = generation_config.guidance_scale is not None and generation_config.guidance_scale != 1

        return generation_config, model_kwargs

    def _prepare_model_inputs(
        self,
        inputs: Optional[torch.Tensor] = None,
        bos_token_id: Optional[torch.Tensor] = None,
        model_kwargs: Optional[dict[str, torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
        inputs, input_name, model_kwargs = super()._prepare_model_inputs(
            inputs=inputs,
            bos_token_id=bos_token_id,
            model_kwargs=model_kwargs,
        )

        # If CFG is requested we fill in the unconditioned parts
        if self._uses_cfg:
            unconditioned_inputs = torch.zeros_like(inputs)
            inputs = torch.cat([inputs, unconditioned_inputs], dim=0)

            if model_kwargs.get("attention_mask", None) is not None:
                model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat(2, 1)

        return inputs, input_name, model_kwargs
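
    # CFG sketch (hypothetical shapes): a conditioned text batch of shape (2, 32) becomes (4, 32)
    # after concatenating an all-zero unconditioned copy, and `attention_mask` is repeated to match;
    # the conditioned and unconditioned halves are recombined later by the CFG logits processor.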

    def _prepare_decoder_input_ids_for_generation(
        self,
        batch_size: int,
        model_input_name: str,
        model_kwargs: dict[str, torch.Tensor],
        decoder_start_token_id: torch.Tensor,
        device: Optional[torch.device] = None,
    ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]:
        """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
        # 1. Check whether the user has defined `decoder_input_ids` and `decoder_attention_mask`; if not, warn and default
        decoder_input_ids = decoder_attention_mask = None
        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
            decoder_input_ids = model_kwargs.pop("decoder_input_ids")
        if model_kwargs is not None and "decoder_attention_mask" in model_kwargs:
            decoder_attention_mask = model_kwargs.pop("decoder_attention_mask")

        # We allow generating without preparation (no proper delay) but discourage it
        if decoder_input_ids is None or decoder_attention_mask is None:
            logger.warning_once(
                "In order to generate with Dia, we need the processed audio input: Got `decoder_input_ids`:"
                f" {decoder_input_ids is not None} and got `decoder_attention_mask`={decoder_attention_mask is not None}."
                f" This can be achieved via the [`DiaProcessor`] but now defaulting to non-delayed generation."
            )

            num_channels = self.config.decoder_config.num_channels
            real_batch_size = batch_size // 2 if self._uses_cfg else batch_size

            if decoder_input_ids is None:
                decoder_input_ids = torch.full(
                    (real_batch_size, 1, num_channels), decoder_start_token_id, dtype=torch.long, device=device
                )

            decoder_attention_mask = torch.ones(
                size=(real_batch_size, decoder_input_ids.shape[1]), dtype=torch.long, device=device
            )

        # 2. Determine the valid input and what works as mask within the input
        delay_mask = decoder_input_ids.long()
        valid_input_size = (
            decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == self.config.pad_token_id).sum(dim=-1).max()
        )
        decoder_input_ids = delay_mask[:, :valid_input_size].transpose(1, 2).long()
        decoder_attention_mask = decoder_attention_mask[:, :valid_input_size].long()

        # 3. Overwrite into model kwargs
        model_kwargs["decoder_attention_mask"] = decoder_attention_mask
        model_kwargs["decoder_delay_mask"] = delay_mask

        return decoder_input_ids, model_kwargs
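
    # Shape note: `decoder_delay_mask` keeps the full processed prompt as (batch, mask_len, channels),
    # while the returned `decoder_input_ids` is the not-yet-padded prefix transposed to
    # (batch, channels, valid_len) so it can be flattened for the generation loop.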

    def prepare_inputs_for_generation(
        self,
        input_ids,
        encoder_outputs=None,  # Using this to easily get the batch size
        decoder_delay_mask=None,
        **kwargs,
    ):
        # Reshape decoder input_ids to 3D to be compile friendly and to fit the expected model input shape
        batch_size = encoder_outputs[0].shape[0] // 2 if self._uses_cfg else encoder_outputs[0].shape[0]
        input_ids = input_ids.reshape(batch_size, self.config.decoder_config.num_channels, -1).transpose(1, 2)

        # Base method handles most things except CFG and the delay pattern mask
        model_inputs = super().prepare_inputs_for_generation(input_ids, encoder_outputs=encoder_outputs, **kwargs)

        # Post processing for CFG and overwriting via delay pattern mask
        # 1. Delay pattern mask -- force tokens where prediction is not allowed (i.e. mask != pad_token)
        model_inputs["decoder_input_ids"] = self.apply_delay_mask(
            input_ids, self.config.pad_token_id, decoder_delay_mask
        )

        # Depending on cache usage, we pass either the full sequence or only the last token
        if model_inputs.get("use_cache", False) and model_inputs["cache_position"][0] > 0:
            model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"][:, -1, :][:, None, :]

        # Be compile friendly
        model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"].contiguous()

        # 2. Apply CFG duplication if needed
        if self._uses_cfg:
            for key in ["decoder_input_ids", "decoder_attention_mask", "decoder_position_ids"]:
                if model_inputs.get(key, None) is not None:
                    # double the first dimension and keep everything else the same
                    repeat_pattern = tuple([2] + [1] * (model_inputs[key].ndim - 1))
                    model_inputs[key] = model_inputs[key].repeat(*repeat_pattern)

        return model_inputs
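
    # Shape note (illustrative): the sampling loop tracks decoder ids flattened to (batch * channels, seq);
    # this hook restores (batch, seq, channels) for the forward pass, and with CFG the decoder tensors'
    # batch dimension is doubled again to match the doubled encoder outputs.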

    @staticmethod
    def apply_delay_mask(input_ids: torch.Tensor, pad_id: int, delay_mask: Optional[torch.Tensor]) -> torch.Tensor:
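        """
        Force the tokens dictated by the delay pattern and keep freely predicted positions.

        Wherever `delay_mask` equals `pad_id`, the token already present in `input_ids` is kept;
        everywhere else the mask value overwrites it. A minimal sketch with made-up token ids
        (`1025` standing in for the pad id, shapes `(batch, seq_len, channels)`):

        ```python
        import torch

        pad = 1025
        input_ids = torch.tensor([[[7, 7], [8, 8], [9, 9]]])  # (1, 3, 2)
        delay_mask = torch.tensor([[[1026, pad], [pad, 1026]]])  # (1, 2, 2)

        out = DiaGenerationMixin.apply_delay_mask(input_ids, pad, delay_mask)
        # out[0, 0] == [1026, 7], out[0, 1] == [8, 1026]; positions beyond the mask stay untouched
        ```
        """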
        if delay_mask is None:
            return input_ids

        mask_len = min(input_ids.shape[1], delay_mask.shape[1])
        valid_mask = delay_mask[:, :mask_len, :]
        valid_input = input_ids[:, :mask_len, :]

        # Overwrite the respective parts of the input
        input_ids[:, :mask_len, :] = torch.where(valid_mask == pad_id, valid_input, valid_mask)

        return input_ids

    def _main_generate_loop(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        use_model_defaults: Optional[bool] = None,
        custom_generate: Optional[str] = None,
        **kwargs,
    ):
        # ********** mostly taken from main generate function up to calling the different methods (see NOTE) **********
        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
        assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)  # only used for assisted generation

        generation_config, model_kwargs = self._prepare_generation_config(
            generation_config, use_model_defaults, **kwargs
        )
        self._validate_model_kwargs(model_kwargs.copy())
        self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)

        # 2. Set generation parameters if not already defined
        if synced_gpus is None:
            synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1

        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        # 3. Define model inputs
        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        batch_size = inputs_tensor.shape[0]

        device = inputs_tensor.device
        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

        # 4. Define other model kwargs
        if "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name, generation_config
            )

        # 5. Prepare `input_ids` which will be used for auto-regressive generation
        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
            batch_size=batch_size,
            model_input_name=model_input_name,
            model_kwargs=model_kwargs,
            decoder_start_token_id=generation_config._decoder_start_token_tensor,
            device=inputs_tensor.device,
        )

        if generation_config.token_healing:
            input_ids = self.heal_tokens(input_ids, tokenizer)

        if streamer is not None:
            streamer.put(input_ids.cpu())

        # 6. Prepare `max_length` depending on other stopping criteria.
        # NOTE: incorrect `input_ids.shape[1]` previously
        input_ids_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            has_default_min_length=has_default_min_length,
            model_input_name=model_input_name,
            inputs_tensor=inputs_tensor,
            input_ids_length=input_ids_length,
        )

        # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
        # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
        # dynamically overrides this value as it can need more than the last token logits
        if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
            model_kwargs["logits_to_keep"] = 1

        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

        # 7. Prepare the cache.
        # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
        # - different models have a different cache name expected by the model (default = "past_key_values")
        # - `max_length`, prepared above, is used to determine the maximum cache length
        max_cache_length = generation_config.max_length - 1
        if (
            inputs_tensor.shape[1] != input_ids_length
            and model_input_name == "inputs_embeds"
            and not self.config.is_encoder_decoder
        ):
            max_cache_length += inputs_tensor.shape[1]
        self._prepare_cache_for_generation(
            generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
        )

        # 8. determine generation mode
        generation_mode = generation_config.get_generation_mode(assistant_model)

        if streamer is not None and (generation_config.num_beams > 1):
            raise ValueError(
                "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
            )

        # 9. prepare logits processors and stopping criteria
        prepared_logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
            device=inputs_tensor.device,
            model_kwargs=model_kwargs,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
        )
        prepared_stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
        )

        # Set model_kwargs `use_cache` so we can use it later in forward runs
        model_kwargs["use_cache"] = generation_config.use_cache
        # ******************* taken from main generate function up to calling the different methods *******************

        # Flatten to the inner 2D representation (batch * channels, seq_len) used in the generation loop
        input_ids = input_ids.reshape(-1, input_ids.shape[-1])

        # 10. go into different generation modes
        if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
            # 11. `num_return_sequences` expansion (additional sequences per batch) is not supported
            if generation_config.num_return_sequences > 1:
                raise ValueError("`num_return_sequences>1` is incompatible with Dia.")

            # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
            return self._sample(
                input_ids,
                logits_processor=prepared_logits_processor,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                streamer=streamer,
                **model_kwargs,
            )
        else:
            raise ValueError(
                "Got incompatible mode for generation, should be one of greedy or sampling. "
                "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
            )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        use_model_defaults: Optional[bool] = None,
        custom_generate: Optional[str] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
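        """
        Generate Dia audio codebook tokens and undo the internal channel flattening before returning.

        A minimal usage sketch; the checkpoint id, generation parameters, and processor calls below
        follow the public Dia examples and are illustrative rather than prescriptive:

        ```python
        from transformers import AutoProcessor, DiaForConditionalGeneration

        repo_id = "nari-labs/Dia-1.6B-0626"
        processor = AutoProcessor.from_pretrained(repo_id)
        model = DiaForConditionalGeneration.from_pretrained(repo_id)

        inputs = processor(text=["[S1] Hello there. [S2] Hi!"], padding=True, return_tensors="pt")
        outputs = model.generate(**inputs, max_new_tokens=256, guidance_scale=3.0, temperature=1.8, top_k=45)

        audio = processor.batch_decode(outputs)
        processor.save_audio(audio, "dia_sample.wav")
        ```
        """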
        # We expect the initial input ids to be the complete mask (delayed input)
        delay_mask = kwargs.get("decoder_input_ids", None)
        if delay_mask is not None:
            delay_mask = delay_mask.clone()

        output = self._main_generate_loop(
            inputs=inputs,
            generation_config=generation_config,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            synced_gpus=synced_gpus,
            assistant_model=assistant_model,
            streamer=streamer,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            use_model_defaults=use_model_defaults,
            custom_generate=custom_generate,
            **kwargs,
        )

        return_dict_in_generate = not isinstance(output, torch.Tensor)

        if return_dict_in_generate:
            output_sequences = output.sequences
        else:
            output_sequences = output

        # Reshape from 2D (bsz * channels, seq_len) to 3D (bsz, seq_len, channels)
        num_channels = self.config.decoder_config.num_channels
        bsz = output_sequences.shape[0] // num_channels
        output_sequences = output_sequences.reshape(bsz, num_channels, -1).transpose(1, 2)

        # Apply delay mask
        output_sequences = self.apply_delay_mask(output_sequences, self.config.pad_token_id, delay_mask)

        if return_dict_in_generate:
            output.sequences = output_sequences
        else:
            output = output_sequences

        return output