# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import os
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import numpy as np
import torch
import torch.distributed as dist
from huggingface_hub import file_exists
from packaging import version
from torch import nn
from torch.nn import functional as F

from ..cache_utils import (
    Cache,
    DynamicCache,
    EncoderDecoderCache,
    HybridChunkedCache,
    OffloadedCache,
    OffloadedHybridCache,
)
from ..configuration_utils import PretrainedConfig
from ..dynamic_module_utils import (
    check_python_requirements,
    get_cached_module_file,
    get_class_in_module,
    resolve_trust_remote_code,
)
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..integrations.fsdp import is_fsdp_managed_module
from ..masking_utils import create_masks_for_generate
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..pytorch_utils import isin_mps_friendly
from ..tokenization_utils import ExtensionsTrie
from ..utils import (
    ModelOutput,
    is_accelerate_available,
    is_hqq_available,
    is_optimum_quanto_available,
    is_torchdynamo_exporting,
    logging,
)
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .candidate_generator import (
    AssistantVocabTranslatorCache,
    AssistedCandidateGenerator,
    AssistedCandidateGeneratorDifferentTokenizers,
    CandidateGenerator,
    EarlyExitCandidateGenerator,
    PromptLookupCandidateGenerator,
    UniversalSpeculativeDecodingGenerator,
    _prepare_attention_mask,
    _prepare_token_type_ids,
)
from .configuration_utils import (
    NEED_SETUP_CACHE_CLASSES_MAPPING,
    QUANT_BACKEND_CLASSES_MAPPING,
    CompileConfig,
    GenerationConfig,
    GenerationMode,
)
from .continuous_batching import ContinuousMixin
from .logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
    EncoderRepetitionPenaltyLogitsProcessor,
    EpsilonLogitsWarper,
    EtaLogitsWarper,
    ExponentialDecayLengthPenalty,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitNormalization,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    MinNewTokensLengthLogitsProcessor,
    MinPLogitsWarper,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    SequenceBiasLogitsProcessor,
    SuppressTokensAtBeginLogitsProcessor,
    SuppressTokensLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
    UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
from .stopping_criteria import (
    ConfidenceCriteria,
    EosTokenCriteria,
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    StopStringCriteria,
)


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel
    from ..tokenization_utils_base import PreTrainedTokenizerBase
    from .streamers import BaseStreamer

logger = logging.get_logger(__name__)

if is_accelerate_available():
    from accelerate.hooks import AlignDevicesHook, add_hook_to_module


# Variable names used to hold the cache at generation time
ALL_CACHE_NAMES = [
    "past_key_values",  # default
    "cache_params",  # mamba-based models
    "state",  # rwkv
    "mems",  # xlnet
    "past_buckets_states",  # reformer
]

@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
    """
    Outputs of decoder-only generation models, when using non-beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
        past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`):
            Returns the model cache, used to speed up decoding. Different models have a different cache format, check
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

    sequences: torch.LongTensor
    scores: Optional[tuple[torch.FloatTensor]] = None
    logits: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[tuple[tuple[tuple[torch.FloatTensor]]]] = None


@dataclass
class GenerateEncoderDecoderOutput(ModelOutput):
    """
    Outputs of encoder-decoder generation models, when using non-beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
        past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Returns the model cache, used to speed up decoding. Different models have a different cache format, check
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

    sequences: torch.LongTensor
    scores: Optional[tuple[torch.FloatTensor]] = None
    logits: Optional[tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[tuple[tuple[tuple[torch.FloatTensor]]]] = None


@dataclass
class GenerateBeamDecoderOnlyOutput(ModelOutput):
    """
    Outputs of decoder-only generation models, when using beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
            with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, sequence_length)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
        past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`):
            Returns the model cache, used to speed up decoding. Different models have a different cache format, check
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

    sequences: torch.LongTensor
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[tuple[torch.FloatTensor]] = None
    logits: Optional[tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[tuple[tuple[tuple[torch.FloatTensor]]]] = None


@dataclass
class GenerateBeamEncoderDecoderOutput(ModelOutput):
    """
    Outputs of encoder-decoder generation models, when using beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
            with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, sequence_length)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
        past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`):
            Returns the model cache, used to speed up decoding. Different models have a different cache format, check
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

    sequences: torch.LongTensor
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[tuple[torch.FloatTensor]] = None
    logits: Optional[tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[tuple[tuple[tuple[torch.FloatTensor]]]] = None


# TODO (joao): remove the equivalent classes and typing shortcuts below in v5
# Equivalent classes (kept for retrocompatibility purposes)
GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput
ContrastiveSearchDecoderOnlyOutput = GenerateDecoderOnlyOutput
SampleDecoderOnlyOutput = GenerateDecoderOnlyOutput

ContrastiveSearchEncoderDecoderOutput = GenerateEncoderDecoderOutput
GreedySearchEncoderDecoderOutput = GenerateEncoderDecoderOutput
SampleEncoderDecoderOutput = GenerateEncoderDecoderOutput

BeamSearchDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput
BeamSampleDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput

BeamSearchEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput
BeamSampleEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput

GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput]

# Typing shortcuts
GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]
GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput]
GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput]


class GenerationMixin(ContinuousMixin):
    """
    A class containing all functions for auto-regressive text generation, to be used as a mixin in model classes.
    Inheriting from this class causes the model to have special generation-related behavior, such as loading a
    `GenerationConfig` at initialization time or ensuring `generate`-related tests are run in `transformers` CI.

    A model class should inherit from `GenerationMixin` to enable calling methods like `generate`, or when it
    has defined a custom `generate` method that relies on `GenerationMixin`, directly or indirectly, and which
    approximately shares the same interface as public methods like `generate`. Three examples:
        - `LlamaForCausalLM` should inherit from `GenerationMixin` to enable calling `generate` and other public
          methods in the mixin;
        - `BlipForQuestionAnswering` has a custom `generate` method that approximately shares the same interface as
          `GenerationMixin.generate` (it has a few extra arguments, and the same output). That function also calls
          `GenerationMixin.generate` indirectly, through an inner model. As such, `BlipForQuestionAnswering` should
          inherit from `GenerationMixin` to benefit from all generation-related automation in our codebase;
        - `BarkModel` has a custom `generate` method and one of its inner models calls `GenerationMixin.generate`.
          However, its `generate` does not share the same interface as `GenerationMixin.generate`. In this case,
          `BarkModel` should NOT inherit from `GenerationMixin`, as it breaks the `generate` interface.

    The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
        - *greedy decoding* if `num_beams=1` and `do_sample=False`
        - *contrastive search* if `penalty_alpha>0` and `top_k>1`
        - *multinomial sampling* if `num_beams=1` and `do_sample=True`
        - *beam-search decoding* if `num_beams>1` and `do_sample=False`
        - *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True`
        - *diverse beam-search decoding* if `num_beams>1` and `num_beam_groups>1`
        - *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
        - *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`

    To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
    """

    def load_custom_generate(
        self,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        trust_remote_code: Optional[bool] = None,
        **kwargs,
    ) -> Callable:
        """
        Loads and returns a custom generate function, given a model repo.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:
                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                    - A path to a *directory* containing model weights saved using
                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
            trust_remote_code (`bool`, *optional*):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to `True` for repositories you trust and in which you have read the code, as it
                will execute code present on the Hub on your local machine.
            **kwargs:
                Additional keyword arguments for remote code loading.

        Raises:
            OSError: If `pretrained_model_name_or_path` does not contain a `custom_generate` subdirectory.

        Returns:
            A callable that can be used to generate text.
        """
        # Does `pretrained_model_name_or_path` have a `custom_generate` subdirectory? If not -> OSError
        is_local_code = os.path.exists(pretrained_model_name_or_path)
        has_custom_generate_folder = True
        if is_local_code:
            if not os.path.exists(os.path.join(pretrained_model_name_or_path, "custom_generate/generate.py")):
                has_custom_generate_folder = False
        else:
            if not file_exists(pretrained_model_name_or_path, "custom_generate/generate.py"):
                has_custom_generate_folder = False

        if not has_custom_generate_folder:
            raise OSError(
                f"`{pretrained_model_name_or_path}` does not contain a `custom_generate` subdirectory with a "
                "`generate.py` file, can't load the custom generate function."
            )

        # Handle opt-in `trust_remote_code` and related exceptions
        error_message = (
            f"The repository `{pretrained_model_name_or_path}` contains custom generation code that will override "
            "the default `generate` method."
        )
        resolve_trust_remote_code(
            trust_remote_code,
            pretrained_model_name_or_path,
            has_local_code=is_local_code,
            has_remote_code=not is_local_code,
            error_message=error_message,
        )

        # Load the custom generate function
        check_python_requirements(
            pretrained_model_name_or_path, requirements_file="custom_generate/requirements.txt", **kwargs
        )
        module = get_cached_module_file(
            pretrained_model_name_or_path, module_file="custom_generate/generate.py", **kwargs
        )
        custom_generate_function = get_class_in_module("generate", module)
        return custom_generate_function
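
    # Illustrative usage (the Hub repo id below is hypothetical; it must ship `custom_generate/generate.py`):
    #   custom_generate = model.load_custom_generate("some-org/repo-with-custom-generate", trust_remote_code=True)
    # The returned callable is the `generate` object defined inside that file.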

    def _cache_dependant_input_preparation(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: Optional[torch.FloatTensor],
        cache_position: Optional[torch.LongTensor],
    ) -> tuple[torch.FloatTensor, torch.LongTensor]:
        """
        Generic cache-dependent input preparation.
        The code is put in a separate function to allow granular unit testing, as it needs a different
        implementation to be exportable.

        If we have a cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        - Exception 1: when passing input_embeds, input_ids may be missing entries
        - Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        - Exception 3: with synced GPUs cache_position may go out of bounds, but we only want a dummy token in that case
        - Exception 4: if input_embeds are passed, slice them through `cache_position` to keep only the unprocessed
          tokens and generate the first token for each sequence. Later, use the generated input ids for continuation.

        The current implementation does not rely on ``self`` and could be
        a class method. It is left as a standard method to be easily rewritten.
        """
        if is_torchdynamo_exporting():
            return self._cache_dependant_input_preparation_exporting(input_ids, inputs_embeds, cache_position)
        if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
            inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
        elif (
            inputs_embeds is not None  # Exception 1
            or (cache_position[-1] >= input_ids.shape[1])  # Exception 3
        ):
            input_ids = input_ids[:, -cache_position.shape[0] :]
        elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
            input_ids = input_ids[:, cache_position]
        return inputs_embeds, input_ids
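
    # Worked example of the default case (illustrative shapes): with a cache holding 8 processed tokens,
    # `input_ids` of shape (batch, 9) and `cache_position = [8]`, only the single unprocessed token
    # `input_ids[:, [8]]` is kept and fed to the model on this step.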

    def _cache_dependant_input_preparation_exporting(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: Optional[torch.FloatTensor],
        cache_position: Optional[torch.LongTensor],
    ) -> tuple[torch.FloatTensor, torch.LongTensor]:
        """
        This method implements ``_cache_dependant_input_preparation``
        with :func:`torch.cond` to make it exportable with :func:`torch.export.export`.
        The code is put in a separate function to allow granular unit testing.
        """
        if inputs_embeds is None:
            input_ids = input_ids[:, cache_position]
        else:
            # This is the code we need to implement with torch.cond:
            # if input_ids.shape[1] == 0:
            #     inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
            # else:
            #     if cache_position[-1] >= input_ids.shape[1]:
            #         input_ids = input_ids[:, -cache_position.shape[0] :]
            #     else:
            #         if input_ids.shape[1] != cache_position.shape[0]:
            #             input_ids = input_ids[:, cache_position]
            def branch_1(inputs_embeds, cache_position):
                return inputs_embeds[:, -cache_position.shape[0] :]

            def branch_2(input_ids, cache_position):
                return input_ids[:, -cache_position.shape[0] :]

            def branch_3(input_ids, cache_position):
                return input_ids[:, cache_position]

            inputs_embeds, input_ids = torch.cond(
                input_ids.shape[1] == 0,
                (
                    lambda input_ids, inputs_embeds, cache_position: (
                        branch_1(inputs_embeds, cache_position),
                        input_ids,
                    )
                ),
                (
                    lambda input_ids, inputs_embeds, cache_position: (
                        inputs_embeds,
                        torch.cond(
                            cache_position[-1] >= input_ids.shape[1],
                            branch_2,
                            lambda input_ids, cache_position: (
                                torch.cond(
                                    input_ids.shape[1] != cache_position.shape[0],
                                    branch_3,
                                    (lambda input_ids, cache_position: input_ids),
                                    [input_ids, cache_position],
                                )
                            ),
                            [input_ids, cache_position],
                        ),
                    )
                ),
                [input_ids, inputs_embeds, cache_position],
            )
        return inputs_embeds, input_ids

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """
        Prepare the model inputs for generation. It includes operations like computing the 4D attention mask or
        slicing inputs given the existing cache.

        See the forward pass in the model documentation for expected arguments (different models might have different
        requirements for e.g. `past_key_values`). This function should work as is for most LLMs.
        """

        # 1. Handle BC:
        model_inputs = {}
        model_inputs["cache_position"] = cache_position

        # 2. Generic cache-dependent input preparation
        if past_key_values is not None:
            model_inputs["past_key_values"] = past_key_values
            inputs_embeds, input_ids = self._cache_dependant_input_preparation(
                input_ids, inputs_embeds, cache_position
            )

        # 3. Prepare base model inputs
        input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step for every prompt.
        if not self.config.is_encoder_decoder:
            if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
                model_inputs[input_ids_key] = None
                model_inputs["inputs_embeds"] = inputs_embeds
            else:
                # `clone` calls in this function ensure a consistent stride. See #32227
                model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
                model_inputs["inputs_embeds"] = None
        else:
            model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)

        # 4. Create missing `position_ids` on the fly
        encoder_attention_mask = attention_mask if self.config.is_encoder_decoder else None
        attention_mask = (
            kwargs.pop("decoder_attention_mask", None) if self.config.is_encoder_decoder else attention_mask
        )
        attention_mask_key = "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask"
        position_ids_key = "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids"
        if (
            attention_mask is not None
            and kwargs.get(position_ids_key) is None
            and position_ids_key in set(inspect.signature(self.forward).parameters.keys())
        ):
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            kwargs[position_ids_key] = position_ids  # placed in kwargs for further processing (see below)

        # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
        for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:
            model_input = kwargs.get(model_input_name)
            if model_input is not None:
                if past_key_values is not None:
                    current_input_length = (
                        model_inputs["inputs_embeds"].shape[1]
                        if model_inputs.get("inputs_embeds") is not None
                        else model_inputs[input_ids_key].shape[1]
                    )
                    model_input = model_input[:, -current_input_length:]
                    model_input = model_input.clone(memory_format=torch.contiguous_format)
                model_inputs[model_input_name] = model_input

        # 6. Create 4D attention mask if we are using a compilable cache (important for performant compiled forward
        # pass)
        if (
            isinstance(past_key_values, Cache)
            and past_key_values.is_compileable
            and attention_mask is not None
            and attention_mask.ndim == 2
        ):
            if not self.config.is_encoder_decoder and model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
            else:
                batch_size, sequence_length = model_inputs[input_ids_key].shape[:2]

            # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
            # the 4D causal mask exists, it should be present in the base model (XXXModel class) or in its decoder.
            base_model = getattr(self, self.base_model_prefix, self)
            decoder = base_model.get_decoder() if hasattr(base_model, "get_decoder") else None
            causal_mask_creation_function = getattr(
                base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
            )
            if causal_mask_creation_function is None and decoder is not None:  # it may be in the decoder
                causal_mask_creation_function = getattr(
                    decoder, "_prepare_4d_causal_attention_mask_with_cache_position", None
                )

            # If it's not defined, it means the model uses the new general mask API
            if causal_mask_creation_function is None:  # can't be found
                token_type_ids = model_inputs.get("token_type_ids", None)
                position_ids = model_inputs.get(position_ids_key, None)
                # Some models may overwrite the general one
                causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
                attention_mask = causal_mask_creation_function(
                    config=self.config,
                    # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
                    input_embeds=torch.empty((batch_size, sequence_length), dtype=self.dtype),
                    attention_mask=attention_mask,
                    cache_position=cache_position,
                    past_key_values=past_key_values,
                    position_ids=position_ids,
                    token_type_ids=token_type_ids,
                )
            else:
                attention_mask = causal_mask_creation_function(
                    attention_mask,
                    sequence_length=sequence_length,
                    target_length=past_key_values.get_max_cache_shape(),
                    dtype=self.dtype,
                    cache_position=cache_position,
                    batch_size=batch_size,
                    config=self.config,
                    past_key_values=past_key_values,
                )
        if attention_mask is not None:
            model_inputs[attention_mask_key] = attention_mask

        if encoder_attention_mask is not None:
            model_inputs["attention_mask"] = encoder_attention_mask

        if "flash" in self.config._attn_implementation and self._supports_attention_backend:
            tensor_kws = {"dtype": torch.int32, "device": self.device}
            pos = model_inputs["position_ids"][:, -1]

            cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0)
            max_length_k = int(pos.max()) + 1

            bs, seq_len = input_ids.size()
            q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1)
            cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0)
            max_length_q = int(q_len.max())

            model_inputs.update(
                cu_seq_lens_q=cu_seq_lens_q.to(self.device),
                cu_seq_lens_k=cu_seq_lens_k.to(self.device),
                max_length_q=max_length_q,
                max_length_k=max_length_k,
            )
        # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
        for key, value in kwargs.items():
            if key not in model_inputs:
                model_inputs[key] = value

        # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
        model_inputs.pop("labels", None)
        return model_inputs
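
    # Illustrative shapes at a single decoding step (assuming a decoder-only model with a cache): with
    # `input_ids` of shape (batch, prompt_len + n_generated) and `cache_position` holding only the last position,
    # the returned dict carries `input_ids` of shape (batch, 1), the full 2D `attention_mask`, `cache_position`,
    # and `past_key_values`, which is what the model's `forward` expects during cached decoding.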

    def _prepare_model_inputs(
        self,
        inputs: Optional[torch.Tensor] = None,
        bos_token_id: Optional[torch.Tensor] = None,
        model_kwargs: Optional[dict[str, torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
        """
        This function extracts the model-specific `inputs` for generation.
        """
        # 1. retrieve all kwargs that are non-None or non-model input related.
        # some encoder-decoder models have different names for model and encoder
        if (
            self.config.is_encoder_decoder
            and hasattr(self, "encoder")
            and self.encoder.main_input_name != self.main_input_name
        ):
            input_name = self.encoder.main_input_name
        else:
            input_name = self.main_input_name

        model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}

        # 2. check whether model_input_name is passed as kwarg
        # if yes and `inputs` is None use kwarg inputs
        inputs_kwarg = model_kwargs.pop(input_name, None)
        if inputs_kwarg is not None and inputs is not None:
            raise ValueError(
                f"`inputs`: {inputs} were passed alongside {input_name} which is not allowed. "
                f"Make sure to either pass {inputs} or {input_name}=..."
            )
        elif inputs_kwarg is not None:
            inputs = inputs_kwarg

        # 3. In the presence of `inputs_embeds` for text models:
        # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
        # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
        # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
        # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
        # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
        if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
            if model_kwargs["inputs_embeds"] is None:
                model_kwargs.pop("inputs_embeds")
            elif not self.config.is_encoder_decoder:
                has_inputs_embeds_forwarding = "inputs_embeds" in set(
                    inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
                )
                if not has_inputs_embeds_forwarding:
                    raise ValueError(
                        f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
                        "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
                        "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
                    )
                # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
                # the attention mask) can rely on the actual model input.
                model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
                    inputs, bos_token_id, model_kwargs=model_kwargs
                )
                inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
            else:
                if inputs is not None:
                    raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
                inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

        # 4. if `inputs` is still None, try to create `input_ids` from BOS token
        inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
        return inputs, input_name, model_kwargs

    def _maybe_initialize_input_ids_for_generation(
        self,
        inputs: Optional[torch.Tensor] = None,
        bos_token_id: Optional[torch.Tensor] = None,
        model_kwargs: Optional[dict[str, torch.Tensor]] = None,
    ) -> torch.LongTensor:
        """Initializes input ids for generation, if necessary."""
        if inputs is not None:
            return inputs

        encoder_outputs = model_kwargs.get("encoder_outputs")
        if self.config.is_encoder_decoder and encoder_outputs is not None:
            # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
            shape = encoder_outputs.last_hidden_state.size()[:-1]
            return torch.ones(shape, dtype=torch.long, device=self.device) * -100

        # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
        # soft-prompting or in multimodal implementations built on top of decoder-only language models.
        batch_size = 1
        for value in model_kwargs.values():
            if isinstance(value, torch.Tensor):
                batch_size = value.shape[0]
                break

        if "inputs_embeds" in model_kwargs:
            return torch.ones((batch_size, 0), dtype=torch.long, device=self.device)

        if bos_token_id is None:
            raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")

        return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
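
    # Illustrative outcome: with no `inputs`, no `inputs_embeds`, `bos_token_id = 1`, and a batch size of 2
    # inferred from another tensor in `model_kwargs`, this returns `tensor([[1], [1]])` as the starting prompt.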

    def _prepare_attention_mask_for_generation(
        self,
        inputs_tensor: torch.Tensor,
        generation_config: GenerationConfig,
        model_kwargs: dict[str, Any],
    ) -> torch.LongTensor:
        pad_token_id = generation_config._pad_token_tensor
        eos_token_id = generation_config._eos_token_tensor

        # `input_ids` may be present in the model kwargs, instead of being the main input (e.g. multimodal model)
        if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
            inputs_tensor = model_kwargs["input_ids"]

        # No information for attention mask inference -> return default attention mask
        default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
        if pad_token_id is None:
            return default_attention_mask

        is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
        if not is_input_ids:
            return default_attention_mask

        is_pad_token_in_inputs = (pad_token_id is not None) and (
            isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any()
        )
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
            isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any()
        )
        can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
        attention_mask_from_padding = inputs_tensor.ne(pad_token_id).long()

        attention_mask = (
            attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
        )
        return attention_mask
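
    # Illustrative example: for `input_ids = [[15, 27, 0, 0]]` with `pad_token_id = 0` distinct from the EOS
    # token, the mask is inferred from padding as `[[1, 1, 0, 0]]`; if the pad token equals the EOS token (or
    # never appears in the inputs), the all-ones default mask is returned instead.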

    def _prepare_encoder_decoder_kwargs_for_generation(
        self,
        inputs_tensor: torch.Tensor,
        model_kwargs,
        model_input_name: Optional[str],
        generation_config: GenerationConfig,
    ) -> dict[str, Any]:
        # 1. get encoder
        encoder = self.get_encoder()
        # Compatibility with Accelerate big model inference: we need the encoder to output its results on the same
        # device as the inputs.
        if hasattr(self, "hf_device_map"):
            if hasattr(encoder, "_hf_hook"):
                encoder._hf_hook.io_same_device = True
            else:
                add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))

        # 2. Prepare encoder args and encoder kwargs from model kwargs and generation config.
        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not any(argument.startswith(p) for p in irrelevant_prefix)
        }
        encoder_signature = set(inspect.signature(encoder.forward).parameters)
        encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
        if not encoder_accepts_wildcard:
            encoder_kwargs = {
                argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
            }
        encoder_kwargs["output_attentions"] = generation_config.output_attentions
        encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states

        # 3. make sure that encoder returns `ModelOutput`
        model_input_name = model_input_name if model_input_name is not None else self.main_input_name
        encoder_kwargs["return_dict"] = True
        encoder_kwargs[model_input_name] = inputs_tensor
        model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)  # type: ignore

        return model_kwargs
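
    # Net effect (illustrative): for an encoder-decoder model called with `input_ids`, `attention_mask`, and
    # `decoder_input_ids`, only the non-decoder arguments are forwarded to the encoder, and the resulting
    # `ModelOutput` is stored under `model_kwargs["encoder_outputs"]` so it can be reused at every decoding step.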

    def _prepare_decoder_input_ids_for_generation(
        self,
        batch_size: int,
        model_input_name: str,
        model_kwargs: dict[str, torch.Tensor],
        decoder_start_token_id: torch.Tensor,
        device: Optional[torch.device] = None,
    ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]:
        """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
        # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
        # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
            decoder_input_ids = model_kwargs.pop("decoder_input_ids")
        elif "input_ids" in model_kwargs and model_input_name != "input_ids":
            decoder_input_ids = model_kwargs.pop("input_ids")
        else:
            decoder_input_ids = None

        # 2. `decoder_start_token_id` must have shape (batch_size, 1)
        if device is None:
            device = self.device
        if decoder_start_token_id.ndim == 1:
            if decoder_start_token_id.shape[0] != batch_size:
                raise ValueError(
                    f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}"
                )
            decoder_start_token_id = decoder_start_token_id.view(-1, 1)
        else:
            decoder_start_token_id = (
                torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
            )

        # 3. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
        # no user input -> use decoder_start_token_id as decoder_input_ids
        if decoder_input_ids is None:
            decoder_input_ids = decoder_start_token_id
        # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token. Note that the
        # original checkpoints can't be detected through `self.__class__.__name__.lower()`, needing custom logic.
        # See: https://github.com/huggingface/transformers/pull/31470
        elif "donut" in self.__class__.__name__.lower() or (
            self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower()
        ):
            pass
        elif self.config.model_type in ["whisper"]:
            pass
        # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
        # decoder_attention_mask if provided)
        elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item():
            decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1)
            if "decoder_attention_mask" in model_kwargs:
                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
                decoder_attention_mask = torch.cat(
                    (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
                    dim=-1,
                )
                model_kwargs["decoder_attention_mask"] = decoder_attention_mask

        return decoder_input_ids, model_kwargs
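
    # Illustrative example: with `decoder_start_token_id = 2` and user-provided `decoder_input_ids = [[5, 6]]`
    # that does not already begin with token 2, the start token is prepended, giving `[[2, 5, 6]]`, and any
    # `decoder_attention_mask` is extended with a leading column of ones to match.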

    @staticmethod
    def _expand_inputs_for_generation(
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        input_ids: Optional[torch.LongTensor] = None,
        **model_kwargs,
    ) -> tuple[torch.LongTensor, dict[str, Any]]:
        """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
        # Do not call torch.repeat_interleave if expand_size is 1 because it clones
        # the input tensor and thus requires more memory although no change is applied
        if expand_size == 1:
            return input_ids, model_kwargs

        def _expand_dict_for_generation(dict_to_expand):
            for key in dict_to_expand:
                if (
                    key != "cache_position"
                    and dict_to_expand[key] is not None
                    and isinstance(dict_to_expand[key], torch.Tensor)
                ):
                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
            return dict_to_expand

        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)

        model_kwargs = _expand_dict_for_generation(model_kwargs)

        if is_encoder_decoder:
            if model_kwargs.get("encoder_outputs") is None:
                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])

        return input_ids, model_kwargs
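
    # Illustrative example: with `expand_size=3` (e.g. `num_beams=3` or `num_return_sequences=3`) and
    # `input_ids = [[1, 2], [3, 4]]`, `repeat_interleave` yields
    # `[[1, 2], [1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]`, and every tensor in `model_kwargs`
    # (except `cache_position`) is expanded the same way along dim 0.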

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> dict[str, Any]:
        # update past_key_values keeping its naming used in model code
        for possible_cache_name in ALL_CACHE_NAMES:
            if possible_cache_name in outputs:
                # TODO (joao): remove output/input mismatch when these old models (xlnet, reformer) are deprecated
                if possible_cache_name in ("past_buckets_states", "mems"):
                    cache_name = "past_key_values"
                else:
                    cache_name = possible_cache_name
                model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
                break

        # update token_type_ids with last value
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

        if not is_encoder_decoder:
            # update attention mask
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )
        else:
            # update decoder attention mask
            if "decoder_attention_mask" in model_kwargs:
                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
                model_kwargs["decoder_attention_mask"] = torch.cat(
                    [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
                    dim=-1,
                )

        if model_kwargs.get("use_cache", True):
            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
        else:
            past_positions = model_kwargs.pop("cache_position")
            new_positions = torch.arange(
                past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype
            ).to(past_positions.device)
            model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
        return model_kwargs
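
    # Illustrative step-to-step update (with `use_cache=True` and `num_new_tokens=1`): an `attention_mask` of
    # shape (batch, 7) gains a column of ones to become (batch, 8), and `cache_position` moves from `[6]` to
    # `[7]`, so the next forward pass only processes the newly generated token.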

    def _get_candidate_generator(
        self,
        generation_config: GenerationConfig,
        input_ids: torch.LongTensor,
        inputs_tensor: torch.Tensor,
        assistant_model: "PreTrainedModel",
        logits_processor: LogitsProcessorList,
        target_tokenizer: "PreTrainedTokenizerBase",
        assistant_tokenizer: "PreTrainedTokenizerBase",
        model_kwargs: dict,
    ) -> CandidateGenerator:
        """
        Returns the candidate generator to be used in `assisted_generation`
        """
        different_tokenizers = all(v is not None for v in (assistant_model, target_tokenizer, assistant_tokenizer))

        if generation_config.assistant_early_exit is not None:
            candidate_generator = EarlyExitCandidateGenerator(
                input_ids=input_ids,
                assistant_model=self,
                generation_config=generation_config,
                model_kwargs=model_kwargs,
                inputs_tensor=inputs_tensor,
                logits_processor=logits_processor,
            )
        elif generation_config.prompt_lookup_num_tokens is not None:
            candidate_generator = PromptLookupCandidateGenerator(
                eos_token_id=generation_config._eos_token_tensor,
                num_output_tokens=generation_config.prompt_lookup_num_tokens,
                max_matching_ngram_size=generation_config.max_matching_ngram_size,
                max_length=generation_config.max_length,
            )
        elif different_tokenizers:
            if generation_config.do_sample is True:
                atm_translator = AssistantVocabTranslatorCache.get_translator(
                    target_tokenizer,
                    assistant_tokenizer,
                    self.config.get_text_config().vocab_size,
                    assistant_model=assistant_model,
                    assistant_prune_lm_head=True,  # prune LM head of assistant model
                )
                # Since we prune the LM head, we cannot use the repetition penalty on the assistant model due to
                # mismatches between token ids and logits index
                assistant_model.generation_config.repetition_penalty = None
                candidate_generator = UniversalSpeculativeDecodingGenerator(
                    input_ids=input_ids,
                    assistant_model=assistant_model,
                    generation_config=generation_config,
                    model_kwargs=model_kwargs,
                    inputs_tensor=inputs_tensor,
                    logits_processor=logits_processor,
                    target_tokenizer=target_tokenizer,
                    assistant_tokenizer=assistant_tokenizer,
                    atm_translator=atm_translator,
                )
            elif generation_config.do_sample is False:
                candidate_generator = AssistedCandidateGeneratorDifferentTokenizers(
                    input_ids=input_ids,
                    assistant_model=assistant_model,
                    generation_config=generation_config,
                    model_kwargs=model_kwargs,
                    inputs_tensor=inputs_tensor,
                    logits_processor=logits_processor,
                    target_tokenizer=target_tokenizer,
                    assistant_tokenizer=assistant_tokenizer,
                )
            else:
                raise ValueError(
                    f"Invalid value for `do_sample`: expected a boolean, got {type(generation_config.do_sample).__name__}"
                )
        else:
            candidate_generator = AssistedCandidateGenerator(
                input_ids=input_ids,
                assistant_model=assistant_model,
                generation_config=generation_config,
                model_kwargs=model_kwargs,
                inputs_tensor=inputs_tensor,
                logits_processor=logits_processor,
            )
        return candidate_generator
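
    # Selection summary (descriptive): `assistant_early_exit` -> EarlyExitCandidateGenerator;
    # `prompt_lookup_num_tokens` -> PromptLookupCandidateGenerator; an assistant model with a different tokenizer
    # -> UniversalSpeculativeDecodingGenerator (sampling) or AssistedCandidateGeneratorDifferentTokenizers (greedy);
    # otherwise -> AssistedCandidateGenerator with a same-tokenizer assistant model.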

    def _get_logits_processor(
        self,
        generation_config: GenerationConfig,
        input_ids_seq_length: Optional[int] = None,
        encoder_input_ids: torch.LongTensor = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        device: Optional[str] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    ) -> LogitsProcessorList:
        """
        This method returns a [`LogitsProcessorList`] object that contains all relevant [`LogitsProcessor`]
        instances used to modify the scores of the language model head.
        """
        # instantiate processors list
        processors = LogitsProcessorList()
        if logits_processor is None:
            logits_processor = []

        if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1:
            processors.append(
                UnbatchedClassifierFreeGuidanceLogitsProcessor(
                    generation_config.guidance_scale,
                    self,
                    unconditional_ids=negative_prompt_ids,
                    unconditional_attention_mask=negative_prompt_attention_mask,
                    use_cache=generation_config.use_cache,
                )
            )
        if generation_config.sequence_bias is not None:
            processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))

        if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0:
            processors.append(
                HammingDiversityLogitsProcessor(
                    diversity_penalty=generation_config.diversity_penalty,
                    num_beams=generation_config.num_beams,
                    num_beam_groups=generation_config.num_beam_groups,
                )
            )
        if (
            generation_config.encoder_repetition_penalty is not None
            and generation_config.encoder_repetition_penalty != 1.0
        ):
            if len(encoder_input_ids.shape) == 2:
                processors.append(
                    EncoderRepetitionPenaltyLogitsProcessor(
                        penalty=generation_config.encoder_repetition_penalty,
                        encoder_input_ids=encoder_input_ids,
                    )
                )
            else:
                warnings.warn(
                    "Passing `encoder_repetition_penalty` requires some form of `input_ids` to be passed to "
                    "`generate`, ignoring the argument.",
                    UserWarning,
                )
        if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
            processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
        if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
            processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
        if (
            generation_config.encoder_no_repeat_ngram_size is not None
            and generation_config.encoder_no_repeat_ngram_size > 0
        ):
            if len(encoder_input_ids.shape) == 2:
                processors.append(
                    EncoderNoRepeatNGramLogitsProcessor(
                        generation_config.encoder_no_repeat_ngram_size,
                        encoder_input_ids,
                    )
                )
            else:
                warnings.warn(
                    "Passing `encoder_no_repeat_ngram_size` requires some form of `input_ids` to be passed to "
                    "`generate`, ignoring the argument.",
                    UserWarning,
                )
        if generation_config.bad_words_ids is not None:
            processors.append(
                NoBadWordsLogitsProcessor(
                    generation_config.bad_words_ids,
                    generation_config._eos_token_tensor,
                )
            )
        if (
            generation_config.min_length is not None
            and getattr(generation_config, "_eos_token_tensor", None) is not None
            and generation_config.min_length > 0
        ):
            processors.append(
                MinLengthLogitsProcessor(
                    generation_config.min_length,
                    generation_config._eos_token_tensor,
                    device=device,
                )
            )
        if (
            generation_config.min_new_tokens is not None
            and getattr(generation_config, "_eos_token_tensor", None) is not None
            and generation_config.min_new_tokens > 0
        ):
            processors.append(
                MinNewTokensLengthLogitsProcessor(
                    input_ids_seq_length,
                    generation_config.min_new_tokens,
                    generation_config._eos_token_tensor,
                    device=device,
                )
            )
        if prefix_allowed_tokens_fn is not None:
            processors.append(
                PrefixConstrainedLogitsProcessor(
                    prefix_allowed_tokens_fn,
                    generation_config.num_beams // generation_config.num_beam_groups,
                )
            )
        if generation_config.forced_bos_token_id is not None:
            processors.append(
                ForcedBOSTokenLogitsProcessor(
                    generation_config.forced_bos_token_id,
                )
            )
        if generation_config.forced_eos_token_id is not None:
            processors.append(
                ForcedEOSTokenLogitsProcessor(
                    generation_config.max_length,
                    generation_config.forced_eos_token_id,
                    device=device,
                )
            )
        if generation_config.remove_invalid_values is True:
            processors.append(InfNanRemoveLogitsProcessor())
        if generation_config.exponential_decay_length_penalty is not None:
            processors.append(
                ExponentialDecayLengthPenalty(
                    generation_config.exponential_decay_length_penalty,
                    generation_config._eos_token_tensor,
                    input_ids_seq_length,
                )
            )
        if generation_config.suppress_tokens is not None:
            processors.append(
                SuppressTokensLogitsProcessor(
                    generation_config.suppress_tokens,
                    device=device,
                )
            )
        if generation_config.begin_suppress_tokens is not None:
            begin_index = input_ids_seq_length
            begin_index = (
                begin_index
                if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
                else begin_index + 1
            )
            processors.append(
                SuppressTokensAtBeginLogitsProcessor(
                    generation_config.begin_suppress_tokens,
                    begin_index,
                    device=device,
                )
            )

        # TODO (joao): find a strategy to specify the order of the processors
        processors = self._merge_criteria_processor_list(processors, logits_processor)

        # Processors previously known as `LogitsWarpers`, only applied with sampling strategies
        if generation_config.do_sample:
            # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
            # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
            if generation_config.num_beams > 1:
                if isinstance(generation_config._eos_token_tensor, list):
                    min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
                elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
                    min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
                else:
                    min_tokens_to_keep = 2
            else:
                min_tokens_to_keep = 1

            # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
            # all samplers can be found in `generation_utils_samplers.py`
            if generation_config.temperature is not None and generation_config.temperature != 1.0:
|
|
processors.append(TemperatureLogitsWarper(generation_config.temperature))
|
|
if generation_config.top_k is not None and generation_config.top_k != 0:
|
|
processors.append(
|
|
TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)
|
|
)
|
|
if generation_config.top_p is not None and generation_config.top_p < 1.0:
|
|
processors.append(
|
|
TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)
|
|
)
|
|
if generation_config.min_p is not None:
|
|
# Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
|
|
processors.append(
|
|
MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep)
|
|
)
|
|
if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
|
|
processors.append(
|
|
TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
|
|
)
|
|
if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
|
|
processors.append(
|
|
EpsilonLogitsWarper(
|
|
epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep
|
|
)
|
|
)
|
|
if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
|
|
processors.append(
|
|
EtaLogitsWarper(
|
|
epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
|
|
)
|
|
)
|
|
|
|
# Watermarking should be after all logits processing is finished (see #34630)
|
|
if generation_config.watermarking_config is not None:
|
|
processors.append(
|
|
generation_config.watermarking_config.construct_processor(
|
|
self.config.get_text_config().vocab_size, device
|
|
)
|
|
)
|
|
|
|
# `LogitNormalization` should always be the last logit processor, when present
|
|
if generation_config.renormalize_logits is True:
|
|
processors.append(LogitNormalization())
|
|
return processors
|
|
|
|
def _get_stopping_criteria(
|
|
self,
|
|
generation_config: GenerationConfig,
|
|
stopping_criteria: Optional[StoppingCriteriaList],
|
|
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
|
**kwargs,
|
|
) -> StoppingCriteriaList:
|
|
criteria = StoppingCriteriaList()
|
|
if generation_config.max_length is not None:
|
|
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
|
|
criteria.append(
|
|
MaxLengthCriteria(
|
|
max_length=generation_config.max_length,
|
|
max_position_embeddings=max_position_embeddings,
|
|
)
|
|
)
|
|
if generation_config.max_time is not None:
|
|
criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
|
|
if generation_config.stop_strings is not None:
|
|
if tokenizer is None:
|
|
raise ValueError(
|
|
"There are one or more stop strings, either in the arguments to `generate` or in the "
|
|
"model's generation config, but we could not locate a tokenizer. When generating with "
|
|
"stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`."
|
|
)
|
|
criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
|
|
if generation_config._eos_token_tensor is not None:
|
|
criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor))
|
|
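        # Assisted generation: the assistant can stop drafting early when its confidence drops below the threshold.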
if (
|
|
generation_config.is_assistant
|
|
and generation_config.assistant_confidence_threshold is not None
|
|
and generation_config.assistant_confidence_threshold > 0
|
|
):
|
|
criteria.append(
|
|
ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold)
|
|
)
|
|
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
|
|
return criteria
|
|
|
|
def _merge_criteria_processor_list(
|
|
self,
|
|
default_list: Union[LogitsProcessorList, StoppingCriteriaList],
|
|
custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
|
|
) -> Union[LogitsProcessorList, StoppingCriteriaList]:
|
|
"""
|
|
Merge user-defined processors/criteria with the ones instantiated inside `generate`. In case the same
|
|
processor/criteria is present on both lists, use the user-defined one.
|
|
|
|
(Note: up to v4.49.0, this function threw an exception is the same logit processor was found twice.)
|
|
"""
|
|
if len(custom_list) == 0:
|
|
return default_list
|
|
|
|
final_list = type(default_list)()
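        # Walk the defaults in order; whenever the user passed an object of the same type, keep the user's instance.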
|
|
for default in default_list:
|
|
using_custom = False
|
|
for custom in custom_list:
|
|
if type(custom) is type(default):
|
|
object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
|
|
logger.warning_once(
|
|
f"A custom {object_type} of type {type(custom)} has been passed to `.generate()`, but it "
|
|
f"was also created in `.generate()`, given its parameterization. The custom {type(custom)} "
|
|
f"will take precedence. Please check the docstring of {type(custom)} to see related "
|
|
"`.generate()` flags."
|
|
)
|
|
final_list.append(custom)
|
|
using_custom = True
|
|
break
|
|
if not using_custom:
|
|
final_list.append(default)
|
|
|
|
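        # Finally, append any user-provided processors/criteria that did not override a default entry.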
for custom in custom_list:
|
|
if custom not in final_list:
|
|
final_list.append(custom)
|
|
return final_list
|
|
|
|
def compute_transition_scores(
|
|
self,
|
|
sequences: torch.Tensor,
|
|
scores: tuple[torch.Tensor],
|
|
beam_indices: Optional[torch.Tensor] = None,
|
|
normalize_logits: bool = False,
|
|
) -> torch.Tensor:
|
|
"""
|
|
Computes the transition scores of sequences given the generation scores (and beam indices, if beam search was
|
|
used). This is a convenient method to quickly obtain the scores of the selected tokens at generation time.
|
|
|
|
Parameters:
|
|
sequences (`torch.LongTensor`):
|
|
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or
|
|
shorter if all batches finished early due to the `eos_token_id`.
|
|
scores (`tuple(torch.FloatTensor)`):
|
|
Transition scores for each vocabulary token at each generation step. Beam transition scores consisting
|
|
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
|
|
Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
|
|
with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
|
|
beam_indices (`torch.LongTensor`, *optional*):
|
|
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
|
|
                `(batch_size*num_return_sequences, sequence_length)`. Only required if `num_beams>1` at
                generate-time.
|
|
normalize_logits (`bool`, *optional*, defaults to `False`):
|
|
Whether to normalize the logits (which, for legacy reasons, may be unnormalized).
|
|
|
|
Return:
|
|
`torch.Tensor`: A `torch.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)` containing
|
|
the transition scores (logits)
|
|
|
|
Examples:
|
|
|
|
```python
|
|
>>> from transformers import GPT2Tokenizer, AutoModelForCausalLM
|
|
>>> import numpy as np
|
|
|
|
>>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
|
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
|
>>> tokenizer.pad_token_id = tokenizer.eos_token_id
|
|
>>> inputs = tokenizer(["Today is"], return_tensors="pt")
|
|
|
|
>>> # Example 1: Print the scores for each token generated with Greedy Search
|
|
>>> outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True)
|
|
>>> transition_scores = model.compute_transition_scores(
|
|
... outputs.sequences, outputs.scores, normalize_logits=True
|
|
... )
|
|
>>> # input_length is the length of the input prompt for decoder-only models, like the GPT family, and 1 for
|
|
>>> # encoder-decoder models, like BART or T5.
|
|
>>> input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
|
|
>>> generated_tokens = outputs.sequences[:, input_length:]
|
|
>>> for tok, score in zip(generated_tokens[0], transition_scores[0]):
|
|
... # | token | token string | log probability | probability
|
|
... print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")
|
|
| 262 | the | -1.414 | 24.33%
|
|
| 1110 | day | -2.609 | 7.36%
|
|
| 618 | when | -2.010 | 13.40%
|
|
| 356 | we | -1.859 | 15.58%
|
|
| 460 | can | -2.508 | 8.14%
|
|
|
|
>>> # Example 2: Reconstruct the sequence scores from Beam Search
|
|
>>> outputs = model.generate(
|
|
... **inputs,
|
|
... max_new_tokens=5,
|
|
... num_beams=4,
|
|
... num_return_sequences=4,
|
|
... return_dict_in_generate=True,
|
|
... output_scores=True,
|
|
... )
|
|
>>> transition_scores = model.compute_transition_scores(
|
|
... outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
|
|
... )
|
|
>>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores.
|
|
>>> # Tip 1: recomputing the scores is only guaranteed to match with `normalize_logits=False`. Depending on the
|
|
>>> # use case, you might want to recompute it with `normalize_logits=True`.
|
|
>>> # Tip 2: the output length does NOT include the input length
|
|
>>> output_length = np.sum(transition_scores.numpy() < 0, axis=1)
|
|
>>> length_penalty = model.generation_config.length_penalty
|
|
>>> reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty)
|
|
>>> print(np.allclose(outputs.sequences_scores, reconstructed_scores))
|
|
True
|
|
```"""
|
|
        # 1. In the absence of `beam_indices`, we can assume that we come from e.g. greedy search, which is
        # equivalent to a beam search approach where the first (and only) beam is always selected
|
|
if beam_indices is None:
|
|
beam_indices = torch.arange(scores[0].shape[0]).view(-1, 1).to(sequences.device)
|
|
beam_indices = beam_indices.expand(-1, len(scores))
|
|
|
|
# 2. reshape scores as [batch_size*vocab_size, # generation steps] with # generation steps being
|
|
# seq_len - input_length
|
|
scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1)
|
|
|
|
# 3. Optionally normalize the logits (across the vocab dimension)
|
|
if normalize_logits:
|
|
scores = scores.reshape(-1, self.config.get_text_config().vocab_size, scores.shape[-1])
|
|
scores = torch.nn.functional.log_softmax(scores, dim=1)
|
|
scores = scores.reshape(-1, scores.shape[-1])
|
|
|
|
# 4. cut beam_indices to longest beam length
|
|
beam_indices_mask = beam_indices < 0
|
|
max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max()
|
|
beam_indices = beam_indices.clone()[:, :max_beam_length]
|
|
beam_indices_mask = beam_indices_mask[:, :max_beam_length]
|
|
|
|
# 5. Set indices of beams that finished early to 0; such indices will be masked correctly afterwards
|
|
beam_indices[beam_indices_mask] = 0
|
|
|
|
# 6. multiply beam_indices with vocab size to gather correctly from scores
|
|
beam_sequence_indices = beam_indices * self.config.get_text_config().vocab_size
|
|
|
|
# 7. Define which indices contributed to scores
|
|
cut_idx = sequences.shape[-1] - max_beam_length
|
|
indices = sequences[:, cut_idx:] + beam_sequence_indices
|
|
|
|
# 8. Compute scores
|
|
transition_scores = scores.gather(0, indices)
|
|
|
|
# 9. Mask out transition_scores of beams that stopped early
|
|
transition_scores[beam_indices_mask] = 0
|
|
|
|
return transition_scores
|
|
|
|
def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer):
|
|
if assistant_model is None:
|
|
return
|
|
|
|
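        # Encoder-decoder main model with a non-encoder-decoder assistant: the encoder-dependent dimensions must
        # still match for the assistant's candidate sequences to be usable.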
if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder:
|
|
attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"]
|
|
attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check]
|
|
are_equal = all(
|
|
getattr(self.config, attr) == getattr(assistant_model.config, attr) for attr in attributes_to_check
|
|
)
|
|
if not are_equal:
|
|
raise ValueError(
|
|
"The main model and the assistant don't have compatible encoder-dependent input shapes. "
|
|
"Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper."
|
|
)
|
|
|
|
doc_reference = (
|
|
"(see https://huggingface.co/docs/transformers/en/generation_strategies#universal-assisted-decoding)"
|
|
)
|
|
if self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size:
|
|
if assistant_tokenizer is not None:
|
|
raise ValueError(
|
|
f"`assistant_tokenizer` is not required when the main and assistant models use the same tokenizer. Please omit `assistant_tokenizer` from `generate()` {doc_reference}."
|
|
)
|
|
else:
|
|
if tokenizer is None or assistant_tokenizer is None:
|
|
raise ValueError(
|
|
f"The main and assistant models have different tokenizers. Please provide `tokenizer` and `assistant_tokenizer` to `generate()` {doc_reference}."
|
|
)
|
|
|
|
def _validate_model_kwargs(self, model_kwargs: dict[str, Any]):
|
|
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
|
|
# Excludes arguments that are handled before calling any model function
|
|
if self.config.is_encoder_decoder:
|
|
for key in ["decoder_input_ids"]:
|
|
model_kwargs.pop(key, None)
|
|
|
|
unused_model_args = []
|
|
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
|
|
# `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
|
|
# `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
|
|
if "kwargs" in model_args or "model_kwargs" in model_args:
|
|
model_args |= set(inspect.signature(self.forward).parameters)
|
|
|
|
# Encoder-Decoder models may also need Encoder arguments from `model_kwargs`
|
|
if self.config.is_encoder_decoder:
|
|
base_model = getattr(self, self.base_model_prefix, None)
|
|
|
|
# allow encoder kwargs
|
|
encoder = getattr(self, "encoder", None)
|
|
# `MusicgenForConditionalGeneration` has `text_encoder` and `audio_encoder`.
|
|
# Also, it has `base_model_prefix = "encoder_decoder"` but there is no `self.encoder_decoder`
|
|
# TODO: A better way to handle this.
|
|
if encoder is None and base_model is not None:
|
|
encoder = getattr(base_model, "encoder", None)
|
|
|
|
if encoder is not None:
|
|
encoder_model_args = set(inspect.signature(encoder.forward).parameters)
|
|
model_args |= encoder_model_args
|
|
|
|
# allow decoder kwargs
|
|
decoder = getattr(self, "decoder", None)
|
|
if decoder is None and base_model is not None:
|
|
decoder = getattr(base_model, "decoder", None)
|
|
|
|
if decoder is not None:
|
|
decoder_model_args = set(inspect.signature(decoder.forward).parameters)
|
|
model_args |= {f"decoder_{x}" for x in decoder_model_args}
|
|
|
|
for key, value in model_kwargs.items():
|
|
if value is not None and key not in model_args:
|
|
unused_model_args.append(key)
|
|
|
|
if unused_model_args:
|
|
raise ValueError(
|
|
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
|
|
" generate arguments will also show up in this list)"
|
|
)
|
|
|
|
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
|
|
"""Performs validation related to the resulting generated length"""
|
|
# 1. Max length warnings related to poor parameterization
|
|
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
|
|
# 20 is the default max_length of the generation config
|
|
warnings.warn(
|
|
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
|
|
"generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
|
|
"generation.",
|
|
UserWarning,
|
|
)
|
|
if input_ids_length >= generation_config.max_length:
|
|
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
|
raise ValueError(
|
|
f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
|
|
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
|
" increasing `max_length` or, better yet, setting `max_new_tokens`."
|
|
)
|
|
|
|
# 2. Min length warnings due to unfeasible parameter combinations
|
|
min_length_error_suffix = (
|
|
" Generation will stop at the defined maximum length. You should decrease the minimum length and/or "
|
|
"increase the maximum length."
|
|
)
|
|
if has_default_max_length:
|
|
min_length_error_suffix += (
|
|
f" Note that `max_length` is set to {generation_config.max_length}, its default value."
|
|
)
|
|
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
|
|
warnings.warn(
|
|
f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than"
|
|
f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
|
|
UserWarning,
|
|
)
|
|
if generation_config.min_new_tokens is not None:
|
|
min_length = generation_config.min_new_tokens + input_ids_length
|
|
if min_length > generation_config.max_length:
|
|
warnings.warn(
|
|
f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when "
|
|
f"added to the prompt length ({input_ids_length}), is larger than"
|
|
f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
|
|
UserWarning,
|
|
)
|
|
|
|
def _prepare_generated_length(
|
|
self,
|
|
generation_config,
|
|
has_default_max_length,
|
|
has_default_min_length,
|
|
model_input_name,
|
|
input_ids_length,
|
|
inputs_tensor,
|
|
):
|
|
"""Prepared max and min length in generation configs to avoid clashes between similar attributes"""
|
|
|
|
if generation_config.max_new_tokens is not None:
|
|
if not has_default_max_length and generation_config.max_length is not None:
|
|
logger.warning(
|
|
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
|
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
|
"Please refer to the documentation for more information. "
|
|
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
|
)
|
|
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
|
|
|
|
# if both `inputs_embeds` and `input_ids` are passed, we do not correct the length
|
|
        # otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond the indicated `max_length`
|
|
elif (
|
|
model_input_name == "inputs_embeds"
|
|
and input_ids_length != inputs_tensor.shape[1]
|
|
and not self.config.is_encoder_decoder
|
|
):
|
|
generation_config.max_length -= inputs_tensor.shape[1]
|
|
elif has_default_max_length: # by default let's always generate 20 new tokens
|
|
if generation_config.max_length == GenerationConfig().max_length:
|
|
generation_config.max_length = generation_config.max_length + input_ids_length
|
|
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
|
|
if max_position_embeddings is not None:
|
|
generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
|
|
|
|
# same for min length
|
|
if generation_config.min_new_tokens is not None:
|
|
if not has_default_min_length:
|
|
logger.warning(
|
|
f"Both `min_new_tokens` (={generation_config.min_new_tokens}) and `min_length`(="
|
|
f"{generation_config.min_length}) seem to have been set. `min_new_tokens` will take precedence. "
|
|
"Please refer to the documentation for more information. "
|
|
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
|
)
|
|
generation_config.min_length = generation_config.min_new_tokens + input_ids_length
|
|
|
|
elif (
|
|
model_input_name == "inputs_embeds"
|
|
and input_ids_length != inputs_tensor.shape[1]
|
|
and not self.config.is_encoder_decoder
|
|
):
|
|
generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)
|
|
|
|
return generation_config
|
|
|
|
def _prepare_generation_config(
|
|
self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: dict
|
|
) -> tuple[GenerationConfig, dict]:
|
|
"""
|
|
Prepares the base generation config, then applies any generation configuration options from kwargs. This
|
|
function handles retrocompatibility with respect to configuration files.
|
|
"""
|
|
# parameterization priority:
|
|
# kwargs > non-global default values in `generation_config` > `model.generation_config` > GenerationConfig()
|
|
# TODO (joao): per-model generation config classes.
|
|
|
|
using_model_generation_config = False
|
|
if generation_config is None:
|
|
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
|
|
# the following conditions must be met
|
|
# 1) the generation config must have been created from the model config (`_from_model_config` field);
|
|
# 2) the generation config must have seen no modification since its creation (the hash is the same);
|
|
# 3) there are non-default generation parameters in the model config.
|
|
# 4) the user must have set new generation parameters in the model config.
|
|
if (
|
|
self.generation_config._from_model_config # 1)
|
|
and self.generation_config._original_object_hash == hash(self.generation_config) # 2)
|
|
and len(self.config._get_non_default_generation_parameters()) > 0 # 3)
|
|
):
|
|
new_generation_config = GenerationConfig.from_model_config(self.config)
|
|
if new_generation_config != self.generation_config: # 4)
|
|
warnings.warn(
|
|
"You have modified the pretrained model configuration to control generation. This is a"
|
|
" deprecated strategy to control generation and will be removed in v5."
|
|
" Please use and modify the model generation configuration (see"
|
|
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
|
|
UserWarning,
|
|
)
|
|
self.generation_config = new_generation_config
|
|
|
|
generation_config = self.generation_config
|
|
using_model_generation_config = True
|
|
|
|
# `torch.export.export` usually raises an exception if it is called
|
|
# with ``strict=True``. deepcopy can only be processed if ``strict=False``.
|
|
generation_config = copy.deepcopy(generation_config)
|
|
|
|
if not using_model_generation_config:
|
|
# If `generation_config` is provided:
|
|
# - `use_model_defaults`: let's fallback ALL default values to the model's generation config
|
|
# - otherwise: legacy behavior, let's just make sure we have the tokens defined
|
|
model_base_version = version.parse(version.parse(self.generation_config.transformers_version).base_version)
|
|
if use_model_defaults is True or (
|
|
use_model_defaults is None and model_base_version >= version.parse("4.50.0")
|
|
):
|
|
modified_values = {}
|
|
global_default_generation_config = GenerationConfig()
|
|
model_generation_config = self.generation_config
|
|
# we iterate over the model's generation config: it may hold custom keys, which we'll want to copy
|
|
for key, model_gen_config_value in model_generation_config.__dict__.items():
|
|
if key.startswith("_") or key == "transformers_version": # metadata
|
|
continue
|
|
global_default_value = getattr(global_default_generation_config, key, None)
|
|
custom_gen_config_value = getattr(generation_config, key, None)
|
|
if (
|
|
custom_gen_config_value == global_default_value
|
|
and model_gen_config_value != global_default_value
|
|
):
|
|
modified_values[key] = model_gen_config_value
|
|
setattr(generation_config, key, model_gen_config_value)
|
|
# edge case: we may set `temperature=0.0` and `do_sample=False`, but the model defaults to
|
|
# `do_sample=True`
|
|
if generation_config.temperature == 0.0:
|
|
generation_config.do_sample = False
|
|
if use_model_defaults is None and len(modified_values) > 0:
|
|
logger.warning_once(
|
|
f"`generation_config` default values have been modified to match model-specific defaults: "
|
|
f"{modified_values}. If this is not desired, please set these values explicitly."
|
|
)
|
|
else:
|
|
if generation_config.bos_token_id is None:
|
|
generation_config.bos_token_id = self.generation_config.bos_token_id
|
|
if generation_config.eos_token_id is None:
|
|
generation_config.eos_token_id = self.generation_config.eos_token_id
|
|
if generation_config.pad_token_id is None:
|
|
generation_config.pad_token_id = self.generation_config.pad_token_id
|
|
if generation_config.decoder_start_token_id is None:
|
|
generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
|
|
|
|
# Finally, apply any passed kwargs
|
|
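        # (`update` consumes the kwargs that match generation attributes; the leftovers are returned and forwarded
        # to the model as `model_kwargs`)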
model_kwargs = generation_config.update(**kwargs)
|
|
|
|
return generation_config, model_kwargs
|
|
|
|
def _get_initial_cache_position(self, seq_length, device, model_kwargs):
|
|
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
|
|
# `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
|
|
if "cache_position" in model_kwargs and model_kwargs["cache_position"]:
|
|
return model_kwargs
|
|
if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:
|
|
cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
|
|
elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder:
|
|
cache_position = (
|
|
torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
|
|
)
|
|
else:
|
|
cache_position = torch.ones(seq_length, dtype=torch.int64, device=device).cumsum(0) - 1
|
|
|
|
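        # When a non-empty cache is passed, keep only the positions that are not yet covered by the cache.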
past_length = 0
|
|
if model_kwargs.get("past_key_values") is not None:
|
|
cache = model_kwargs["past_key_values"]
|
|
past_length = 0
|
|
if not isinstance(cache, Cache):
|
|
past_length = cache[0][0].shape[2]
|
|
elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
|
|
past_length = cache.get_seq_length()
|
|
|
|
cache_position = cache_position[past_length:]
|
|
|
|
model_kwargs["cache_position"] = cache_position
|
|
return model_kwargs
|
|
|
|
def _get_layer_device_map_for_cache_init(self) -> Optional[dict[int, Union[str, int]]]:
|
|
"""
|
|
Returns the device map for each decoder layer, to allocate the cache on the right device.
|
|
Inspired from `dispatch_model` in accelerate.
|
|
"""
|
|
execution_device_map = None
|
|
|
|
if hasattr(self, "hf_device_map"):
|
|
if set(self.hf_device_map.values()) == {"cpu"} or set(self.hf_device_map.values()) == {"cpu", "disk"}:
|
|
main_device = "cpu"
|
|
else:
|
|
main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
|
|
execution_device_map = {
|
|
name: main_device if device in ["cpu", "disk"] else device
|
|
for name, device in self.hf_device_map.items()
|
|
}
|
|
|
|
# No `execution_device_map` -> rely on `self.device` to allocate the cache
|
|
if execution_device_map is None:
|
|
return None
|
|
|
|
# Single device for all layers
|
|
num_hidden_layers = self.config.get_text_config().num_hidden_layers
|
|
if len(execution_device_map) == 1 and "" in execution_device_map:
|
|
return dict.fromkeys(range(num_hidden_layers), execution_device_map[""])
|
|
|
|
# Multiple devices in `execution_device_map` -> we need to map decoder layers to the correct device.
|
|
layer_device_map = {}
|
|
# Case 1: The model has a `get_decoder` method, we can use it to find the decoder name.
|
|
if hasattr(self, "get_decoder"):
|
|
decoder_name = None
|
|
for name, module in self.named_modules():
|
|
if module is self.get_decoder():
|
|
decoder_name = name
|
|
break
|
|
if decoder_name is None:
|
|
raise RuntimeError(
|
|
"`model.get_decoder()` is not returning a named module of the model. This is unexpected, please "
|
|
"open an issue on GitHub."
|
|
)
|
|
|
|
decoder_mapped_modules = [
|
|
module_name for module_name in execution_device_map.keys() if decoder_name in module_name
|
|
]
|
|
# The decoder name may be present in `execution_device_map` in two forms:
|
|
# a) each layer has a device mapping
|
|
if len(decoder_mapped_modules) >= num_hidden_layers:
|
|
for idx in range(num_hidden_layers):
|
|
for module_name in decoder_mapped_modules:
|
|
if f".{idx}." in f"{module_name}.":
|
|
layer_device_map[idx] = execution_device_map[module_name]
|
|
break
|
|
|
|
# b) the whole module is mapped to a single device. If the decoder name is NOT present in the device map,
|
|
# then the mapping is done in a parent module
|
|
else:
|
|
while True:
|
|
if decoder_name in execution_device_map:
|
|
layer_device_map = dict.fromkeys(range(num_hidden_layers), execution_device_map[decoder_name])
|
|
break
|
|
elif "." in decoder_name:
|
|
decoder_name = decoder_name.rsplit(".", 1)[0] # gets the name of the parent module
|
|
else:
|
|
raise RuntimeError(f"Decoder name {decoder_name} not found in execution device map")
|
|
|
|
# Case 2: Legacy code path: assume the decoder layers are named as `(...).X` (X being the layer index)
|
|
else:
|
|
for layer in execution_device_map:
|
|
for idx in range(num_hidden_layers):
|
|
if f".{idx}." in f"{layer}.":
|
|
layer_device_map[idx] = execution_device_map[layer]
|
|
break
|
|
|
|
for idx in range(num_hidden_layers):
|
|
if idx not in layer_device_map:
|
|
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
|
|
return layer_device_map
|
|
|
|
def _get_cache(
|
|
self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
|
|
) -> Cache:
|
|
"""
|
|
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
|
|
new `generate` call requires a larger cache or uses a different batch size.
|
|
|
|
Returns the resulting cache object.
|
|
"""
|
|
if cache_implementation == "hybrid" and "llama4" in getattr(self.config, "model_type", ""):
|
|
cache_implementation = "hybrid_chunked"
|
|
|
|
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
|
|
requires_cross_attention_cache = (
|
|
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
|
|
)
|
|
|
|
if hasattr(self, "_cache"):
|
|
cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache
|
|
|
|
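        # A sliding-window cache never needs to be longer than the window itself.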
if cache_implementation == "sliding_window":
|
|
max_cache_len = min(self.config.sliding_window, max_cache_len)
|
|
|
|
need_new_cache = (
|
|
not hasattr(self, "_cache")
|
|
or (not isinstance(cache_to_check, cache_cls))
|
|
or cache_to_check.max_batch_size != batch_size
|
|
or isinstance(
|
|
cache_to_check, (HybridChunkedCache, OffloadedHybridCache)
|
|
) # due to internal slicing, we always re-init
|
|
or cache_to_check.max_cache_len < max_cache_len
|
|
)
|
|
|
|
if requires_cross_attention_cache and hasattr(self, "_cache"):
|
|
need_new_cache = (
|
|
need_new_cache
|
|
or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1]
|
|
)
|
|
|
|
if need_new_cache:
|
|
if hasattr(self.config, "_pre_quantization_dtype"):
|
|
cache_dtype = self.config._pre_quantization_dtype
|
|
else:
|
|
cache_dtype = self.dtype
|
|
|
|
layer_device_map = self._get_layer_device_map_for_cache_init()
|
|
cache_kwargs = {
|
|
"config": self.config.get_text_config(),
|
|
"max_batch_size": batch_size,
|
|
"max_cache_len": max_cache_len,
|
|
"dtype": cache_dtype,
|
|
"device": device,
|
|
"layer_device_map": layer_device_map,
|
|
}
|
|
if cache_implementation in ["static", "hybrid", "offloaded_static"]:
|
|
cache_kwargs.update({"tp_size": self.tp_size})
|
|
|
|
self._cache = cache_cls(**cache_kwargs)
|
|
if requires_cross_attention_cache:
|
|
encoder_kwargs = cache_kwargs.copy()
|
|
encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1]
|
|
self._cache = EncoderDecoderCache(self._cache, cache_cls(**encoder_kwargs))
|
|
else:
|
|
self._cache.reset()
|
|
return self._cache
|
|
|
|
@classmethod
|
|
def _supports_default_dynamic_cache(cls) -> bool:
|
|
"""
|
|
Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`.
|
|
This adds exception for some models like `Mamba` models which use their own caches
|
|
and do not need to initialize the Cache in advance in order to save memory (because no back and forth
|
|
`to_legacy_cache` and `from_legacy_cache` will be performed for mamba-based models).
|
|
"""
|
|
# NOTE: remove xlnet/reformer when the models are deprecated, non-standard model architecture/cache name
|
|
return not cls._is_stateful and all(
|
|
special_model_name not in cls.__name__.lower()
|
|
for special_model_name in [
|
|
"reformer",
|
|
"minimax",
|
|
"xlnet",
|
|
"lfm2",
|
|
]
|
|
)
|
|
|
|
def _prepare_cache_for_generation(
|
|
self,
|
|
generation_config: GenerationConfig,
|
|
model_kwargs: dict,
|
|
assistant_model: "PreTrainedModel",
|
|
batch_size: int,
|
|
max_cache_length: int,
|
|
device: torch.device,
|
|
    ) -> None:
|
|
"""
|
|
Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is
|
|
instantiated, writes it to `model_kwargs`, under the name expected by the model.
|
|
"""
|
|
|
|
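        # Mamba-style models expose their state as `cache_params` rather than `past_key_values`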
is_hybrid_cache = any(class_name in self.__class__.__name__.lower() for class_name in ["mamba", "falconh1"])
|
|
cache_name = "past_key_values" if not is_hybrid_cache else "cache_params"
|
|
|
|
requires_cross_attention_cache = (
|
|
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
|
|
)
|
|
|
|
# Quick escape route 1: if the user specifies a cache, we only need to:
|
|
# a) check for conflicting `generate` arguments
|
|
# b) convert to the new cache format (if the user passes a legacy cache and model supports it)
|
|
user_defined_cache = model_kwargs.get(cache_name)
|
|
if user_defined_cache is not None:
|
|
if generation_config.cache_implementation is not None:
|
|
raise ValueError(
|
|
f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
|
|
"Cache object) is unsupported. Please use only one of the two."
|
|
)
|
|
if isinstance(user_defined_cache, tuple) and self._supports_default_dynamic_cache():
|
|
model_kwargs[cache_name] = (
|
|
DynamicCache.from_legacy_cache(user_defined_cache)
|
|
if not requires_cross_attention_cache
|
|
else EncoderDecoderCache.from_legacy_cache(user_defined_cache)
|
|
)
|
|
return
|
|
|
|
# Quick escape route 2: if the user specifies no cache is to be used. (conflicting arguments are handled in
|
|
# `generation_config.validate()`)
|
|
if generation_config.use_cache is False:
|
|
return
|
|
|
|
        # Quick escape route 3: models that only support legacy caches or that supply the cache themselves in
        # `prepare_inputs_for_generation` (mamba, zamba, ...)
|
|
if not self._supports_default_dynamic_cache():
|
|
if generation_config.cache_implementation is not None:
|
|
warnings.warn(
|
|
"This model does not support `Cache` instances, it only supports the legacy cache format (tuple "
|
|
f"of tuples). `cache_implementation` (set to {generation_config.cache_implementation}) will be "
|
|
"ignored.",
|
|
UserWarning,
|
|
)
|
|
return
|
|
|
|
# Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation`
|
|
|
|
# TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
|
|
# which is only supported in dynamic caches atm
|
|
if assistant_model is not None and generation_config.cache_implementation is not None:
|
|
logger.warning_once(
|
|
"An assistant model is provided, using a dynamic cache instead of a cache of type="
|
|
f"'{generation_config.cache_implementation}'."
|
|
)
|
|
generation_config.cache_implementation = None
|
|
|
|
generation_config.cache_implementation = generation_config.cache_implementation or getattr(
|
|
self.config.get_text_config(decoder=True), "cache_implementation", None
|
|
)
|
|
if generation_config.cache_implementation is not None:
|
|
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
|
|
if generation_config.cache_implementation == "static" and not self._can_compile_fullgraph:
|
|
raise ValueError(
|
|
"This model does not support `cache_implementation='static'`. Please check the following "
|
|
"issue: https://github.com/huggingface/transformers/issues/28981"
|
|
)
|
|
model_kwargs[cache_name] = self._get_cache(
|
|
cache_implementation=generation_config.cache_implementation,
|
|
batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
|
|
max_cache_len=max_cache_length,
|
|
device=device,
|
|
model_kwargs=model_kwargs,
|
|
)
|
|
elif generation_config.cache_implementation == "quantized":
|
|
if self.config.is_encoder_decoder or not self._supports_default_dynamic_cache():
|
|
raise ValueError(
|
|
"This model does not support the quantized cache. If you want your model to support quantized "
|
|
"cache, please open an issue and tag @zucchini-nlp."
|
|
)
|
|
|
|
cache_config = (
|
|
generation_config.cache_config
|
|
if generation_config.cache_config is not None
|
|
else {"backend": "quanto"}
|
|
)
|
|
cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config["backend"]]
|
|
|
|
                if cache_config["backend"] == "quanto" and not is_optimum_quanto_available():
                    raise ImportError(
                        "You need to install optimum-quanto in order to use KV cache quantization with the "
                        "optimum-quanto backend. Please install it with `pip install optimum-quanto`."
                    )
                elif cache_config["backend"] == "HQQ" and not is_hqq_available():
                    raise ImportError(
                        "You need to install `HQQ` in order to use KV cache quantization with the HQQ backend. "
                        "Please install it with `pip install hqq`."
                    )
|
|
|
|
model_kwargs[cache_name] = cache_class(**cache_config)
|
|
elif generation_config.cache_implementation == "offloaded":
|
|
model_kwargs[cache_name] = OffloadedCache()
|
|
elif generation_config.cache_implementation == "dynamic":
|
|
model_kwargs[cache_name] = DynamicCache()
|
|
|
|
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
|
|
# keeps copying the cache thus using much more memory
|
|
else:
|
|
model_kwargs[cache_name] = (
|
|
DynamicCache()
|
|
if not requires_cross_attention_cache
|
|
else EncoderDecoderCache(DynamicCache(), DynamicCache())
|
|
)
|
|
|
|
def _supports_logits_to_keep(self) -> bool:
|
|
"""
|
|
Return True if the current model supports the keyword argument `logits_to_keep` in forward()
|
|
to save memory. Checking it in this way allows to avoid using a new model attribute.
|
|
"""
|
|
return "logits_to_keep" in set(inspect.signature(self.forward).parameters.keys())
|
|
|
|
def _prepare_special_tokens(
|
|
self,
|
|
generation_config: GenerationConfig,
|
|
kwargs_has_attention_mask: Optional[bool] = None,
|
|
device: Optional[Union[torch.device, str]] = None,
|
|
):
|
|
"""
|
|
Prepares the special tokens for generation, overwriting the generation config with their processed versions
|
|
converted to tensor.
|
|
|
|
Note that `generation_config` is changed in place and stops being serializable after this method is called.
|
|
That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the
|
|
function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
|
|
"""
|
|
|
|
# Convert special tokens to tensors
|
|
def _tensor_or_none(token, device=None):
|
|
if token is None:
|
|
return token
|
|
|
|
device = device if device is not None else self.device
|
|
if isinstance(token, torch.Tensor):
|
|
return token.to(device)
|
|
return torch.tensor(token, device=device, dtype=torch.long)
|
|
|
|
bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
|
|
eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
|
|
pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
|
|
decoder_start_token_tensor = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
|
|
|
|
# for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892)
|
|
if self.config.is_encoder_decoder:
|
|
decoder_start_token_tensor = (
|
|
decoder_start_token_tensor if decoder_start_token_tensor is not None else bos_token_tensor
|
|
)
|
|
|
|
# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
|
|
if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
|
|
eos_token_tensor = eos_token_tensor.unsqueeze(0)
|
|
|
|
# Set pad token if unset (and there are conditions to do so)
|
|
if pad_token_tensor is None and eos_token_tensor is not None:
|
|
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
|
|
logger.warning(
|
|
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
|
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
|
)
|
|
pad_token_tensor = eos_token_tensor[0]
|
|
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
|
|
|
|
# Sanity checks/warnings
|
|
if self.config.is_encoder_decoder and decoder_start_token_tensor is None:
|
|
raise ValueError(
|
|
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
|
|
)
|
|
if (
|
|
eos_token_tensor is not None
|
|
and isin_mps_friendly(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
|
|
):
|
|
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
|
|
logger.warning_once(
|
|
"The attention mask is not set and cannot be inferred from input because pad token is same as "
|
|
"eos token. As a consequence, you may observe unexpected behavior. Please pass your input's "
|
|
"`attention_mask` to obtain reliable results."
|
|
)
|
|
if eos_token_tensor is not None and (
|
|
torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any()
|
|
):
|
|
logger.warning(
|
|
f"`eos_token_id` should consist of positive integers, but is {eos_token_tensor}. Your generation "
|
|
"will not stop until the maximum length is reached. Depending on other flags, it may even crash."
|
|
)
|
|
|
|
# Update generation config with the updated special tokens tensors
|
|
# NOTE: this must be written into a different attribute name than the one holding the original special tokens
|
|
# (in their non-tensor form), in order to enable end-to-end compilation. See
|
|
# https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
|
|
generation_config._bos_token_tensor = bos_token_tensor
|
|
generation_config._eos_token_tensor = eos_token_tensor
|
|
generation_config._pad_token_tensor = pad_token_tensor
|
|
generation_config._decoder_start_token_tensor = decoder_start_token_tensor
|
|
|
|
def _valid_auto_compile_criteria(self, model_kwargs: dict, generation_config: GenerationConfig) -> bool:
|
|
"""
|
|
Determines whether to trigger auto-compilation of the model's forward pass at generation time.
|
|
"""
|
|
# Override: honor `disable_compile` flag
|
|
if generation_config.disable_compile:
|
|
return False
|
|
|
|
# Base logic
|
|
valid_hardware = self.device.type == "cuda" or bool(
|
|
generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices
|
|
)
|
|
using_compilable_cache = (
|
|
isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable
|
|
)
|
|
# TODO @raushan `self._can_compile_fullgraph` can be removed and inferred from model arch (e.g. MoE doesn't support compile)
|
|
can_compile = valid_hardware and using_compilable_cache and self._can_compile_fullgraph
|
|
|
|
# Exception 1: Some quantization methods do not support compilation
|
|
if getattr(self, "hf_quantizer", None) is not None:
|
|
can_compile &= self.hf_quantizer.is_compileable
|
|
|
|
if hasattr(self, "hf_device_map"):
|
|
all_model_devices = set(self.hf_device_map.values())
|
|
# Exception 2: Don't compile if the model is using CPU offload (as of April 2025, this results in a crash)
|
|
has_cpu_offload = "cpu" in all_model_devices and len(all_model_devices) > 1
|
|
can_compile &= not has_cpu_offload
|
|
|
|
# Exception 3: Disk offload is not supported for compilation
|
|
has_disk_offload = "disk" in all_model_devices
|
|
can_compile &= not has_disk_offload
|
|
|
|
# Finally: if the user has manually specified compilation options, but compilation is not possible, let's warn
|
|
# them
|
|
if generation_config.compile_config is not None and not can_compile:
|
|
logger.warning_once(
|
|
"You have set `compile_config`, but we are unable to meet the criteria for compilation. Compilation "
|
|
"will be skipped."
|
|
)
|
|
|
|
return can_compile
|
|
|
|
@torch.no_grad()
|
|
def generate(
|
|
self,
|
|
inputs: Optional[torch.Tensor] = None,
|
|
generation_config: Optional[GenerationConfig] = None,
|
|
logits_processor: Optional[LogitsProcessorList] = None,
|
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
|
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
|
|
synced_gpus: Optional[bool] = None,
|
|
assistant_model: Optional["PreTrainedModel"] = None,
|
|
streamer: Optional["BaseStreamer"] = None,
|
|
negative_prompt_ids: Optional[torch.Tensor] = None,
|
|
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
|
use_model_defaults: Optional[bool] = None,
|
|
custom_generate: Optional[str] = None,
|
|
**kwargs,
|
|
) -> Union[GenerateOutput, torch.LongTensor]:
|
|
r"""
|
|
|
|
Generates sequences of token ids for models with a language modeling head.
|
|
|
|
<Tip warning={true}>
|
|
|
|
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
|
|
model's default generation configuration. You can override any `generation_config` by passing the corresponding
|
|
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
|
|
|
|
For an overview of generation strategies and code examples, check out the [following
|
|
guide](../generation_strategies).
|
|
|
|
</Tip>
|
|
|
|
Parameters:
|
|
inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
|
|
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
|
|
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
|
|
should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
|
|
`input_ids`, `input_values`, `input_features`, or `pixel_values`.
|
|
generation_config ([`~generation.GenerationConfig`], *optional*):
|
|
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
|
|
passed to generate matching the attributes of `generation_config` will override them. If
|
|
`generation_config` is not provided, the default will be used, which has the following loading
|
|
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
|
|
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
|
|
default values, whose documentation should be checked to parameterize generation.
|
|
logits_processor (`LogitsProcessorList`, *optional*):
|
|
Custom logits processors that complement the default logits processors built from arguments and
|
|
generation config. If a logit processor is passed that is already created with the arguments or a
|
|
generation config an error is thrown. This feature is intended for advanced users.
|
|
stopping_criteria (`StoppingCriteriaList`, *optional*):
|
|
Custom stopping criteria that complements the default stopping criteria built from arguments and a
|
|
generation config. If a stopping criteria is passed that is already created with the arguments or a
|
|
generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
|
|
sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
|
|
intended for advanced users.
|
|
            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
                If provided, this function constrains the beam search to allowed tokens only at each step. If not
                provided, no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
                `input_ids`. It has to return a list with the allowed tokens for the next generation step,
                conditioned on the batch ID `batch_id` and the previously generated tokens `input_ids`. This
                argument is useful for constrained generation conditioned on the prefix, as described in
                [Autoregressive Entity Retrieval](https://huggingface.co/papers/2010.00904).
|
|
synced_gpus (`bool`, *optional*):
|
|
Whether to continue running the while loop until max_length. Unless overridden, this flag will be set
|
|
to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid
|
|
deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
|
|
assistant_model (`PreTrainedModel`, *optional*):
|
|
An assistant model that can be used to accelerate generation. The assistant model must have the exact
|
|
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
|
|
is much faster than running generation with the model you're calling generate from. As such, the
|
|
assistant model should be much smaller.
|
|
streamer (`BaseStreamer`, *optional*):
|
|
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
|
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
|
negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
The negative prompt needed for some processors such as CFG. The batch size must match the input batch
|
|
size. This is an experimental feature, subject to breaking API changes in future versions.
|
|
negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
Attention_mask for `negative_prompt_ids`.
|
|
use_model_defaults (`bool`, *optional*):
|
|
When it is `True`, unset parameters in `generation_config` will be set to the model-specific default
|
|
generation configuration (`model.generation_config`), as opposed to the global defaults
|
|
(`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
|
|
`True`.
|
|
            custom_generate (`str`, *optional*):
                A string containing the name of a huggingface.co repository. If provided, the custom `generate`
                function defined in that repository's `custom_generate/generate.py` file will be executed instead of
                the standard `generate` method. Note that the logic for generation is entirely defined in that
                repository, and the return type may be different from the standard `generate` method.
|
|
kwargs (`dict[str, Any]`, *optional*):
|
|
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
|
|
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
|
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
|
|
|
|
Return:
|
|
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
|
|
or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.
|
|
|
|
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
|
|
[`~utils.ModelOutput`] types are:
|
|
|
|
- [`~generation.GenerateDecoderOnlyOutput`],
|
|
- [`~generation.GenerateBeamDecoderOnlyOutput`]
|
|
|
|
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
|
|
[`~utils.ModelOutput`] types are:
|
|
|
|
- [`~generation.GenerateEncoderDecoderOutput`],
|
|
- [`~generation.GenerateBeamEncoderDecoderOutput`]
|
|
"""
|
|
# 0. If requested, load an arbitrary generation recipe from the Hub and run it instead
|
|
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
|
if custom_generate is not None:
|
|
# Get all `generate` arguments in a single variable. Custom functions are responsible for handling them:
|
|
            # they receive the same inputs as `generate`, with `model` instead of `self` and excluding the arguments
            # that trigger the custom generation. They can access methods from `GenerationMixin` through `model`.
|
|
global_keys_to_exclude = {
|
|
"self",
|
|
"kwargs",
|
|
"global_keys_to_exclude",
|
|
"trust_remote_code",
|
|
"custom_generate",
|
|
}
|
|
generate_arguments = {key: value for key, value in locals().items() if key not in global_keys_to_exclude}
|
|
generate_arguments.update(kwargs)
|
|
|
|
custom_generate_function = self.load_custom_generate(
|
|
custom_generate, trust_remote_code=trust_remote_code, **kwargs
|
|
)
|
|
return custom_generate_function(model=self, **generate_arguments)
|
|
|
|
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
|
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
|
|
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation
|
|
|
|
generation_config, model_kwargs = self._prepare_generation_config(
|
|
generation_config, use_model_defaults, **kwargs
|
|
)
|
|
self._validate_model_kwargs(model_kwargs.copy())
|
|
self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
|
|
|
|
# 2. Set generation parameters if not already defined
|
|
if synced_gpus is None:
|
|
synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
|
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
|
|
|
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
|
|
requires_attention_mask = "encoder_outputs" not in model_kwargs
|
|
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
|
|
|
|
# 3. Define model inputs
|
|
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
|
|
inputs, generation_config.bos_token_id, model_kwargs
|
|
)
|
|
batch_size = inputs_tensor.shape[0]
|
|
|
|
device = inputs_tensor.device
|
|
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
|
|
|
|
# decoder-only models must use left-padding for batched generation.
|
|
if not self.config.is_encoder_decoder:
|
|
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
|
|
# Note: if using `inputs_embeds`, this check does not work, because we want to be more hands-off.
|
|
if (
|
|
generation_config._pad_token_tensor is not None
|
|
and batch_size > 1
|
|
and len(inputs_tensor.shape) == 2
|
|
and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
|
|
):
|
|
logger.warning(
|
|
"A decoder-only architecture is being used, but right-padding was detected! For correct "
|
|
"generation results, please set `padding_side='left'` when initializing the tokenizer."
|
|
)
|
|
|
|
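# A minimal sketch of the recommended setup (comment only, not executed), assuming a generic
# causal LM checkpoint such as "gpt2":
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
#     tokenizer.pad_token = tokenizer.eos_token
#     batch = tokenizer(["a short prompt", "a much longer prompt"], return_tensors="pt", padding=True)
#     out = model.generate(**batch)  # pad tokens sit on the left, so the last position holds real text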
# 4. Define other model kwargs
|
|
# decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
|
|
# generating the first new token or not, and we only want to use the embeddings for the first new token)
|
|
if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
|
|
generation_config.use_cache = True
|
|
|
|
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
|
|
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
|
|
inputs_tensor, generation_config, model_kwargs
|
|
)
|
|
elif kwargs_has_attention_mask:
|
|
# TODO (joao): generalize this check with other types of inputs
|
|
if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
|
|
raise ValueError("`attention_mask` passed to `generate` must be 2D.")
|
|
|
|
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
|
|
# if the model is an encoder-decoder, `encoder_outputs` are created and added to `model_kwargs`
|
|
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
|
|
inputs_tensor, model_kwargs, model_input_name, generation_config
|
|
)
|
|
|
|
# 5. Prepare `input_ids` which will be used for auto-regressive generation
|
|
if self.config.is_encoder_decoder:
|
|
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
|
|
batch_size=batch_size,
|
|
model_input_name=model_input_name,
|
|
model_kwargs=model_kwargs,
|
|
decoder_start_token_id=generation_config._decoder_start_token_tensor,
|
|
device=inputs_tensor.device,
|
|
)
|
|
else:
|
|
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
|
|
|
|
if generation_config.token_healing:
|
|
input_ids = self.heal_tokens(input_ids, tokenizer)
|
|
|
|
if streamer is not None:
|
|
streamer.put(input_ids.cpu())
|
|
|
|
# 6. Prepare `max_length` depending on other stopping criteria.
|
|
input_ids_length = input_ids.shape[1]
|
|
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
|
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
|
|
generation_config = self._prepare_generated_length(
|
|
generation_config=generation_config,
|
|
has_default_max_length=has_default_max_length,
|
|
has_default_min_length=has_default_min_length,
|
|
model_input_name=model_input_name,
|
|
inputs_tensor=inputs_tensor,
|
|
input_ids_length=input_ids_length,
|
|
)
|
|
|
|
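# Illustrative sketch of the prepared lengths (comment only; token counts are assumptions):
#
#     prompt = tokenizer("Hello there", return_tensors="pt")   # say 5 prompt tokens
#     out = model.generate(**prompt, max_new_tokens=20)
#     # the prepared `max_length` is then 5 + 20, so `out.shape[1]` is at most 25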
# If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
|
|
# logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
|
|
# dynamically overrides this value as it can need more than the last token logits
|
|
if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
|
|
model_kwargs["logits_to_keep"] = 1
|
|
|
|
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
|
|
|
# 7. Prepare the cache.
|
|
# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
|
|
# - different models have a different cache name expected by the model (default = "past_key_values")
|
|
# - `max_length`, prepared above, is used to determine the maximum cache length
|
|
max_cache_length = generation_config.max_length - 1
|
|
if (
|
|
inputs_tensor.shape[1] != input_ids_length
|
|
and model_input_name == "inputs_embeds"
|
|
and not self.config.is_encoder_decoder
|
|
):
|
|
max_cache_length += inputs_tensor.shape[1]
|
|
self._prepare_cache_for_generation(
|
|
generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
|
|
)
|
|
|
|
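# Illustrative sketch (comment only): the cache backend can be selected through the generation
# config, e.g. a static cache sized with the `max_cache_length` computed above:
#
#     out = model.generate(**inputs, max_new_tokens=32, cache_implementation="static")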
# 8. determine generation mode
|
|
generation_mode = generation_config.get_generation_mode(assistant_model)
|
|
|
|
if streamer is not None and (generation_config.num_beams > 1):
|
|
raise ValueError(
|
|
"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
|
|
)
|
|
|
|
if self.device.type != input_ids.device.type:
|
|
warnings.warn(
|
|
"You are calling .generate() with the `input_ids` being on a device type different"
|
|
f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
|
|
f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
|
|
" Please make sure that you have put `input_ids` to the"
|
|
f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
|
|
" running `.generate()`.",
|
|
UserWarning,
|
|
)
|
|
|
|
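# Illustrative sketch (comment only): keeping the inputs on the model's device avoids the warning
# above and the implicit transfers it describes:
#
#     inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
#     out = model.generate(**inputs)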
# 9. prepare logits processors and stopping criteria
|
|
prepared_logits_processor = self._get_logits_processor(
|
|
generation_config=generation_config,
|
|
input_ids_seq_length=input_ids_length,
|
|
encoder_input_ids=inputs_tensor,
|
|
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
|
logits_processor=logits_processor,
|
|
device=inputs_tensor.device,
|
|
model_kwargs=model_kwargs,
|
|
negative_prompt_ids=negative_prompt_ids,
|
|
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
|
)
|
|
prepared_stopping_criteria = self._get_stopping_criteria(
|
|
generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
|
|
)
|
|
|
|
# Set model_kwargs `use_cache` so we can use it later in forward runs
|
|
model_kwargs["use_cache"] = generation_config.use_cache
|
|
|
|
# 10. go into different generation modes
|
|
if generation_mode == GenerationMode.ASSISTED_GENERATION:
|
|
if generation_config.num_return_sequences > 1:
|
|
raise ValueError(
|
|
"num_return_sequences has to be 1 when doing assisted generate, "
|
|
f"but is {generation_config.num_return_sequences}."
|
|
)
|
|
if batch_size > 1:
|
|
raise ValueError("assisted generate is only supported for batch_size = 1")
|
|
if not model_kwargs["use_cache"]:
|
|
raise ValueError("assisted generate requires `use_cache=True`")
|
|
if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]:
|
|
raise ValueError("assisted generate is not supported with Static cache classes`")
|
|
if self._is_stateful:
|
|
# In assisted generation we need the ability to confirm whether the model would pick certain tokens,
|
|
# which is not possible with stateful models (they can't reset to a previous subset of generated text)
|
|
raise ValueError(
|
|
f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
|
|
)
|
|
|
|
# 11. Get the candidate generator, given the parameterization
|
|
candidate_generator = self._get_candidate_generator(
|
|
generation_config=generation_config,
|
|
input_ids=input_ids,
|
|
inputs_tensor=inputs_tensor,
|
|
assistant_model=assistant_model,
|
|
logits_processor=logits_processor,
|
|
target_tokenizer=tokenizer,
|
|
assistant_tokenizer=assistant_tokenizer,
|
|
model_kwargs=model_kwargs,
|
|
)
|
|
|
|
# 12. run assisted generate
|
|
result = self._assisted_decoding(
|
|
input_ids,
|
|
candidate_generator=candidate_generator,
|
|
logits_processor=prepared_logits_processor,
|
|
stopping_criteria=prepared_stopping_criteria,
|
|
generation_config=generation_config,
|
|
synced_gpus=synced_gpus,
|
|
streamer=streamer,
|
|
**model_kwargs,
|
|
)
|
|
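# Illustrative sketch of assisted generation (comment only; the checkpoint pair is an assumption,
# any main/draft pair sharing a tokenizer works):
#
#     assistant = AutoModelForCausalLM.from_pretrained("distilgpt2")
#     out = model.generate(**inputs, assistant_model=assistant, do_sample=False)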
elif generation_mode == GenerationMode.DOLA_GENERATION:
|
|
if not trust_remote_code:
|
|
logger.warning_once(
|
|
"DoLa Decoding is scheduled to be moved to a `custom_generate` repository in v4.55.0. "
|
|
"To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call."
|
|
)
|
|
if self._is_stateful:
|
|
# DoLa decoding was not designed for stateful models, and would require some changes
|
|
raise ValueError(
|
|
f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}"
|
|
)
|
|
result = self._dola_decoding(
|
|
input_ids,
|
|
dola_layers=generation_config.dola_layers,
|
|
logits_processor=prepared_logits_processor,
|
|
stopping_criteria=prepared_stopping_criteria,
|
|
generation_config=generation_config,
|
|
synced_gpus=synced_gpus,
|
|
streamer=streamer,
|
|
**model_kwargs,
|
|
)
|
|
|
|
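# Illustrative sketch of DoLa decoding (comment only), which contrasts final-layer logits with
# logits read out from earlier layers:
#
#     out = model.generate(**inputs, dola_layers="high", repetition_penalty=1.2, trust_remote_code=True)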
elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
|
|
if not trust_remote_code:
|
|
logger.warning_once(
|
|
"Contrastive Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. "
|
|
"To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call."
|
|
)
|
|
if not model_kwargs["use_cache"]:
|
|
raise ValueError("Contrastive search requires `use_cache=True`")
|
|
if self._is_stateful:
|
|
# Just like assisted generation, we need to be able to rollback to a previous state (see comment above)
|
|
raise ValueError(
|
|
f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}"
|
|
)
|
|
|
|
result = self._contrastive_search(
|
|
input_ids,
|
|
logits_processor=prepared_logits_processor,
|
|
stopping_criteria=prepared_stopping_criteria,
|
|
generation_config=generation_config,
|
|
synced_gpus=synced_gpus,
|
|
streamer=streamer,
|
|
**model_kwargs,
|
|
)
|
|
|
|
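# Illustrative sketch (comment only): contrastive search is selected when `penalty_alpha` > 0 and
# `top_k` > 1:
#
#     out = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=64, trust_remote_code=True)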
elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
|
|
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
|
|
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
|
input_ids=input_ids,
|
|
expand_size=generation_config.num_return_sequences,
|
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
|
**model_kwargs,
|
|
)
|
|
|
|
# 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
|
|
result = self._sample(
|
|
input_ids,
|
|
logits_processor=prepared_logits_processor,
|
|
stopping_criteria=prepared_stopping_criteria,
|
|
generation_config=generation_config,
|
|
synced_gpus=synced_gpus,
|
|
streamer=streamer,
|
|
**model_kwargs,
|
|
)
|
|
|
|
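# Illustrative sketch (comment only): the same branch covers greedy search and multinomial sampling:
#
#     greedy = model.generate(**inputs, do_sample=False)
#     sampled = model.generate(**inputs, do_sample=True, temperature=0.8, top_p=0.95, num_return_sequences=3)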
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
|
|
# 11. interleave input_ids with `num_beams` additional sequences per batch
|
|
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
|
input_ids=input_ids,
|
|
expand_size=generation_config.num_beams,
|
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
|
**model_kwargs,
|
|
)
|
|
# 12. run beam sample
|
|
result = self._beam_search(
|
|
input_ids,
|
|
logits_processor=prepared_logits_processor,
|
|
stopping_criteria=prepared_stopping_criteria,
|
|
generation_config=generation_config,
|
|
synced_gpus=synced_gpus,
|
|
**model_kwargs,
|
|
)
|
|
|
|
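# Illustrative sketch (comment only): beam search keeps `num_beams` hypotheses per input and
# returns the `num_return_sequences` best ones:
#
#     out = model.generate(**inputs, num_beams=4, num_return_sequences=2, early_stopping=True)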
elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
|
|
logger.warning_once(
|
|
"Group Beam Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. "
|
|
"To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call."
|
|
)
|
|
# 11. prepare beam search scorer
|
|
beam_scorer = BeamSearchScorer(
|
|
batch_size=batch_size,
|
|
num_beams=generation_config.num_beams,
|
|
device=inputs_tensor.device,
|
|
length_penalty=generation_config.length_penalty,
|
|
do_early_stopping=generation_config.early_stopping,
|
|
num_beam_hyps_to_keep=generation_config.num_return_sequences,
|
|
num_beam_groups=generation_config.num_beam_groups,
|
|
max_length=generation_config.max_length,
|
|
)
|
|
# 12. interleave input_ids with `num_beams` additional sequences per batch
|
|
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
|
input_ids=input_ids,
|
|
expand_size=generation_config.num_beams,
|
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
|
**model_kwargs,
|
|
)
|
|
# 13. run beam search
|
|
result = self._group_beam_search(
|
|
input_ids,
|
|
beam_scorer,
|
|
logits_processor=prepared_logits_processor,
|
|
stopping_criteria=prepared_stopping_criteria,
|
|
generation_config=generation_config,
|
|
synced_gpus=synced_gpus,
|
|
**model_kwargs,
|
|
)
|
|
|
|
elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
|
|
logger.warning_once(
|
|
"Constrained Beam Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. "
|
|
"To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call."
|
|
)
|
|
final_constraints = []
|
|
if generation_config.constraints is not None:
|
|
final_constraints = generation_config.constraints
|
|
|
|
if generation_config.force_words_ids is not None:
|
|
|
|
def typeerror():
|
|
raise ValueError(
|
|
"`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]` "
|
|
f"of positive integers, but is {generation_config.force_words_ids}."
|
|
)
|
|
|
|
if (
|
|
not isinstance(generation_config.force_words_ids, list)
|
|
or len(generation_config.force_words_ids) == 0
|
|
):
|
|
typeerror()
|
|
|
|
for word_ids in generation_config.force_words_ids:
|
|
if isinstance(word_ids[0], list):
|
|
if not isinstance(word_ids, list) or len(word_ids) == 0:
|
|
typeerror()
|
|
if any(not isinstance(token_ids, list) for token_ids in word_ids):
|
|
typeerror()
|
|
if any(
|
|
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
|
|
for token_ids in word_ids
|
|
):
|
|
typeerror()
|
|
|
|
constraint = DisjunctiveConstraint(word_ids)
|
|
else:
|
|
if not isinstance(word_ids, list) or len(word_ids) == 0:
|
|
typeerror()
|
|
if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
|
|
typeerror()
|
|
|
|
constraint = PhrasalConstraint(word_ids)
|
|
final_constraints.append(constraint)
|
|
|
|
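# Illustrative sketch of the two accepted `force_words_ids` layouts (comment only; the token ids
# are placeholders):
#
#     force_words_ids = [[42, 7]]            # list[list[int]]: the phrase [42, 7] must appear
#     force_words_ids = [[[42, 7], [13]]]    # list[list[list[int]]]: either [42, 7] or [13] must appear
#     out = model.generate(**inputs, num_beams=4, force_words_ids=force_words_ids)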
# 11. prepare beam search scorer
|
|
constrained_beam_scorer = ConstrainedBeamSearchScorer(
|
|
constraints=final_constraints,
|
|
batch_size=batch_size,
|
|
num_beams=generation_config.num_beams,
|
|
device=inputs_tensor.device,
|
|
length_penalty=generation_config.length_penalty,
|
|
do_early_stopping=generation_config.early_stopping,
|
|
num_beam_hyps_to_keep=generation_config.num_return_sequences,
|
|
max_length=generation_config.max_length,
|
|
)
|
|
# 12. interleave input_ids with `num_beams` additional sequences per batch
|
|
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
|
input_ids=input_ids,
|
|
expand_size=generation_config.num_beams,
|
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
|
**model_kwargs,
|
|
)
|
|
# 13. run beam search
|
|
result = self._constrained_beam_search(
|
|
input_ids,
|
|
constrained_beam_scorer=constrained_beam_scorer,
|
|
logits_processor=prepared_logits_processor,
|
|
stopping_criteria=prepared_stopping_criteria,
|
|
generation_config=generation_config,
|
|
synced_gpus=synced_gpus,
|
|
**model_kwargs,
|
|
)
|
|
|
|
# Convert to legacy cache format if requested
|
|
if (
|
|
generation_config.return_legacy_cache is True
|
|
and hasattr(result, "past_key_values")
|
|
and getattr(result.past_key_values, "to_legacy_cache", None) is not None
|
|
):
|
|
result.past_key_values = result.past_key_values.to_legacy_cache()
|
|
return result
|
|
|
|
def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
|
|
"""
|
|
Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is
|
|
fed through `this_peer_finished`. ZeRO stage 3-friendly.
|
|
"""
|
|
if synced_gpus:
|
|
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
|
# The following logic allows an early break if all peers finished generating their sequence
|
|
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device)
|
|
# send 0.0 if we finished, 1.0 otherwise
|
|
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
|
# did all peers finish? the reduced sum will be 0.0 then
|
|
if this_peer_finished_flag.item() == 0.0:
|
|
return False
|
|
elif this_peer_finished:
|
|
return False
|
|
return True
|
|
|
|
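# Minimal sketch of the synchronization above (comment only), assuming `torch.distributed` is
# initialized: each rank contributes a "still running" flag and stops only when the sum is 0.
#
#     flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device)
#     dist.all_reduce(flag, op=dist.ReduceOp.SUM)
#     everyone_done = flag.item() == 0.0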
def heal_tokens(
|
|
self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None
|
|
) -> torch.LongTensor:
|
|
r"""
|
|
Applies token healing: replaces each sequence's tail token with its most likely extension.
|
|
Parameters:
|
|
input_ids (`torch.LongTensor`): The sequence used as a prompt for the generation.
|
|
tokenizer (`PreTrainedTokenizerBase`, *optional*): The tokenizer used to decode the input ids.
|
|
Return:
|
|
`torch.LongTensor` where each sequence has its tail token replaced with its appropriate extension.
|
|
"""
|
|
if tokenizer is None:
|
|
raise ValueError(
|
|
" When generating with token healing, you must pass the model's tokenizer to the `tokenizer` "
|
|
"argument of `generate`."
|
|
)
|
|
|
|
bos_token_id, pad_token_id = tokenizer.bos_token_id, tokenizer.pad_token_id
|
|
vocab_trie = ExtensionsTrie(tokenizer.get_vocab())
|
|
generation_config = GenerationConfig(max_new_tokens=1, pad_token_id=pad_token_id)
|
|
|
|
# assumption: leading/trailing whitespace is not meaningful, so the prompts are
|
|
# stripped before re-tokenizing to desensitize generation to whitespace artefacts
|
|
prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)]
|
|
input_ids = tokenizer(
|
|
prompts,
|
|
return_tensors="pt",
|
|
padding=True,
|
|
).input_ids.to(input_ids.device)
|
|
|
|
# replace bos with pad to not condition healing on it
|
|
input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids)
|
|
|
|
"""
|
|
The code below assumes `input_ids` is not empty, so return early if it contains no elements.
|
|
"""
|
|
if input_ids.numel() == 0:
|
|
return input_ids
|
|
|
|
tail_ids = input_ids[:, -1].tolist()
|
|
|
|
space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0]
|
|
# tail tokens are used for a prefix search, thus, whitespaces are replaced with
|
|
# their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace
|
|
tail_toks = (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)
|
|
|
|
for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)):
|
|
batch_ids = input_ids[batch_idx]
|
|
if torch.all(batch_ids == pad_token_id).item():
|
|
continue # skip empty sequences (all pad ids)
|
|
|
|
# apply bias for alternatives (extensions) to the tail token
|
|
"""
|
|
`seq_bias` keys have to be tuples of token ids (ints), so the tokenizer is used to convert each
token string to its id.
|
|
"""
|
|
seq_bias = {
|
|
(tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in vocab_trie.extensions(prefix=tail_tok)
|
|
}
|
|
|
|
if len(seq_bias) == 1:
|
|
continue # skip if there are no token alternatives to heal with
|
|
|
|
# slightly favor original token to limit aggressive healing e.g. 'http' -> 'https'
|
|
seq_bias[(tail_id,)] += 1.0
|
|
generation_config.update(sequence_bias=seq_bias)
|
|
|
|
trimmed_ids = batch_ids[:-1]
|
|
|
|
"""
|
|
The code below assumes `trimmed_ids` is not empty, so check its element count first.
|
|
"""
|
|
if trimmed_ids.numel() == 0:
|
|
continue
|
|
|
|
# if the prompt is a single (non-pad) token, regenerate from bos
|
|
if len(batch_ids[batch_ids != pad_token_id]) == 1:
|
|
trimmed_ids[-1] = bos_token_id
|
|
|
|
input_ids[batch_idx] = self.generate(trimmed_ids.unsqueeze(0), generation_config=generation_config)
|
|
|
|
return input_ids
|
|
|
|
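# Illustrative sketch (comment only): token healing is enabled via the generation config and needs
# the tokenizer to be passed along:
#
#     out = model.generate(**inputs, token_healing=True, tokenizer=tokenizer, max_new_tokens=16)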
def _dola_decoding(
|
|
self,
|
|
input_ids: torch.LongTensor,
|
|
dola_layers: Union[str, list[int]],
|
|
logits_processor: LogitsProcessorList,
|
|
stopping_criteria: StoppingCriteriaList,
|
|
generation_config: GenerationConfig,
|
|
synced_gpus: bool,
|
|
streamer: "BaseStreamer",
|
|
**model_kwargs,
|
|
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
|
r"""
|
|
Generates sequences of token ids for models with a language modeling head using **dola decoding** and can be
|
|
used for decoder-only text models.
|
|
The method is based on the paper "DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language
|
|
Models" (https://huggingface.co/papers/2309.03883) in ICLR 2024.
|
|
|
|
Parameters:
|
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
|
The sequence used as a prompt for the generation.
|
|
dola_layers (`Union[str, list[int]]`):
|
|
The candidate layers used in contrasting layers of DoLa. It can be either 1) 'low' or 'high', which
|
|
means the lower part or higher part of the model layers, respectively, or 2) a list of layer indices
|
|
to be used for candidate layers. The 0-th layer is the word embedding layer of the model.
|
|
logits_processor (`LogitsProcessorList`):
|
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
|
used to modify the prediction scores of the language modeling head applied at each generation step.
|
|
stopping_criteria (`StoppingCriteriaList`, *optional*):
|
|
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
|
used to tell if the generation loop should stop.
|
|
generation_config ([`~generation.GenerationConfig`]):
|
|
The generation configuration to be used as parametrization of the decoding method.
|
|
synced_gpus (`bool`):
|
|
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
|
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
|
streamer (`BaseStreamer`, *optional*):
|
|
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
|
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
|
model_kwargs:
|
|
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
|
|
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
|
|
|
|
Return:
|
|
[`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`]
|
|
or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
|
|
[`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
|
`return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
|
|
`model.config.is_encoder_decoder=True`.
|
|
"""
|
|
|
|
if self.config.is_encoder_decoder:
|
|
raise ValueError("DoLa decoding is only available for decoder-only models.")
|
|
# init values
|
|
|
|
pad_token_id = generation_config._pad_token_tensor
|
|
output_attentions = generation_config.output_attentions
|
|
output_hidden_states = generation_config.output_hidden_states
|
|
output_scores = generation_config.output_scores
|
|
output_logits = generation_config.output_logits
|
|
return_dict_in_generate = generation_config.return_dict_in_generate
|
|
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
|
|
do_sample = generation_config.do_sample
|
|
|
|
# init attention / hidden states / scores tuples
|
|
scores = () if (return_dict_in_generate and output_scores) else None
|
|
raw_logits = () if (return_dict_in_generate and output_logits) else None
|
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
|
|
|
# keep track of which sequences are already finished
|
|
batch_size, cur_length = input_ids.shape[:2]
|
|
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
|
model_kwargs = self._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)
|
|
|
|
this_peer_finished = False
|
|
|
|
# prepare layers for DoLa decoding
|
|
final_layer = self.config.get_text_config().num_hidden_layers
|
|
# if the model has tied word embeddings, we skip the word embeddings (0-th) layer and start from the 2nd layer,
|
|
# as the early exit from word embeddings will become an identity function
|
|
# if the model is really shallow (<=2 layers), we use the 1st layer if it's not the final layer and the 0-th
|
|
# layer otherwise. Notice that DoLa does not help shallow models much.
|
|
if not self.config.tie_word_embeddings:
|
|
start_layer = 0
|
|
elif final_layer > 2:
|
|
start_layer = 2
|
|
elif final_layer == 2:
|
|
start_layer = 1
|
|
else:
|
|
start_layer = 0
|
|
|
|
# For `N`-layer models with `N <= 40` layers, the layers of `range(0, N // 2, 2)` and `range(N // 2, N, 2)`
|
|
# are used for `'low'` and `'high'` layers, respectively.
|
|
# For models with `N > 40` layers, the layers of `range(0, 20, 2)` and `range(N - 20, N, 2)` are used for
|
|
# `'low'` and `'high'` layers, respectively.
|
|
if isinstance(dola_layers, str) and dola_layers == "low":
|
|
if start_layer == final_layer // 2:
|
|
candidate_premature_layers = [start_layer]
|
|
else:
|
|
candidate_premature_layers = (
|
|
list(range(start_layer, final_layer // 2, 2))
|
|
if final_layer <= 40
|
|
else list(range(start_layer, 20, 2))
|
|
)
|
|
elif isinstance(dola_layers, str) and dola_layers == "high":
|
|
candidate_premature_layers = (
|
|
list(range(final_layer // 2, final_layer, 2))
|
|
if final_layer <= 40
|
|
else list(range(final_layer - 20, final_layer, 2))
|
|
)
|
|
# Set the `dola_layers` to a list of integers for layer indices to contrast manually specified layers.
|
|
elif isinstance(dola_layers, list):
|
|
candidate_premature_layers = [i for i in dola_layers if i < final_layer]
|
|
else:
|
|
raise ValueError("dola_layers must be either 'low', 'high' or a list of integers.")
|
|
|
|
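# Worked example of the layer selection above (comment only; the 32-layer figure is an assumption):
# with tied word embeddings, `start_layer` is 2 and `final_layer` is 32, so
#     dola_layers="low"  -> candidate layers range(2, 16, 2)  == [2, 4, ..., 14]
#     dola_layers="high" -> candidate layers range(16, 32, 2) == [16, 18, ..., 30]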
lm_head = self.get_output_embeddings()
|
|
if lm_head is None:
|
|
raise ValueError("DoLa is not supported for models that don't have output embeddings.")
|
|
|
|
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
|
# prepare model inputs
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
|
|
|
# forward pass to get next token
|
|
outputs = self(
|
|
**model_inputs,
|
|
return_dict=True,
|
|
output_attentions=output_attentions,
|
|
output_hidden_states=True,
|
|
)
|
|
|
|
# .float() is needed to retain precision for later logits manipulations
|
|
final_layer_next_token_logits = outputs.logits[:, -1, :].detach().to(copy=True, dtype=torch.float32)
|
|
final_logits = outputs.logits[:, -1, :].float()
|
|
candidate_premature_logits = {}
|
|
for candidate_premature_layer in candidate_premature_layers:
|
|
candidate_premature_logits[candidate_premature_layer] = lm_head(
|
|
outputs.hidden_states[candidate_premature_layer][:, -1, :]
|
|
).to(final_logits.device)
|
|
|
|
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
|
|
model_kwargs = self._update_model_kwargs_for_generation(
|
|
outputs,
|
|
model_kwargs,
|
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
|
)
|
|
if synced_gpus and this_peer_finished:
|
|
continue
|
|
|
|
next_token_logits = _dola_select_contrast(
|
|
candidate_premature_layers, candidate_premature_logits, final_logits
|
|
)
|
|
next_token_logits = next_token_logits.to(input_ids.device)
|
|
# pre-process distribution
|
|
next_token_scores = logits_processor(input_ids, next_token_logits)
|
|
|
|
# Store scores, attentions and hidden_states when required
|
|
if return_dict_in_generate:
|
|
if output_scores:
|
|
scores += (next_token_scores,)
|
|
if output_logits:
|
|
raw_logits += (final_layer_next_token_logits,)
|
|
if output_attentions:
|
|
decoder_attentions += (
|
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
|
)
|
|
if self.config.is_encoder_decoder:
|
|
cross_attentions += (outputs.cross_attentions,)
|
|
|
|
if output_hidden_states:
|
|
decoder_hidden_states += (
|
|
(outputs.decoder_hidden_states,)
|
|
if self.config.is_encoder_decoder
|
|
else (outputs.hidden_states,)
|
|
)
|
|
|
|
if do_sample: # sample
|
|
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
|
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
|
else: # argmax
|
|
next_tokens = torch.argmax(next_token_scores, dim=-1)
|
|
|
|
# finished sentences should have their next token be a padding token
|
|
if has_eos_stopping_criteria:
|
|
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
|
|
|
# update generated ids, model inputs, and length for next step
|
|
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
|
if streamer is not None:
|
|
streamer.put(next_tokens.cpu())
|
|
|
|
# stop when each sentence is finished
|
|
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
|
this_peer_finished = unfinished_sequences.max() == 0
|
|
|
|
if streamer is not None:
|
|
streamer.end()
|
|
|
|
if return_dict_in_generate:
|
|
return GenerateDecoderOnlyOutput(
|
|
sequences=input_ids,
|
|
scores=scores,
|
|
logits=raw_logits,
|
|
attentions=decoder_attentions,
|
|
hidden_states=decoder_hidden_states,
|
|
past_key_values=model_kwargs.get("past_key_values"),
|
|
)
|
|
else:
|
|
return input_ids
|
|
|
|
@torch.no_grad()
|
|
def _contrastive_search(
|
|
self,
|
|
input_ids: torch.LongTensor,
|
|
logits_processor: LogitsProcessorList,
|
|
stopping_criteria: StoppingCriteriaList,
|
|
generation_config: GenerationConfig,
|
|
synced_gpus: bool,
|
|
streamer: Optional["BaseStreamer"],
|
|
**model_kwargs,
|
|
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
|
r"""
|
|
Generates sequences of token ids for models with a language modeling head using **contrastive search** and can
|
|
be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
|
|
|
Parameters:
|
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
|
The sequence used as a prompt for the generation.
|
|
logits_processor (`LogitsProcessorList`):
|
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
|
used to modify the prediction scores of the language modeling head applied at each generation step.
|
|
stopping_criteria (`StoppingCriteriaList`):
|
|
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
|
used to tell if the generation loop should stop.
|
|
generation_config ([`~generation.GenerationConfig`]):
|
|
The generation configuration to be used as parametrization of the decoding method.
|
|
synced_gpus (`bool`):
|
|
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
|
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
|
streamer (`BaseStreamer`, *optional*):
|
|
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
|
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
|
model_kwargs:
|
|
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
|
|
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
|
|
|
|
Return:
|
|
[`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`]
|
|
or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
|
|
[`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
|
`return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
|
|
`model.config.is_encoder_decoder=True`.
|
|
"""
|
|
# init values
|
|
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
|
|
top_k = generation_config.top_k
|
|
penalty_alpha = generation_config.penalty_alpha
|
|
pad_token_id = generation_config._pad_token_tensor
|
|
output_attentions = generation_config.output_attentions
|
|
output_hidden_states = generation_config.output_hidden_states
|
|
output_scores = generation_config.output_scores
|
|
output_logits = generation_config.output_logits
|
|
return_dict_in_generate = generation_config.return_dict_in_generate
|
|
sequential = generation_config.low_memory
|
|
|
|
# init attention / hidden states / scores tuples
|
|
raw_logits = () if (return_dict_in_generate and output_logits) else None
|
|
scores = () if (return_dict_in_generate and output_scores) else None
|
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
|
|
|
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
|
if return_dict_in_generate and self.config.is_encoder_decoder:
|
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
|
encoder_hidden_states = (
|
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
|
)
|
|
|
|
# keep track of which sequences are already finished
|
|
batch_size, cur_len = input_ids.shape[:2]
|
|
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
|
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
|
|
|
|
# Create cosine_matrix_mask based on the attention_mask
|
|
cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)
|
|
if self.config.is_encoder_decoder:
|
|
if "decoder_attention_mask" in model_kwargs and model_kwargs["decoder_attention_mask"] is not None:
|
|
cosine_matrix_mask = model_kwargs["decoder_attention_mask"]
|
|
else:
|
|
cosine_matrix_mask = model_kwargs["attention_mask"]
|
|
cosine_matrix_mask = cosine_matrix_mask.repeat_interleave(top_k, dim=0)
|
|
|
|
this_peer_finished = False
|
|
|
|
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
|
# if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
|
|
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
|
|
if model_kwargs.get("past_key_values") is None or (
|
|
isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache))
|
|
and model_kwargs["past_key_values"].get_seq_length() == 0
|
|
):
|
|
# prepare inputs
|
|
model_kwargs["use_cache"] = True
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
|
|
|
# encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
|
|
# the `encoder_outputs`
|
|
outputs = self(
|
|
**model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
|
|
)
|
|
|
|
# last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with
|
|
# previous tokens)
|
|
if self.config.is_encoder_decoder:
|
|
last_hidden_states = outputs.decoder_hidden_states[-1]
|
|
else:
|
|
last_hidden_states = outputs.hidden_states[-1]
|
|
|
|
# next logit for contrastive search to select top-k candidate tokens
|
|
# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
|
|
# (the clone itself is always small)
|
|
# torch.float32 is needed to retain precision for later logits manipulations
|
|
logit_for_next_step = outputs.logits[:, -1, :].to(
|
|
copy=True, dtype=torch.float32, device=input_ids.device
|
|
)
|
|
|
|
model_kwargs = self._update_model_kwargs_for_generation(
|
|
outputs,
|
|
model_kwargs,
|
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
|
)
|
|
|
|
if not sequential:
|
|
# Expands model inputs top_k times, for batched forward passes (akin to beam search).
|
|
# input_ids is required for expanding visual inputs in qwen2vl
|
|
_, model_kwargs = self._expand_inputs_for_generation(
|
|
input_ids=input_ids,
|
|
expand_size=top_k,
|
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
|
**model_kwargs,
|
|
)
|
|
|
|
past_key_values = model_kwargs.get("past_key_values")
|
|
if past_key_values is None:
|
|
raise ValueError(
|
|
f"{self.__class__.__name__} does not support caching and therefore **can't** be used "
|
|
"for contrastive search."
|
|
)
|
|
elif (
|
|
not isinstance(past_key_values[0], (tuple, torch.Tensor))
|
|
or past_key_values[0][0].shape[0] != batch_size
|
|
):
|
|
raise ValueError(
|
|
f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be "
|
|
"used for contrastive search without further modifications."
|
|
)
|
|
|
|
# contrastive_search main logic start:
|
|
# contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
|
|
# degeneration penalty
|
|
processed_logit_for_next_step = logits_processor(input_ids, logit_for_next_step)
|
|
next_probs = nn.functional.softmax(processed_logit_for_next_step, dim=-1)
|
|
|
|
top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k)
|
|
|
|
# Store scores, attentions and hidden_states when required
|
|
if return_dict_in_generate:
|
|
if output_logits:
|
|
raw_logits += (logit_for_next_step,)
|
|
if output_scores:
|
|
scores += (processed_logit_for_next_step,)
|
|
if output_attentions:
|
|
decoder_attentions += (
|
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
|
)
|
|
if self.config.is_encoder_decoder:
|
|
cross_attentions += (outputs.cross_attentions,)
|
|
|
|
if output_hidden_states:
|
|
decoder_hidden_states += (
|
|
(outputs.decoder_hidden_states,)
|
|
if self.config.is_encoder_decoder
|
|
else (outputs.hidden_states,)
|
|
)
|
|
|
|
# This is needed to properly delete outputs.logits which may be very large for this first iteration
|
|
# Otherwise a reference to outputs.logits is kept all along until after the next call to self.forward()
|
|
del outputs
|
|
|
|
if not sequential:
|
|
# Replicates the new past_key_values to match the `top_k` candidates
|
|
past = model_kwargs["past_key_values"]
|
|
# If it is a static cache, modify it in-place layer after layer to save memory
|
|
if isinstance(past, DynamicCache) or (
|
|
isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache)
|
|
):
|
|
past.batch_repeat_interleave(top_k)
|
|
else:
|
|
new_key_values = []
|
|
for layer in past:
|
|
items = []
|
|
# item is either the key or the value matrix
|
|
for item in layer:
|
|
items.append(item.repeat_interleave(top_k, dim=0))
|
|
new_key_values.append(tuple(items))
|
|
|
|
past = tuple(new_key_values)
|
|
|
|
model_kwargs["past_key_values"] = past
|
|
|
|
if sequential:
|
|
all_outputs = []
|
|
for i in range(top_k):
|
|
# compute the candidate tokens by the language model and collect their hidden_states
|
|
next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)
|
|
|
|
outputs = self(
|
|
**next_model_inputs,
|
|
return_dict=True,
|
|
output_hidden_states=True,
|
|
output_attentions=output_attentions,
|
|
)
|
|
if isinstance(outputs["past_key_values"], DynamicCache) or (
|
|
isinstance(outputs["past_key_values"], EncoderDecoderCache)
|
|
and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache)
|
|
):
|
|
# Remove past K-V from output since we don't need to stack later
|
|
outputs["past_key_values"] = None
|
|
# Remove last token from past K-V since we don't want to append it at this point
|
|
model_kwargs["past_key_values"].crop(-1)
|
|
|
|
all_outputs.append(outputs)
|
|
outputs = stack_model_outputs(all_outputs, self.config.get_text_config())
|
|
|
|
else:
|
|
# compute the candidate tokens by the language model and collect their hidden_states
|
|
# assembles top_k_ids into batch of size k
|
|
next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
|
|
|
|
outputs = self(
|
|
**next_model_inputs,
|
|
return_dict=True,
|
|
output_hidden_states=True,
|
|
output_attentions=output_attentions,
|
|
)
|
|
|
|
# This is essential to avoid having a last reference to the big past K-V and double the necessary memory
|
|
# in the next loop
|
|
del next_model_inputs
|
|
|
|
# name is different for encoder-decoder and decoder-only models
|
|
if self.config.is_encoder_decoder:
|
|
next_hidden = outputs.decoder_hidden_states[-1]
|
|
full_hidden_states = outputs.decoder_hidden_states
|
|
else:
|
|
next_hidden = outputs.hidden_states[-1]
|
|
full_hidden_states = outputs.hidden_states
|
|
|
|
# .float() is needed to retain precision for later logits manipulations
|
|
logits = outputs.logits[:, -1, :].float()
|
|
context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)
|
|
|
|
# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
|
|
# model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't
|
|
# introduce (noticeable) slowdowns on single-device runs.
|
|
selected_idx = _ranking_fast(
|
|
context_hidden, next_hidden, top_k_probs, cosine_matrix_mask, penalty_alpha, top_k
|
|
)
|
|
cosine_matrix_mask = torch.cat(
|
|
[cosine_matrix_mask, cosine_matrix_mask.new_ones((cosine_matrix_mask.shape[0], 1))], dim=-1
|
|
)
|
|
selected_idx = selected_idx.to("cpu")
|
|
|
|
# This will be used instead of the previous inefficient torch.stack(torch.split())
|
|
augmented_idx = torch.tensor([x + i * top_k for i, x in enumerate(selected_idx)])
|
|
|
|
# prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
|
|
# the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
|
|
# (model confidence minus degeneration penalty); (6) decoder hidden_states
|
|
next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx]
|
|
next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k))
|
|
next_hidden = next_hidden[range(batch_size), selected_idx, :]
|
|
last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)
|
|
|
|
next_decoder_hidden_states = ()
|
|
for layer in full_hidden_states:
|
|
layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
|
|
next_decoder_hidden_states += (layer,)
|
|
|
|
# generate past_key_values cache of only the selected token
|
|
if sequential:
|
|
next_model_input = self.prepare_inputs_for_generation(
|
|
top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs
|
|
)
|
|
|
|
selected_outputs = self(
|
|
**next_model_input,
|
|
return_dict=True,
|
|
output_hidden_states=False,
|
|
output_attentions=False,
|
|
)
|
|
next_past_key_values = selected_outputs["past_key_values"]
|
|
|
|
else:
|
|
next_past_key_values = None
|
|
for possible_cache_name in ALL_CACHE_NAMES:
|
|
next_past_key_values = next_past_key_values or getattr(outputs, possible_cache_name, None)
|
|
# Do it in-place layer per layer to save memory
|
|
if isinstance(next_past_key_values, DynamicCache) or (
|
|
isinstance(next_past_key_values, EncoderDecoderCache)
|
|
and isinstance(next_past_key_values.self_attention_cache, DynamicCache)
|
|
):
|
|
next_past_key_values.batch_select_indices(augmented_idx)
|
|
else:
|
|
new_key_values = []
|
|
for layer in next_past_key_values:
|
|
items = []
|
|
# item is either the key or the value matrix
|
|
for item in layer:
|
|
items.append(item[augmented_idx, ...])
|
|
new_key_values.append(tuple(items))
|
|
|
|
next_past_key_values = tuple(new_key_values)
|
|
|
|
logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
|
|
logit_for_next_step = logit_for_next_step.to(input_ids.device)
|
|
|
|
# Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
|
|
if self.config.is_encoder_decoder:
|
|
next_step_cross_attentions = ()
|
|
next_step_decoder_attentions = ()
|
|
if output_attentions:
|
|
for layer in outputs.cross_attentions:
|
|
layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
|
|
next_step_cross_attentions += (layer,)
|
|
for layer in outputs.decoder_attentions:
|
|
layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
|
|
next_step_decoder_attentions += (layer,)
|
|
outputs = Seq2SeqLMOutput(
|
|
past_key_values=next_past_key_values,
|
|
decoder_hidden_states=next_decoder_hidden_states,
|
|
decoder_attentions=next_step_decoder_attentions or None,
|
|
cross_attentions=next_step_cross_attentions or None,
|
|
)
|
|
else:
|
|
next_step_attentions = ()
|
|
if output_attentions:
|
|
for layer in outputs.attentions:
|
|
layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
|
|
next_step_attentions += (layer,)
|
|
outputs = CausalLMOutputWithPast(
|
|
past_key_values=next_past_key_values,
|
|
hidden_states=next_decoder_hidden_states,
|
|
attentions=next_step_attentions or None,
|
|
)
|
|
# contrastive_search main logic end
|
|
|
|
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
|
|
model_kwargs = self._update_model_kwargs_for_generation(
|
|
outputs,
|
|
model_kwargs,
|
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
|
)
|
|
if synced_gpus and this_peer_finished:
|
|
continue
|
|
|
|
# finished sentences should have their next token be a padding token
|
|
if has_eos_stopping_criteria:
|
|
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
|
|
|
# update generated ids, model inputs, and length for next step
|
|
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
|
if streamer is not None:
|
|
streamer.put(next_tokens.cpu())
|
|
|
|
# stop when each sentence is finished
|
|
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
|
this_peer_finished = unfinished_sequences.max() == 0
|
|
|
|
if streamer is not None:
|
|
streamer.end()
|
|
|
|
if return_dict_in_generate:
|
|
# Contrastive search works by forward looking at the next token, so we need to exclude it from
|
|
# `past_key_values` to be consistent with the other decoding methods
|
|
if model_kwargs.get("past_key_values") is not None:
|
|
if isinstance(model_kwargs["past_key_values"], DynamicCache) or (
|
|
isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
|
|
and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache)
|
|
):
|
|
model_kwargs["past_key_values"].crop(-1)
|
|
else:
|
|
past_key_values = []
|
|
for layer in model_kwargs["past_key_values"]:
|
|
layer_past_key_values = []
|
|
for item in layer:
|
|
layer_past_key_values.append(item[..., :-1, :])
|
|
past_key_values.append(tuple(layer_past_key_values))
|
|
model_kwargs["past_key_values"] = tuple(past_key_values)
|
|
|
|
if self.config.is_encoder_decoder:
|
|
return GenerateEncoderDecoderOutput(
|
|
sequences=input_ids,
|
|
scores=scores,
|
|
logits=raw_logits,
|
|
encoder_attentions=encoder_attentions,
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
decoder_attentions=decoder_attentions,
|
|
cross_attentions=cross_attentions,
|
|
decoder_hidden_states=decoder_hidden_states,
|
|
past_key_values=model_kwargs.get("past_key_values"),
|
|
)
|
|
else:
|
|
return GenerateDecoderOnlyOutput(
|
|
sequences=input_ids,
|
|
scores=scores,
|
|
logits=raw_logits,
|
|
attentions=decoder_attentions,
|
|
hidden_states=decoder_hidden_states,
|
|
past_key_values=model_kwargs.get("past_key_values"),
|
|
)
|
|
else:
|
|
return input_ids
|
|
|
|
def _sample(
|
|
self,
|
|
input_ids: torch.LongTensor,
|
|
logits_processor: LogitsProcessorList,
|
|
stopping_criteria: StoppingCriteriaList,
|
|
generation_config: GenerationConfig,
|
|
synced_gpus: bool,
|
|
streamer: Optional["BaseStreamer"],
|
|
**model_kwargs,
|
|
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
|
r"""
|
|
Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
|
|
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
|
|
|
Parameters:
|
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
|
The sequence used as a prompt for the generation.
|
|
logits_processor (`LogitsProcessorList`):
|
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
|
used to modify the prediction scores of the language modeling head applied at each generation step.
|
|
stopping_criteria (`StoppingCriteriaList`):
|
|
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
|
used to tell if the generation loop should stop.
|
|
generation_config ([`~generation.GenerationConfig`]):
|
|
The generation configuration to be used as parametrization of the decoding method.
|
|
synced_gpus (`bool`):
|
|
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
|
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
|
streamer (`BaseStreamer`, *optional*):
|
|
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
|
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
|
model_kwargs:
|
|
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
|
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
|
|
|
Return:
|
|
[`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
|
|
A `torch.LongTensor` containing the generated tokens (default behaviour) or a
|
|
[`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
|
`return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
|
|
`model.config.is_encoder_decoder=True`.
|
|
"""
|
|
# init values
|
|
pad_token_id = generation_config._pad_token_tensor
|
|
output_attentions = generation_config.output_attentions
|
|
output_hidden_states = generation_config.output_hidden_states
|
|
output_scores = generation_config.output_scores
|
|
output_logits = generation_config.output_logits
|
|
return_dict_in_generate = generation_config.return_dict_in_generate
|
|
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
|
|
do_sample = generation_config.do_sample
|
|
|
|
# init attention / hidden states / scores tuples
|
|
scores = () if (return_dict_in_generate and output_scores) else None
|
|
raw_logits = () if (return_dict_in_generate and output_logits) else None
|
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
|
|
|
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
|
if return_dict_in_generate and self.config.is_encoder_decoder:
|
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
|
encoder_hidden_states = (
|
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
|
)
|
|
|
|
# keep track of which sequences are already finished
|
|
batch_size, cur_len = input_ids.shape[:2]
|
|
this_peer_finished = False
|
|
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
|
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
|
|
|
|
model_forward = self.__call__
|
|
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
|
|
if compile_forward:
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "0"
|
|
# If we use FA2 and a static cache, we cannot compile with fullgraph
|
|
if self.config._attn_implementation == "flash_attention_2" and getattr(
|
|
model_kwargs.get("past_key_values"), "is_compileable", False
|
|
):
|
|
if generation_config.compile_config is None:
|
|
generation_config.compile_config = CompileConfig(fullgraph=False)
|
|
# only raise warning if the user passed an explicit compile-config (otherwise, simply change the default without confusing the user)
|
|
elif generation_config.compile_config.fullgraph:
|
|
logger.warning_once(
|
|
"When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
|
|
"FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
|
|
)
|
|
generation_config.compile_config.fullgraph = False
|
|
model_forward = self.get_compiled_call(generation_config.compile_config)
|
|
|
|
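# Illustrative sketch (comment only): the automatic compilation can be steered explicitly through
# `compile_config`, e.g. to opt out of fullgraph capture:
#
#     from transformers.generation.configuration_utils import CompileConfig
#     out = model.generate(**inputs, cache_implementation="static", compile_config=CompileConfig(fullgraph=False))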
if generation_config.prefill_chunk_size is not None:
|
|
model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
|
|
is_prefill = False
|
|
else:
|
|
is_prefill = True
|
|
|
|
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
|
# prepare model inputs
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
|
|
|
# prepare variable output controls (note: some models won't accept all output controls)
|
|
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
|
|
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
|
|
|
|
if is_prefill:
|
|
outputs = self(**model_inputs, return_dict=True)
|
|
is_prefill = False
|
|
else:
|
|
outputs = model_forward(**model_inputs, return_dict=True)
|
|
|
|
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
|
|
model_kwargs = self._update_model_kwargs_for_generation(
|
|
outputs,
|
|
model_kwargs,
|
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
|
)
|
|
if synced_gpus and this_peer_finished:
|
|
continue
|
|
|
|
# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
|
|
# (the clone itself is always small)
|
|
next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
|
|
|
|
# pre-process distribution
|
|
next_token_scores = logits_processor(input_ids, next_token_logits)
|
|
|
|
# Store scores, attentions and hidden_states when required
|
|
if return_dict_in_generate:
|
|
if output_scores:
|
|
scores += (next_token_scores,)
|
|
if output_logits:
|
|
raw_logits += (next_token_logits,)
|
|
if output_attentions:
|
|
decoder_attentions += (
|
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
|
)
|
|
if self.config.is_encoder_decoder:
|
|
cross_attentions += (outputs.cross_attentions,)
|
|
|
|
if output_hidden_states:
|
|
decoder_hidden_states += (
|
|
(outputs.decoder_hidden_states,)
|
|
if self.config.is_encoder_decoder
|
|
else (outputs.hidden_states,)
|
|
)
|
|
|
|
# token selection
|
|
if do_sample:
|
|
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
|
# TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
|
|
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
|
else:
|
|
next_tokens = torch.argmax(next_token_scores, dim=-1)
|
|
|
|
# finished sentences should have their next token be a padding token
|
|
if has_eos_stopping_criteria:
|
|
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
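# `unfinished_sequences` holds 1 for still-active rows and 0 for finished ones, so finished rows keep
# appending `pad_token_id` until the loop exits.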
|
|
|
|
# update generated ids, model inputs, and length for next step
|
|
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
|
if streamer is not None:
|
|
streamer.put(next_tokens.cpu())
|
|
|
|
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
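# `stopping_criteria` returns a per-row boolean that is True once a row should stop; the `& ~` ensures a
# row that has finished stays finished.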
|
|
this_peer_finished = unfinished_sequences.max() == 0
|
|
cur_len += 1
|
|
|
|
# This is needed to properly delete outputs.logits which may be very large for first iteration
|
|
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
|
|
del outputs
|
|
|
|
if streamer is not None:
|
|
streamer.end()
|
|
|
|
if return_dict_in_generate:
|
|
if self.config.is_encoder_decoder:
|
|
return GenerateEncoderDecoderOutput(
|
|
sequences=input_ids,
|
|
scores=scores,
|
|
logits=raw_logits,
|
|
encoder_attentions=encoder_attentions,
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
decoder_attentions=decoder_attentions,
|
|
cross_attentions=cross_attentions,
|
|
decoder_hidden_states=decoder_hidden_states,
|
|
past_key_values=model_kwargs.get("past_key_values"),
|
|
)
|
|
else:
|
|
return GenerateDecoderOnlyOutput(
|
|
sequences=input_ids,
|
|
scores=scores,
|
|
logits=raw_logits,
|
|
attentions=decoder_attentions,
|
|
hidden_states=decoder_hidden_states,
|
|
past_key_values=model_kwargs.get("past_key_values"),
|
|
)
|
|
else:
|
|
return input_ids
|
|
|
|
@staticmethod
|
|
def _flatten_beam_dim(tensor: torch.Tensor) -> torch.Tensor:
|
|
"""[batch_size, num_beams, ...] -> [batch_size * num_beams, ...]"""
|
|
shape = list(tensor.shape)
|
|
return torch.reshape(tensor, [shape[0] * shape[1]] + shape[2:])
|
|
|
|
@staticmethod
|
|
def _unflatten_beam_dim(tensor: torch.Tensor, batch_size: int, num_beams: int) -> torch.Tensor:
|
|
"""[batch_size * num_beams, ...] -> [batch_size, num_beams, ...]"""
|
|
shape = list(tensor.shape)
|
|
return torch.reshape(tensor, [batch_size, num_beams] + shape[1:])
|
|
|
|
@staticmethod
|
|
def _gather_beams(tensor: torch.Tensor, beam_indices: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
Gathers the beam slices indexed by `beam_indices` into a new beam array.
|
|
|
|
Args:
|
|
tensor (`torch.Tensor`): A tensor containing data to be gathered. The tensor is a 2D or a 3D tensor
|
|
with the first two dimensions depicting the batch and the beam dimensions.
|
|
beam_indices (`torch.Tensor` of shape `(batch_size, num_beams_to_select)`): The indices of the beams to
|
|
select.
|
|
|
|
Returns:
|
|
A tensor with the selected beams
|
|
"""
|
|
# `take_along_dim` requires its indices arg to have the same number of dims as `input`
|
|
while len(beam_indices.shape) < len(tensor.shape):
|
|
beam_indices = beam_indices.unsqueeze(-1)
|
|
gathered_tensor = torch.take_along_dim(input=tensor, indices=beam_indices, dim=1)
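# e.g. a `tensor` of shape (batch_size, beams_to_keep, seq_len) gathered with `beam_indices` of shape
# (batch_size, num_beams) yields a (batch_size, num_beams, seq_len) result.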
|
|
return gathered_tensor
|
|
|
|
@staticmethod
|
|
def _check_early_stop_heuristic(
|
|
is_early_stop_heuristic_unsatisfied: torch.Tensor,
|
|
running_beam_scores: torch.Tensor,
|
|
beam_scores: torch.Tensor,
|
|
is_sent_finished: torch.Tensor,
|
|
cur_len: int,
|
|
max_length: int,
|
|
decoder_prompt_len: int,
|
|
early_stopping: Union[bool, str],
|
|
length_penalty: float,
|
|
):
|
|
"""
|
|
Determine whether early stopping is possible by checking if the best possible score of running beams
|
|
could still improve upon the finished ones.
|
|
|
|
Mechanism:
|
|
- Without a length penalty, beam scores typically decrease as more tokens are generated.
|
|
So, if the *best possible* score from any running beam is already worse than the *worst* finished beam,
|
|
we can safely stop early.
|
|
- With a length penalty, scores may increase with longer sequences. In this case, we use heuristics
|
|
to estimate the best possible score — though this estimate may not always be correct — and stop
|
|
if no further improvement seems likely.
|
|
|
|
We apply different heuristics depending on the value of `early_stopping`:
|
|
1. `early_stopping == False`:
|
|
-> Use a heuristic that assumes the best possible score is obtained at the current generated length, i.e. `cur_len - decoder_prompt_len`.
|
|
-> See detailed discussion: https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
|
|
|
|
2. `early_stopping == "never"`:
|
|
-> Estimate the best score using either `max_length` or `cur_len`, depending on the sign of `length_penalty`.
|
|
-> A positive length penalty favors longer sequences, so we use `max_length` in that case.
|
|
|
|
NOTE: the canonical beam search implementation can be replicated with `early_stopping="never"` and
|
|
`length_penalty=0.0`, which are NOT the default flags. The default behavior was empirically found to produce
|
|
better sequences (prior to 2022), and changing it is BC breaking.
|
|
"""
|
|
if early_stopping == "never" and length_penalty > 0.0:
|
|
best_hypothetical_length = max_length - decoder_prompt_len
|
|
else:
|
|
best_hypothetical_length = cur_len - decoder_prompt_len
|
|
best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty)
|
|
worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9)
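# batches with no finished hypothesis yet get -1e9 here, so any running beam still counts as a possible
# improvement for them.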
|
|
return is_early_stop_heuristic_unsatisfied & torch.any(
|
|
best_possible_running_score > worst_finished_score, dim=-1, keepdim=True
|
|
)
|
|
|
|
@staticmethod
|
|
def _beam_search_has_unfinished_sequences(
|
|
is_early_stop_heuristic_unsatisfied: torch.Tensor,
|
|
is_sent_finished: torch.Tensor,
|
|
next_token_hits_stopping_criteria: torch.Tensor,
|
|
early_stopping: Union[bool, str],
|
|
):
|
|
"""
|
|
Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False
|
|
"""
|
|
# a. Can the open beams improve the top completed scores?
|
|
improvement_possible = torch.any(is_early_stop_heuristic_unsatisfied)
|
|
|
|
# b. Is there still a beam without fully completed sequences? This is only relevant if early_stopping is
|
|
# enabled, where we want to finish as soon as all beams have a completed sequence.
|
|
exists_open_beam = ~(torch.all(is_sent_finished) & (early_stopping is True))
|
|
|
|
# c. Have we hit a stopping criterion with all running sequences and have no way to continue? e.g. we have
|
|
# reached `max_length`
|
|
valid_continuations = ~torch.all(next_token_hits_stopping_criteria)
|
|
|
|
return improvement_possible & exists_open_beam & valid_continuations
|
|
|
|
def _get_top_k_continuations(
|
|
self,
|
|
accumulated_log_probs: torch.Tensor,
|
|
running_sequences: torch.Tensor,
|
|
running_beam_indices: torch.Tensor,
|
|
cur_len: int,
|
|
decoder_prompt_len: int,
|
|
do_sample: bool,
|
|
beams_to_keep: int,
|
|
num_beams: int,
|
|
vocab_size: int,
|
|
batch_size: int,
|
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
"""
|
|
Get top-K continuations given the accumulated log probs on the next token.
|
|
|
|
A few notes to understand what's going on:
|
|
1. Each item in batch has `num_beams` * `vocab_size` candidate continuations. For each item, get the
|
|
top K [K = (number of EOS tokens + 1) * `num_beams`] candidates with the highest accumulated
|
|
log-probabilities, or sample them without replacement using the accumulated scores
|
|
2. We gather the top K (as opposed to `num_beams`, or any number lower than K) here so that we have at
|
|
least `num_beams` sequences remaining to continue the live beam search.
|
|
3. Note that other stopping criteria might result in impossible-to-continue beams, i.e. all continuations
|
|
selected in this step hit the stopping criteria.
|
|
"""
|
|
# TODO (joao): This function should take an optional beam scorer function, to manipulate the scores after
|
|
# token selection. The function should be an argument exposed, so that custom scoring functions can be
|
|
# defined.
|
|
|
|
# Gather the top K scores from _all_ beams.
|
|
if do_sample:
|
|
topk_indices = torch.multinomial(
|
|
nn.functional.softmax(accumulated_log_probs, dim=-1), num_samples=beams_to_keep
|
|
)
|
|
topk_log_probs = torch.gather(input=accumulated_log_probs, dim=1, index=topk_indices)
|
|
else:
|
|
topk_log_probs, topk_indices = torch.topk(accumulated_log_probs, k=beams_to_keep)
|
|
|
|
# Gather K top beams, recover the beam index by floor division and token id by modulo division
|
|
topk_current_beam_indices = topk_indices // vocab_size
|
|
topk_running_beam_indices = self._gather_beams(running_beam_indices, topk_current_beam_indices)
|
|
topk_running_sequences = self._gather_beams(running_sequences, topk_current_beam_indices)
|
|
topk_ids = topk_indices % vocab_size
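# e.g. with vocab_size=50_000, a flat index of 123_456 corresponds to beam 2 (123_456 // 50_000) and
# token id 23_456 (123_456 % 50_000).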
|
|
|
|
# Write the newly selected token ids into the top-K running sequences at the current position.
|
|
topk_running_sequences[:, :, cur_len] = topk_ids
|
|
|
|
# we want to store the beam indices with batch information -> real beam index = beam index % num beams
|
|
batch_offset = torch.arange(batch_size, device=topk_ids.device).view(-1, 1) * num_beams
|
|
batch_modified_indices = topk_current_beam_indices + batch_offset
|
|
topk_running_beam_indices[:, :, cur_len - decoder_prompt_len] = batch_modified_indices
|
|
|
|
return topk_log_probs, topk_running_sequences, topk_running_beam_indices
|
|
|
|
def _get_running_beams_for_next_iteration(
|
|
self,
|
|
topk_log_probs: torch.Tensor,
|
|
topk_running_sequences: torch.Tensor,
|
|
topk_running_beam_indices: torch.Tensor,
|
|
next_token_hits_stopping_criteria: torch.Tensor,
|
|
num_beams: int,
|
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
"""
|
|
Given the top-K continuations, their scores, and whether they hit a stopping criterion, select the
|
|
best non-finished beams to continue beam search in the next iteration.
|
|
"""
|
|
# To prevent just-finished sequences from being used in subsequent iterations, set their log probs
|
|
# to a very large negative value
|
|
topk_running_log_probs = topk_log_probs + next_token_hits_stopping_criteria.to(torch.float32) * -1.0e9
|
|
|
|
next_topk_indices = torch.topk(topk_running_log_probs, k=num_beams)[1]
|
|
running_sequences = self._gather_beams(topk_running_sequences, next_topk_indices)
|
|
running_beam_scores = self._gather_beams(topk_running_log_probs, next_topk_indices)
|
|
running_beam_indices = self._gather_beams(topk_running_beam_indices, next_topk_indices)
|
|
return running_sequences, running_beam_scores, running_beam_indices
|
|
|
|
def _update_finished_beams(
|
|
self,
|
|
sequences: torch.Tensor,
|
|
topk_running_sequences: torch.Tensor,
|
|
beam_scores: torch.Tensor,
|
|
topk_log_probs: torch.Tensor,
|
|
beam_indices: torch.Tensor,
|
|
topk_running_beam_indices: torch.Tensor,
|
|
is_early_stop_heuristic_unsatisfied: torch.Tensor,
|
|
is_sent_finished: torch.Tensor,
|
|
next_token_hits_stopping_criteria: torch.Tensor,
|
|
top_num_beam_mask: torch.Tensor,
|
|
num_beams: int,
|
|
cur_len: int,
|
|
decoder_prompt_len: int,
|
|
length_penalty: float,
|
|
early_stopping: Union[bool, str],
|
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
"""
|
|
Updates the finished beams if (and only if) there are new completed sequences that have a higher score than
|
|
the current finished sequences.
|
|
"""
|
|
# Only the top `num_beams` sequences can be considered for the final returned sequences. Remember: the
|
|
# remaining sequences only exist as a backup to ensure that we have at least `num_beams` sequences to
|
|
# continue.
|
|
did_top_num_beams_just_finished = next_token_hits_stopping_criteria & top_num_beam_mask[None, :]
|
|
|
|
# Further process topk logits for the finished beams
|
|
# - add length penalty
|
|
topk_log_probs = topk_log_probs / ((cur_len + 1 - decoder_prompt_len) ** length_penalty)
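# `cur_len + 1` is the generated length including the token selected in this step (`cur_len` itself is
# only incremented by the caller after this function returns).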
|
|
# - make sure no scores can be added anymore if beam is full and early stopping is on
|
|
beams_in_batch_are_full = torch.all(is_sent_finished, axis=-1, keepdims=True) & (early_stopping is True)
|
|
topk_log_probs += beams_in_batch_are_full.to(torch.float32) * -1.0e9
|
|
# - make sure no scores can be added anymore if improvement is not possible
|
|
topk_log_probs += (~is_early_stop_heuristic_unsatisfied).to(torch.float32) * -1.0e9
|
|
|
|
# - make sure still running sequences cannot be chosen as finalized beam
|
|
topk_log_probs += (~did_top_num_beams_just_finished) * -1.0e9
|
|
|
|
# Get finalized `num_beams` sequences for the next generation step -- combine the previous finalized
|
|
# data with the new finalized sequences (if any, non-finalized sequences have a very large negative score
|
|
# in this step), and keep the best `num_beams` sequences.
|
|
merged_sequences = torch.cat((sequences, topk_running_sequences), dim=1)
|
|
merged_scores = torch.cat((beam_scores, topk_log_probs), dim=1)
|
|
merged_beam_indices = torch.cat((beam_indices, topk_running_beam_indices), dim=1)
|
|
merged_is_sent_finished = torch.cat((is_sent_finished, did_top_num_beams_just_finished), dim=1)
|
|
topk_merged_indices = torch.topk(merged_scores, k=num_beams)[1]
|
|
sequences = self._gather_beams(merged_sequences, topk_merged_indices)
|
|
beam_scores = self._gather_beams(merged_scores, topk_merged_indices)
|
|
beam_indices = self._gather_beams(merged_beam_indices, topk_merged_indices)
|
|
is_sent_finished = self._gather_beams(merged_is_sent_finished, topk_merged_indices)
|
|
return sequences, beam_scores, beam_indices, is_sent_finished
|
|
|
|
# end of auxiliary functions for beam search
|
|
|
|
def _beam_search(
|
|
self,
|
|
input_ids: torch.LongTensor,
|
|
logits_processor: LogitsProcessorList,
|
|
stopping_criteria: StoppingCriteriaList,
|
|
generation_config: GenerationConfig,
|
|
synced_gpus: bool,
|
|
**model_kwargs,
|
|
) -> Union[GenerateBeamOutput, torch.LongTensor]:
|
|
r"""
|
|
Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
|
|
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
|
|
|
If it's the first time you're diving into Beam Search, we recommend you read the following blog post:
|
|
https://huggingface.co/blog/how-to-generate (especially the beam search section).
|
|
|
|
You can recompute the sequence scores from the individual scores using the `compute_transition_scores` function
|
|
(https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationMixin.compute_transition_scores)
|
|
|
|
Parameters:
|
|
input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
|
|
The sequence used as a prompt for the generation.
|
|
logits_processor (`LogitsProcessorList`):
|
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
|
used to modify the prediction scores of the language modeling head applied at each generation step.
|
|
stopping_criteria (`StoppingCriteriaList`):
|
|
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
|
used to tell if the generation loop should stop.
|
|
generation_config ([`~generation.GenerationConfig`]):
|
|
The generation configuration to be used as parametrization of the decoding method.
|
|
synced_gpus (`bool`):
|
|
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
|
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
|
model_kwargs:
|
|
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
|
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
|
|
|
Return:
|
|
[`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
|
|
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
|
|
[`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
|
`return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
|
|
`model.config.is_encoder_decoder=True`.
|
|
"""
|
|
|
|
# 1. init beam_search values
|
|
pad_token_id = generation_config._pad_token_tensor
|
|
eos_token_id = generation_config._eos_token_tensor
|
|
output_attentions = generation_config.output_attentions
|
|
output_hidden_states = generation_config.output_hidden_states
|
|
output_scores = generation_config.output_scores
|
|
output_logits = generation_config.output_logits
|
|
return_dict_in_generate = generation_config.return_dict_in_generate
|
|
do_sample = generation_config.do_sample
|
|
early_stopping = generation_config.early_stopping
|
|
length_penalty = generation_config.length_penalty
|
|
max_length = generation_config.max_length
|
|
num_beams = generation_config.num_beams
|
|
num_return_sequences = generation_config.num_return_sequences
|
|
|
|
batch_size_unflattened, cur_len = input_ids.shape[:2]
|
|
batch_size = batch_size_unflattened // num_beams
|
|
# TODO (joao): standardize special cases
|
|
if self.__class__.__name__ == "MoshiDepthDecoder":
|
|
vocab_size = self.config.audio_vocab_size
|
|
elif self.__class__.__name__ == "ImageGPTForCausalImageModeling":
|
|
vocab_size = self.get_output_embeddings().out_features
|
|
else:
|
|
vocab_size = self.config.get_text_config().vocab_size
|
|
decoder_prompt_len = cur_len
|
|
this_peer_finished = False
|
|
|
|
# At each beam search step, we want to keep top K [K = (number of EOS tokens + 1) * `num_beams`] candidates
|
|
# with the highest log-probabilities, or sample K continuations without replacement. We gather the top K
|
|
# (as opposed to `num_beams`, or any number lower than K) so that we have at least `num_beams` sequences
|
|
# non-finished to continue the live beam search, in case the top `num_beams` all select an EOS token.
|
|
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
|
|
beams_to_keep = max(2, 1 + n_eos_tokens) * num_beams
|
|
top_num_beam_mask = torch.cat(
|
|
(torch.ones((num_beams), dtype=torch.bool), torch.zeros((beams_to_keep - num_beams), dtype=torch.bool)),
|
|
dim=0,
|
|
).to(input_ids.device)
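# boolean mask of shape (beams_to_keep,): True for the first `num_beams` entries, i.e. the only candidates
# that are allowed to be finalized as completed sequences.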
|
|
|
|
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
|
|
|
|
# (joao) feature lost in the refactor. Probably won't implement, hurts readability with minimal gains (there
|
|
# are newer low-memory alternatives like the offloaded cache)
|
|
sequential = generation_config.low_memory
|
|
if sequential:
|
|
raise ValueError(
|
|
"`low_memory=True` is not supported after the beam search refactor. Please check the discussion in "
|
|
"#35802 *after the PR got merged*, and add a comment there if your questions are not yet answered."
|
|
)
|
|
|
|
# 2. init output tuples
|
|
all_scores = () if (return_dict_in_generate and output_scores) else None
|
|
raw_logits = () if (return_dict_in_generate and output_logits) else None
|
|
beam_indices = () if (return_dict_in_generate and output_logits) else None
|
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
|
|
|
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
|
if return_dict_in_generate and self.config.is_encoder_decoder:
|
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
|
encoder_hidden_states = (
|
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
|
)
|
|
|
|
# 3. init running tensors and static-shaped placeholders
|
|
|
|
# per batch, beam-item holding current token in loop and completed sequences
|
|
output_fill_value = pad_token_id or eos_token_id[0] if eos_token_id is not None else -1
|
|
running_sequences = torch.full(
|
|
(batch_size, num_beams, max_length),
|
|
fill_value=output_fill_value,
|
|
dtype=torch.int64,
|
|
device=input_ids.device,
|
|
)
|
|
running_sequences[:, :, :cur_len] = self._unflatten_beam_dim(input_ids, batch_size, num_beams)
|
|
sequences = running_sequences.detach().clone()
|
|
|
|
# per batch, beam-item score, logprobs
|
|
# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
|
|
# of the first beam are considered to avoid sampling the exact same tokens across all beams.
|
|
running_beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
|
running_beam_scores[:, 1:] = -1e9
|
|
beam_scores = torch.full((batch_size, num_beams), fill_value=-1e9, dtype=torch.float, device=input_ids.device)
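# finished-beam scores start at -1e9 so that the first genuinely completed hypothesis always replaces the
# placeholder entries.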
|
|
|
|
# per batch, beam-item state bit indicating if sentence has finished.
|
|
is_sent_finished = torch.zeros((batch_size, num_beams), dtype=torch.bool, device=input_ids.device)
|
|
|
|
# per batch state bit indicating if there is a possibility to improve the best finished sentence.
|
|
is_early_stop_heuristic_unsatisfied = torch.ones((batch_size, 1), dtype=torch.bool, device=input_ids.device)
|
|
|
|
# per batch, beam-item state bit indicating if there are valid continuations.
|
|
next_token_hits_stopping_criteria = torch.zeros(
|
|
(batch_size, num_beams), dtype=torch.bool, device=input_ids.device
|
|
)
|
|
|
|
# per batch selected beam indices
|
|
running_beam_indices = torch.full(
|
|
(batch_size, num_beams, max_length - cur_len), fill_value=-1, dtype=torch.int32, device=input_ids.device
|
|
)
|
|
beam_indices = running_beam_indices.detach().clone()
|
|
|
|
# 4. run the generation loop
|
|
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
|
# a. Forward current tokens, obtain the logits
|
|
flat_running_sequences = self._flatten_beam_dim(running_sequences[:, :, :cur_len])
|
|
model_inputs = self.prepare_inputs_for_generation(flat_running_sequences, **model_kwargs)
|
|
|
|
# prepare variable output controls (note: some models won't accept all output controls)
|
|
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
|
|
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
|
|
|
|
model_outputs = self(**model_inputs, return_dict=True)
|
|
|
|
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
|
|
model_kwargs = self._update_model_kwargs_for_generation(
|
|
model_outputs,
|
|
model_kwargs,
|
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
|
)
|
|
if synced_gpus and this_peer_finished:
|
|
continue
|
|
|
|
# Copy is needed to avoid keeping a hanging ref
|
|
logits = model_outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
|
|
|
|
# b. Compute log probs -- get log probabilities from logits, process logits with processors (*e.g.*
|
|
# `temperature`, ...), and add new logprobs to existing running logprobs scores.
|
|
log_probs = nn.functional.log_softmax(logits, dim=-1)
|
|
log_probs = logits_processor(flat_running_sequences, log_probs)
|
|
|
|
# Store logits, attentions and hidden_states when required
|
|
if return_dict_in_generate:
|
|
if output_logits:
|
|
raw_logits += (logits.clone(),)
|
|
if return_dict_in_generate and output_scores:
|
|
all_scores += (log_probs.clone(),)
|
|
|
|
if output_attentions:
|
|
decoder_attentions += (
|
|
(model_outputs.decoder_attentions,)
|
|
if self.config.is_encoder_decoder
|
|
else (model_outputs.attentions,)
|
|
)
|
|
if self.config.is_encoder_decoder:
|
|
cross_attentions += (model_outputs.cross_attentions,)
|
|
|
|
if output_hidden_states:
|
|
decoder_hidden_states += (
|
|
(model_outputs.decoder_hidden_states,)
|
|
if self.config.is_encoder_decoder
|
|
else (model_outputs.hidden_states,)
|
|
)
|
|
|
|
# This is needed to properly delete logits which may be very large for first iteration
|
|
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
|
|
del model_outputs
|
|
|
|
log_probs = self._unflatten_beam_dim(log_probs, batch_size, num_beams)
|
|
log_probs = log_probs + running_beam_scores[:, :, None]
|
|
log_probs = torch.reshape(log_probs, (batch_size, num_beams * vocab_size))
|
|
|
|
# c. Retrieve top-K continuations, i.e. select the next token (greedy or sampling) and then keep the best
|
|
# continuations among all beams based on the accumulated scores.
|
|
topk_log_probs, topk_running_sequences, topk_running_beam_indices = self._get_top_k_continuations(
|
|
accumulated_log_probs=log_probs,
|
|
running_sequences=running_sequences,
|
|
running_beam_indices=running_beam_indices,
|
|
cur_len=cur_len,
|
|
decoder_prompt_len=decoder_prompt_len,
|
|
do_sample=do_sample,
|
|
beams_to_keep=beams_to_keep,
|
|
num_beams=num_beams,
|
|
vocab_size=vocab_size,
|
|
batch_size=batch_size,
|
|
)
|
|
|
|
# d. Check which running sequences have finished
|
|
next_token_hits_stopping_criteria = stopping_criteria(
|
|
self._flatten_beam_dim(topk_running_sequences[:, :, : cur_len + 1]), # remove unfilled token indexes
|
|
all_scores,
|
|
)
|
|
next_token_hits_stopping_criteria = self._unflatten_beam_dim(
|
|
next_token_hits_stopping_criteria, batch_size, beams_to_keep
|
|
)
|
|
|
|
# e. Get the non-finished running `num_beams` sequences for the next generation step
|
|
running_sequences, running_beam_scores, running_beam_indices = self._get_running_beams_for_next_iteration(
|
|
topk_log_probs=topk_log_probs,
|
|
topk_running_sequences=topk_running_sequences,
|
|
topk_running_beam_indices=topk_running_beam_indices,
|
|
next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
|
|
num_beams=num_beams,
|
|
)
|
|
|
|
# f. Update the completed beams if a new high score in a finished sequence is found
|
|
sequences, beam_scores, beam_indices, is_sent_finished = self._update_finished_beams(
|
|
sequences=sequences,
|
|
topk_running_sequences=topk_running_sequences,
|
|
beam_scores=beam_scores,
|
|
topk_log_probs=topk_log_probs,
|
|
beam_indices=beam_indices,
|
|
topk_running_beam_indices=topk_running_beam_indices,
|
|
is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied,
|
|
is_sent_finished=is_sent_finished,
|
|
next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
|
|
top_num_beam_mask=top_num_beam_mask,
|
|
num_beams=num_beams,
|
|
cur_len=cur_len,
|
|
decoder_prompt_len=decoder_prompt_len,
|
|
length_penalty=length_penalty,
|
|
early_stopping=early_stopping,
|
|
)
|
|
|
|
# g. Prepare remaining data for the next iteration, including computing the stopping condition for
|
|
# beam search as a whole (as opposed to individual beams, i.e. `stopping_criteria`)
|
|
|
|
# pluck the cache from the beam indices that will be used in the next iteration
|
|
# NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
|
|
if model_kwargs.get("past_key_values", None) is not None:
|
|
beam_idx = self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len])
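# `beam_idx` holds, for each beam kept in this step, the flat index of its parent beam in the previous
# step -- exactly what the cache reordering below expects.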
|
|
if hasattr(self, "_reorder_cache"):
|
|
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
|
|
else:
|
|
model_kwargs["past_key_values"].reorder_cache(beam_idx)
|
|
|
|
cur_len = cur_len + 1
|
|
is_early_stop_heuristic_unsatisfied = self._check_early_stop_heuristic(
|
|
is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied,
|
|
running_beam_scores=running_beam_scores,
|
|
beam_scores=beam_scores,
|
|
is_sent_finished=is_sent_finished,
|
|
cur_len=cur_len,
|
|
max_length=max_length,
|
|
decoder_prompt_len=decoder_prompt_len,
|
|
early_stopping=early_stopping,
|
|
length_penalty=length_penalty,
|
|
)
|
|
this_peer_finished = not self._beam_search_has_unfinished_sequences(
|
|
is_early_stop_heuristic_unsatisfied,
|
|
is_sent_finished,
|
|
next_token_hits_stopping_criteria,
|
|
early_stopping,
|
|
)
|
|
|
|
# 5. prepare outputs
|
|
# Take best beams for each batch (the score is sorted in descending order)
|
|
sequences = self._flatten_beam_dim(sequences[:, :num_return_sequences, :])
|
|
beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences])
|
|
beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :])
|
|
|
|
# Crop the static-shaped tensors to the actual size.
|
|
# `beam_indices` is initialized with -1s, and is updated with the beam index of the generated token at each
|
|
# step. We can use it to detect the generated length, which may be != `cur_len` (e.g. selected beam is from a
|
|
# previous decoding iteration)
|
|
max_generated_length = ((beam_indices + 1).bool()).sum(dim=1).max()
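# entries that were never filled stay at -1, which `+ 1` turns into 0 (False), so the sum counts only the
# positions that actually received a beam index.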
|
|
output_length = decoder_prompt_len + max_generated_length
|
|
sequences = sequences[:, :output_length]
|
|
beam_indices = beam_indices[:, :max_generated_length]
|
|
|
|
if return_dict_in_generate:
|
|
if not output_scores:
|
|
beam_scores = None
|
|
|
|
if self.config.is_encoder_decoder:
|
|
return GenerateBeamEncoderDecoderOutput(
|
|
sequences=sequences,
|
|
sequences_scores=beam_scores,
|
|
scores=all_scores,
|
|
logits=raw_logits,
|
|
beam_indices=beam_indices,
|
|
encoder_attentions=encoder_attentions,
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
decoder_attentions=decoder_attentions,
|
|
cross_attentions=cross_attentions,
|
|
decoder_hidden_states=decoder_hidden_states,
|
|
past_key_values=model_kwargs.get("past_key_values"),
|
|
)
|
|
else:
|
|
return GenerateBeamDecoderOnlyOutput(
|
|
sequences=sequences,
|
|
sequences_scores=beam_scores,
|
|
scores=all_scores,
|
|
logits=raw_logits,
|
|
beam_indices=beam_indices,
|
|
attentions=decoder_attentions,
|
|
hidden_states=decoder_hidden_states,
|
|
past_key_values=model_kwargs.get("past_key_values"),
|
|
)
|
|
else:
|
|
return sequences
|
|
|
|
def _group_beam_search(
|
|
self,
|
|
input_ids: torch.LongTensor,
|
|
beam_scorer: BeamScorer,
|
|
logits_processor: LogitsProcessorList,
|
|
stopping_criteria: StoppingCriteriaList,
|
|
generation_config: GenerationConfig,
|
|
synced_gpus: bool,
|
|
**model_kwargs,
|
|
):
|
|
r"""
|
|
Generates sequences of token ids for models with a language modeling head using **diverse beam search
|
|
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
|
|
|
Parameters:
|
|
input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
|
|
The sequence used as a prompt for the generation.
|
|
beam_scorer (`BeamScorer`):
|
|
A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
|
|
sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
|
|
logits_processor (`LogitsProcessorList`):
|
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
|
used to modify the prediction scores of the language modeling head applied at each generation step.
|
|
stopping_criteria (`StoppingCriteriaList`):
|
|
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
|
used to tell if the generation loop should stop.
|
|
generation_config ([`~generation.GenerationConfig`]):
|
|
The generation configuration to be used as parametrization of the decoding method.
|
|
synced_gpus (`bool`):
|
|
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
|
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
|
model_kwargs:
|
|
Additional model specific kwargs that will be forwarded to the `forward` function of the model. If
|
|
model is an encoder-decoder model the kwargs should include `encoder_outputs`.
|
|
|
|
Return:
|
|
[`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
|
|
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
|
|
[`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
|
`return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
|
|
`model.config.is_encoder_decoder=True`.
|
|
"""
|
|
# init values
|
|
pad_token_id = generation_config._pad_token_tensor
|
|
eos_token_id = generation_config._eos_token_tensor
|
|
output_attentions = generation_config.output_attentions
|
|
output_hidden_states = generation_config.output_hidden_states
|
|
output_scores = generation_config.output_scores
|
|
output_logits = generation_config.output_logits
|
|
return_dict_in_generate = generation_config.return_dict_in_generate
|
|
|
|
num_beams = beam_scorer.num_beams
|
|
num_beam_groups = beam_scorer.num_beam_groups
|
|
num_sub_beams = num_beams // num_beam_groups
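# e.g. `num_beams=6` with `num_beam_groups=3` gives `num_sub_beams=2` beams per group.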
|
|
batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
|
|
device = input_ids.device
|
|
|
|
batch_beam_size, cur_len = input_ids.shape
|
|
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
|
|
|
|
if return_dict_in_generate and output_scores:
|
|
beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
|
|
else:
|
|
beam_indices = None
|
|
|
|
if num_beams * batch_size != batch_beam_size:
|
|
raise ValueError(
|
|
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
|
)
|
|
|
|
# init attention / hidden states / scores tuples
|
|
scores = () if (return_dict_in_generate and output_scores) else None
|
|
raw_logits = () if (return_dict_in_generate and output_logits) else None
|
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
|
|
|
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
|
if return_dict_in_generate and self.config.is_encoder_decoder:
|
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
|
encoder_hidden_states = (
|
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
|
)
|
|
|
|
# initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
|
|
# the same group don't produce the same tokens every time.
|
|
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
|
beam_scores[:, ::num_sub_beams] = 0
|
|
beam_scores = beam_scores.view((batch_size * num_beams,))
|
|
|
|
this_peer_finished = False
|
|
|
|
decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder
|
|
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
|
# predicted tokens in cur_len step
|
|
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
|
|
|
|
# indices which will form the beams in the next time step
|
|
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
|
|
|
|
# do one decoder step on all beams of all sentences in batch
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
|
|
|
# prepare variable output controls (note: some models won't accept all output controls)
|
|
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
|
|
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
|
|
|
|
outputs = self(**model_inputs, return_dict=True)
|
|
|
|
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
|
|
model_kwargs = self._update_model_kwargs_for_generation(
|
|
outputs,
|
|
model_kwargs,
|
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
|
)
|
|
if synced_gpus and this_peer_finished:
|
|
cur_len = cur_len + 1
|
|
continue
|
|
|
|
if output_scores:
|
|
processed_score = torch.zeros_like(outputs.logits[:, -1, :])
|
|
if output_logits:
|
|
# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
|
|
# (the clone itself is always small)
|
|
raw_logit_score = outputs.logits[:, -1, :].to(copy=True, device=input_ids.device)
|
|
|
|
for beam_group_idx in range(num_beam_groups):
|
|
group_start_idx = beam_group_idx * num_sub_beams
|
|
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
|
|
group_size = group_end_idx - group_start_idx
|
|
|
|
# indices of beams of current group among all sentences in batch
|
|
batch_group_indices = []
|
|
|
|
for batch_idx in range(batch_size):
|
|
batch_group_indices.extend(
|
|
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
|
|
)
|
|
group_input_ids = input_ids[batch_group_indices]
|
|
|
|
# select outputs of beams of current group only
|
|
# No need to clone() the logits here as they will not retain outputs.logits at the end of the loop
|
|
# .float() is needed to retain precision for later logits manipulations
|
|
next_token_logits = outputs.logits[batch_group_indices, -1, :].to(
|
|
dtype=torch.float32, device=input_ids.device
|
|
)
|
|
|
|
next_token_scores = nn.functional.log_softmax(
|
|
next_token_logits, dim=-1
|
|
) # (batch_size * group_size, vocab_size)
|
|
vocab_size = next_token_scores.shape[-1]
|
|
|
|
next_token_scores_processed = logits_processor(
|
|
group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
|
|
)
|
|
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
|
|
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
|
|
|
|
if output_scores:
|
|
processed_score[batch_group_indices] = next_token_scores_processed
|
|
|
|
# reshape for beam search
|
|
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
|
|
|
|
# Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non-EOS token per beam.
|
|
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
|
|
next_token_scores, next_tokens = torch.topk(
|
|
next_token_scores, max(2, 1 + n_eos_tokens) * group_size, dim=1, largest=True, sorted=True
|
|
)
|
|
|
|
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
|
|
next_tokens = next_tokens % vocab_size
|
|
|
|
# stateless
|
|
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
|
beam_outputs = beam_scorer.process(
|
|
group_input_ids,
|
|
next_token_scores,
|
|
next_tokens,
|
|
next_indices,
|
|
pad_token_id=pad_token_id,
|
|
eos_token_id=eos_token_id,
|
|
beam_indices=process_beam_indices,
|
|
group_index=beam_group_idx,
|
|
decoder_prompt_len=decoder_prompt_len,
|
|
)
|
|
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
|
|
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
|
beam_idx = beam_outputs["next_beam_indices"]
|
|
|
|
if return_dict_in_generate and output_scores:
|
|
beam_indices[beam_group_idx] = tuple(
|
|
beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
|
|
)
|
|
|
|
input_ids[batch_group_indices] = group_input_ids[beam_idx]
|
|
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
|
current_tokens[batch_group_indices] = group_input_ids[:, -1]
|
|
|
|
# (beam_idx // group_size) -> batch_idx
|
|
# (beam_idx % group_size) -> offset of idx inside the group
|
|
reordering_indices[batch_group_indices] = (
|
|
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor")
|
|
+ group_start_idx
|
|
+ (beam_idx % group_size)
|
|
)
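# e.g. with num_beams=4, group_size=2, group_start_idx=2: beam_idx=3 (batch item 1, second beam of the
# group) maps to the global reordering index 4 * 1 + 2 + 1 = 7.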
|
|
|
|
# Store scores, attentions and hidden_states when required
|
|
if return_dict_in_generate:
|
|
if output_scores:
|
|
scores += (processed_score,)
|
|
if output_logits:
|
|
raw_logits += (raw_logit_score,)
|
|
if output_attentions:
|
|
decoder_attentions += (
|
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
|
)
|
|
if self.config.is_encoder_decoder:
|
|
cross_attentions += (outputs.cross_attentions,)
|
|
|
|
if output_hidden_states:
|
|
decoder_hidden_states += (
|
|
(outputs.decoder_hidden_states,)
|
|
if self.config.is_encoder_decoder
|
|
else (outputs.hidden_states,)
|
|
)
|
|
|
|
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
|
|
|
|
# This is needed to properly delete outputs.logits which may be very large for first iteration
|
|
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
|
|
# IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
|
|
# (that way the memory peak does not include outputs.logits)
|
|
del outputs
|
|
|
|
# NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
|
|
if model_kwargs.get("past_key_values", None) is not None:
|
|
if hasattr(self, "_reorder_cache"):
|
|
model_kwargs["past_key_values"] = self._reorder_cache(
|
|
model_kwargs["past_key_values"], reordering_indices
|
|
)
|
|
else:
|
|
model_kwargs["past_key_values"].reorder_cache(reordering_indices)
|
|
|
|
# increase cur_len
|
|
cur_len = cur_len + 1
|
|
|
|
if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
|
|
this_peer_finished = True
|
|
|
|
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
|
sequence_outputs = beam_scorer.finalize(
|
|
input_ids,
|
|
beam_scores,
|
|
next_tokens,
|
|
next_indices,
|
|
pad_token_id=pad_token_id,
|
|
eos_token_id=eos_token_id,
|
|
max_length=stopping_criteria.max_length,
|
|
beam_indices=final_beam_indices,
|
|
decoder_prompt_len=decoder_prompt_len,
|
|
)
|
|
|
|
if return_dict_in_generate:
|
|
if not output_scores:
|
|
sequence_outputs["sequence_scores"] = None
|
|
|
|
if self.config.is_encoder_decoder:
|
|
return GenerateBeamEncoderDecoderOutput(
|
|
sequences=sequence_outputs["sequences"],
|
|
sequences_scores=sequence_outputs["sequence_scores"],
|
|
scores=scores,
|
|
logits=raw_logits,
|
|
beam_indices=sequence_outputs["beam_indices"],
|
|
encoder_attentions=encoder_attentions,
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
decoder_attentions=decoder_attentions,
|
|
cross_attentions=cross_attentions,
|
|
decoder_hidden_states=decoder_hidden_states,
|
|
past_key_values=model_kwargs.get("past_key_values"),
|
|
)
|
|
else:
|
|
return GenerateBeamDecoderOnlyOutput(
|
|
sequences=sequence_outputs["sequences"],
|
|
sequences_scores=sequence_outputs["sequence_scores"],
|
|
scores=scores,
|
|
logits=raw_logits,
|
|
beam_indices=sequence_outputs["beam_indices"],
|
|
attentions=decoder_attentions,
|
|
hidden_states=decoder_hidden_states,
|
|
past_key_values=model_kwargs.get("past_key_values"),
|
|
)
|
|
else:
|
|
return sequence_outputs["sequences"]
|
|
|
|
def _constrained_beam_search(
|
|
self,
|
|
input_ids: torch.LongTensor,
|
|
constrained_beam_scorer: ConstrainedBeamSearchScorer,
|
|
logits_processor: LogitsProcessorList,
|
|
stopping_criteria: StoppingCriteriaList,
|
|
generation_config: GenerationConfig,
|
|
synced_gpus: bool,
|
|
**model_kwargs,
|
|
) -> Union[GenerateBeamOutput, torch.LongTensor]:
|
|
r"""
|
|
Generates sequences of token ids for models with a language modeling head using **constrained beam search
|
|
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
|
|
|
Parameters:
|
|
input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
|
|
The sequence used as a prompt for the generation.
|
|
constrained_beam_scorer (`ConstrainedBeamSearchScorer`):
|
|
A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
|
|
sorted during generation, while satisfying a list of positive constraints. For more information, the
|
|
documentation of [`ConstrainedBeamSearchScorer`] should be read.
|
|
logits_processor (`LogitsProcessorList`):
|
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
|
used to modify the prediction scores of the language modeling head applied at each generation step.
|
|
stopping_criteria (`StoppingCriteriaList`):
|
|
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
|
used to tell if the generation loop should stop.
|
|
generation_config ([`~generation.GenerationConfig`]):
|
|
The generation configuration to be used as parametrization of the decoding method.
|
|
synced_gpus (`bool`):
|
|
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
|
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
|
model_kwargs:
|
|
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
|
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
|
|
|
Return:
|
|
[`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
|
|
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
|
|
[`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
|
`return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
|
|
`model.config.is_encoder_decoder=True`.
|
|
"""
|
|
# init values
|
|
pad_token_id = generation_config._pad_token_tensor
|
|
eos_token_id = generation_config._eos_token_tensor
|
|
output_attentions = generation_config.output_attentions
|
|
output_hidden_states = generation_config.output_hidden_states
|
|
output_scores = generation_config.output_scores
|
|
output_logits = generation_config.output_logits
|
|
return_dict_in_generate = generation_config.return_dict_in_generate
|
|
|
|
batch_size = len(constrained_beam_scorer._beam_hyps)
|
|
num_beams = constrained_beam_scorer.num_beams
|
|
|
|
batch_beam_size, cur_len = input_ids.shape[:2]
|
|
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
|
|
|
|
if num_beams * batch_size != batch_beam_size:
|
|
raise ValueError(
|
|
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
|
)
|
|
|
|
# init attention / hidden states / scores tuples
|
|
scores = () if (return_dict_in_generate and output_scores) else None
|
|
raw_logits = () if (return_dict_in_generate and output_logits) else None
|
|
beam_indices = (
|
|
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
|
|
)
|
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
|
|
|
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
|
if return_dict_in_generate and self.config.is_encoder_decoder:
|
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
|
encoder_hidden_states = (
|
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
|
)
|
|
|
|
# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
|
|
# of the first beam are considered to avoid sampling the exact same tokens across all beams.
|
|
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
|
beam_scores[:, 1:] = -1e9
|
|
beam_scores = beam_scores.view((batch_size * num_beams,))
|
|
|
|
this_peer_finished = False
|
|
|
|
decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder
|
|
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
|
|
|
# prepare variable output controls (note: some models won't accept all output controls)
|
|
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
|
|
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
|
|
|
|
outputs = self(**model_inputs, return_dict=True)
|
|
|
|
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
|
|
model_kwargs = self._update_model_kwargs_for_generation(
|
|
outputs,
|
|
model_kwargs,
|
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
|
)
|
|
if synced_gpus and this_peer_finished:
|
|
cur_len = cur_len + 1
|
|
continue
|
|
|
|
# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
|
|
# (the clone itself is always small)
|
|
# .float() is needed to retain precision for later logits manipulations
|
|
next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
|
|
next_token_scores = nn.functional.log_softmax(
|
|
next_token_logits, dim=-1
|
|
) # (batch_size * num_beams, vocab_size)
|
|
|
|
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
|
|
|
|
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
|
|
next_token_scores_processed
|
|
)
|
|
|
|
scores_for_all_vocab = next_token_scores.clone()
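# keep the scores over the full vocabulary (not just the top-k picks): `constrained_beam_scorer.process`
# below uses them to score constraint tokens that may fall outside the top-k.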
|
|
|
|
# Store scores, attentions and hidden_states when required
|
|
if return_dict_in_generate:
|
|
if output_scores:
|
|
scores += (next_token_scores,)
|
|
if output_logits:
|
|
raw_logits += (next_token_logits,)
|
|
if output_attentions:
|
|
decoder_attentions += (
|
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
|
)
|
|
if self.config.is_encoder_decoder:
|
|
cross_attentions += (outputs.cross_attentions,)
|
|
|
|
if output_hidden_states:
|
|
decoder_hidden_states += (
|
|
(outputs.decoder_hidden_states,)
|
|
if self.config.is_encoder_decoder
|
|
else (outputs.hidden_states,)
|
|
)
|
|
|
|
# reshape for beam search
|
|
vocab_size = next_token_scores.shape[-1]
|
|
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
|
|
|
|
# Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non-EOS token per beam.
|
|
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
|
|
next_token_scores, next_tokens = torch.topk(
|
|
next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
|
|
)
|
|
|
|
next_indices = (next_tokens / vocab_size).long()
|
|
next_tokens = next_tokens % vocab_size
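# `next_indices` (floor division) is the source beam and `next_tokens` (modulo) the token id, recovered
# from the flattened (num_beams * vocab_size) score matrix.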
|
|
|
|
# stateless
|
|
beam_outputs = constrained_beam_scorer.process(
|
|
input_ids,
|
|
next_token_scores,
|
|
next_tokens,
|
|
next_indices,
|
|
scores_for_all_vocab,
|
|
pad_token_id=pad_token_id,
|
|
eos_token_id=eos_token_id,
|
|
beam_indices=beam_indices,
|
|
decoder_prompt_len=decoder_prompt_len,
|
|
)
|
|
beam_scores = beam_outputs["next_beam_scores"]
|
|
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
|
beam_idx = beam_outputs["next_beam_indices"]
|
|
|
|
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
|
|
|
# This is needed to properly delete outputs.logits which may be very large for first iteration
|
|
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
|
|
# IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
|
|
# (that way the memory peak does not include outputs.logits)
|
|
del outputs
|
|
|
|
# NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
|
|
if model_kwargs.get("past_key_values", None) is not None:
|
|
if hasattr(self, "_reorder_cache"):
|
|
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
|
|
else:
|
|
model_kwargs["past_key_values"].reorder_cache(beam_idx)
|
|
|
|
if return_dict_in_generate and output_scores:
|
|
beam_indices = tuple(beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))
|
|
|
|
# increase cur_len
|
|
cur_len = cur_len + 1
|
|
|
|
if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
|
|
this_peer_finished = True
|
|
|
|
sequence_outputs = constrained_beam_scorer.finalize(
|
|
input_ids,
|
|
beam_scores,
|
|
next_tokens,
|
|
next_indices,
|
|
pad_token_id=pad_token_id,
|
|
eos_token_id=eos_token_id,
|
|
max_length=stopping_criteria.max_length,
|
|
beam_indices=beam_indices,
|
|
decoder_prompt_len=decoder_prompt_len,
|
|
)
|
|
|
|
if return_dict_in_generate:
|
|
if not output_scores:
|
|
sequence_outputs["sequence_scores"] = None
|
|
if self.config.is_encoder_decoder:
|
|
return GenerateBeamEncoderDecoderOutput(
|
|
sequences=sequence_outputs["sequences"],
|
|
sequences_scores=sequence_outputs["sequence_scores"],
|
|
scores=scores,
|
|
logits=raw_logits,
|
|
beam_indices=sequence_outputs["beam_indices"],
|
|
encoder_attentions=encoder_attentions,
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
decoder_attentions=decoder_attentions,
|
|
cross_attentions=cross_attentions,
|
|
decoder_hidden_states=decoder_hidden_states,
|
|
past_key_values=model_kwargs.get("past_key_values"),
|
|
)
|
|
else:
|
|
return GenerateBeamDecoderOnlyOutput(
|
|
sequences=sequence_outputs["sequences"],
|
|
sequences_scores=sequence_outputs["sequence_scores"],
|
|
scores=scores,
|
|
logits=raw_logits,
|
|
beam_indices=sequence_outputs["beam_indices"],
|
|
attentions=decoder_attentions,
|
|
hidden_states=decoder_hidden_states,
|
|
past_key_values=model_kwargs.get("past_key_values"),
|
|
)
|
|
else:
|
|
return sequence_outputs["sequences"]
|
|
|
|
def _assisted_decoding(
|
|
self,
|
|
input_ids: torch.LongTensor,
|
|
candidate_generator: CandidateGenerator,
|
|
logits_processor: LogitsProcessorList,
|
|
stopping_criteria: StoppingCriteriaList,
|
|
generation_config: GenerationConfig,
|
|
synced_gpus: bool,
|
|
streamer: Optional["BaseStreamer"],
|
|
**model_kwargs,
|
|
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
|
r"""
|
|
Generates sequences of token ids for models with a language modeling head using **greedy decoding** or
|
|
**sampling** (depending on `do_sample`), assisted by candidate sequences. Assisted generation is an example of a
|
|
candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text
|
|
models.
|
|
|
|
Parameters:
|
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
|
The sequence used as a prompt for the generation.
|
|
candidate_generator (`CandidateGenerator`):
|
|
A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For
|
|
more information, the documentation of [`CandidateGenerator`] should be read.
|
|
logits_processor (`LogitsProcessorList`):
|
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
|
used to modify the prediction scores of the language modeling head applied at each generation step.
|
|
stopping_criteria (`StoppingCriteriaList`):
|
|
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
|
used to tell if the generation loop should stop.
|
|
generation_config ([`~generation.GenerationConfig`]):
|
|
The generation configuration to be used as parametrization of the decoding method.
|
|
synced_gpus (`bool`):
|
|
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
|
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
|
streamer (`BaseStreamer`, *optional*):
|
|
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
|
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
|
model_kwargs:
|
|
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
|
|
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
|
|
|
|
Return:
|
|
[`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or
|
|
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
|
|
[`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
|
`return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
|
|
`model.config.is_encoder_decoder=True`.
|
|
"""
|
|
        # init values
        do_sample = generation_config.do_sample
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        batch_size, cur_len = input_ids.shape[:2]
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

        this_peer_finished = False
        is_first_iteration = True  # to preserve the same API in the output as other generation methods
        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            cur_len = input_ids.shape[1]

            # 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device
            candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
            candidate_input_ids = candidate_input_ids.to(self.device)
            if candidate_logits is not None:
                candidate_logits = candidate_logits.to(self.device)

            candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
            is_done_candidate = stopping_criteria(candidate_input_ids, None)

            # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
            # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
            # we use this forward pass to also pick the subsequent logits in the original model.

            # 2.1. Prepare the model inputs
            candidate_kwargs = copy.copy(model_kwargs)
            candidate_kwargs = _prepare_attention_mask(
                candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
            )
            candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
            if "cache_position" in candidate_kwargs:
                candidate_kwargs["cache_position"] = torch.cat(
                    (
                        candidate_kwargs["cache_position"],
                        torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
                    ),
                    dim=0,
                )

            model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
            if "logits_to_keep" in model_inputs:
                model_inputs["logits_to_keep"] = candidate_length + 1

            # 2.2. Run a forward pass on the candidate sequence
            # prepare variable output controls (note: some models won't accept all output controls)
            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

            outputs = self(**model_inputs)

            # 2.3. Process the new logits
            # .float() is needed to retain precision for later logits manipulations
            new_logits = outputs.logits[:, -candidate_length - 1 :].to(
                dtype=torch.float32, device=input_ids.device
            )  # excludes the input prompt if present
            next_token_logits = new_logits.clone()
            if len(logits_processor) > 0:
                for i in range(candidate_length + 1):
                    new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])

            # 3. Select the accepted tokens. There are two possible cases:
            # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
            # 👉 Apply algorithm 1 from the speculative decoding paper (https://huggingface.co/papers/2211.17192).
            if do_sample and candidate_logits is not None:
                valid_tokens, n_matches = _speculative_sampling(
                    candidate_input_ids,
                    candidate_logits,
                    candidate_length,
                    new_logits,
                    is_done_candidate,
                )

            # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
            # original model logits with the candidate tokens. We can keep the candidate tokens until the first
            # mismatch, or until the max length is reached.
            else:
                if do_sample:
                    probs = new_logits.softmax(dim=-1)
                    selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
                else:
                    selected_tokens = new_logits.argmax(dim=-1)

                candidate_new_tokens = candidate_input_ids[:, cur_len:]
                # `n_matches` is the length of the longest prefix of candidate tokens that agrees with the tokens the
                # main model would have selected: the cumulative sum of mismatches stays at 0 until the first mismatch.
                n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()

                # Ensure we don't generate beyond max_len or an EOS token
                if is_done_candidate and n_matches == candidate_length:
                    n_matches -= 1
                valid_tokens = selected_tokens[:, : n_matches + 1]

            # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
            # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
            # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
            # is no match.

            # 4.1. Get the valid continuation, after the matching tokens
            input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
            if streamer is not None:
                streamer.put(valid_tokens.cpu())
            new_cur_len = input_ids.shape[1]

            # 4.2. Discard past key values relative to unused assistant tokens
            outputs.past_key_values.crop(new_cur_len - 1)

            # 5. Update the candidate generation strategy if needed
            candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)

            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
                num_new_tokens=n_matches + 1,
            )
            if synced_gpus and this_peer_finished:
                continue

            # Store scores, attentions and hidden_states when required
            # Assistant: modified to append one tuple element per token, as in the other generation methods.
            if return_dict_in_generate:
                newly_added_length = n_matches + 1
                if output_scores:
                    scores += tuple(new_logits[:, i, :] for i in range(newly_added_length))
                if output_logits:
                    raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length))

                newly_added_length = new_cur_len if is_first_iteration else newly_added_length
                if output_attentions:
                    if self.config.is_encoder_decoder:
                        cross_attentions = _split_model_outputs(
                            cross_attentions, outputs.cross_attentions, cur_len, newly_added_length
                        )
                        decoder_attentions = _split_model_outputs(
                            decoder_attentions,
                            outputs.decoder_attentions,
                            cur_len,
                            newly_added_length,
                            is_decoder_attention=True,
                        )
                    # some (V)LLMs have hard requirement on SDPA and thus never return attn
                    elif outputs.attentions[0] is not None:
                        decoder_attentions = _split_model_outputs(
                            decoder_attentions,
                            outputs.attentions,
                            cur_len,
                            newly_added_length,
                            is_decoder_attention=True,
                        )
                if output_hidden_states:
                    if self.config.is_encoder_decoder:
                        decoder_hidden_states = _split_model_outputs(
                            decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length
                        )
                    else:
                        decoder_hidden_states = _split_model_outputs(
                            decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length
                        )

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0
            is_first_iteration = False

        if streamer is not None:
            streamer.end()

        if (
            hasattr(candidate_generator, "assistant_model")
            and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic"
        ):
            candidate_generator.assistant_model.generation_config.num_assistant_tokens = (
                candidate_generator.num_assistant_tokens
            )
        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GenerateEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
            else:
                return GenerateDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
        else:
            return input_ids

    def _prefill_chunking(self, input_ids: torch.LongTensor, generation_config: GenerationConfig, **model_kwargs):
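        """
        Prefills the key/value cache by running the prompt (all tokens but the last one) through the model in chunks
        of `generation_config.prefill_chunk_size` instead of a single forward pass. The last prompt token is left out
        so that the regular decoding loop can produce the first generated token. Returns the updated `model_kwargs`
        (with the filled cache and the next `cache_position`).
        """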
        # Even if we are not compiling the forward, flex is always compiled when used. With chunk prefill, we may
        # end up needing just a bit more graphs than the default (which is 8). Doing this avoids very cryptic warnings
        torch._dynamo.config.cache_size_limit = 64

        chunk_size = generation_config.prefill_chunk_size
        # Only chunk up to the token just before the last one, so that decoding is completely performed outside this
        # function (here we simply prefill the cache)
        input_chunks = torch.split(input_ids[:, :-1], chunk_size, dim=-1)

        if "past_key_values" not in model_kwargs:
            raise ValueError("Cannot use prefill chunking without a cache")

        model_forward = self.forward

        compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
        if compile_forward:
            model_forward = self.get_compiled_call(generation_config.compile_config)

        attention_mask = model_kwargs.pop("attention_mask", None)

        past_length = 0
        for input_chunk in input_chunks:
            current_length = past_length + input_chunk.shape[-1]
            # Prepare inputs
            if attention_mask is not None:
                model_kwargs["attention_mask"] = attention_mask[:, :current_length]
            model_kwargs["cache_position"] = torch.arange(
                past_length, current_length, dtype=torch.long, device=input_chunk.device
            )
            model_kwargs["position_ids"] = model_kwargs["cache_position"].unsqueeze(0)
            model_inputs = self.prepare_inputs_for_generation(input_chunk, **model_kwargs)

            outputs = model_forward(**model_inputs, return_dict=True)

            model_kwargs["past_key_values"] = outputs.past_key_values
            past_length = current_length

        model_kwargs["attention_mask"] = attention_mask
        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
        _ = model_kwargs.pop("position_ids", None)

        return model_kwargs


def _speculative_sampling(
    candidate_input_ids,
    candidate_logits,
    candidate_length,
    new_logits,
    is_done_candidate,
):
    """
    Applies sampling as in the speculative decoding paper (https://huggingface.co/papers/2211.17192, algorithm 1).
    Returns the selected tokens, as well as the number of candidate matches.

    NOTE: Unless otherwise stated, the variable names match those in the paper.
    """
    new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
    # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
    # selected by the assistant, respectively.
    q = candidate_logits.softmax(dim=-1)
    q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
    p = new_logits.softmax(dim=-1)
    p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
    probability_ratio = p_i / q_i

    # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
    # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio
    # (= keep with p = probability_ratio). Keep all the tokens until the first rejection
    r_i = torch.rand_like(probability_ratio)
    is_accepted = r_i <= probability_ratio
    n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()  # this is `n` in algorithm 1

    # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
    if is_done_candidate and n_matches == candidate_length:
        # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
        # due to acceptance on EOS we fix `n_matches`
        n_matches -= 1
        valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
    else:
        # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
        gamma = candidate_logits.shape[1]
        p_n_plus_1 = p[:, n_matches, :]
        if n_matches < gamma:
            q_n_plus_1 = q[:, n_matches, :]
            p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
            p_prime.div_(p_prime.sum())
        else:
            p_prime = p_n_plus_1
        t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]

        # The selected tokens include the matches (if any) plus the next sampled tokens
        if n_matches > 0:
            valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
        else:
            valid_tokens = t

    return valid_tokens, n_matches


def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
    """
    Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
    where each member corresponds to a single generated token.
    """
    # Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the
    # prompt.
    if len(outputs) == 0:
        new_tuple = ()
        for layer in new_outputs:
            last_dim_size = cur_len if is_decoder_attention else layer.shape[-1]
            new_tuple += (layer[..., :cur_len, :last_dim_size],)
        outputs += (new_tuple,)
        # The first iteration contains the prompt + 1 generated token, let's update the length variables accordingly
        cur_len += 1
        added_len -= cur_len

    for i in range(added_len):
        new_tuple = ()
        for layer in new_outputs:
            last_dim_size = cur_len + i if is_decoder_attention else layer.shape[-1]
            new_tuple += (layer[..., i : i + 1, :last_dim_size],)
        outputs += (new_tuple,)
    return outputs


def _ranking_fast(
    context_hidden: torch.FloatTensor,
    next_hidden: torch.FloatTensor,
    next_top_k_probs: torch.FloatTensor,
    cosine_matrix_mask: torch.LongTensor,
    alpha: float,
    beam_width: int,
) -> torch.FloatTensor:
    """
    Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described
    in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best candidate for each
    row in the batch.
    """
    norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)  # [B*K, S]

    # Penalize cosine_matrix based on the cosine_matrix_mask (ignore padding positions)
    # Using a large negative value for masked positions
    cosine_matrix_mask = cosine_matrix_mask.to(dtype=cosine_matrix.dtype)
    cosine_matrix_mask = (1 - cosine_matrix_mask) * torch.finfo(cosine_matrix.dtype).min
    cosine_matrix = cosine_matrix + cosine_matrix_mask

    degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1)  # [B*K]
    next_top_k_probs = next_top_k_probs.view(-1)  # [B*K]
    contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
    contrastive_score = torch.stack(torch.split(contrastive_score, beam_width))  # [B, K]
    _, selected_idx = contrastive_score.max(dim=-1)  # [B]
    return selected_idx


def stack_model_outputs(model_outputs: list[ModelOutput], config: PretrainedConfig) -> ModelOutput:
    """
    Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the
    specific ModelOutput subclass from the list provided.
    """
    if not model_outputs:
        raise ValueError("Input list is empty.")

    # Infer the class from the first object in the list
    model_output_cls = type(model_outputs[0])

    # Ensure all objects are of the same type
    if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
        raise ValueError("All elements in the list should be of the same type.")

    # Helper function to concat tensors or tuples of tensors
    def _concat(data):
        """
        Reverse of `_split` function above.
        """
        if any(attr is None for attr in data):
            return None
        if isinstance(data[0], torch.Tensor):
            return torch.cat(data, dim=0)
        elif isinstance(data[0], tuple):
            # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
            if isinstance(data[0][0], tuple):
                return tuple(
                    tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0])))
                    for i in range(len(data[0]))
                )
            else:
                return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0])))
        elif isinstance(data[0], (int, float)):
            # If the elements are integers or floats, return a tensor
            return torch.tensor(data)
        else:
            raise TypeError(f"Unexpected attribute type: {type(data[0])}")

    # Use a dictionary comprehension to gather attributes from all objects and concatenate them
    concatenated_data = {
        k: _concat([getattr(model_output, k) for model_output in model_outputs])
        for k in model_output_cls.__dataclass_fields__.keys()
    }

    # Return a new object of the inferred class with the concatenated attributes
    return model_output_cls(**concatenated_data)


def _relative_top_filter(
    scores: torch.FloatTensor,
    baseline_scores: torch.FloatTensor,
    relative_top: float = 0.1,
    filter_value: float = -float("Inf"),
    base_filter_value=-1e-3,
    min_tokens_to_keep: int = 1,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
    """
    Reference: https://github.com/XiangLi1999/ContrastiveDecoding/blob/170e9142e92159c1237d731e240f5eb14aabf428/transformers/src/transformers/generation_logits_process.py#L235
    Apply filtering to only keep tokens with a probability above a certain threshold. The threshold is defined as `relative_top` * max probability in the distribution.
    """
    scores_normalized = scores.log_softmax(dim=-1)
    baseline_scores_normalized = baseline_scores.log_softmax(dim=-1)
    sorted_logits, sorted_indices = torch.sort(scores_normalized, descending=True)
    min_thresh = sorted_logits[..., min_tokens_to_keep - 1]
    probs_max = torch.max(scores_normalized, dim=-1).values
    probs_thresh = probs_max + np.log(relative_top)
    probs_thresh = torch.min(min_thresh, probs_thresh)
    probs_thresh = probs_thresh.unsqueeze(-1)
    baseline_scores_normalized[scores_normalized < probs_thresh] = base_filter_value
    scores_normalized[scores_normalized < probs_thresh] = filter_value
    return scores_normalized, baseline_scores_normalized


def _dola_select_contrast(
    candidate_premature_layers: list[int],
    candidate_premature_logits: dict[int, torch.FloatTensor],
    final_logits: torch.FloatTensor,
) -> torch.FloatTensor:
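    """
    DoLa ("Decoding by Contrasting Layers") contrast step: selects the candidate premature layer whose next-token
    distribution diverges most from the final layer (measured by Jensen-Shannon divergence), then returns the
    contrasted logits `final_logits - base_logits` after applying `_relative_top_filter`.
    """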
    if len(candidate_premature_layers) == 1:
        base_logits = candidate_premature_logits[candidate_premature_layers[0]]
        final_logits, base_logits = _relative_top_filter(final_logits, base_logits)
        logits = final_logits - base_logits
        return logits

    # 1. Stacking all premature_layers into a new dimension
    stacked_premature_layers = torch.stack([candidate_premature_logits[i] for i in candidate_premature_layers], dim=0)

    # 2. Calculate the softmax values for mature_layer and all premature_layers
    # shape: (batch_size, vocab_size)
    softmax_mature_layer = F.softmax(final_logits, dim=-1)
    # shape: (num_premature_layers, batch_size, vocab_size)
    softmax_premature_layers = F.softmax(stacked_premature_layers, dim=-1)

    # 3. Calculate the average distribution
    # shape: (num_premature_layers, batch_size, vocab_size)
    avg_dist = 0.5 * (softmax_mature_layer[None, :, :] + softmax_premature_layers)

    # 4. Calculate log-softmax for the KL divergence
    # shape: (batch_size, vocab_size)
    log_softmax_mature_layer = F.log_softmax(final_logits, dim=-1)
    # shape: (num_premature_layers, batch_size, vocab_size)
    log_softmax_premature_layers = F.log_softmax(stacked_premature_layers, dim=-1)

    # 5. Calculate the KL divergences and then the JS divergences
    # shape: (num_premature_layers, batch_size)
    kl1 = F.kl_div(log_softmax_mature_layer[None, :, :], avg_dist, reduction="none").mean(-1)
    # shape: (num_premature_layers, batch_size)
    kl2 = F.kl_div(log_softmax_premature_layers, avg_dist, reduction="none").mean(-1)
    js_divs = 0.5 * (kl1 + kl2)  # shape: (num_premature_layers, batch_size)

    # 6. Reduce the batchmean
    js_divs = js_divs.mean(-1)  # shape: (num_premature_layers,)
    premature_layer = candidate_premature_layers[int(js_divs.argmax().item())]

    base_logits = candidate_premature_logits[premature_layer]
    final_logits, base_logits = _relative_top_filter(final_logits, base_logits)
    logits = final_logits - base_logits
    return logits